huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

Fix sigmoid gradient calculation and move sigmoid into a specialized op #2114

Closed: MilkFather closed this pull request 2 weeks ago

MilkFather commented 3 weeks ago

This pull request is inspired by #1722 and may supersede it.

As reported in #1722, computing the gradient of the sigmoid function for large negative f32 inputs may produce NaNs:

input =
[ -548.3162,  -201.7132,   -74.2032,   -27.2899,   -10.0179,    -3.6269,    -1.1752,     0.0000,     1.1752,     3.6269,    10.0179,    27.2899,    74.2032,   201.7132,   548.3162]
sigmoid =
[  0.0000e0,   0.0000e0, 5.9424e-33, 1.4065e-12,  4.4592e-5,  2.5909e-2,  2.3592e-1,  5.0000e-1,  7.6408e-1,  9.7409e-1,  9.9996e-1,   1.0000e0,   1.0000e0,   1.0000e0,   1.0000e0]
grad =
[       NaN,        NaN,   0.0000e0, 1.4065e-12,  4.4591e-5,  2.5238e-2,  1.8026e-1,  2.5000e-1,  1.8026e-1,  2.5238e-2,  4.4591e-5, 1.4065e-12, 5.9424e-33,   0.0000e0,   0.0000e0]
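The NaN/NaN/0 pattern at the left end of this gradient is what you get when sigmoid is built from primitive ops and differentiated by the chain rule. The scalar sketch below is plain Rust, not candle's code. It assumes the composition 1 / (1 + exp(-x)) (neg, exp, add, recip) and assumes the backward rule for recip(s) is -g / s². With those assumptions, exp(-x) overflows to inf, the recip backward collapses to -0.0, and the exp backward then evaluates (-0.0) * inf = NaN.

```rust
/// Hypothetical scalar re-creation of sigmoid composed from primitives,
/// y = 1 / (1 + exp(-x)), differentiated by the chain rule op by op.
/// Assumption: the backward of recip(s) is -g / s^2.
fn naive_sigmoid_grad(x: f32) -> f32 {
    // Forward pass through the primitives.
    let n = -x; // neg
    let e = n.exp(); // exp: overflows to inf in f32 once n exceeds ~88.7
    let s = e + 1.0; // add
    // Backward pass with upstream gradient 1.0.
    let g_s = -1.0 / (s * s); // recip: for large s, s * s is inf and g_s is -0.0
    let g_e = g_s; // add passes the gradient through unchanged
    let g_n = g_e * e; // exp: multiplies by exp(n); (-0.0) * inf = NaN
    -g_n // neg
}

fn main() {
    for x in [-548.3162_f32, -201.7132, -74.2032, 0.0] {
        println!("x = {x}: grad = {:e}", naive_sigmoid_grad(x));
    }
}
```

At x = -74.2032, exp(-x) is still finite (about 1.7e32) but s * s overflows, so the sketch returns 0 instead of the true 5.9424e-33. This matches the third entry of the gradient above.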

This pull request fixes the NaN problem by exploiting a property of the derivative of the sigmoid function, namely

d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
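The identity follows directly from the definition σ(x) = 1 / (1 + e^{-x}), using 1 - σ(x) = e^{-x} / (1 + e^{-x}):

$$
\sigma'(x) = \frac{e^{-x}}{(1 + e^{-x})^2} = \frac{1}{1 + e^{-x}} \cdot \frac{e^{-x}}{1 + e^{-x}} = \sigma(x)\,\bigl(1 - \sigma(x)\bigr)
$$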

In this approach, the derivative is computed only from the already-computed sigmoid output, using subtraction and multiplication. No exponential is evaluated in the backward pass, so the NaNs are avoided. It also needs far fewer operations, which gives a ~2.5x speedup in the benchmark below (6.58 vs. 16.92 microseconds per iteration).
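As a scalar sketch of the proposed backward rule, here is a plain Rust version that does not use candle; the function names are illustrative. The backward function only reads the saved forward output y, so the 0 * inf product from the composite chain rule cannot occur.

```rust
/// Forward pass: sigmoid(x) = 1 / (1 + exp(-x)).
/// For very negative x, exp(-x) overflows to inf and 1 / inf = 0,
/// which is still a finite, valid result.
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Backward pass expressed only through the saved output y:
/// d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x), times the upstream gradient.
fn sigmoid_backward(y: f32, grad_out: f32) -> f32 {
    grad_out * (1.0 - y) * y
}

fn main() {
    let inputs = [-548.3162_f32, -201.7132, -74.2032, 0.0, 74.2032, 548.3162];
    for &x in &inputs {
        let y = sigmoid(x);
        let g = sigmoid_backward(y, 1.0);
        println!("x = {x:>10}, sigmoid = {y:e}, grad = {g:e}");
    }
}
```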

Code snippet used to test the proposed approach:

```rust
use std::{hint::black_box, time::Instant};

use candle_core::{Device, Var};
use candle_nn::ops::sigmoid;

fn main() {
    let device = Device::Cpu;
    // Inputs ranging from large negative to large positive values.
    let x = Var::new(
        &[
            -548.3162_f32, -201.7132, -74.2032, -27.2899, -10.0179, -3.6269, -1.1752, 0.0000,
            1.1752, 3.6269, 10.0179, 27.2899, 74.2032, 201.7132, 548.3162,
        ],
        &device,
    )
    .unwrap();

    // One forward + backward pass: print sigmoid(x) and its gradient w.r.t. x.
    let y = sigmoid(&x).unwrap();
    let grads = y.backward().unwrap();
    let x_grad = grads.get(&x).unwrap();
    println!("{}", y);
    println!("{}", x_grad);

    // Time 100,000 forward + backward iterations; black_box keeps the
    // compiler from optimizing the work away.
    let start_time = Instant::now();
    for _ in 0..100000 {
        let y = black_box(sigmoid(&x).unwrap());
        let grads = black_box(y.backward().unwrap());
        let _ = black_box(grads.get(&x).unwrap());
    }
    let end_time = Instant::now();
    let duration = end_time - start_time;
    println!("avg iter: {} microseconds", duration.as_micros() as f64 / 100000.0);
}
```

Output using candle-core 0.4.1 and candle-nn 0.4.1

[  0.0000e0,   0.0000e0, 5.9424e-33, 1.4065e-12,  4.4592e-5,  2.5909e-2,
  2.3592e-1,  5.0000e-1,  7.6408e-1,  9.7409e-1,  9.9996e-1,   1.0000e0,
   1.0000e0,   1.0000e0,   1.0000e0]
Tensor[[15], f32]
[       NaN,        NaN,   0.0000e0, 1.4065e-12,  4.4591e-5,  2.5238e-2,
  1.8026e-1,  2.5000e-1,  1.8026e-1,  2.5238e-2,  4.4591e-5, 1.4065e-12,
 5.9424e-33,   0.0000e0,   0.0000e0]
Tensor[[15], f32]
avg iter: 16.92021 microseconds

Output using the proposed approach

[  0.0000e0,   0.0000e0, 5.9424e-33, 1.4065e-12,  4.4592e-5,  2.5909e-2,
  2.3592e-1,  5.0000e-1,  7.6408e-1,  9.7409e-1,  9.9996e-1,   1.0000e0,
   1.0000e0,   1.0000e0,   1.0000e0]
Tensor[[15], f32]
[  0.0000e0,   0.0000e0, 5.9424e-33, 1.4065e-12,  4.4591e-5,  2.5238e-2,
  1.8026e-1,  2.5000e-1,  1.8026e-1,  2.5238e-2,  4.4582e-5,   0.0000e0,
   0.0000e0,   0.0000e0,   0.0000e0]
Tensor[[15], f32]
avg iter: 6.57586 microseconds

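Note that the gradients lose precision for large positive inputs (e.g. 4.4582e-5 instead of 4.4591e-5 at 10.0179, and 0 instead of 1.4065e-12 at 27.2899). This is because 1 - sigmoid(x) suffers from cancellation when sigmoid(x) is close to 1 in f32, and becomes exactly 0 once sigmoid(x) rounds to 1.0. This might be acceptable, since it is still better than NaNs.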
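The precision loss at large positive inputs can be reproduced without candle. Once x exceeds roughly 17, exp(-x) is too small to change 1.0 in f32, so y = 1.0 and y * (1 - y) = 0. The plain Rust sketch below compares the f32 result with an f64 reference, exp(-x) / (1 + exp(-x))², which avoids the 1 - y cancellation for positive x. The reference is only there for illustration and is not what the PR implements; the function names are hypothetical.

```rust
/// f32 forward pass, as in the sketch of the proposed approach.
fn sigmoid_f32(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Proposed gradient computed from the f32 output: y * (1 - y).
fn grad_from_output_f32(x: f32) -> f32 {
    let y = sigmoid_f32(x);
    y * (1.0 - y)
}

/// f64 reference for x >= 0: exp(-x) / (1 + exp(-x))^2 has no 1 - y term,
/// so it stays accurate when sigmoid(x) is very close to 1.
fn reference_grad_f64(x: f64) -> f64 {
    let e = (-x).exp();
    e / ((1.0 + e) * (1.0 + e))
}

fn main() {
    for x in [10.0179_f32, 27.2899, 74.2032] {
        println!(
            "x = {x}: f32 y*(1-y) = {:e}, f64 reference = {:e}",
            grad_from_output_f32(x),
            reference_grad_f64(x as f64)
        );
    }
}
```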

I am not sure whether I should write CUDA and Metal kernels. Also, due to the introduction of a customized gradient calculation routine, sigmoid is now a specialized op for Tensor.

LaurentMazare commented 2 weeks ago

Could you use a CustomOp that would be defined in candle_nn only for this? This would avoid adding more complexity to candle-core. Also we will want to have cuda and metal kernels for this as otherwise the models that currently use the sigmoid function won't work when run on these devices.

MilkFather commented 2 weeks ago

I have switched to a CustomOp-based implementation, although it is CPU-only for now. Curiously, even when I compile the above test script with the metal feature flag on my MacBook, the script does not fail despite the missing Metal implementation.

I am investigating how to integrate cuda and metal code into the codebase, and most importantly, how to call them.

LaurentMazare commented 2 weeks ago

Metal and CUDA implementations are optional, but the op will fail at runtime if one tries to use it on these devices. You can find examples with custom CUDA and Metal kernels in this ops.rs file.
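To see why a missing kernel only fails at runtime, consider this stdlib-only sketch. Here the op trait provides default GPU forward methods that return an error. The error only appears when the op is actually dispatched to that device. The `Device` enum, the `UnaryCustomOp` trait and `apply` are hypothetical. candle's CustomOp1 follows a similar pattern with a required CPU forward and optional device forwards, but it uses its own storage and layout types.

```rust
/// Hypothetical stand-in for a backend selector.
#[derive(Debug, Clone, Copy)]
enum Device {
    Cpu,
    Cuda,
    Metal,
}

/// Hypothetical custom-op trait: only the CPU forward is required; the GPU
/// forwards have default bodies that return an error when called.
trait UnaryCustomOp {
    fn name(&self) -> &'static str;
    fn cpu_fwd(&self, xs: &[f32]) -> Vec<f32>;
    fn cuda_fwd(&self, _xs: &[f32]) -> Result<Vec<f32>, String> {
        Err(format!("no cuda implementation for {}", self.name()))
    }
    fn metal_fwd(&self, _xs: &[f32]) -> Result<Vec<f32>, String> {
        Err(format!("no metal implementation for {}", self.name()))
    }
}

struct Sigmoid;

impl UnaryCustomOp for Sigmoid {
    fn name(&self) -> &'static str {
        "sigmoid"
    }
    fn cpu_fwd(&self, xs: &[f32]) -> Vec<f32> {
        xs.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect()
    }
    // cuda_fwd and metal_fwd keep their erroring defaults.
}

/// Dispatch on the device the data lives on; compiling with GPU support is
/// not enough to trigger the error, the op has to run on that device.
fn apply(op: &dyn UnaryCustomOp, device: Device, xs: &[f32]) -> Result<Vec<f32>, String> {
    match device {
        Device::Cpu => Ok(op.cpu_fwd(xs)),
        Device::Cuda => op.cuda_fwd(xs),
        Device::Metal => op.metal_fwd(xs),
    }
}

fn main() {
    let xs = [-1.0_f32, 0.0, 1.0];
    println!("{:?}", apply(&Sigmoid, Device::Cpu, &xs));
    println!("{:?}", apply(&Sigmoid, Device::Metal, &xs));
}
```

This is consistent with the earlier observation that the test script ran fine with the metal feature enabled, because that script only uses Device::Cpu.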

MilkFather commented 2 weeks ago

@LaurentMazare I added an experimental Metal kernel implementation. I admit that writing kernels and the supporting code is a pain. However, the code does not run properly on my Mac: all outputs are zero. Maybe my Mac is too old; if you have a suitable device, please test it for me.

EDIT: I have also added a CUDA implementation. It works fine for me, but it would be good to double-check.

Note: I heavily borrowed code from candle-core, to the extent of copying struct definitions to make everything work. This makes me wonder whether we should expose more APIs to enable external tensor unary/binary operations. For example, we could treat every primitive operation uniformly, so that each one is implemented as a struct plus a CustomOp trait impl.
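The "struct + trait impl for every op" idea can be sketched in stdlib-only Rust. In this sketch, a primitive such as exp and a custom op such as sigmoid implement the same forward/backward trait and go through the same generic driver. The `UnaryOp` trait, the struct names and the `bwd(x, y, g)` signature are hypothetical. They are loosely inspired by the fact that a custom op's backward can see the forward result, and candle's actual signatures differ.

```rust
/// Hypothetical uniform interface shared by primitive and custom unary ops.
trait UnaryOp {
    fn fwd(&self, x: f32) -> f32;
    /// `x` is the input, `y` the saved forward output, `g` the upstream gradient.
    fn bwd(&self, x: f32, y: f32, g: f32) -> f32;
}

struct Exp;
struct Sigmoid;

impl UnaryOp for Exp {
    fn fwd(&self, x: f32) -> f32 {
        x.exp()
    }
    // d/dx exp(x) = exp(x), which is exactly the saved output.
    fn bwd(&self, _x: f32, y: f32, g: f32) -> f32 {
        g * y
    }
}

impl UnaryOp for Sigmoid {
    fn fwd(&self, x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }
    // d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x), from the saved output.
    fn bwd(&self, _x: f32, y: f32, g: f32) -> f32 {
        g * (1.0 - y) * y
    }
}

/// Generic driver that treats every op the same way.
fn value_and_grad(op: &dyn UnaryOp, x: f32) -> (f32, f32) {
    let y = op.fwd(x);
    (y, op.bwd(x, y, 1.0))
}

fn main() {
    let (y, dy) = value_and_grad(&Exp, 0.0);
    println!("exp: f(0) = {y}, f'(0) = {dy}");
    let (y, dy) = value_and_grad(&Sigmoid, 0.0);
    println!("sigmoid: f(0) = {y}, f'(0) = {dy}");
}
```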

LaurentMazare commented 2 weeks ago

Thanks, looks pretty good to me. I've added a test and it worked well on my MacBook M2 Pro. I've also exposed a bit more of the CUDA internals to reduce the duplication; it's certainly not perfect, but at least it will do for this one.