huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

candle


Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support) and ease of use. Try our online demos: whisper, LLaMA2, T5, yolo, Segment Anything.

Get started

Make sure that you have candle-core correctly installed as described in Installation.

Let's see how to run a simple matrix multiplication. Write the following to your myapp/src/main.rs file:

use candle_core::{Device, Tensor};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;

    let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
    let b = Tensor::randn(0f32, 1., (3, 4), &device)?;

    let c = a.matmul(&b)?;
    println!("{c}");
    Ok(())
}

cargo run should display a tensor of shape Tensor[[2, 4], f32].
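The output has shape [2, 4] because multiplying an (m, k) matrix by a (k, n) matrix gives an (m, n) matrix. Each output entry is the dot product of a row of a with a column of b, so the inner dimensions (3 here) must match. The plain-Rust sketch below has no candle dependency and is not how candle implements matmul. It only makes the shape rule concrete, using the same (2, 3) x (3, 4) shapes and fixed inputs instead of random ones.

```rust
/// Naive row-major matrix multiply: (m, k) x (k, n) -> (m, n).
/// Illustration only; candle dispatches to optimized CPU/GPU kernels.
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k, "lhs must have shape (m, k)");
    assert_eq!(b.len(), k * n, "rhs must have shape (k, n)");
    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            // Dot product of row i of `a` with column j of `b`.
            let mut acc = 0.0;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            c[i * n + j] = acc;
        }
    }
    c
}

fn main() {
    // Same shapes as the candle example: (2, 3) x (3, 4) -> (2, 4).
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = [1.0; 12];
    let c = matmul(&a, &b, 2, 3, 4);
    println!("{:?}", c); // prints [6.0, 6.0, 6.0, 6.0, 15.0, 15.0, 15.0, 15.0]
}
```

Multiplying by an all-ones matrix makes each output row equal to the sum of the corresponding input row (1 + 2 + 3 = 6 and 4 + 5 + 6 = 15), which makes the 2 x 4 result easy to check by hand.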

If you installed candle with CUDA support, you can run the same code on the GPU by changing the device:

- let device = Device::Cpu;
+ let device = Device::new_cuda(0)?;

For more advanced examples, please have a look at the following section.

Check out our examples

These online demos run entirely in your browser:

We also provide some command-line examples using state-of-the-art models:

Run them using commands like:

cargo run --example quantized --release

To use CUDA, add --features cuda to the example command line. If you have cuDNN installed, use --features cudnn for even more speedups.

There are also some wasm examples for whisper and llama2.c. You can either build them with trunk or try them online: whisper, llama2, T5, Phi-1.5 and Phi-2, and Segment Anything Model.

For LLaMA2, run the following command to retrieve the weight files and start a test server:

cd candle-wasm-examples/llama2-c
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
trunk serve --release --port 8081

And then head over to http://localhost:8081/.

Useful External Resources

If you have an addition to this list, please submit a pull request.

Features

How to use

Cheatsheet:

| | Using PyTorch | Using Candle |
|---|---|---|
| Creation | torch.Tensor([[1, 2], [3, 4]]) | Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)? |
| Creation | torch.zeros((2, 2)) | Tensor::zeros((2, 2), DType::F32, &Device::Cpu)? |
| Indexing | tensor[:, :4] | tensor.i((.., ..4))? |
| Operations | tensor.view((2, 2)) | tensor.reshape((2, 2))? |
| Operations | a.matmul(b) | a.matmul(&b)? |
| Arithmetic | a + b | (&a + &b)? |
| Device | tensor.to(device="cuda") | tensor.to_device(&Device::new_cuda(0)?)? |
| Dtype | tensor.to(dtype=torch.float16) | tensor.to_dtype(DType::F16)? |
| Saving | torch.save({"A": A}, "model.bin") | candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")? |
| Loading | weights = torch.load("model.bin") | candle::safetensors::load("model.safetensors", &device)? |
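The Indexing and reshape rows follow row-major semantics: reshape keeps the elements in the same order and only requires the total element count to stay the same, while tensor[:, :4] keeps every row and the first four columns. The sketch below illustrates both rules on a toy, hypothetical Toy2d type in plain Rust. It is not candle's Tensor, and it assumes a strict policy of rejecting an out-of-range column count rather than clamping it.

```rust
/// A toy row-major 2-D tensor used only to illustrate the semantics of the
/// cheatsheet's `reshape` and `i((.., ..n))` rows. This is NOT candle's Tensor.
#[derive(Debug, Clone, PartialEq)]
struct Toy2d {
    rows: usize,
    cols: usize,
    data: Vec<f32>, // element (r, c) is stored at data[r * cols + c]
}

impl Toy2d {
    fn new(rows: usize, cols: usize, data: Vec<f32>) -> Result<Self, String> {
        if rows * cols != data.len() {
            return Err(format!(
                "shape ({rows}, {cols}) needs {} elements, got {}",
                rows * cols,
                data.len()
            ));
        }
        Ok(Toy2d { rows, cols, data })
    }

    /// Like `reshape((rows, cols))`: same elements in the same row-major
    /// order, new shape. Fails if the element count would change.
    fn reshape(&self, rows: usize, cols: usize) -> Result<Self, String> {
        Toy2d::new(rows, cols, self.data.clone())
    }

    /// Like `tensor[:, :n]`: keep every row and the first `n` columns.
    /// This sketch rejects n > cols instead of clamping.
    fn first_cols(&self, n: usize) -> Result<Self, String> {
        if n > self.cols {
            return Err(format!("asked for {n} columns, tensor has {}", self.cols));
        }
        let mut data = Vec::with_capacity(self.rows * n);
        for r in 0..self.rows {
            let start = r * self.cols;
            data.extend_from_slice(&self.data[start..start + n]);
        }
        Toy2d::new(self.rows, n, data)
    }
}

fn main() -> Result<(), String> {
    let t = Toy2d::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])?;
    println!("{:?}", t.first_cols(2)?); // rows [1, 2] and [4, 5]
    println!("{:?}", t.reshape(3, 2)?); // same six values, shape (3, 2)
    assert!(t.reshape(2, 2).is_err()); // 6 elements cannot become 2 x 2
    Ok(())
}
```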

Structure

FAQ

Why should I use Candle?

Candle's core goal is to make serverless inference possible. Full machine learning frameworks like PyTorch are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight binaries.

Secondly, Candle lets you remove Python from production workloads. Python overhead can seriously hurt performance, and the GIL is a notorious source of headaches.

Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like safetensors and tokenizers.

Other ML frameworks

Common Errors

Missing symbols when compiling with the mkl or accelerate features

If you get missing symbols when compiling binaries or tests with the mkl or accelerate features, for example for mkl:

  = note: /usr/bin/ld: (....o): in function `blas::sgemm':
          .../blas-0.22.0/src/lib.rs:1944: undefined reference to `sgemm_'
          collect2: error: ld returned 1 exit status

  = note: some `extern` functions couldn't be found; some native libraries may need to be installed or have their path specified
  = note: use the `-l` flag to specify native libraries to link
  = note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo

or for accelerate:

Undefined symbols for architecture arm64:
            "_dgemm_", referenced from:
                candle_core::accelerate::dgemm::h1b71a038552bcabe in libcandle_core...
            "_sgemm_", referenced from:
                candle_core::accelerate::sgemm::h2cf21c592cba3c47 in libcandle_core...
          ld: symbol(s) not found for architecture arm64

This is likely caused by a missing linker flag needed to link the mkl library. For mkl, you can try adding the following at the top of your binary:

extern crate intel_mkl_src;

or for accelerate:

extern crate accelerate_src;
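If the same binary must also build without MKL or Accelerate, one option is to gate each extern crate line on its Cargo feature. This is a sketch that assumes your Cargo.toml defines features named "mkl" and "accelerate" which enable optional intel-mkl-src and accelerate-src dependencies. The blas_backend helper is hypothetical and only reports which provider was compiled in.

```rust
// Link the BLAS provider crate only when its feature is enabled, so a plain
// `cargo build` (no features) still compiles.
// ASSUMPTION: Cargo.toml maps feature "mkl" to the optional intel-mkl-src
// dependency and feature "accelerate" to the optional accelerate-src dependency.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

/// Hypothetical helper: report which BLAS provider this build links against.
fn blas_backend() -> &'static str {
    if cfg!(feature = "mkl") {
        "mkl"
    } else if cfg!(feature = "accelerate") {
        "accelerate"
    } else {
        "none"
    }
}

fn main() {
    println!("BLAS backend: {}", blas_backend());
}
```

Built with cargo run --features mkl, the MKL symbols are linked in. Without features, the extern crate lines are compiled out and the helper reports "none".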

Cannot run the LLaMA examples: access to source requires login credentials

Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401

This is likely because you do not have permission to access the LLaMA-v2 model. To fix this, register on the Hugging Face Hub, accept the LLaMA-v2 model conditions, and set up your authentication token. See issue #350 for more details.
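To check whether a token is set up, you can look for the token file that huggingface-cli login writes. The sketch below assumes the huggingface_hub convention that the token is stored at $HF_HOME/token, with HF_HOME defaulting to ~/.cache/huggingface. The hf_token_path function is hypothetical, and the Rust hub client used by the examples may also consult other locations.

```rust
use std::path::PathBuf;

/// Hypothetical helper: where a Hugging Face token is expected on disk,
/// ASSUMING the `$HF_HOME/token` convention with HF_HOME defaulting to
/// `$HOME/.cache/huggingface`. Inputs are explicit so the logic is testable.
fn hf_token_path(hf_home: Option<&str>, home: &str) -> PathBuf {
    let base = match hf_home {
        Some(dir) => PathBuf::from(dir),
        None => PathBuf::from(home).join(".cache").join("huggingface"),
    };
    base.join("token")
}

fn main() {
    let home = std::env::var("HOME").unwrap_or_default();
    let hf_home = std::env::var("HF_HOME").ok();
    let path = hf_token_path(hf_home.as_deref(), &home);
    if path.exists() {
        println!("found token file at {}", path.display());
    } else {
        println!("no token at {}; run `huggingface-cli login` first", path.display());
    }
}
```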

Missing cute/cutlass headers when compiling flash-attn

  In file included from kernels/flash_fwd_launch_template.h:11:0,
                   from kernels/flash_fwd_hdim224_fp16_sm80.cu:5:
  kernels/flash_fwd_kernel.h:8:10: fatal error: cute/algorithm/copy.hpp: No such file or directory
   #include <cute/algorithm/copy.hpp>
            ^~~~~~~~~~~~~~~~~~~~~~~~~
  compilation terminated.
  Error: nvcc error while compiling:

cutlass is provided as a git submodule, so run the following command to check it out properly:

git submodule update --init

Compiling with flash-attention fails

/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:

This is a bug in gcc-11 triggered by the CUDA compiler. To fix it, install a different supported gcc version, for example gcc-10, and set the NVCC_CCBIN environment variable to the path of that compiler.

env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
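For context, nvcc selects its host C++ compiler with the -ccbin flag, which accepts either a compiler executable or the directory containing it. The sketch below shows one plausible way a CUDA build script could forward NVCC_CCBIN to that flag. This is an assumption for illustration: the nvcc_args helper is hypothetical, and candle's build scripts may assemble the command differently.

```rust
/// Hypothetical helper: build nvcc arguments for one kernel, forwarding an
/// optional host-compiler location via nvcc's `-ccbin` flag.
fn nvcc_args(ccbin: Option<&str>, kernel: &str) -> Vec<String> {
    // `-c`: compile the kernel to an object file without linking.
    let mut args = vec!["-c".to_string(), kernel.to_string()];
    // `-ccbin <path>` makes nvcc use this host compiler (e.g. gcc-10)
    // instead of the system default (e.g. the buggy gcc-11).
    if let Some(dir) = ccbin {
        args.push("-ccbin".to_string());
        args.push(dir.to_string());
    }
    args
}

fn main() {
    // A real build.rs would also print `cargo:rerun-if-env-changed=NVCC_CCBIN`.
    let ccbin = std::env::var("NVCC_CCBIN").ok();
    let args = nvcc_args(ccbin.as_deref(), "kernels/flash_fwd_hdim224_fp16_sm80.cu");
    println!("nvcc {}", args.join(" "));
}
```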

Linking error on windows when running rustdoc or mdbook tests

Couldn't compile the test.
---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
error: linking with `link.exe` failed: exit code: 1181
//very long chain of linking
 = note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'

Make sure you link all native libraries that might be located outside the project's target directory. For example, to run the mdbook tests:

mdbook test candle-book -L .\target\debug\deps\ `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib

Extremely slow model load time with WSL

This may be caused by the models being loaded from /mnt/c, a Windows drive accessed through WSL; more details are available on Stack Overflow.
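If you want your program to warn about this case, a simple heuristic is to check whether the model path sits under a WSL drive mount. The sketch below assumes the default WSL automount root, where Windows drives appear as /mnt/<drive letter>; it will not detect a custom root configured in /etc/wsl.conf. The on_wsl_windows_mount function is hypothetical.

```rust
use std::path::{Component, Path};

/// Hypothetical heuristic: true if `path` is under `/mnt/<drive letter>`,
/// the default location where WSL mounts Windows drives (e.g. /mnt/c).
/// ASSUMPTION: the automount root has not been changed in /etc/wsl.conf.
fn on_wsl_windows_mount(path: &Path) -> bool {
    let mut comps = path.components();
    let (Some(Component::RootDir), Some(Component::Normal(mnt)), Some(Component::Normal(drive))) =
        (comps.next(), comps.next(), comps.next())
    else {
        return false;
    };
    // A drive mount is a single ASCII letter such as "c" or "d".
    let is_drive_letter = drive
        .to_str()
        .map_or(false, |d| d.len() == 1 && d.chars().all(|c| c.is_ascii_alphabetic()));
    mnt == "mnt" && is_drive_letter
}

fn main() {
    let model = Path::new("/mnt/c/Users/me/models/model.safetensors");
    if on_wsl_windows_mount(model) {
        eprintln!("warning: loading from a Windows drive under WSL can be very slow");
    }
}
```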

Tracking down errors

Set RUST_BACKTRACE=1 to get a backtrace whenever a candle error is generated.
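The environment variable works because std::backtrace::Backtrace::capture only walks the stack when RUST_LIB_BACKTRACE or RUST_BACKTRACE is set to a value other than "0"; otherwise it returns a cheap, disabled backtrace. The sketch below shows that mechanism with a hypothetical TracedError type. It is not candle's error type, and check_shapes is an invented example function.

```rust
use std::backtrace::{Backtrace, BacktraceStatus};
use std::error::Error;
use std::fmt;

/// Hypothetical error type (not candle's) that records a backtrace when created.
#[derive(Debug)]
struct TracedError {
    msg: String,
    backtrace: Backtrace,
}

impl TracedError {
    fn new(msg: impl Into<String>) -> Self {
        // Walks the stack only if RUST_LIB_BACKTRACE / RUST_BACKTRACE enable it;
        // otherwise the backtrace's status is Disabled and capture is cheap.
        TracedError { msg: msg.into(), backtrace: Backtrace::capture() }
    }
}

impl fmt::Display for TracedError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.msg)?;
        // Append the stack trace only when one was actually captured.
        if self.backtrace.status() == BacktraceStatus::Captured {
            write!(f, "\n{}", self.backtrace)?;
        }
        Ok(())
    }
}

impl Error for TracedError {}

/// Hypothetical check: matmul needs lhs columns == rhs rows.
fn check_shapes(lhs: (usize, usize), rhs: (usize, usize)) -> Result<(), TracedError> {
    if lhs.1 != rhs.0 {
        return Err(TracedError::new(format!("shape mismatch in matmul: {lhs:?} x {rhs:?}")));
    }
    Ok(())
}

fn main() {
    if let Err(e) = check_shapes((2, 3), (4, 5)) {
        // Run with RUST_BACKTRACE=1 to also see where the error was created.
        eprintln!("error: {e}");
    }
}
```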