coreylowman / dfdx

Deep learning in Rust, with shape checked tensors and neural networks

Send/Sync for Device #888

Open blogle opened 10 months ago

blogle commented 10 months ago

Currently it is not possible to move devices or tensors across threads.

https://github.com/coreylowman/dfdx/blob/4476b5ee19dc9cc446388545560c57e80cb086c8/dfdx-core/src/tensor/cuda/device.rs#L22-L33

Despite everything in the Device being wrapped in an Arc, the CUDA device is not actually Send or Sync, almost certainly because some raw *mut pointers are nested in there for FFI. As a result, various functionality is rather burdensome to implement. For example, suppose you want a pipeline where tensors are prepared in thread A while inference is done in thread B: tensors can't be moved across threads without copying them to the host and back (serializing to a Vec), and device methods can't be called from both threads.
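For context on why the Arc does not help: `Arc<T>` is only `Send`/`Sync` when `T: Send + Sync`, and raw pointers (`*mut T`) are neither, so any struct that stores one directly loses both auto traits. A library can opt back in with `unsafe impl Send`/`Sync` once it has verified the pointer may be used from other threads. The sketch below uses a hypothetical `RawHandle` standing in for an FFI handle; it is not dfdx's or cudarc's actual type, and whether the real CUDA handles are safe to share is exactly the open question in this issue.

```rust
use std::sync::Arc;
use std::thread;

/// Hypothetical stand-in for an FFI handle (e.g. a context or stream pointer).
/// Storing a raw pointer makes this type !Send and !Sync automatically.
struct RawHandle {
    ptr: *mut u8,
}

/// Newtype that opts back into Send/Sync.
/// SAFETY (assumed for this sketch only): the pointed-to resource may be used
/// from any thread. A real library must verify this against the FFI's rules.
struct SharedHandle(RawHandle);
unsafe impl Send for SharedHandle {}
unsafe impl Sync for SharedHandle {}

/// Compile-time check: only callable for Send + Sync types.
fn assert_send_sync<T: Send + Sync>() {}

fn main() {
    // `assert_send_sync::<Arc<RawHandle>>();` would fail to compile,
    // because Arc<T> is Send/Sync only when T: Send + Sync.
    assert_send_sync::<Arc<SharedHandle>>();

    let mut value = 42u8;
    let shared = Arc::new(SharedHandle(RawHandle { ptr: &mut value as *mut u8 }));
    let remote = Arc::clone(&shared);
    // The whole Arc is moved into the new thread (the access goes through Deref).
    let addr = thread::spawn(move || remote.0.ptr as usize).join().unwrap();
    assert_eq!(addr, shared.0.ptr as usize);
    println!("handle moved across threads");
}
```

The unsafe impls are the only thing that changes compared with the raw-pointer struct, which is why the question for dfdx is less about code and more about whether the underlying CUDA objects actually permit cross-thread use.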

I am not familiar enough with the underlying implementation to know whether the device can implement Sync, but it would be great if, at a minimum, the device could be sent across threads. Right now I could certainly just create a device per thread, but I am not sure how much overhead that involves, or what the implications are of not sharing caches across instances.
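One way to live with a non-`Send` device today is a variant of the "device per thread" idea above: construct the device inside a dedicated worker thread, never move it, and pass only host data (`Vec<f32>`) in and out over channels. The sketch uses a hypothetical `FakeDevice` (kept `!Send` via an `Rc`) and a made-up `forward` that sums each row; it illustrates the ownership pattern only, not dfdx's API, and says nothing about the overhead or cache-sharing questions raised above.

```rust
use std::rc::Rc;
use std::sync::mpsc;
use std::thread;

/// Hypothetical device: the Rc makes it !Send, like the CUDA device.
struct FakeDevice {
    _not_send: Rc<()>,
}

impl FakeDevice {
    fn new() -> Self {
        FakeDevice { _not_send: Rc::new(()) }
    }

    /// Hypothetical "inference": sums each row of a batch.
    fn forward(&self, batch: &[Vec<f32>]) -> Vec<f32> {
        batch.iter().map(|row| row.iter().sum::<f32>()).collect()
    }
}

/// A job carries host data plus a channel for the reply.
type Job = (Vec<Vec<f32>>, mpsc::Sender<Vec<f32>>);

/// Spawns a worker that owns the device for its whole lifetime.
fn spawn_worker() -> (mpsc::Sender<Job>, thread::JoinHandle<()>) {
    let (tx, rx) = mpsc::channel::<Job>();
    let handle = thread::spawn(move || {
        // Created here, so the device never has to cross a thread boundary.
        let dev = FakeDevice::new();
        for (batch, reply) in rx {
            let _ = reply.send(dev.forward(&batch));
        }
        // The loop ends once every Sender<Job> has been dropped.
    });
    (tx, handle)
}

fn main() {
    let (jobs, worker) = spawn_worker();
    // Any thread may prepare data and submit it; only host memory moves.
    let producer = {
        let jobs = jobs.clone();
        thread::spawn(move || {
            let (reply_tx, reply_rx) = mpsc::channel();
            jobs.send((vec![vec![1.0, 2.0], vec![3.0, 4.0]], reply_tx)).unwrap();
            reply_rx.recv().unwrap()
        })
    };
    let out = producer.join().unwrap();
    println!("{:?}", out); // [3.0, 7.0]
    drop(jobs); // close the channel so the worker exits
    worker.join().unwrap();
}
```

The cost of this design is the host round trip the issue complains about: every batch crosses the channel as a Vec, so it avoids the Send problem rather than solving it.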

blogle commented 10 months ago

I am realizing that for the example I described, you wouldn't actually get any parallelism from adding threads, since the copy (assuming it's synchronous) and the inference kernels would be interleaved on the same CUDA stream.

Nevertheless, it would be great if something like the following pseudocode were possible:

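```rust
// Desired usage; does not compile today because the CUDA device and tensors
// are not Send/Sync. ResNet, Request, Tensor and preprocess are placeholders.
use std::sync::Arc;
use futures::StreamExt; // map, ready_chunks, for_each
use tokio::sync::mpsc::{self, Sender};
use tokio_stream::wrappers::ReceiverStream;

fn inference_server(results: Sender<Tensor>) -> Sender<Request> {
    let dev = dfdx::AutoDevice::default();
    let model = Arc::new(dev.build_module::<ResNet, f32>());
    let (tx, rx) = mpsc::channel(256);
    let inferencer = ReceiverStream::new(rx)
        .map(|data| preprocess(data))
        .ready_chunks(32) // batch up to 32 items that are already available
        .map(move |batch_vec| dev.tensor(batch_vec))
        .for_each(move |tensor| {
            let (model, results) = (Arc::clone(&model), results.clone());
            async move {
                // Run the forward pass on a blocking thread, off the async runtime.
                let _ = tokio::task::spawn_blocking(move || {
                    results.blocking_send(model.forward(tensor))
                })
                .await;
            }
        });

    tokio::spawn(inferencer);

    tx
}
```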