huggingface / candle

Minimalist ML framework for Rust

Performance issues compared to Pytorch #1139

Open joeyballentine opened 1 year ago

joeyballentine commented 1 year ago

Hello. I mentioned this in the discord and worked with a member to make sure I wasn't doing anything dumb. I tested the release version of my candle code with cudnn enabled vs equivalent pytorch code, and comparatively candle is about 4x slower.

I have attached the code I was using to compare. It contains both the original python/pytorch implementation of the RealESRGAN RRDBNet arch, as well as my Candle implementation.

I'm limited on time, or I would have set up a proper repo for this with a script/program that runs both tests automatically, but this is the best I can do at the moment. To use either script, you'll probably have to adjust the paths to match the locations of the model and your test images (I did not include test images). I recommend trying ~10 smallish images (128x for example).

Context from discord: https://discord.com/channels/879548962464493619/1136218819447238726/1164985040854339736

Code: rust_candle_test.zip

And the model (had to upload to drive) https://drive.google.com/file/d/1AyvArWkR3qonMV2pBtk3zDkct0yJrh5Z/view?usp=sharing

For reference, here are the results from when I benchmarked it:

PyTorch:

Model took 323.999ms // First run, takes a long time
Saved 00000.png
Model took 35.4905ms // Second run, and subsequent runs after, take significantly less
Saved 00001.png

Candle:

Model took 262.1319ms // First run, takes a while but less time than torch
Saved 00000.png
Model took 124.97ms // Second run, and subsequent runs after, take less time but still far more than pytorch
Saved 00001.png

Please let me know if you need or want any more information. Candle is a very interesting project and seems very promising. It just currently doesn't seem to have as much optimization as pytorch.

LaurentMazare commented 1 year ago

Tagging this with "help wanted" in case anyone wants to cut their teeth on profiling cuda kernels. The idea would be first to replicate the slowness locally, ideally with torch.backends.cudnn.benchmark set to False on the PyTorch side to avoid cudnn's algorithm selection, and then to run both the pytorch and candle versions under nsys and look at where the time is actually spent on both sides.
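For anyone picking this up, the PyTorch side of that setup could look roughly like the sketch below; the Conv2d model and 128x128 input are stand-ins for the RRDBNet scripts attached above, and the nsys invocation in the comment is just one way to capture a trace.

import time
import torch

# Profile the whole run with Nsight Systems, e.g.:
#   nsys profile -o pytorch_report python bench_pytorch.py
# Keep cudnn's autotuner off so PyTorch does not benchmark and cache the
# fastest convolution algorithm, which would skew the comparison with candle.
torch.backends.cudnn.benchmark = False

# Stand-in model and input; the real test uses the RRDBNet from the attached
# scripts and ~10 smallish (128x128) images.
model = torch.nn.Conv2d(3, 64, 3, padding=1).cuda().eval()
x = torch.rand(1, 3, 128, 128, device="cuda")

with torch.no_grad():
    for _ in range(10):
        torch.cuda.synchronize()
        start = time.time()
        _ = model(x)
        torch.cuda.synchronize()  # wait for the kernels, not just the launches
        print(f"Model took {(time.time() - start) * 1000:.3f}ms")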

jeromeku commented 1 year ago

@LaurentMazare @joeyballentine

Did some quick profiling using nsys.

TL;DR, here are the total times (in ns) by kernel and framework:

| kernel | framework | total |
| --- | --- | --- |
| cudnn_ampere_scudnn_winograd_128x128_ldg1_ldg4_rel | pytorch | 8.55e+09 |
| void at::native::::CatArrayBatchedCopy_al | pytorch | 5.87e+09 |
| sm86_xmma_fprop_implicit_gemm_tf32f32_tf32f32f32 | pytorch | 4.35e+09 |
| void at::native::elementwise_kernel<(int)128, (int | pytorch | 2.30e+09 |
| void cudnn::ops::nchwToNhwcKernel | pytorch | 1.98e+09 |
| void at::native::vectorized_elementwise_kernel<(in | pytorch | 1.67e+09 |
| void at::native::vectorized_elementwise_kernel<(in | pytorch | 1.31e+09 |
| void at::native::vectorized_elementwise_kernel<(in | pytorch | 8.72e+08 |
| cudnn_ampere_scudnn_128x32_relu_small_nn_v1 | pytorch | 3.37e+08 |
| void at::native::::upsample_nearest2d_out | pytorch | 1.22e+08 |
| void cudnn::winograd::generateWinogradTilesKernel< | pytorch | 9.84e+07 |
| void cutlass_cudnn::Kernel | pytorch | 1.65e+07 |
| sm86_xmma_fprop_implicit_gemm_indexed_tf32f32_tf32 | pytorch | 1.45e+07 |
| void cudnn::ops::nhwcToNchwKernel | pytorch | 7.99e+06 |
| cudnn_ampere_scudnn_128x64_relu_small_nn_v1 | pytorch | 6.73e+06 |
| void cask_cudnn::computeOffsetsKernel<(bool)0, (bo | pytorch | 1.85e+06 |
| void at::native::elementwise_kernel<(int)128, (int | pytorch | 5.62e+05 |

| kernel | framework | total |
| --- | --- | --- |
| ucopy_f32 | rust | 1.42e+10 |
| badd_f32 | rust | 1.14e+10 |
| cudnn_infer_ampere_scudnn_winograd_128x128_ldg1_ld | rust | 5.96e+09 |
| affine_f32 | rust | 3.78e+09 |
| void cudnn::ops::nchwToNhwcKernel | rust | 3.69e+09 |
| bmaximum_f32 | rust | 3.69e+09 |
| bminimum_f32 | rust | 3.68e+09 |
| void cutlass_cudnn_infer::Kernel | rust | 2.66e+09 |
| void cutlass_cudnn_infer::Kernel | rust | 2.60e+09 |
| void cudnn::ops::nhwcToNchwKernel | rust | 1.21e+09 |
| upsample_nearest2d_f32 | rust | 6.36e+08 |
| void cutlass_cudnn_infer::Kernel | rust | 5.87e+08 |
| cudnn_infer_ampere_scudnn_128x32_relu_small_nn_v1 | rust | 3.34e+08 |
| void cudnn::winograd::generateWinogradTilesKernel< | rust | 7.22e+07 |
| cudnn_infer_ampere_scudnn_128x64_relu_small_nn_v1 | rust | 6.83e+06 |
| void cask_cudnn_infer::computeOffsetsKernel<(bool) | rust | 1.83e+06 |

It seems like some of the custom unary and binary ops in candle_kernels are an issue.

I added nvtx annotations to focus the profiling on the model.

The above table is only for ForwardPass.

Here are the total times for these ranges:

| framework | range | total |
| --- | --- | --- |
| pytorch | Iteration | 3.20e+10 |
| pytorch | ForwardPass | 1.30e+10 |
| pytorch | cuBLAS:cublasCreate_v2 | 4.08e+06 |
| rust | Iteration | 8.43e+10 |
| rust | ForwardPass | 5.62e+10 |
| rust | cuBLAS:cublasCreate_v2 | 1.83e+05 |
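On the PyTorch side, ranges like Iteration and ForwardPass can be pushed with torch.cuda.nvtx roughly as follows (a sketch with a stand-in model, not the exact notebook code; the Rust side needs an equivalent NVTX binding):

import torch

# Stand-in model and input; the range names mirror the tables above, and
# Nsight Systems then reports time per named range.
model = torch.nn.Conv2d(3, 64, 3, padding=1).cuda().eval()
x = torch.rand(1, 3, 128, 128, device="cuda")

with torch.no_grad():
    for _ in range(10):
        torch.cuda.nvtx.range_push("Iteration")
        torch.cuda.nvtx.range_push("ForwardPass")
        _ = model(x)
        torch.cuda.synchronize()
        torch.cuda.nvtx.range_pop()  # ForwardPass
        torch.cuda.nvtx.range_pop()  # Iteration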

Additional notes:

rust_candle_test.zip notebook.zip analysis.zip

@LaurentMazare lmk if further kernel profiling / benchmarking makes sense as it pertains to candle more generally.

LaurentMazare commented 1 year ago

Thanks @jeromeku for the very detailed analysis and for providing the files and instructions to replicate it; that's a great contribution to candle and will certainly inspire others to do some profiling. Do feel free to do more of these. I'm not sure what the best way is to give this some visibility; maybe some writeups/tutorials that we link to from the main readme (and that could also be advertised on r/rust etc.).

For the curious, the notebook as a gist.

And indeed the low-hanging fruit in this case seems to be the ucopy and badd kernels. These are implemented in a very naive way; in particular, ucopy, when called on a non-contiguous array, just computes the "strided index" for every position, each time doing some modulo etc. computations without trying to reuse the previous results. This should be easy to optimize, at least for specific use cases (for example, where the final dimension is already contiguous, we should use a memcpy-like cuda operation to accelerate copies). It's probably worth looking at how pytorch/other frameworks do it and taking some inspiration there. I can have a look, though it will be at least a week as I'm moving at the moment and don't have my desktop computer at hand anymore.
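To make that concrete, here is a rough Python sketch of the per-element index arithmetic a naive strided copy performs, next to a fast path that copies each contiguous row in one slice. It only illustrates the idea; it is not the actual candle CUDA kernel.

# Illustrative sketch only: mimics the index arithmetic of a naive strided
# copy, not candle's actual ucopy_f32 CUDA code.

def strided_offset(flat_idx, dims, strides):
    # What the naive kernel does for EVERY element: a chain of div/mod ops
    # turning a flat output index into an offset into the strided input.
    offset = 0
    for dim, stride in zip(reversed(dims), reversed(strides)):
        offset += (flat_idx % dim) * stride
        flat_idx //= dim
    return offset

def naive_copy(src, dst, dims, strides):
    # One full round of div/mod arithmetic per element.
    n = 1
    for d in dims:
        n *= d
    for i in range(n):
        dst[i] = src[strided_offset(i, dims, strides)]

def copy_contiguous_rows(src, dst, dims, strides):
    # Fast path when the last dimension is contiguous (stride == 1):
    # compute the row offset once, then copy the whole row (memcpy-like).
    assert strides[-1] == 1
    row_len = dims[-1]
    n_rows = 1
    for d in dims[:-1]:
        n_rows *= d
    for r in range(n_rows):
        base = strided_offset(r * row_len, dims, strides)
        dst[r * row_len:(r + 1) * row_len] = src[base:base + row_len]

if __name__ == "__main__":
    # Example: copy every other row of a 4x6 row-major buffer, i.e. a
    # non-contiguous view with dims (2, 6) and strides (12, 1).
    src = list(range(24))
    out_a, out_b = [0] * 12, [0] * 12
    naive_copy(src, out_a, (2, 6), (12, 1))
    copy_contiguous_rows(src, out_b, (2, 6), (12, 1))
    assert out_a == out_b == src[0:6] + src[12:18]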

jeromeku commented 1 year ago

Some thoughts @LaurentMazare:

jeromeku commented 1 year ago

@LaurentMazare

Did some quick digging into pytorch ops:

Looking into cutlass also, primarily copy atom and algorithm.

Will poke around some more and see which parts can be re-purposed.

dlfrnaos19 commented 1 year ago

I tested a simple benchmark using PyTorch and Rust’s candle. I performed matmul, increasing the number of repetitions on a tensor of size (1000,1000) in float32 format.

torch: n=99, 0.183099s; n=999, 1.894550s; n=9999, 18.856167s

candle: n=99, 0.214144s; n=999, 1.782565s; n=9999, 17.811970s

python code

import torch
from tqdm import tqdm
import time

torch_sample = torch.rand(1000, 1000, dtype=torch.float32).to("cpu")
torch_sample

n = 9999

torch_start = time.time()
for i in tqdm(range(n)):
    torch.matmul(torch_sample, torch_sample)
torch_end = time.time()

torch_time = torch_end - torch_start

print(f"torch time: {torch_time} s")

rust code

use candle_core::Tensor;
use candle_core::Device::Cpu;
use std::time::Instant;
use candle_core::Result;

fn main() -> Result<()> {
    let device = Cpu;
    let candle_sample = Tensor::rand(0f32, 1f32, (1000,1000), &device)?;

    let n = 9999;
    let candle_start = Instant::now();
    for _ in 0..n {
        candle_sample.matmul(&candle_sample)?;
    }

    let candle_end = candle_start.elapsed().as_secs_f32();
    eprintln!("candle_end = {:?}", candle_end);
    Ok(())
}

rust command -> cargo run --release

cpu Ryzen Threadripper PRO 3955WX 16-Cores

jeromeku commented 1 year ago

@LaurentMazare

Did some more research on pytorch internals.

Some notes:

Lots to take inspiration from, starting with copy, unary, and binary kernels and beyond -- look forward to helping with these improvements!

edgarriba commented 11 months ago

> I tested a simple benchmark using PyTorch and Rust’s candle. I performed matmul, increasing the number of repetitions on a tensor of size (1000,1000) in float32 format.

@dlfrnaos19 not sure if this is the right place to comment about this, but have you tried with small vectors/matrices? There are some specific needs, e.g. in robotics and geometric vision & graphics, where pytorch does a really bad job computing fast dot products or chaining pose transforms (Vec3, Mat3x3, Mat4x4, etc.) because all the complex dispatching machinery takes more time than the actual kernel execution. Reference discussion: https://github.com/pytorch/pytorch/issues/103313

It would be great if candle could solve that.
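As a rough illustration of that dispatch cost (not a rigorous benchmark), one can time many individual 3x3 matmuls in PyTorch against a single batched call over the same data; the gap between the two is mostly per-call overhead rather than arithmetic:

import time
import torch

n = 100_000
a = torch.randn(n, 3, 3)
b = torch.randn(n, 3, 3)

# Many individual 3x3 matmuls: time is dominated by per-call dispatch
# overhead, not by the ~45 floating point operations each product needs.
start = time.time()
for i in range(n):
    torch.matmul(a[i], b[i])
print(f"looped 3x3 matmuls: {time.time() - start:.4f}s")

# The same arithmetic as one batched call, dispatched once.
start = time.time()
torch.bmm(a, b)
print(f"batched bmm: {time.time() - start:.4f}s")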

dlfrnaos19 commented 11 months ago

@edgarriba I have reviewed the issue you posted. The test I conducted was a simple repetition of the matmul operation in a for loop with the 1000x1000 matrix mentioned in the issue. If the matrix is smaller than this, please specify the shape of the matrix so that I can conduct a test.

edgarriba commented 11 months ago

@dlfrnaos19 I'm interested in small matrix/vector operators, e.g. 3x1, 4x1, 3x3, 4x4 (the most common shapes used in robotics/vision/graphics), which are involved in camera poses and projections. Happy to contribute in this direction.

dlfrnaos19 commented 11 months ago

@edgarriba I don't know if this is what you wanted.

3x3 tensors, matmul run ten million times:

torch.mm, cpu: 55.3527s
torch.matmul, cpu: 55.5209s
torch.mm, gpu (A6000): 112.1874s
numpy.dot, cpu: 14.3464s
numba jit + numpy, cpu: 3.1331s
candle, cpu: 2.7602s
candle, gpu: 0.0005s

It looks very fast, but please check it again.

fn cpu_test() {
    let a = Tensor::randn(0.1f32, 1f32, (10000000,3,3), &Cpu);
    let b = Tensor::randn(0.1f32, 1f32, (10000000, 3,3), &Cpu);

    let start = Instant::now();
    let res: Vec<_> = zip(a,b)
    .into_iter()
    .map(move |(x, y)| {x.matmul(&y).unwrap();})
    .collect();

    let end = start.elapsed().as_secs_f32();
    eprintln!("end = {:?}", end);
    eprintln!("res = {:?}", res);
}
fn gpu_test() {

    let device = Device::new_cuda(0).unwrap();

    let a = Tensor::randn(0.1f32, 1f32, (10000000,3,3), &device);
    let b = Tensor::randn(0.1f32, 1f32, (10000000, 3,3), &device);

    let start = Instant::now();
    let res: Vec<_> = zip(a,b)
    .into_iter()
    .map(move |(x, y)| {x.matmul(&y).unwrap();})
    .collect();
    let end = start.elapsed().as_secs_f32();
    eprintln!("end = {:?}", end);
    eprintln!("res = {:?}", res);
}

import time
import numpy as np
import torch
from numba import jit
from tqdm import tqdm

# torch.mm on CPU
tensor_cpu = torch.randn(10000000, 3, 3)
tensor_cpu2 = torch.randn(10000000, 3, 3)

start_time = time.time()

for a, b in tqdm(zip(tensor_cpu, tensor_cpu2), total=tensor_cpu.shape[0]):
    result_cpu = torch.mm(a, b)

end_time = time.time()
elapsed_time = end_time - start_time

# torch.mm on GPU
tensor_gpu = torch.randn(10000000, 3, 3, device='cuda')
tensor_gpu2 = torch.randn(10000000, 3, 3, device='cuda')

start_time = time.time()

for a, b in tqdm(zip(tensor_gpu, tensor_gpu2), total=tensor_gpu.shape[0]):
    result_gpu = torch.mm(a, b)

end_time = time.time()
elapsed_time = end_time - start_time

# numpy.dot on CPU
numpy_tensor1 = np.random.rand(10000000, 3, 3)
numpy_tensor2 = np.random.rand(10000000, 3, 3)

start_time = time.time()

for a, b in tqdm(zip(numpy_tensor1, numpy_tensor2), total=numpy_tensor2.shape[0]):
    numpy_result = np.dot(a, b)

end_time = time.time()
elapsed_time = end_time - start_time

# numpy.dot with numba jit
@jit(nopython=True)
def main(numpy_tensor1, numpy_tensor2):
    for a, b in zip(numpy_tensor1, numpy_tensor2):
        numpy_result = np.dot(a, b)

numpy_tensor1 = np.random.rand(10000000, 3, 3)
numpy_tensor2 = np.random.rand(10000000, 3, 3)
start_time = time.time()
main(numpy_tensor1=numpy_tensor1, numpy_tensor2=numpy_tensor2)
end_time = time.time()
elapsed_time = end_time - start_time

edgarriba commented 11 months ago

@dlfrnaos19 the results are very impressive! thanks so much!

lamnguyenx commented 10 months ago

@dlfrnaos19: I think you shouldn't wrap the tensor iterator in tqdm in this case; it increases the total execution time.
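For example, the timing loop can be run once with and once without the progress bar to see how much tqdm itself contributes (a quick sketch with an arbitrary size):

import time
import torch
from tqdm import tqdm

n = 100000
a = torch.randn(n, 3, 3)
b = torch.randn(n, 3, 3)

# Loop wrapped in tqdm: the measured time includes progress-bar bookkeeping.
start = time.time()
for x, y in tqdm(zip(a, b), total=n):
    torch.mm(x, y)
print(f"with tqdm: {time.time() - start:.4f}s")

# Bare loop: the measured time covers only the matmul calls.
start = time.time()
for x, y in zip(a, b):
    torch.mm(x, y)
print(f"without tqdm: {time.time() - start:.4f}s")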

joeyballentine commented 10 months ago

By the way, I forgot to mention this when I found out, but in my initial benchmark I was using a custom sequential layer that was a bit inefficient. I meant to redo my benchmark after I fixed that, but I completely forgot to. I will try to remember to do that and verify that it is still slower.

super-fun-surf commented 2 months ago

this is great work y'all. is this ongoing? and is there a repo for benchmarks or results?

ivanstepanovftw commented 2 months ago

@dlfrnaos19

> torch.mm, gpu(a6000), 112.1874s
>
> for a, b in tqdm(zip(tensor_gpu, tensor_gpu2), total=tensor_gpu.shape[0]):
>     result_gpu = torch.mm(a,b)

Man, this way you are benchmarking CPython. You should perform batched matrix multiplication.

result_cpu = torch.bmm(tensor_cpu, tensor_cpu2)

Requesting new benchmarks.

ivanstepanovftw commented 2 months ago

torch.bmm

import torch
import time

# Creating random tensors
tensor_cpu = torch.randn(10000000, 3, 3)
tensor_cpu2 = torch.randn(10000000, 3, 3)

# Measuring time
start_time = time.time()

# Batch matrix multiplication
result_cpu = torch.bmm(tensor_cpu, tensor_cpu2)

end_time = time.time()
elapsed_time = end_time - start_time

print(f"Elapsed time: {elapsed_time:.4f} seconds")

As for Candle, I did not manage to compile your code, because Tensor implements neither Iterator nor an .iter method. I tried to make it work with indexing:

Cargo.toml

[package]
name = "candle_test"
version = "0.1.0"
edition = "2021"

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.6.1" }
Old code

```rust
use candle_core::Tensor;
use std::time::Instant;
use candle_core::Device::Cpu;
use candle_core::IndexOp;

fn cpu_test() {
    // Create two random tensors on the CPU
    let a = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &Cpu).unwrap();
    let b = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &Cpu).unwrap();

    // Measure the elapsed time for tensor operations
    let start = Instant::now();

    // Create an empty tensor to store results
    // let mut res = Vec::with_capacity(a.shape()[0]);

    for i in 0..a.shape().dims()[0] {
        // Extract slices of tensors
        let a_slice = a.i(i).unwrap();
        let b_slice = b.i(i).unwrap();

        // Perform matrix multiplication
        let result = a_slice.matmul(&b_slice).unwrap();
        // res.push(result);
    }

    let end = start.elapsed().as_secs_f32();

    eprintln!("Elapsed time = {:?}", end);
    // eprintln!("Result length = {:?}", res.len()); // Avoid printing large results
}

fn main() {
    cpu_test();
}
```

src/main.rs

use candle_core::Device::Cpu;
use candle_core::Tensor;
use std::time::Instant;

fn cpu_test() {
    // Create two random tensors on the CPU
    let a = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &Cpu).unwrap();
    let b = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &Cpu).unwrap();

    // Measure the elapsed time for tensor operations
    let start = Instant::now();

    let _result = a.matmul(&b).unwrap();

    let end = start.elapsed().as_secs_f32();

    eprintln!("Elapsed time = {:?}", end);
    // eprintln!("Result length = {:?}", res.len()); // Avoid printing large results
}
fn main() {
    cpu_test();
}

And run with cargo run --release

Results:

cpu, torch.bmm: Elapsed time: 0.1386 seconds
cpu, numpy.dot: Elapsed time: 33.0611 seconds
cpu, numpy.dot + numba: Elapsed time: 4.3504 seconds
cpu, HuggingFace/Candle: Elapsed time = 3.8913457

Not sure why Candle is so slow. Or why Rust is so slow. Tracing?

Details

![image](https://github.com/user-attachments/assets/fe4693f7-2572-4027-b5a9-12f1128a4d89)

EricLBuehler commented 2 months ago

@ivanstepanovftw Tensor::matmul supports batched matrix multiplication, similar to torch.bmm. Something like:

use candle_core::Device::Cpu;
use candle_core::Tensor;
use std::time::Instant;

fn cpu_test() {
    // Create two random tensors on the CPU
    let a = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &Cpu).unwrap();
    let b = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &Cpu).unwrap();

    // Measure the elapsed time for tensor operations
    let start = Instant::now();

    let _result = a.matmul(&b).unwrap();

    let end = start.elapsed().as_secs_f32();

    eprintln!("Elapsed time = {:?}", end);
    // eprintln!("Result length = {:?}", res.len()); // Avoid printing large results
}
fn main() {
    cpu_test();
}

Will be much faster.

edgarriba commented 2 months ago

@ivanstepanovftw my original request above was to verify the performance of small-matrix ops (not batched), which I think makes the previous test still valid.

https://github.com/huggingface/candle/issues/1139#issuecomment-1833721335

ivanstepanovftw commented 2 months ago

@EricLBuehler, thank you, I did not check whether this was possible. I have updated the code in my comment and modified the Candle result to match torch.bmm. Btw, the results are still weak.

@edgarriba, have you checked dfdx? I would be happy if someone provided a benchmark for that. I will also keep the edit history for the previous loop-based results. Actually, you may need something like the ONNX runtime for that task. Or, I don't know, the ggml library?