huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

No backward pass for `RmsNorm` if tensor is contiguous #2168

Open agerasev opened 3 weeks ago

agerasev commented 3 weeks ago

`RmsNorm` switches to a faster implementation if the input tensor is contiguous:

https://github.com/huggingface/candle/blob/82b641fd2752e3b14db6a9c91faef70e3329f3b5/candle-nn/src/layer_norm.rs#L174-L175
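For reference, here is a plain-Rust sketch of what the two paths compute for a single row: a composed version that materializes each intermediate, as a chain of differentiable tensor ops would, and a fused single-pass version. The function names are hypothetical and this is not candle's code. Bias and dtype handling are omitted, and the "fused" function only illustrates the idea of a one-pass kernel.

```rust
/// Composed path: each step produces an intermediate, the way a chain of
/// differentiable ops (square, mean, add eps, sqrt, divide, scale) would.
fn rms_norm_composed(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let squares: Vec<f32> = x.iter().map(|v| v * v).collect();
    let mean_sq = squares.iter().sum::<f32>() / x.len() as f32;
    let denom = (mean_sq + eps).sqrt();
    let normed: Vec<f32> = x.iter().map(|v| v / denom).collect();
    normed.iter().zip(weight).map(|(n, w)| n * w).collect()
}

/// Fused path (illustrative): one pass accumulates the sum of squares,
/// one pass writes the output, with no intermediate buffers.
fn rms_norm_fused(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mut sum_sq = 0.0f32;
    for v in x {
        sum_sq += v * v;
    }
    let inv_rms = 1.0 / (sum_sq / x.len() as f32 + eps).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect()
}

fn main() {
    let x = [1.0f32, -2.0, 3.0, 0.5];
    let w = [1.0f32, 0.5, 2.0, 1.0];
    let a = rms_norm_composed(&x, &w, 1e-5);
    let b = rms_norm_fused(&x, &w, 1e-5);
    // Both paths agree numerically; only differentiability differs.
    for (p, q) in a.iter().zip(&b) {
        assert!((p - q).abs() < 1e-5);
    }
    println!("{:?}", b);
}
```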

However, that fast implementation does not support the backward pass:

https://github.com/huggingface/candle/blob/82b641fd2752e3b14db6a9c91faef70e3329f3b5/candle-nn/src/ops.rs#L640
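For context, these are the gradients a backward pass for the fused op would need to produce, derived by hand for one row $x \in \mathbb{R}^n$ with weight $w$, epsilon $\varepsilon$, and upstream gradient $g_i = \partial L / \partial y_i$ (bias omitted). The notation is introduced here for illustration and is not taken from the linked code. The weight gradient is summed over rows in a batch.

```math
\begin{aligned}
r &= \sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \varepsilon}, \qquad y_i = \frac{w_i x_i}{r}, \\
\frac{\partial L}{\partial x_j} &= \frac{w_j g_j}{r} - \frac{x_j}{n r^3}\sum_{i=1}^{n} w_i g_i x_i, \\
\frac{\partial L}{\partial w_j} &= \frac{g_j x_j}{r}.
\end{aligned}
```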

Would it be better to implement `ModuleT` rather than `Module` for `RmsNorm`, and use the faster implementation only when `train == false`?
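A minimal sketch of the proposed dispatch, assuming a train-aware trait shaped like candle's `ModuleT::forward_t(xs, train)`. Everything below (`Tensor`, `ModuleT`, `Path`, `select_path`, `RmsNorm`) is a hypothetical stand-in in plain Rust, not candle's actual types. The fused kernel is taken only for contiguous inputs at inference time; training and non-contiguous inputs fall back to the composed path.

```rust
// Hypothetical stand-ins for illustration only; these are not candle's types.
struct Tensor {
    data: Vec<f32>,
    contiguous: bool,
}

// Mirrors the shape of a train-aware forward: the caller states whether it is training.
trait ModuleT {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Vec<f32>;
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum Path {
    Fused,    // fast kernel, forward only (no backward)
    Composed, // built from differentiable ops, supports backward
}

/// Proposed rule: use the fused kernel only for contiguous inputs when not training.
fn select_path(contiguous: bool, train: bool) -> Path {
    if contiguous && !train {
        Path::Fused
    } else {
        Path::Composed
    }
}

struct RmsNorm {
    weight: Vec<f32>,
    eps: f32,
}

impl RmsNorm {
    // Stand-in for the fused kernel: one pass, nothing recorded for autograd.
    fn forward_fused(&self, x: &[f32]) -> Vec<f32> {
        let sum_sq: f32 = x.iter().map(|v| v * v).sum();
        let inv_rms = 1.0 / (sum_sq / x.len() as f32 + self.eps).sqrt();
        x.iter().zip(&self.weight).map(|(v, w)| v * inv_rms * w).collect()
    }

    // Stand-in for the composed path: in a real framework each step would be
    // a separate differentiable op the autograd graph can record.
    fn forward_composed(&self, x: &[f32]) -> Vec<f32> {
        let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
        let denom = (mean_sq + self.eps).sqrt();
        x.iter().map(|v| v / denom).zip(&self.weight).map(|(n, w)| n * w).collect()
    }
}

impl ModuleT for RmsNorm {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Vec<f32> {
        match select_path(xs.contiguous, train) {
            Path::Fused => self.forward_fused(&xs.data),
            Path::Composed => self.forward_composed(&xs.data),
        }
    }
}

fn main() {
    let norm = RmsNorm { weight: vec![1.0; 4], eps: 1e-5 };
    let xs = Tensor { data: vec![1.0, -2.0, 3.0, 0.5], contiguous: true };
    let eval = norm.forward_t(&xs, false);
    let train = norm.forward_t(&xs, true);
    // Both paths agree numerically; only the training path is differentiable.
    assert!(eval.iter().zip(&train).all(|(a, b)| (a - b).abs() < 1e-5));
    println!("{:?}", select_path(xs.contiguous, false));
}
```

The trade-off of this design is that training always pays for the slower composed path. Implementing a backward for the fused op itself would avoid that, at the cost of writing the gradient by hand.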