Properly cancel request channels on disconnect

This commit is contained in:
Dominik Nakamura 2021-03-15 14:56:14 +09:00
parent 00c0204285
commit e41396407f
No known key found for this signature in database
GPG key ID: E4C6A749B2491910

View file

@ -81,7 +81,7 @@ pub struct Client {
/// A list of currently waiting requests to get a response back. The key is the string version
/// of a request ID and the value is a oneshot sender that allows to send the response back to
/// the other end that waits for the response.
receivers: Arc<Mutex<HashMap<String, oneshot::Sender<serde_json::Value>>>>,
receivers: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
/// Broadcast sender that distributes received events to all current listeners. Events are
/// dropped if nobody listens.
#[cfg(feature = "events")]
@ -162,10 +162,7 @@ impl Client {
.map_err(Error::Connect)?;
let (write, mut read) = socket.split();
let receivers = Arc::new(Mutex::new(HashMap::<
String,
oneshot::Sender<serde_json::Value>,
>::new()));
let receivers = Arc::new(Mutex::new(HashMap::<_, oneshot::Sender<_>>::new()));
let receivers2 = Arc::clone(&receivers);
#[cfg(feature = "events")]
let (event_sender, _) =
@ -194,9 +191,10 @@ impl Client {
.as_object()
.and_then(|obj| obj.get("message-id"))
.and_then(|id| id.as_str())
.and_then(|id| id.parse().ok())
{
debug!("got message with id {}", message_id);
if let Some(tx) = receivers2.lock().await.remove(message_id) {
if let Some(tx) = receivers2.lock().await.remove(&message_id) {
tx.send(json).ok();
}
} else {
@ -226,6 +224,10 @@ impl Client {
};
events_tx.send(event).ok();
}
// clear all outstanding receivers to stop them from waiting forever on responses
// they'll never receive.
receivers2.lock().await.clear();
});
let write = Mutex::new(write);
@ -273,9 +275,9 @@ impl Client {
where
T: DeserializeOwned,
{
let id = self.id_counter.fetch_add(1, Ordering::SeqCst).to_string();
let id = self.id_counter.fetch_add(1, Ordering::SeqCst);
let req = Request {
message_id: &id,
message_id: &id.to_string(),
ty: req,
};
let json = serde_json::to_string(&req).map_err(Error::SerializeMessage)?;
@ -284,12 +286,18 @@ impl Client {
self.receivers.lock().await.insert(id, tx);
debug!("sending message: {}", json);
self.write
let write_result = self
.write
.lock()
.await
.send(Message::Text(json))
.await
.map_err(Error::Send)?;
.map_err(Error::Send);
if let Err(e) = write_result {
self.receivers.lock().await.remove(&id);
return Err(e);
}
let mut resp = rx.await.map_err(Error::ReceiveMessage)?;