guillaume-be / rust-bert

Rust native ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)
https://docs.rs/crate/rust-bert
Apache License 2.0

GPT-2 text generation throws an unexpected error #436

Closed kj3moraes closed 7 months ago

kj3moraes commented 7 months ago

I have been facing an error when prompting TextGeneration models. The full trace is below:

thread '<unnamed>' panicked at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tch-0.13.0/src/wrappers/tensor_generated.rs:9631:48:
called `Result::unwrap()` on an `Err` value: Torch("index_copy_(): index 56 is out of bounds for dimension 0 with size 56\nException raised from operator() at /pytorch/aten/src/ATen/native/cpu/IndexKernel.cpp:302 (most recent call first):\nframe #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6b (0x7fc114c5a6bb in /home/kjmoraes/Coding/libtorch-2.0.0/libtorch/lib/libc10.so)\nframe #1: <unknown function> + 0x569bda7 (0x7fc11a69bda7 in /home/kjmoraes/Coding/libtorch-2.0.0/libtorch/lib/libtorch_cpu.so)\nframe #2: at::TensorIteratorBase::serial_for_each(c10::function_ref<void (char**, long const*, long, long)>, at::Range) const + 0x1d9 (0x7fc116413879 in /home/kjmoraes/Coding/libtorch-2.0.0/libtorch/lib/libtorch_cpu.so)\nframe #3: at::TensorIteratorBase::for_each(c10::function_ref<void (char**, long const*, long, long)>, long) + 0x167 (0x7fc116413ed7 in /home/kjmoraes/Coding/libtorch-2.0.0/libtorch/lib/libtorch_cpu.so)\nframe #4: <unknown function> + 0x5684b6e (0x7fc11a684b6e in /home/kjmoraes/Coding/libtorch-2.0.0/libtorch/lib/libtorch_cpu.so)\nframe #5: <unknown function> + 0x5685440 (0x7fc11a685440 in /home/kjmoraes/Coding/libtorch-2.0.0/libtorch/lib/libtorch_cpu.so)\nframe #6: at::native::structured_index_copy_out::impl(at::Tensor const&, long, at::Tensor const&, at::Tensor const&, at::Tensor const&) + 0x4eb (0x7fc116a8c32b in /home/kjmoraes/Coding/libtorch-2.0.0/libtorch/lib/libtorch_cpu.so)\nframe #7: <unknown function> + 0x26417af (0x7fc1176417af in /home/kjmoraes/Coding/libtorch-2.0.0/libtorch/lib/libtorch_cpu.so)\nframe #8: at::_ops::index_copy_::redispatch(c10::DispatchKeySet, at::Tensor&, long, at::Tensor const&, at::Tensor const&) + 0x9b (0x7fc116f921ab in /home/kjmoraes/Coding/libtorch-2.0.0/libtorch/lib/libtorch_cpu.so)\nframe #9: <unknown function> + 0x43b93a1 (0x7fc1193b93a1 in /home/kjmoraes/Coding/libtorch-2.0.0/libtorch/lib/libtorch_cpu.so)\nframe #10: 
at::_ops::index_copy_::redispatch(c10::DispatchKeySet, at::Tensor&, long, at::Tensor const&, at::Tensor const&) + 0x9b (0x7fc116f921ab in /home/kjmoraes/Coding/libtorch-2.0.0/libtorch/lib/libtorch_cpu.so)\nframe #11: <unknown function> + 0x3a8bf10 (0x7fc118a8bf10 in /home/kjmoraes/Coding/libtorch-2.0.0/libtorch/lib/libtorch_cpu.so)\nframe #12: at::_ops::index_copy_::call(at::Tensor&, long, at::Tensor const&, at::Tensor const&) + 0x17c (0x7fc116ffb41c in /home/kjmoraes/Coding/libtorch-2.0.0/libtorch/lib/libtorch_cpu.so)\nframe #13: <unknown function> + 0xff3ee4 (0x55e566b44ee4 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #14: <unknown function> + 0xffab9b (0x55e566b4bb9b in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #15: <unknown function> + 0xf74c78 (0x55e566ac5c78 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #16: <unknown function> + 0xf7dfef (0x55e566acefef in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #17: <unknown function> + 0xc7ce5b (0x55e5667cde5b in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #18: <unknown function> + 0xc4cf02 (0x55e56679df02 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #19: <unknown function> + 0xbb6dc5 (0x55e566707dc5 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #20: <unknown function> + 0xc83840 (0x55e5667d4840 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #21: <unknown function> + 0x8a4fc8 (0x55e5663f5fc8 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #22: <unknown function> + 0x780fd3 (0x55e5662d1fd3 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #23: <unknown function> + 0x780552 (0x55e5662d1552 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #24: <unknown function> + 0x345eb9 
(0x55e565e96eb9 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #25: <unknown function> + 0x9b5b6e (0x55e566506b6e in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #26: <unknown function> + 0x9b69d2 (0x55e5665079d2 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #27: <unknown function> + 0x66a48e (0x55e5661bb48e in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #28: <unknown function> + 0xa0c084 (0x55e56655d084 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #29: <unknown function> + 0xa0a404 (0x55e56655b404 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #30: <unknown function> + 0x8431a5 (0x55e5663941a5 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #31: <unknown function> + 0x93e5c6 (0x55e56648f5c6 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #32: <unknown function> + 0x6fe104 (0x55e56624f104 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #33: <unknown function> + 0x93e421 (0x55e56648f421 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #34: <unknown function> + 0x89ef6d (0x55e5663eff6d in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #35: <unknown function> + 0x586bda (0x55e5660d7bda in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #36: <unknown function> + 0x583eb9 (0x55e5660d4eb9 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #37: <unknown function> + 0x586d54 (0x55e5660d7d54 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #38: <unknown function> + 0x8aa59e (0x55e5663fb59e in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #39: <unknown function> + 0x7fb8e6 (0x55e56634c8e6 in 
/home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #40: <unknown function> + 0x47bb80 (0x55e565fccb80 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #41: <unknown function> + 0x983cd0 (0x55e5664d4cd0 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #42: <unknown function> + 0x61dd16 (0x55e56616ed16 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #43: <unknown function> + 0x63547b (0x55e56618647b in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #44: <unknown function> + 0x5fcddf (0x55e56614dddf in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #45: <unknown function> + 0x47b9c5 (0x55e565fcc9c5 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #46: <unknown function> + 0x525e8e (0x55e566076e8e in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #47: <unknown function> + 0x28ab315 (0x55e5683fc315 in /home/kjmoraes/Coding/fable/monocle/src-tauri/target/debug/unnamed)\nframe #48: <unknown function> + 0x9f822 (0x7fc10ee9f822 in /lib64/libc.so.6)\nframe #49: <unknown function> + 0x3f450 (0x7fc10ee3f450 in /lib64/libc.so.6)\n")
stack backtrace:
   0: rust_begin_unwind
             at /rustc/79e9716c980570bfd1f666e3b16ac583f0168962/library/std/src/panicking.rs:597:5
   1: core::panicking::panic_fmt
             at /rustc/79e9716c980570bfd1f666e3b16ac583f0168962/library/core/src/panicking.rs:72:14
   2: core::result::unwrap_failed
             at /rustc/79e9716c980570bfd1f666e3b16ac583f0168962/library/core/src/result.rs:1652:5
   3: core::result::Result<T,E>::unwrap
             at /rustc/79e9716c980570bfd1f666e3b16ac583f0168962/library/core/src/result.rs:1077:23
   4: tch::wrappers::tensor_generated::<impl tch::wrappers::tensor::Tensor>::index_copy_
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tch-0.13.0/src/wrappers/tensor_generated.rs:9631:9
   5: rust_bert::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator::generate_beam_search
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/rust-bert-0.21.0/src/pipelines/generation_utils.rs:1545:25
   6: rust_bert::pipelines::generation_utils::LanguageGenerator::generate_from_ids_and_past::{{closure}}
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/rust-bert-0.21.0/src/pipelines/generation_utils.rs:2131:17
   7: tch::wrappers::tensor::no_grad
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tch-0.13.0/src/wrappers/tensor.rs:814:18
   8: rust_bert::pipelines::generation_utils::LanguageGenerator::generate_from_ids_and_past
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/rust-bert-0.21.0/src/pipelines/generation_utils.rs:2129:44
   9: rust_bert::pipelines::generation_utils::LanguageGenerator::generate_indices
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/rust-bert-0.21.0/src/pipelines/generation_utils.rs:1902:9
  10: rust_bert::pipelines::text_generation::TextGenerationOption::generate_indices
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/rust-bert-0.21.0/src/pipelines/text_generation.rs:349:38
  11: rust_bert::pipelines::text_generation::TextGenerationModel::generate
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/rust-bert-0.21.0/src/pipelines/text_generation.rs:611:26
  12: app::parser::summaries::summarize_file::{{closure}}
             at ./src/parser/summaries.rs:100:32
  13: app::embed::cookie_builder::CookieBuilder::bake_biscuit::{{closure}}
             at ./src/embed/cookie_builder.rs:30:54
  14: app::embed::cookie_builder::CookieBuilder::bake_biscuits::{{closure}}
             at ./src/embed/cookie_builder.rs:57:53
  15: app::indexer::indexer::Indexer::index::{{closure}}
             at ./src/indexer/indexer.rs:40:80
  16: tokio::runtime::park::CachedParkThread::block_on::{{closure}}
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tokio-1.33.0/src/runtime/park.rs:282:63
  17: tokio::runtime::coop::with_budget
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tokio-1.33.0/src/runtime/coop.rs:107:5
  18: tokio::runtime::coop::budget
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tokio-1.33.0/src/runtime/coop.rs:73:5
  19: tokio::runtime::park::CachedParkThread::block_on
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tokio-1.33.0/src/runtime/park.rs:282:31
  20: tokio::runtime::context::blocking::BlockingRegionGuard::block_on
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tokio-1.33.0/src/runtime/context/blocking.rs:66:9
  21: tokio::runtime::scheduler::multi_thread::MultiThread::block_on::{{closure}}
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tokio-1.33.0/src/runtime/scheduler/multi_thread/mod.rs:87:13
  22: tokio::runtime::context::runtime::enter_runtime
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tokio-1.33.0/src/runtime/context/runtime.rs:65:16
  23: tokio::runtime::scheduler::multi_thread::MultiThread::block_on
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tokio-1.33.0/src/runtime/scheduler/multi_thread/mod.rs:86:9
  24: tokio::runtime::runtime::Runtime::block_on
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tokio-1.33.0/src/runtime/runtime.rs:350:45
  25: tauri::async_runtime::Runtime::block_on
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tauri-1.5.2/src/async_runtime.rs:126:25
  26: tauri::async_runtime::GlobalRuntime::block_on
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tauri-1.5.2/src/async_runtime.rs:71:7
  27: tauri::async_runtime::block_on
             at /home/kjmoraes/.cargo/registry/src/index.crates.io-6f17d22bba15001f/tauri-1.5.2/src/async_runtime.rs:264:3
  28: app::launch_background_process::{{closure}}
             at ./src/main.rs:85:9
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.

This occurs in this line in my code

// This needs to be done because the model download is a blocking task
// which is traditionally not allowed inside an async context.
let blocking_task =
      task::spawn_blocking(|| TextGenerationModel::new(Default::default()).unwrap());
let model = blocking_task.await.unwrap();
let file_contents = fs::read_to_string(file_path).unwrap();
let mut prompt = read_to_string("./src/parser/prompts/text_summarization.prompt").unwrap();
prompt.push_str(&file_contents);
let model_output = model.generate(&vec![prompt], None);

I append the file contents onto a prompt that is no more than ~200 tokens.

Do you know why this might be happening?

guillaume-be commented 7 months ago

Hello @kj3moraes,

Thank you for raising this. Could you please share a reproducible example (i.e. with a prompt and file you could share publicly) and turn sampling off for the generation?

guillaume-be commented 7 months ago

@kj3moraes just had a look and I believe I have found the issue. The default maximum length for the text generation pipeline is 56 tokens (this includes the input prompt). Your input prompt most likely exceeds the 56 token default of the pipeline, and you have 2 solutions:

Please let me know how this works. I will add a check to improve error handling when users provide inputs longer than the maximum length, thank you for raising this.
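For readers hitting the same panic, here is a minimal, self-contained sketch of the kind of up-front length check described above. It does not use rust-bert: `GenerationError` and `remaining_budget` are hypothetical names invented for this illustration, and treating a prompt of exactly `max_length` tokens as an error is an assumption drawn from the "index 56 is out of bounds for dimension 0 with size 56" message. The only figures taken from the thread are the 56-token default and the roughly 200-token prompt.

```rust
use std::fmt;

/// Hypothetical error type for this sketch; it is not part of rust-bert.
#[derive(Debug, PartialEq)]
enum GenerationError {
    PromptTooLong { prompt_len: usize, max_length: usize },
}

impl fmt::Display for GenerationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            GenerationError::PromptTooLong { prompt_len, max_length } => write!(
                f,
                "prompt has {prompt_len} tokens but max_length is {max_length} (prompt included)"
            ),
        }
    }
}

/// Returns how many new tokens may still be generated, or an error when the
/// prompt alone already fills or exceeds the `max_length` budget.
/// Assumption: `max_length` counts the prompt tokens, as stated in the thread.
fn remaining_budget(prompt_len: usize, max_length: usize) -> Result<usize, GenerationError> {
    if prompt_len >= max_length {
        return Err(GenerationError::PromptTooLong { prompt_len, max_length });
    }
    Ok(max_length - prompt_len)
}

fn main() {
    // 56 is the pipeline default quoted above; 200 approximates the reported prompt size.
    match remaining_budget(200, 56) {
        Ok(n) => println!("room for {n} new tokens"),
        Err(e) => println!("error: {e}"),
    }

    // A short prompt leaves room for generation.
    match remaining_budget(10, 56) {
        Ok(n) => println!("room for {n} new tokens"),
        Err(e) => println!("error: {e}"),
    }
}
```

With a check like this, the caller gets an error value it can handle instead of a panic deep inside `index_copy_`.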

kj3moraes commented 7 months ago

> Thank you for raising this. Could you please share a reproducible example (i.e. with a prompt and file you could share publicly) and turn sampling off for the generation?

This is the prompt that I use

Task: Extract summary and keywords from code

You are given a file containing code in a programming language. Your task is to reads the code from the file and generates a JSON output with two keys - 'summary' and 'keywords'.

1. 'summary': A string describing what the code is doing. This summary should capture the main purpose or functionality of the code in a concise manner.

2. 'keywords': A list of strings that includes relevant keywords related to the programming language used, the task being performed, or any significant terms present in the code.

The following is an example:

INPUT: 
from transformers import pipeline

image_to_text_model = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

def convert_image_to_text(image_path: str) -> str :
    return image_to_text_model(image_path)

{
    "summary": "Function to caption the image at the specified path and returns it as a string",
    "keywords": ["python", "huggingface", "transformers", "image", "caption", "captioning"] 
}

INPUT:

An example file would be

use std::path::PathBuf;

pub fn path_to_string(path: &PathBuf) -> String {
    path.display().to_string()
}

pub fn path_to_filename_string(path: &PathBuf) -> Option<String> {
    Some(path.file_name()?.to_str()?.to_string())
}

kj3moraes commented 7 months ago

> * Change the `max_length` value of the [`TextGenerationConfig`](https://github.com/guillaume-be/rust-bert/blob/9f2cd17e914dee9570e981c63a4021beb33250c2/src/pipelines/text_generation.rs#L59) (you are currently using the default constructor). This is probably what you want to do.

This worked, thanks a lot. Returning a Result would be better for sure.