LaurentMazare / tch-rs

Rust bindings for the C++ api of PyTorch.
Apache License 2.0

[Question] Cuda Graphs support? #631

Open finnkauski opened 1 year ago

finnkauski commented 1 year ago

A handy tool provided in Torch that I think would make a great addition to the bindings is the CUDAGraph support.

You can see the header file here.

What are your thoughts on this?

LaurentMazare commented 1 year ago

Thanks for the suggestion. Here is a small attempt to expose this API: #632. Would you mind trying it out to see if it works for your use cases?

LaurentMazare commented 1 year ago

Ah, actually, I'm just adding an example at the moment and it seems that this will require being able to specify a CUDA stream to work. So never mind for now: I'll add support for specifying the CUDA stream and will come back when the basic example runs.

finnkauski commented 1 year ago

Amazing. I think these bits are crucial to expose for highly performance-sensitive use cases, so your work is appreciated. Drop another message here and I'll try to apply the proposed API to our use case.

LaurentMazare commented 1 year ago

You should be able to try out #632 right now if you manually install libtorch (I think it's using 1.13 and not 2.0 on this branch though). Re merging the branch, this is a bit trickier than I thought because of https://github.com/pytorch/pytorch/issues/47743 . The issue is that the header files in the libtorch CUDA binary packages are not self-contained, so if there is no local CUDA install, compiling this branch will not work. I'm not really sure how to get around this at the moment. Vendoring the CUDA headers doesn't sound great. Maybe this could be put under a feature flag for now and only be available to users that have a proper libtorch install, but that doesn't sound great either.
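
To make the feature-flag option concrete, here is a minimal sketch of how such gating typically looks in Rust. The feature name `cuda-graph` and the `cuda_graph` module are hypothetical, chosen only for illustration; they are not taken from #632 or from tch-rs itself. The idea is that code needing the CUDA headers is only compiled when the user opts in, and a stub is compiled otherwise.

```rust
// Hypothetical sketch: `cuda-graph` is an assumed feature name that would be
// declared in Cargo.toml as `[features] cuda-graph = []`. It is not a known
// tch-rs feature.

// Compiled only when the user opts in with `--features cuda-graph`.
// The real bindings (which need the CUDA headers at build time) would go here.
#[cfg(feature = "cuda-graph")]
mod cuda_graph {
    pub fn is_available() -> bool {
        true
    }
}

// Fallback stub so the rest of the crate still builds without a CUDA install.
#[cfg(not(feature = "cuda-graph"))]
mod cuda_graph {
    pub fn is_available() -> bool {
        false
    }
}

fn main() {
    println!(
        "CUDA graph bindings compiled in: {}",
        cuda_graph::is_available()
    );
}
```

With this layout, users without a local CUDA install build the crate as usual and simply do not get the extra module, which is the trade-off described above.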

finnkauski commented 8 months ago

Apologies for the long absence. Priorities changed a lot and only now can I return to trying out tch again. I tried out candle where you're also a contributor so thank you for that too!

With respect to this, the API example you have looks very nice to work with, although the use case I had no longer really exists. That said, since this was opened a lot has changed in the way tch is installed, including the addition of automated version detection and installation/downloading. I think that makes the idea of vendoring these headers more palatable? I also think that, given Dynamo compilation in Python and the optimisations it is able to do, having this API might help someone speed up their Rust versions.