huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0
15.82k stars 952 forks source link

When trying to use GPU for SAM inference, I get "matmul is only supported for contiguous tensors lstride" #2373

Open deathknight0718 opened 3 months ago

deathknight0718 commented 3 months ago

This is my test code:

candle version: 0.6.0

    fn sam() {
        let result: Result<(), Error> = (|| {
            let directory = "/home/foliage/model/candle-sam".to_string();
            let device = Device::new_cuda(0)?;
            let mode = "ST".to_string();
            let builder = match mode.as_str() {
                "PT" => Ok(VarBuilder::from_pth(format!("{}/pytorch_model.bin", directory), DType::F32, &device)?),
                "ST" => Ok(unsafe { VarBuilder::from_mmaped_safetensors(&[format!("{}/sam_vit_b_01ec64.safetensors", directory)], DType::F32, &device)? }),
                _ => Err(Error::MODEL()),
            }?;
            let model = Sam::new(768, 12, 12, &[2, 5, 8, 11], builder)?;
            let path = "/home/foliage/project/foliage/foliage-ai/src/main/rust/native/src/sample.jpg".to_string();
            let image = load_image(&device, &path, Some(IMAGE_SIZE))?;
            let image_data = image.data.to_device(&device)?;
            let bboxes = model.generate_masks(&image_data, /* points_per_side */ 32, /* crop_n_layer */ 0, /* crop_overlap_ratio */ 512. / 1500., /* crop_n_points_downscale_factor */ 1)?;
            for (idx, bbox) in bboxes.iter().enumerate() {
                let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
                let (h, w) = mask.dims2()?;
                let mask = mask.broadcast_as((3, h, w))?;
                image_save(format!("/home/foliage/project/foliage/foliage-ai/src/main/rust/native/src/sam_mask{idx}.png"), Image { data: mask, ih: image.ih, iw: image.iw })?;
            }
            return Ok(());
        })();
        result.unwrap_or_else(|e| {
            println!("failed {}", e.to_string());
        });
    }

This is the stack trace; calling `model.generate_masks` throws the following error:

---- models::model_sam::tests::sam stdout ----
failed Error! operation failed: matmul is only supported for contiguous tensors lstride: Layout { shape: [64, 256], stride: [1792, 1], start_offset: 256 } rstride: Layout { shape: [256, 256], stride: [1, 256], start_offset: 0 } mnk: (64, 256, 256)
   0: candle_core::error::Error::bt
             at /home/foliage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-core-0.6.0/src/error.rs:231:25
   1: candle_core::cuda_backend::error::<impl core::convert::From<candle_core::cuda_backend::error::CudaError> for candle_core::error::Error>::from
             at /home/foliage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-core-0.6.0/src/cuda_backend/error.rs:50:9
   2: <core::result::Result<T,F> as core::ops::try_trait::FromResidual<core::result::Result<core::convert::Infallible,E>>>::from_residual
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/core/src/result.rs:1989:27
   3: candle_core::cuda_backend::gemm_config
             at /home/foliage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-core-0.6.0/src/cuda_backend/mod.rs:1081:9
   4: <candle_core::cuda_backend::CudaStorage as candle_core::backend::BackendStorage>::matmul
             at /home/foliage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-core-0.6.0/src/cuda_backend/mod.rs:1687:27
   5: candle_core::storage::Storage::matmul
             at /home/foliage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-core-0.6.0/src/storage.rs:723:31
   6: candle_core::tensor::Tensor::matmul
             at /home/foliage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-core-0.6.0/src/tensor.rs:1197:23
   7: <candle_nn::linear::Linear as candle_core::Module>::forward
             at /home/foliage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-nn-0.6.0/src/linear.rs:49:17
   8: <candle_transformers::models::with_tracing::Linear as candle_core::Module>::forward
             at /home/foliage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-transformers-0.6.0/src/models/with_tracing.rs:71:9
   9: <candle_transformers::models::segment_anything::mask_decoder::MlpMaskDecoder as candle_core::Module>::forward
             at /home/foliage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-transformers-0.6.0/src/models/segment_anything/mask_decoder.rs:48:18
  10: candle_transformers::models::segment_anything::mask_decoder::MaskDecoder::predict_masks
             at /home/foliage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-transformers-0.6.0/src/models/segment_anything/mask_decoder.rs:219:21
  11: candle_transformers::models::segment_anything::mask_decoder::MaskDecoder::forward
             at /home/foliage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-transformers-0.6.0/src/models/segment_anything/mask_decoder.rs:159:33
  12: candle_transformers::models::segment_anything::sam::Sam::process_crop
             at /home/foliage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-transformers-0.6.0/src/models/segment_anything/sam.rs:251:51
  13: candle_transformers::models::segment_anything::sam::Sam::generate_masks
             at /home/foliage/.cargo/registry/src/index.crates.io-6f17d22bba15001f/candle-transformers-0.6.0/src/models/segment_anything/sam.rs:340:21
  14: foliageai::models::model_sam::tests::sam::{{closure}}
             at ./src/models/model_sam.rs:221:26
  15: foliageai::models::model_sam::tests::sam
             at ./src/models/model_sam.rs:208:41
  16: foliageai::models::model_sam::tests::sam::{{closure}}
             at ./src/models/model_sam.rs:207:13
  17: core::ops::function::FnOnce::call_once
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/core/src/ops/function.rs:250:5
  18: core::ops::function::FnOnce::call_once
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/core/src/ops/function.rs:250:5
  19: test::__rust_begin_short_backtrace
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/test/src/lib.rs:625:18
  20: test::run_test_in_process::{{closure}}
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/test/src/lib.rs:648:60
  21: <core::panic::unwind_safe::AssertUnwindSafe<F> as core::ops::function::FnOnce<()>>::call_once
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/core/src/panic/unwind_safe.rs:272:9
  22: std::panicking::try::do_call
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/std/src/panicking.rs:559:40
  23: std::panicking::try
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/std/src/panicking.rs:523:19
  24: std::panic::catch_unwind
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/std/src/panic.rs:149:14
  25: test::run_test_in_process
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/test/src/lib.rs:648:27
  26: test::run_test::{{closure}}
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/test/src/lib.rs:569:43
  27: test::run_test::{{closure}}
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/test/src/lib.rs:599:41
  28: std::sys_common::backtrace::__rust_begin_short_backtrace
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/std/src/sys_common/backtrace.rs:155:18
  29: std::thread::Builder::spawn_unchecked_::{{closure}}::{{closure}}
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/std/src/thread/mod.rs:542:17
  30: <core::panic::unwind_safe::AssertUnwindSafe<F> as core::ops::function::FnOnce<()>>::call_once
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/core/src/panic/unwind_safe.rs:272:9
  31: std::panicking::try::do_call
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/std/src/panicking.rs:559:40
  32: std::panicking::try
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/std/src/panicking.rs:523:19
  33: std::panic::catch_unwind
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/std/src/panic.rs:149:14
  34: std::thread::Builder::spawn_unchecked_::{{closure}}
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/std/src/thread/mod.rs:541:30
  35: core::ops::function::FnOnce::call_once{{vtable.shim}}
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/core/src/ops/function.rs:250:5
  36: <alloc::boxed::Box<F,A> as core::ops::function::FnOnce<Args>>::call_once
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/alloc/src/boxed.rs:2063:9
  37: <alloc::boxed::Box<F,A> as core::ops::function::FnOnce<Args>>::call_once
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/alloc/src/boxed.rs:2063:9
  38: std::sys::pal::unix::thread::Thread::new::thread_start
             at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library/std/src/sys/pal/unix/thread.rs:108:17
  39: start_thread
  40: __clone3

In addition, when I used CPU inference, more than 10 minutes passed and there were still no results.

deathknight0718 commented 3 months ago

Before that, I checked contiguity with the `if` statement below, and the tensors were contiguous before the `model.generate_masks` call.

if !image_data.is_contiguous() {
    image_data = image_data.contiguous()?;
}