snapview / tokio-tungstenite

Future-based Tungstenite for Tokio. Lightweight stream-based WebSocket implementation.
MIT License

Question/Help Needed: How to properly handle a WebSocket close? #292

Closed: BlueGradientHorizon closed this issue 1 year ago

BlueGradientHorizon commented 1 year ago

Hello. I'm trying to write a client-server application, but I'm having some trouble understanding how to handle the WebSocket close event. The source code below includes both the client and the server. I shortened it as much as I could, but it is still long, sorry about that.

Cargo.toml:

[package]
...

[dependencies]
tokio = { version = "1", features = ["full"] }
tokio-stream = "0.1"
tokio-tungstenite = "0.19"
futures-channel = "0.3"
futures-util = "0.3"
url = "2.4"

main.rs:

use std::{
    sync::{Arc, Mutex as StdMutex},
    collections::HashMap,
    net::SocketAddr, time::Duration
};

use futures_util::{StreamExt, TryStreamExt, future, pin_mut, lock::Mutex as FutMutex, SinkExt};
use tokio::{sync::mpsc, net::{TcpListener, TcpStream}, io::{AsyncWriteExt, AsyncReadExt}};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_tungstenite::{tungstenite::{Message, self}, connect_async};

/*
    SERVER IMPLEMENTATION
 */

async fn server_main() -> Result<(), std::io::Error> {
    let (mut _handle, cb_receiver) = MyServerHandle::new("0.0.0.0:12345".to_string()).await;
    _handle.run().await;
    tokio::spawn(async move { process_server_cb_message(cb_receiver).await });

    tokio::time::sleep(Duration::from_secs(7)).await;
    _handle.stop().await;

    tokio::time::sleep(Duration::from_millis(500)).await;

    Ok(())
}

async fn process_server_cb_message(mut cb_receiver: mpsc::Receiver<ServerCallbackMessage>) {
    while let Some(_msg) = cb_receiver.recv().await { }
}

#[derive(Clone, Debug)]
pub struct Client {
    tx: Option<mpsc::UnboundedSender<Message>>
}

impl Client {
    pub fn new(tx: mpsc::UnboundedSender<Message>) -> Self {
        Client { tx: Some(tx) }
    }

    pub fn disconnect(&mut self) {
        self.tx = None;
        println!("disconnect() end");
    }

    pub fn process_incoming(&self, msg: Message) -> Result<(), tungstenite::Error> {
        println!("msg: {}", msg.to_text().unwrap().trim()); Ok(())
    }
}

type ClientMap = Arc<StdMutex<HashMap<SocketAddr, Arc<StdMutex<Client>>>>>;

struct Server {
    cb_sender: mpsc::Sender<ServerCallbackMessage>,
    receiver: mpsc::Receiver<ServerMessage>,
    client_map: ClientMap,
    listener: TcpListener,
    listener_stop_tx: mpsc::Sender<()>,
    listener_stop_rx: mpsc::Receiver<()>
}

#[derive(Debug)]
pub enum ServerMessage { Stop }
#[derive(Debug)]
pub enum ServerCallbackMessage { }

impl Server {
    pub async fn new(cb_sender: mpsc::Sender<ServerCallbackMessage>, receiver: mpsc::Receiver<ServerMessage>, addr: String) -> Self {
        let client_map = ClientMap::new(StdMutex::new(HashMap::new()));
        let listener = TcpListener::bind(&addr).await.unwrap();
        let (listener_stop_tx, listener_stop_rx) = mpsc::channel(1);

        Server {
            cb_sender,
            receiver,
            client_map,
            listener,
            listener_stop_tx,
            listener_stop_rx
        }
    }

    pub async fn run(&mut self) {
        tokio::select! {
            _ = Self::run_connection_handler(&self.listener, self.client_map.clone()) => (),
            _ = Self::run_message_handler(&mut self.receiver, self.cb_sender.clone(), self.listener_stop_tx.clone()) => (),
            _ = self.listener_stop_rx.recv() => self.disconnect_all().await
        };
        println!("run() end");
    }

    /* CONNECTION HANDLER */

    async fn run_connection_handler(listener: &TcpListener, client_map: ClientMap) {
        tokio::select! {
            _ = async {
                loop {
                    match listener.accept().await {
                        Ok((stream, addr)) => { tokio::spawn(Self::handle_connection(client_map.clone(), stream, addr)); },
                        Err(e) => println!("e: {}", e),
                    };
                }
            } => (),
        };
    }

    async fn handle_connection(client_map: ClientMap, raw_stream: TcpStream, addr: SocketAddr) {
        let ws_stream = tokio_tungstenite::accept_async(raw_stream).await.unwrap();
        println!("WebSocket connection established: {}", addr);

        let (client_tx, client_rx) = mpsc::unbounded_channel::<Message>();
        let client_rx = UnboundedReceiverStream::new(client_rx);

        let client = Arc::new(StdMutex::new(Client::new(client_tx)));
        client_map.lock().unwrap().insert(addr, Arc::clone(&client));

        let (mut outgoing, incoming) = ws_stream.split();

        let process_incoming = incoming.try_for_each(|msg| {
            match client.lock().unwrap().process_incoming(msg) {
                Ok(_) => return future::ok(()),
                Err(e) => return future::err(e),
            };
        });

        let client_rx_to_outgoing = client_rx.map(Ok).forward(&mut outgoing);

        pin_mut!(process_incoming, client_rx_to_outgoing);

        tokio::select! {
            v = process_incoming => println!("v: {:?}", v),
            _ = client_rx_to_outgoing => { println!("client_rx_to_outgoing ended") }
        }

        outgoing.close().await.unwrap();

        println!("{} disconnected", &addr);
        client_map.lock().unwrap().remove(&addr);
    }

    /* MESSAGE HANDLER */

    async fn run_message_handler(receiver: &mut mpsc::Receiver<ServerMessage>, cb_sender: mpsc::Sender<ServerCallbackMessage>, listener_stop_tx: mpsc::Sender<()>) {
        while let Some(msg) = receiver.recv().await {
            Self::handle_message(msg, cb_sender.clone(), listener_stop_tx.clone()).await;
        }
    }

    async fn handle_message(msg: ServerMessage, _cb_sender: mpsc::Sender<ServerCallbackMessage>, listener_stop_tx: mpsc::Sender<()>) {
        match msg {
            ServerMessage::Stop => {
                listener_stop_tx.send(()).await.unwrap();
            }
        }
    }

    async fn disconnect_all(&mut self) {
        for (_, client) in self.client_map.lock().unwrap().clone().into_iter() {
            client.lock().unwrap().disconnect();
        }
    }
}

pub struct MyServerHandle {
    server: Arc<FutMutex<Server>>,
    sender: mpsc::Sender<ServerMessage>
}

impl MyServerHandle {
    pub async fn new(addr: String) -> (Self, mpsc::Receiver<ServerCallbackMessage>) {
        let (sender, receiver) = mpsc::channel(8);
        let (cb_sender, cb_receiver) = mpsc::channel(8);
        let server = Arc::new(FutMutex::new(Server::new(cb_sender, receiver, addr).await));
        (Self { server, sender }, cb_receiver)
    }

    pub async fn run(&mut self) {
        let server = Arc::clone(&self.server);
        tokio::spawn(async move { server.lock().await.run().await });
    }

    pub async fn stop(&self) {
        self.sender.send(ServerMessage::Stop).await.unwrap();
    } 
}

/*
    CLIENT IMPLEMENTATION
 */

async fn client_main() {
    let (stdin_tx, stdin_rx) = futures_channel::mpsc::unbounded();
    tokio::spawn(read_stdin(stdin_tx.clone()));

    let url = url::Url::parse("ws://127.0.0.1:12345").unwrap();
    let (ws_stream, _) = connect_async(url).await.unwrap();
    println!("WebSocket handshake has been successfully completed");

    let (write, mut read) = ws_stream.split();

    let stdin_to_ws = stdin_rx.map(Ok).forward(write);
    let ws_to_stdout = {
        async {
            while let Some(message) = read.next().await {
                println!("got msg: {message:?}");
                match message {
                    Ok(message) => match message {
                        Message::Binary(data) => {
                            let mut stdout = tokio::io::stdout();
                            stdout.write_all(&data).await.unwrap();
                            stdout.write_all(&[b'\n']).await.unwrap();
                            stdout.flush().await.unwrap();
                        },
                        Message::Close(f) => println!("close message was received with CloseFrame: {f:?}"),
                        _ => (),
                    },
                    Err(e) => match e {
                        tokio_tungstenite::tungstenite::Error::ConnectionClosed => { println!("ConnectionClosed") },
                        tokio_tungstenite::tungstenite::Error::AlreadyClosed => { println!("AlreadyClosed") },
                        tokio_tungstenite::tungstenite::Error::Io(e) => { 
                            println!("Io: {e}");
                            match stdin_tx.unbounded_send(Message::binary("".to_string())) {
                                Ok(_) => println!("ok"), 
                                Err(_) => println!("err")
                            };
                        },
                        tokio_tungstenite::tungstenite::Error::Protocol(e) => { println!("Protocol: {e}") },
                        _ => (),
                    }
                };
            }
        }
    };

    pin_mut!(stdin_to_ws, ws_to_stdout);
    future::select(stdin_to_ws, ws_to_stdout).await;
}

async fn read_stdin(tx: futures_channel::mpsc::UnboundedSender<Message>) {
    let mut stdin = tokio::io::stdin();
    loop {
        let mut buf = vec![0; 1024];
        let n = match stdin.read(&mut buf).await {
            Err(_) | Ok(0) => break,
            Ok(n) => n,
        };
        buf.truncate(n);
        tx.unbounded_send(Message::binary(buf)).unwrap();
    }
}

#[tokio::main]
async fn main() {
    match std::env::args().nth(1) {
        Some(arg) => {
            match arg.as_str() {
                "server" => server_main().await.unwrap(),
                "client" => client_main().await,
                _ => println!("please pass one argument (`server` or `client`)")
            }
        },
        None => println!("please pass one argument (`server` or `client`)")
    }
}

To start the server or client, pass the appropriate command line argument:

cargo run -- server
cargo run -- client

The problem: 7 seconds after the server starts (a delay I added for testing), it disconnects all connected clients, and at that point the client sometimes, but not always, gets an error. It can take 5 to 20 runs of the server and client to reproduce it. Client console output:

WebSocket handshake has been successfully completed
got msg: Ok(Close(None))
close message was received with CloseFrame: None
got msg: Err(Io(Os { code: 10053, kind: ConnectionAborted, message: "A program on your host machine aborted an established connection." }))
Io: A program on your host machine terminated the established connection. (os error 10053)
ok

I'm on Windows 11. The error code varies slightly between runs, but the meaning is roughly the same. Questions:

  1. Is my server shutting down the client correctly? Please comment on the disconnect() method of Client and on the outgoing.close().await.unwrap(); call in handle_connection.
  2. Is my client receiving and processing messages correctly (while let Some(message) = read.next().await { ... })?
  3. If the client receives Message::Close, should I take some action myself to close the WebSocket connection, or does tokio_tungstenite handle it automatically?
  4. Should next() eventually return a ConnectionClosed error? Currently, after receiving Message::Close, I get an Io error instead.
  5. In the client's Io error branch, why is "ok" printed to the console even after an error reporting that the connection was aborted?

Maybe I missed something in the official documentation. I recently started learning Rust and the tokio and tokio-tungstenite crates. If you see that I did something wrong in the code, or not as usual, too complicated, or vice versa, please also let me know. Thank you for your time.

daniel-abramov commented 1 year ago

The code is quite large, so I only skimmed through it (it looks like it could be rewritten in a much simpler and more concise way to demonstrate the problem). I'll just go through your questions and try to answer them:

  1. Calling close() on the sink (outgoing) is fine (it essentially just sends and flushes the close frame).
  2. Reading the message - yes, processing it - depends on what you're trying to achieve.
  3. If you're talking about the closing handshake, tungstenite handles it for you, i.e. there is no need to take any further action upon receipt of the close message. That said, some users may want to do something in such cases if their application logic requires it.
  4. As long as you use the latest version, you should not: ConnectionClosed errors are generally converted transparently into None on the receiving stream. You're getting a different error, though (Io / ConnectionAborted), so there is likely something wrong with your application logic.
  5. There is no reason why it should not. As per the code, the message is printed if sending to the channel succeeds. The channel in question is a regular async mpsc::unbounded(). Sending to the channel will succeed as long as the receiver is alive at the moment of sending.

"If you see that I did something wrong in the code, or not as usual, too complicated, or vice versa, please also let me know."

In terms of complexity, it generally feels like it could be simplified a lot. In terms of "not as usual" - I'd say there are too many mutexes that seem to be redundant (not necessary, could be rewritten in a different way). But this will come with time and practice (and also while getting familiar with Tokio more).

Hope this was helpful!