joeyballentine opened this issue 1 year ago
Tagging this with "help wanted" in case anyone wants to cut their teeth on profiling CUDA kernels. The idea would be first to replicate the slowness locally, ideally with `torch.backends.cudnn.benchmark` set to `False` on the PyTorch side to avoid the cudnn method selection, and then to run both the PyTorch and candle versions within nsys and look at where the time is actually spent on both sides.
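For anyone picking this up, here is a minimal sketch of the PyTorch side of that setup (the file name and the exact nsys invocation are illustrative, not prescriptive):

```python
# bench_pytorch.py -- illustrative driver for the PyTorch side of the comparison.
import torch

# Avoid cudnn's runtime algorithm selection so both frameworks are compared
# with a fixed kernel choice.
torch.backends.cudnn.benchmark = False

# ... build the model and run the forward passes to be profiled here ...

# Then capture a trace for each implementation, e.g.:
#   nsys profile -o pytorch_run python bench_pytorch.py
#   nsys profile -o candle_run  ./target/release/<candle-binary>
```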
@LaurentMazare @joeyballentine
Did some quick profiling using `nsys`.
TL;DR, here are the total times (in ns) by kernel and framework:

| kernel | framework | total (ns) |
|---|---|---|
| cudnn_ampere_scudnn_winograd_128x128_ldg1_ldg4_rel | pytorch | 8.55e+09 |
| void at::native::::CatArrayBatchedCopy_al | pytorch | 5.87e+09 |
| sm86_xmma_fprop_implicit_gemm_tf32f32_tf32f32f32 | pytorch | 4.35e+09 |
| void at::native::elementwise_kernel<(int)128, (int | pytorch | 2.30e+09 |
| void cudnn::ops::nchwToNhwcKernel | pytorch | 1.98e+09 |
| void at::native::vectorized_elementwise_kernel<(in | pytorch | 1.67e+09 |
| void at::native::vectorized_elementwise_kernel<(in | pytorch | 1.31e+09 |
| void at::native::vectorized_elementwise_kernel<(in | pytorch | 8.72e+08 |
| cudnn_ampere_scudnn_128x32_relu_small_nn_v1 | pytorch | 3.37e+08 |
| void at::native::::upsample_nearest2d_out | pytorch | 1.22e+08 |
| void cudnn::winograd::generateWinogradTilesKernel< | pytorch | 9.84e+07 |
| void cutlass_cudnn::Kernel | pytorch | 1.65e+07 |
| sm86_xmma_fprop_implicit_gemm_indexed_tf32f32_tf32 | pytorch | 1.45e+07 |
| void cudnn::ops::nhwcToNchwKernel | pytorch | 7.99e+06 |
| cudnn_ampere_scudnn_128x64_relu_small_nn_v1 | pytorch | 6.73e+06 |
| void cask_cudnn::computeOffsetsKernel<(bool)0, (bo | pytorch | 1.85e+06 |
| void at::native::elementwise_kernel<(int)128, (int | pytorch | 5.62e+05 |
| kernel | framework | total (ns) |
|---|---|---|
| ucopy_f32 | rust | 1.42e+10 |
| badd_f32 | rust | 1.14e+10 |
| cudnn_infer_ampere_scudnn_winograd_128x128_ldg1_ld | rust | 5.96e+09 |
| affine_f32 | rust | 3.78e+09 |
| void cudnn::ops::nchwToNhwcKernel | rust | 3.69e+09 |
| bmaximum_f32 | rust | 3.69e+09 |
| bminimum_f32 | rust | 3.68e+09 |
| void cutlass_cudnn_infer::Kernel | rust | 2.66e+09 |
| void cutlass_cudnn_infer::Kernel | rust | 2.60e+09 |
| void cudnn::ops::nhwcToNchwKernel | rust | 1.21e+09 |
| upsample_nearest2d_f32 | rust | 6.36e+08 |
| void cutlass_cudnn_infer::Kernel | rust | 5.87e+08 |
| cudnn_infer_ampere_scudnn_128x32_relu_small_nn_v1 | rust | 3.34e+08 |
| void cudnn::winograd::generateWinogradTilesKernel< | rust | 7.22e+07 |
| cudnn_infer_ampere_scudnn_128x64_relu_small_nn_v1 | rust | 6.83e+06 |
| void cask_cudnn_infer::computeOffsetsKernel<(bool) | rust | 1.83e+06 |
Seems like some of the custom unary and binary ops in `candle_kernels` are an issue.

I added `nvtx` annotations to focus the profiling on the model:

- `Iteration`: start / end of each iteration, which includes image loading and saving
- `ForwardPass`: start / end of each model forward pass

The above table is only for `ForwardPass`.
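For context, on the PyTorch side such annotations boil down to NVTX range push/pop calls. A minimal sketch (the `model` / `load_image` / `save_image` callables here are placeholders, not the actual benchmark code):

```python
import torch

def run(model, image_paths, load_image, save_image):
    """Wrap each iteration and each forward pass in an NVTX range so nsys
    can report time per range (names match the ranges described above)."""
    for path in image_paths:
        torch.cuda.nvtx.range_push("Iteration")    # includes image load + save
        img = load_image(path)
        torch.cuda.nvtx.range_push("ForwardPass")  # model forward only
        out = model(img)
        torch.cuda.nvtx.range_pop()                # end ForwardPass
        save_image(out, path)
        torch.cuda.nvtx.range_pop()                # end Iteration
```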
Here are the total times for these ranges:

| framework | range | total |
|---|---|---|
| pytorch | Iteration | 3.20e+10 |
| pytorch | ForwardPass | 1.30e+10 |
| pytorch | cuBLAS:cublasCreate_v2 | 4.08e+06 |
| rust | Iteration | 8.43e+10 |
| rust | ForwardPass | 5.62e+10 |
| rust | cuBLAS:cublasCreate_v2 | 1.83e+05 |
Additional notes:

- `inputs` folder. This is the `images` folder referenced in the python / rust code.
- `nsys-profile.sh` as well as a `Makefile` that automates profiling the respective frameworks with a few knobs (e.g., whether or not to include cpu-sampling).
- `analysis` as well as a python notebook to parse / summarize the data. More data in `analysis` than presented here in case you're interested in diving further. The notebook shows how to load this additional data.

Attachments: rust_candle_test.zip, notebook.zip, analysis.zip
@LaurentMazare lmk if further kernel profiling / benchmarking makes sense as pertains to candle more generally.
Thanks @jeromeku for the very detailed analysis and providing the files and instructions to replicate, that's a great contribution to candle and will certainly inspire others to do some profiling. Certainly do feel free to do more of these. Not sure what the best way is to give this some visibility, maybe some writeups/tutorials that we link to in the main readme (and that could also be advertised on r/rust etc).
For the curious, the notebook as a gist.
And indeed the low hanging fruit in this case seems to be the `ucopy` and `badd` kernels. These are implemented in a very naive way; in particular, `ucopy`, when called on a non-contiguous array, just computes the "strided index" for every position, each time doing some modulo etc. computations without trying to re-use the previous results. This should be easy to optimize at least for specific use cases (where the final dimension is already contiguous, for example, we should use a memcpy-like cuda operation to accelerate copies). Probably worth looking at how pytorch/other frameworks do it and taking some inspiration there - I can have a look, though it will be at least a week as I'm moving at the moment and don't have my desktop computer at hand anymore.
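To make the cost concrete, here is a pure-Python sketch of the index arithmetic involved -- not candle's actual CUDA kernel, just the pattern described above -- contrasting the per-element div/mod recomputation with a row-wise fast path when the last dimension is contiguous:

```python
def strided_index(linear: int, dims: list[int], strides: list[int]) -> int:
    """Naive pattern: derive the strided offset for one element with a full
    chain of div/mod over every dimension -- repeated for every element."""
    offset = 0
    for dim, stride in zip(reversed(dims), reversed(strides)):
        offset += (linear % dim) * stride
        linear //= dim
    return offset

def naive_copy(src, dims, strides):
    """Copy element by element, recomputing the strided index each time."""
    n = 1
    for d in dims:
        n *= d
    return [src[strided_index(i, dims, strides)] for i in range(n)]

def rowwise_copy(src, dims, strides):
    """If the last dimension is contiguous (stride 1), compute the index once
    per row and copy the whole row -- the memcpy-like fast path suggested above."""
    assert strides[-1] == 1
    row_len = dims[-1]
    n_rows = 1
    for d in dims[:-1]:
        n_rows *= d
    out = []
    for r in range(n_rows):
        base = strided_index(r * row_len, dims, strides)
        out.extend(src[base:base + row_len])
    return out
```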
Some thoughts @LaurentMazare:

- The custom `ops` seem to be the issue.
- Is more granular profiling with `nsight compute` needed? The macro level data from `nsys` seems to have already surfaced addressable issues.

@LaurentMazare
Did some quick digging into `pytorch` ops:

- The kernel implementations live in `ATen/native/cuda`.

Looking into `cutlass` also, primarily the `copy` atom and algorithm. Will poke around some more and see which parts can be re-purposed.
I tested a simple benchmark using PyTorch and Rust’s candle. I performed matmul, increasing the number of repetitions on a tensor of size (1000,1000) in float32 format.
| n | torch (s) | candle (s) |
|---|---|---|
| 99 | 0.183099 | 0.214144 |
| 999 | 1.894550 | 1.782565 |
| 9999 | 18.856167 | 17.811970 |
python code:

```python
import torch
from tqdm import tqdm
import time

torch_sample = torch.rand(1000, 1000, dtype=torch.float32).to("cpu")
torch_sample

n = 9999
torch_start = time.time()
for i in tqdm(range(n)):
    torch.matmul(torch_sample, torch_sample)
torch_end = time.time()
torch_time = torch_end - torch_start
print(f"torch time: {torch_time} s")
```
rust code:

```rust
use candle_core::Tensor;
use candle_core::Device::Cpu;
use std::time::Instant;
use candle_core::Result;

fn main() -> Result<()> {
    let device = Cpu;
    let candle_sample = Tensor::rand(0f32, 1f32, (1000, 1000), &device)?;
    let n = 9999;
    let candle_start = Instant::now();
    for _ in 0..n {
        candle_sample.matmul(&candle_sample)?;
    }
    let candle_end = candle_start.elapsed().as_secs_f32();
    eprintln!("candle_end = {:?}", candle_end);
    Ok(())
}
```
rust command: `cargo run --release`
cpu: Ryzen Threadripper PRO 3955WX (16 cores)
@LaurentMazare
Did some more research on `pytorch` internals.

Some notes:

- Actual kernel implementations are in `ATen/native/cuda`, though much of the call chain (e.g., `torch.add`) is obscured behind a long chain of dispatching calls that route the call depending on `dtype`, `backend`, `tensor traits`, etc.
- An `nsys` `CUDA` trace shows this call chain ultimately resulting in a vectorized elementwise kernel (a quick way to reproduce this is sketched after these notes).
- Copy ops: these live in `ATen/native/cuda/Copy.cu`. `copy_kernel_cuda` implements the specific kind of copy -- device<->device, CPU<->device, etc. Especially instructive is `cuda_vectorized_test.cu`, which demonstrates various vectorized memory ops: alignment checks, vectorized copy kernel, etc.
- Unary, binary, and other common ops: see `ATen/native/cuda/CUDALoops.cuh`. `gpu_kernel_impl` is the high level function that routes the call to the appropriate launch params and kernel for elementwise ops. The memory access helpers are in `ATen/native/cuda/MemoryAccess.cuh` and `ATen/native/cuda/Loops.cuh`. The kernels themselves are in `Unary*Kernels.cu`, `Binary*Kernels.cu`, and other aptly named files for `Loss`, `Activations`, etc. within the `ATen/native/cuda` directory. `add` is handled a bit differently -- the kernel is in `ATen/native/ufunc/add.h` -- though most of the efficient handling of this is in `CUDALoops.cuh` per above, plus `pytorch` codegen (`Codegen`).
- Lots to take inspiration from, starting with `copy`, `unary`, and `binary` kernels and beyond -- look forward to helping with these improvements!
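As an aside, a quick way to see which CUDA kernel a given op dispatches to, without a full `nsys` run, is `torch.profiler`; a minimal sketch (requires a CUDA-capable machine):

```python
import torch
from torch.profiler import ProfilerActivity, profile

# Profile a single elementwise add on the GPU and print the CUDA kernels it
# dispatched to (e.g. a vectorized_elementwise_kernel variant).
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    _ = a + b
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total"))
```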
> I tested a simple benchmark using PyTorch and Rust's candle. I performed matmul, increasing the number of repetitions on a tensor of size (1000,1000) in float32 format.
@dlfrnaos19 not sure if this is the right place to comment about this, but have you tried with small vectors/matrices? There are some specific needs, e.g. in robotics and geometric vision & graphics, where pytorch does a really bad job computing fast dot products or chaining pose transforms (Vec3, Mat3x3, Mat4x4, etc) because of all the complex dispatching mechanism that takes more time than the actual kernel execution. Reference discussion: https://github.com/pytorch/pytorch/issues/103313

Would be great if candle could solve that.
@edgarriba I have reviewed the issue you posted. The test I conducted was a simple repetition of the matmul operation in a for loop with the 1000x1000 matrix mentioned in the issue. If the matrix is smaller than this, please specify the shape of the matrix so that I can conduct a test.
@dlfrnaos19 I'm interested in small matrix/vector operators, e.g. 3x1, 4x1, 3x3, 4x4 (the most common shapes used in robotics/vision/graphics), that involve camera poses and projections. Happy to contribute in this direction.
@edgarriba don't know if this is what you wanted.

3x3 tensors, ten million matmuls:

| method | device | time |
|---|---|---|
| torch.mm | cpu | 55.3527 s |
| torch.matmul | cpu | 55.5209 s |
| torch.mm | gpu (A6000) | 112.1874 s |
| numpy.dot | cpu | 14.3464 s |
| numba jit + numpy | cpu | 3.1331 s |
| candle | cpu | 2.7602 s |
| candle | gpu | 0.0005 s |
It looks very fast, but please check it again.
```rust
fn cpu_test() {
    let a = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &Cpu);
    let b = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &Cpu);
    let start = Instant::now();
    let res: Vec<_> = zip(a, b)
        .into_iter()
        .map(move |(x, y)| {x.matmul(&y).unwrap();})
        .collect();
    let end = start.elapsed().as_secs_f32();
    eprintln!("end = {:?}", end);
    eprintln!("res = {:?}", res);
}

fn gpu_test() {
    let device = Device::new_cuda(0).unwrap();
    let a = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &device);
    let b = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &device);
    let start = Instant::now();
    let res: Vec<_> = zip(a, b)
        .into_iter()
        .map(move |(x, y)| {x.matmul(&y).unwrap();})
        .collect();
    let end = start.elapsed().as_secs_f32();
    eprintln!("end = {:?}", end);
    eprintln!("res = {:?}", res);
}
```
```python
import time

import numpy as np
import torch
from numba import jit
from tqdm import tqdm

# torch.mm on CPU, looping over the batch in Python
tensor_cpu = torch.randn(10000000, 3, 3)
tensor_cpu2 = torch.randn(10000000, 3, 3)
start_time = time.time()
for a, b in tqdm(zip(tensor_cpu, tensor_cpu2), total=tensor_cpu.shape[0]):
    result_cpu = torch.mm(a, b)
end_time = time.time()
elapsed_time = end_time - start_time

# torch.mm on GPU, same Python loop
tensor_gpu = torch.randn(10000000, 3, 3, device='cuda')
tensor_gpu2 = torch.randn(10000000, 3, 3, device='cuda')
start_time = time.time()
for a, b in tqdm(zip(tensor_gpu, tensor_gpu2), total=tensor_gpu.shape[0]):
    result_gpu = torch.mm(a, b)
end_time = time.time()
elapsed_time = end_time - start_time

# numpy.dot on CPU
numpy_tensor1 = np.random.rand(10000000, 3, 3)
numpy_tensor2 = np.random.rand(10000000, 3, 3)
start_time = time.time()
for a, b in tqdm(zip(numpy_tensor1, numpy_tensor2), total=numpy_tensor2.shape[0]):
    numpy_result = np.dot(a, b)
end_time = time.time()
elapsed_time = end_time - start_time

# numba-jitted numpy loop
@jit(nopython=True)
def main(numpy_tensor1, numpy_tensor2):
    for a, b in zip(numpy_tensor1, numpy_tensor2):
        numpy_result = np.dot(a, b)

numpy_tensor1 = np.random.rand(10000000, 3, 3)
numpy_tensor2 = np.random.rand(10000000, 3, 3)
start_time = time.time()
main(numpy_tensor1=numpy_tensor1, numpy_tensor2=numpy_tensor2)
end_time = time.time()
elapsed_time = end_time - start_time
```
@dlfrnaos19 the results are very impressive! thanks so much!
@dlfrnaos19 : I think you shouldn't wrap the tensor iterator within tqdm in this case. It increases the total exec time.
By the way, I forgot to mention this when I found out, but in my initial benchmark I was using a custom sequential layer that was a bit inefficient. I meant to redo my benchmark after I fixed that, but I completely forgot to. I will try to remember to do that and verify that it is still slower.
this is great work y'all. is this ongoing? and is there a repo for benchmarks or results?
@dlfrnaos19

> torch.mm, gpu(a6000), 112.1874s
>
> `for a, b in tqdm(zip(tensor_gpu, tensor_gpu2), total=tensor_gpu.shape[0]): result_gpu = torch.mm(a,b)`

Man, this way you are benchmarking CPython. You should perform batched matrix multiplication:

`result_cpu = torch.bmm(tensor_cpu, tensor_cpu2)`

Requesting new benchmarks.
`torch.bmm`:

```python
import torch
import time

# Creating random tensors
tensor_cpu = torch.randn(10000000, 3, 3)
tensor_cpu2 = torch.randn(10000000, 3, 3)

# Measuring time
start_time = time.time()

# Batch matrix multiplication
result_cpu = torch.bmm(tensor_cpu, tensor_cpu2)

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.4f} seconds")
```
As for Candle, I did not manage to compile your code, because neither `Iterator` is implemented nor `.iter` found for `Tensor`. Tried to make it with indexing:
Cargo.toml:

```toml
[package]
name = "candle_test"
version = "0.1.0"
edition = "2021"

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.6.1" }
```
```rust
use candle_core::Tensor;
use std::time::Instant;
use candle_core::Device::Cpu;
use candle_core::IndexOp;

fn cpu_test() {
    // Create two random tensors on the CPU
    let a = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &Cpu).unwrap();
    let b = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &Cpu).unwrap();

    // Measure the elapsed time for tensor operations
    let start = Instant::now();

    // Create an empty tensor to store results
    // let mut res = Vec::with_capacity(a.shape()[0]);

    for i in 0..a.shape().dims()[0] {
        // Extract slices of tensors
        let a_slice = a.i(i).unwrap();
        let b_slice = b.i(i).unwrap();
        // Perform matrix multiplication
        let result = a_slice.matmul(&b_slice).unwrap();
        // res.push(result);
    }

    let end = start.elapsed().as_secs_f32();
    eprintln!("Elapsed time = {:?}", end);
    // eprintln!("Result length = {:?}", res.len()); // Avoid printing large results
}

fn main() {
    cpu_test();
}
```
src/main.rs:

```rust
use candle_core::Device::Cpu;
use candle_core::Tensor;
use std::time::Instant;

fn cpu_test() {
    // Create two random tensors on the CPU
    let a = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &Cpu).unwrap();
    let b = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &Cpu).unwrap();

    // Measure the elapsed time for tensor operations
    let start = Instant::now();
    let _result = a.matmul(&b).unwrap();
    let end = start.elapsed().as_secs_f32();
    eprintln!("Elapsed time = {:?}", end);
    // eprintln!("Result length = {:?}", res.len()); // Avoid printing large results
}

fn main() {
    cpu_test();
}
```
And run with `cargo run --release`.

Results:

- cpu, torch.bmm: Elapsed time: 0.1386 seconds
- cpu, numpy.dot: Elapsed time: 33.0611 seconds
- cpu, numpy.dot + numba: Elapsed time: 4.3504 seconds
- cpu, HuggingFace/Candle: Elapsed time = 3.8913457 seconds
Not sure why Candle is so slow. Or why Rust is so slow. Tracing?
![image](https://github.com/user-attachments/assets/fe4693f7-2572-4027-b5a9-12f1128a4d89)
@ivanstepanovftw `Tensor::matmul` supports batching, similar to `torch.bmm`. Something like:
```rust
use candle_core::Device::Cpu;
use candle_core::Tensor;
use std::time::Instant;

fn cpu_test() {
    // Create two random tensors on the CPU
    let a = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &Cpu).unwrap();
    let b = Tensor::randn(0.1f32, 1f32, (10000000, 3, 3), &Cpu).unwrap();

    // Measure the elapsed time for tensor operations
    let start = Instant::now();
    let _result = a.matmul(&b).unwrap();
    let end = start.elapsed().as_secs_f32();
    eprintln!("Elapsed time = {:?}", end);
    // eprintln!("Result length = {:?}", res.len()); // Avoid printing large results
}

fn main() {
    cpu_test();
}
```
Will be much faster.
@ivanstepanovftw my original request above was to verify small matrix ops perf (not in batch), which I think makes the previous test still valid: https://github.com/huggingface/candle/issues/1139#issuecomment-1833721335
@EricLBuehler, thank you, I did not check if this was possible. I have updated the code in my comment and modified the Candle result to match torch.bmm. Btw, the results are still weak.
@edgarriba, have you checked dfdx? I would be happy if someone provided a benchmark for it. I will also keep the edit history for the previous loop-based results. Actually, you may need something like ONNX Runtime for that task. Or, I don't know, the ggml library?
Hello. I mentioned this in the discord and worked with a member to make sure I wasn't doing anything dumb. I tested the release version of my candle code with cudnn enabled vs equivalent pytorch code, and comparatively candle is about 4x slower.
I have attached the code I was using to compare. It contains both the original python/pytorch implementation of the RealESRGAN RRDBNet arch, as well as my Candle implementation.
I'm limited on time or I would have set up a proper repo for this with a script/program that would run both tests automatically, but this is the best I can do at the moment. In order to use either script, you'll probably have to adjust the paths in each script to match the path of the model and your test images (I did not include test images). I recommend trying ~10 smallish images (128x for example).
Context from discord: https://discord.com/channels/879548962464493619/1136218819447238726/1164985040854339736
Code: rust_candle_test.zip
And the model (had to upload to drive) https://drive.google.com/file/d/1AyvArWkR3qonMV2pBtk3zDkct0yJrh5Z/view?usp=sharing
For reference, here are the results from when I benchmarked it:
PyTorch:
Candle:
Please let me know if you need or want any more information. Candle is a very interesting project and seems very promising. It just currently doesn't seem to have as much optimization as pytorch.