guillaume-be / rust-bert

Rust native ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)
https://docs.rs/crate/rust-bert
Apache License 2.0
2.51k stars 211 forks

make ZeroShotClassification support the roberta-large-mnli #422

Closed njaard closed 9 months ago

njaard commented 10 months ago

I don't know whether this change is correct, so a review would be appreciated.

guillaume-be commented 9 months ago

Thank you @njaard - this was indeed a mistake. Could you please update the error string a few lines below with "You can only supply a RobertaConfig for Roberta!"?

njaard commented 9 months ago

Done.

Also note that I removed the reference to token_type_ids. The removal is necessary for the model to work, but I don't understand the mechanism well enough to justify it.

guillaume-be commented 9 months ago

Could you please clarify? RobertaConfig is an alias for BertConfig, so the first change should not alter behaviour. Could you please provide a reproducible example for the error caused by token_type_ids?
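
To illustrate the alias point, here is a minimal sketch. The `BertConfig` struct below is a simplified, hypothetical stand-in with two illustrative fields, not the crate's actual definition, and `describe` is a made-up helper. It only shows that a Rust type alias gives a second name to the same type, so swapping one name for the other cannot change behaviour.

```rust
// Hypothetical, simplified stand-in for the real configuration struct.
#[derive(Debug, Clone, PartialEq)]
pub struct BertConfig {
    pub hidden_size: i64,
    pub type_vocab_size: i64,
}

// A type alias introduces a second name for the same type, not a new type.
pub type RobertaConfig = BertConfig;

// Takes a BertConfig; because of the alias, a RobertaConfig is accepted as well.
pub fn describe(config: &BertConfig) -> String {
    format!(
        "hidden_size={}, type_vocab_size={}",
        config.hidden_size, config.type_vocab_size
    )
}
```

A value built as `RobertaConfig { hidden_size: 1024, type_vocab_size: 1 }` can be passed to `describe` directly, with no conversion.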

njaard commented 9 months ago

@guillaume-be, with token_type_ids, I get many errors of this form:

../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [464,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

And this torch backtrace:

thread '<unnamed>' panicked at 'called `Result::unwrap()` on an `Err` value: Torch("CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`\nException raised from createCublasHandle at ../aten/src/ATen/cuda/CublasHandlePool.cpp:18 (most recent call first):\nframe #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6b (0x7fdac465a6bb in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libc10.so)\nframe #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xbf (0x7fdac46555ef in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libc10.so)\nframe #2: <unknown function> + 0x2f252cb (0x7fdadfd252cb in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libtorch_cuda.so)\nframe #3: at::cuda::getCurrentCUDABlasHandle() + 0x962 (0x7fdadfd26b12 in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libtorch_cuda.so)\nframe #4: <unknown function> + 0x2f10a8b (0x7fdadfd10a8b in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libtorch_cuda.so)\nframe #5: <unknown function> + 0x2f41bfa (0x7fdadfd41bfa in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libtorch_cuda.so)\nframe #6: <unknown function> + 0x2c293b1 (0x7fdadfa293b1 in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libtorch_cuda.so)\nframe #7: <unknown function> + 0x2c29470 (0x7fdadfa29470 in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libtorch_cuda.so)\nframe #8: at::_ops::addmm::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) + 0xab (0x7fdac69a435b in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libtorch_cpu.so)\nframe #9: <unknown function> + 0x3a6ad24 (0x7fdac846ad24 in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libtorch_cpu.so)\nframe #10: <unknown function> + 0x3a6bac2 
(0x7fdac846bac2 in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libtorch_cpu.so)\nframe #11: at::_ops::addmm::call(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::Scalar const&, c10::Scalar const&) + 0x1b1 (0x7fdac6a08de1 in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libtorch_cpu.so)\nframe #12: at::native::linear(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&) + 0x559 (0x7fdac6256db9 in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libtorch_cpu.so)\nframe #13: <unknown function> + 0x29a11f6 (0x7fdac73a11f6 in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libtorch_cpu.so)\nframe #14: at::_ops::linear::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&) + 0x196 (0x7fdac69f2c86 in /libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117/lib/libtorch_cpu.so)\nframe #15: <unknown function> + 0x72eb95 (0x563111740b95 in ./program)\nframe #16: <unknown function> + 0x70afff (0x56311171cfff in ./program)\nframe #17: <unknown function> + 0x4b0973 (0x5631114c2973 in ./program)\nframe #18: <unknown function> + 0x4b0fa3 (0x5631114c2fa3 in ./program)\nframe #19: <unknown function> + 0x4b1e46 (0x5631114c3e46 in ./program)\nframe #20: <unknown function> + 0x4f69b6 (0x5631115089b6 in ./program)\nframe #21: <unknown function> + 0x4f8ef2 (0x56311150aef2 in ./program)\nframe #22: <unknown function> + 0x47bc80 (0x56311148dc80 in ./program)\nframe #23: <unknown function> + 0x15d00c (0x56311116f00c in ./program)\nframe #24: <unknown function> + 0x14ef13 (0x563111160f13 in ./program)\nframe #25: <unknown function> + 0x18623c (0x56311119823c in ./program)\nframe #26: <unknown function> + 0x18c4b2 (0x56311119e4b2 in ./program)\nframe #27: <unknown function> + 0x18b827 (0x56311119d827 in ./program)\nframe #28: <unknown function> + 0x18d366 (0x56311119f366 in ./program)\nframe #29: <unknown function> + 0x18e1aa (0x5631111a01aa in ./program)\nframe #30: <unknown function> + 0x85d9c5 (0x56311186f9c5 in 
./program)\nframe #31: <unknown function> + 0x89044 (0x7fdac44a8044 in /lib/x86_64-linux-gnu/libc.so.6)\nframe #32: <unknown function> + 0x1095fc (0x7fdac45285fc in /lib/x86_64-linux-gnu/libc.so.6)\n")', /home/charles/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tch-0.13.0/src/wrappers/tensor_generated.rs:10865:37

guillaume-be commented 9 months ago

Thank you for the feedback. This was due to a bug in the tokenization crate, and I have pushed an update for it (v8.1.1 should solve the issue). Could you please update the dependencies, revert the change to token_type_ids, and try again?
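
For readers who hit the same assertion: `srcIndex < srcSelectDimSize` is the bounds check of an index-select, the operation behind embedding lookups. It fires when an index is at least as large as the number of rows in the table. The thread does not confirm which lookup failed. One plausible reading, stated here as an assumption, is that the tokenizer produced token type ids of 1 for the second segment while the model's token-type embedding had only one row. The helper below is a hypothetical pre-flight check in plain Rust, not part of rust-bert or tch.

```rust
/// Returns the position and value of the first index that falls outside a
/// lookup table with `table_rows` rows, or `None` if every index is valid.
/// Hypothetical helper for illustration only.
pub fn first_out_of_range(indices: &[i64], table_rows: i64) -> Option<(usize, i64)> {
    indices
        .iter()
        .copied()
        .enumerate()
        // An index is valid only if 0 <= idx < table_rows.
        .find(|&(_, idx)| idx < 0 || idx >= table_rows)
}
```

With a one-row table, `first_out_of_range(&[0, 0, 1, 1], 1)` reports the offending entry at position 2, whereas the same ids pass against a two-row table. Because CUDA kernels run asynchronously, a failed device-side assertion can surface later as an unrelated-looking error such as the cuBLAS one above. Rerunning on CPU, or with the environment variable `CUDA_LAUNCH_BLOCKING=1`, usually gives a more direct report.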