tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

CUDA backend does not work with rust nightly #1867

Open jggc opened 3 months ago

jggc commented 3 months ago

Describe the bug

It is not actually clear whether this is a bug or a feature/documentation request, but here it goes:

Running Rust nightly 2024-05-30, no matter how I set up libtorch I end up with:

2024-06-08T16:45:38.148370Z ERROR burn_train::learner::application_logger: PANIC => panicked at /home/user/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tch-0.15.0/src/wrappers/tensor_generated.rs:7988:40:                  
called `Result::unwrap()` on an `Err` value: Torch("Could not run 'aten::empty.memory_format' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective....

The reason I am reporting this is that it is at least the third time I have encountered this same issue, for different reasons such as:

What is my point? I think this error is totally unhelpful, and there is a lot of room for improvement regarding the tch-gpu setup.

What are your thoughts?

Should we:

  1. Implement pre-flight checks (see the sketch below)
  2. Improve and consolidate the documentation
  3. Improve the error message; reading "operator does not exist" does not hint very well at where the issue actually is, IMHO.
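
As a rough illustration of what option 1 could look like, here is a minimal sketch of a pre-flight check built on the tch crate's CUDA queries; the function name and where it would live in Burn are assumptions, not existing API:

```rust
use tch::Cuda;

/// Hypothetical pre-flight check: fail early with an actionable message
/// instead of libtorch's opaque "operator does not exist" panic.
fn preflight_check_cuda() {
    // Keep the libtorch_cuda symbols linked in (see the workaround below).
    tch::maybe_init_cuda();

    if !Cuda::is_available() {
        panic!(
            "CUDA was requested but libtorch reports no CUDA device. \
             Check that LIBTORCH points to a CUDA-enabled libtorch build \
             and that the environment variables used when building tch-rs \
             are still set in the current shell."
        );
    }
}
```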
caelunshun commented 3 months ago

I've fixed this by depending on tch version 0.15 and adding tch::maybe_init_cuda() to the start of main(). This seems to stop the linker from removing the libtorch_cuda dependency, which is what causes that error message (at least in my case).

This problem could definitely be documented better; it took me a couple hours to figure this out.
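
For reference, a minimal sketch of this workaround, assuming tch 0.15 as a direct dependency and a CUDA-enabled libtorch install:

```rust
// Cargo.toml (sketch):
// [dependencies]
// tch = "0.15"

fn main() {
    // Referencing tch::maybe_init_cuda() keeps the libtorch_cuda library from
    // being dropped by the linker, which is what triggers the
    // "Could not run 'aten::empty.memory_format' ... 'CUDA' backend" panic.
    tch::maybe_init_cuda();

    // ... the rest of the training/inference code using the LibTorch backend ...
}
```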

laggui commented 3 months ago

I agree that when running into issues with tch, the actual error is never really clear.

What happens most often is trying to use the CUDA version when the environment variable was set in another shell (so it is not persistent): you run your program and get an error similar to the one you posted. Cargo gets all sorts of confused, and re-resolving tch-rs based on changes to the environment variable never seemed to work for me, so I end up cleaning the cache and rebuilding the package.
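
As a quick way to spot that non-persistent-shell case, a small sketch that just prints the variables the libtorch setup typically relies on (the exact set of variables is an assumption and may differ per setup):

```rust
fn main() {
    // Print the environment variables the libtorch setup usually depends on,
    // so it is obvious when the current shell no longer has them set.
    for key in ["LIBTORCH", "LD_LIBRARY_PATH", "TORCH_CUDA_VERSION"] {
        match std::env::var(key) {
            Ok(val) => println!("{key}={val}"),
            Err(_) => println!("{key} is not set in this shell"),
        }
    }
}
```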

We tried to improve the setup, but the environment variables are required by tch-rs, so it is not straightforward to circumvent (I tried). We could definitely add some documentation for common issues at the very least. The best we can do about the error message from the torch side is probably to match the generic error message and give some tips/cues.
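
A minimal sketch of that last idea, matching the generic libtorch message and appending a hint; the helper function and the exact wording are assumptions, not existing Burn code:

```rust
/// Hypothetical helper: wrap a tch error with a setup hint when it looks like
/// the well-known "CUDA backend missing" failure.
fn add_cuda_hint(err: &tch::TchError) -> String {
    let msg = err.to_string();
    if msg.contains("Could not run") && msg.contains("'CUDA' backend") {
        format!(
            "{msg}\nHint: libtorch_cuda was probably not linked, or the \
             environment variables used when building the tch crate are no \
             longer set. Try `cargo clean -p torch-sys` and rebuild with the \
             variables exported."
        )
    } else {
        msg
    }
}
```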

We're open to suggestions!