Totodore / socketioxide

A socket.io server implementation in Rust that integrates with the Tower ecosystem and the Tokio stack.
https://docs.rs/socketioxide
MIT License
1.13k stars, 49 forks

Cannot keep ref to extension objects across .await points in handlers/middlewares #295

Closed: X-OrBit closed this issue 1 month ago

X-OrBit commented 3 months ago

Describe the bug

Cannot use socket extensions inside an async namespace middleware.

To Reproduce

Code that reproduces the error:

main.rs

use socketioxide::extensions::{Extensions, Ref};
use socketioxide::extract::SocketRef;
use socketioxide::handler::ConnectHandler;
use socketioxide::SocketIo;
use tokio::net::TcpListener;

fn handler(s: SocketRef) {
    println!("socket connected on / namespace with id: {}", s.id);
}

async fn test_async() -> Result<(), reqwest::Error> {
    reqwest::get("https://google.com").await?;
    Ok(())
}

async fn middleware(socket: SocketRef) -> Result<(), socketioxide::SocketError<()>> {
    let ext: &Extensions = &socket.extensions;
    let number: Option<Ref<i32>> = ext.get::<i32>(); // commenting out this line...
    let _ = test_async().await;                      // ...or this one makes the compilation error go away
    Ok(())
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let (layer, io) = SocketIo::new_layer();
    io.ns("/", handler.with(middleware));
    let app = axum::Router::new().layer(layer);

    let listener = TcpListener::bind("127.0.0.1:5002").await.unwrap();

    axum::serve(listener, app).await?;

    Ok(())
}

Cargo.toml

[package]
name = "test"
version = "0.1.0"
edition = "2021"

[dependencies]
axum = { version = "0.7.5" }
socketioxide = { version = "0.12.0", features = ["extensions"] }
reqwest = "0.12.2"
tokio = { version = "1", features = ["full"] }

Error:

error[E0277]: the trait bound `fn(SocketRef) -> impl Future<Output = Result<(), SocketError<()>>> {middleware}: ConnectMiddleware<LocalAdapter, _>` is not satisfied
   --> src/bin/test_socket.rs:25:24
    |
25  |     io.ns("/", handler.with(middleware));
    |                        ^^^^ the trait `ConnectMiddleware<LocalAdapter, _>` is not implemented for fn item `fn(SocketRef) -> impl Future<Output = Result<(), SocketError<()>>> {middleware}`
    |
note: required by a bound in `ConnectHandler::{opaque#0}`
   --> /Users/xorbit/.cargo/registry/src/index.crates.io-6f17d22bba15001f/socketioxide-0.12.0/src/handler/connect.rs:239:12
    |
239 |         M: ConnectMiddleware<A, T1> + Send + Sync + 'static,
    |            ^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `ConnectHandler::{opaque#0}`

Expected behavior

The code compiles and extensions work properly.


Totodore commented 3 months ago

This happens because Ref is not Send: if you hold one across an .await point, the future returned by your middleware is not Send either. Handler and middleware futures are required to be Send so that they can later be boxed.
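The Send requirement can be reproduced in isolation with the standard library alone. In this sketch, std::sync::MutexGuard (which, like the Ref returned by Extensions::get, is not Send) stands in for the extension reference, and the hypothetical boxed helper mimics a framework that stores futures as Pin<Box<dyn Future + Send>>. Neither is socketioxide API; the small block_on executor exists only so the sketch runs without a runtime.

```rust
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

// Hypothetical helper: accepts only Send futures, like a framework that
// boxes handler futures as Pin<Box<dyn Future + Send>>.
fn boxed<'a, F>(fut: F) -> Pin<Box<dyn Future<Output = i32> + Send + 'a>>
where
    F: Future<Output = i32> + Send + 'a,
{
    Box::pin(fut)
}

// Stand-in for test_async() from the issue.
async fn do_io() {}

// Compiles: the non-Send guard is dropped at the end of the inner block,
// so it is not part of the future's state at the .await point.
async fn scoped(ext: &Mutex<i32>) -> i32 {
    let value = {
        let guard = ext.lock().unwrap();
        *guard
    };
    do_io().await;
    value
}

// Would NOT compile if passed to boxed(): the guard is still alive across
// the .await, so the future is not Send (the same kind of error as above).
//
// async fn held(ext: &Mutex<i32>) -> i32 {
//     let guard = ext.lock().unwrap();
//     do_io().await;
//     *guard
// }

// Minimal executor so the sketch runs without an async runtime.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}
```

Usage: block_on(boxed(scoped(&Mutex::new(7)))) yields 7, while swapping in the commented-out held version fails to compile at the boxed call.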

One possible fix would be to replace the DashMap backing Extensions with a tokio::sync::RwLock<HashMap>, whose lock guards are Send.

For now, the only workarounds are either to confine the Ref to a scope that ends before the .await point:


async fn middleware(socket: SocketRef) -> Result<(), socketioxide::SocketError<()>> {
    {
        let ext: &Extensions = &socket.extensions;
        let number: Option<Ref<i32>> = ext.get::<i32>();
        // use `number` here
    } // `number` (a non-Send Ref) is dropped here, before the .await
    let _ = test_async().await;
    Ok(())
}

or to clone the content and drop the Ref before awaiting.
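The clone-based workaround can be sketched the same way: copy the value out while the guard is alive, so only an owned value crosses the .await. Here Extensions is a hypothetical std-only type map (a Mutex<HashMap<TypeId, Box<dyn Any + Send + Sync>>>), not socketioxide's DashMap-backed type, and get_cloned and require_send are illustrative helpers.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::future::Future;
use std::sync::Mutex;

// Hypothetical std-only stand-in for a type-keyed extensions map
// (NOT socketioxide's API, which is backed by a DashMap).
#[derive(Default)]
struct Extensions {
    map: Mutex<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
}

impl Extensions {
    fn insert<T: Send + Sync + 'static>(&self, value: T) {
        self.map.lock().unwrap().insert(TypeId::of::<T>(), Box::new(value));
    }

    // Clones the value out. The lock guard is a temporary dropped before
    // this function returns, so callers only ever hold an owned T.
    fn get_cloned<T: Clone + 'static>(&self) -> Option<T> {
        self.map
            .lock()
            .unwrap()
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
            .cloned()
    }
}

// Stand-in for test_async() from the issue.
async fn do_io() {}

// Only an owned Option<i32> is alive across the .await.
async fn middleware(ext: &Extensions) -> Option<i32> {
    let number: Option<i32> = ext.get_cloned::<i32>();
    do_io().await;
    number
}

// Compile-time check mirroring the framework's Send requirement.
fn require_send<F: Future + Send>(fut: F) -> F {
    fut
}
```

The trade-off is one clone per access, which, as noted in the thread, is cheap for small values.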

X-OrBit commented 3 months ago

Thanks for the answer. For now I went with cloning, since the objects were not too complicated. Limiting the scope is a good solution too; I did not think of it myself.

Are there any plans to back extensions with an RwLock or Mutex?

Totodore commented 2 months ago

Actually, this affects every handler, not only middlewares. The only thing I could do is provide an Extension extractor that directly gets and clones the object, like axum's Extension extractor. I might also add an HttpExtension extractor that would correspond to the HTTP request extensions.
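A get-and-clone extractor of this kind can be sketched without socketioxide. Everything below is hypothetical: Socket, the FromSocket trait, Extension<T>, call and on_connect are illustrative stand-ins loosely modeled on how axum's Extension extractor clones a stored value, not socketioxide's actual API. The extractor holds a short-lived read guard only while cloning and hands the handler an owned T.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::RwLock;

// Hypothetical stand-in for a socket that owns a type-keyed extensions map.
#[derive(Default)]
struct Socket {
    extensions: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
}

impl Socket {
    fn insert<T: Send + Sync + 'static>(&self, value: T) {
        self.extensions
            .write()
            .unwrap()
            .insert(TypeId::of::<T>(), Box::new(value));
    }
}

// Hypothetical extractor trait: builds a handler argument from the socket.
trait FromSocket: Sized {
    fn from_socket(socket: &Socket) -> Option<Self>;
}

// Owns a clone of the stored value, so no lock guard reaches the handler.
struct Extension<T>(T);

impl<T: Clone + Send + Sync + 'static> FromSocket for Extension<T> {
    fn from_socket(socket: &Socket) -> Option<Self> {
        let map = socket.extensions.read().unwrap();
        let value = map.get(&TypeId::of::<T>())?.downcast_ref::<T>()?.clone();
        Some(Extension(value)) // `map` (the read guard) is dropped on return
    }
}

// Hypothetical dispatcher: runs the extractor, then calls the handler.
fn call<T: FromSocket, R>(socket: &Socket, handler: impl Fn(T) -> R) -> Option<R> {
    T::from_socket(socket).map(handler)
}

// Example handler: receives an owned u64, not a guard into the map.
fn on_connect(Extension(user_id): Extension<u64>) -> u64 {
    user_id
}
```

Because the handler only ever sees an owned value, it could hold it across .await points freely; the cost is requiring T: Clone, the same constraint axum's Extension extractor imposes.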