agerasev opened 3 weeks ago
`RmsNorm` switches to a faster implementation if the tensor is contiguous:
https://github.com/huggingface/candle/blob/82b641fd2752e3b14db6a9c91faef70e3329f3b5/candle-nn/src/layer_norm.rs#L174-L175
But that implementation does not support the backward pass:
https://github.com/huggingface/candle/blob/82b641fd2752e3b14db6a9c91faef70e3329f3b5/candle-nn/src/ops.rs#L640
Maybe it would be better to implement `ModuleT` rather than `Module` for `RmsNorm`, and use the faster implementation only if `train == false`?
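
For illustration, a minimal sketch of what that could look like, assuming `RmsNorm` stays a newtype over `LayerNorm` (so `self.0.weight`, `self.0.eps`, and the composed-ops `forward` are the existing ones) and that the current `Module` impl would be dropped or adjusted, since candle's blanket `impl<M: Module> ModuleT for M` would otherwise conflict:

```rust
use candle::{Module, ModuleT, Result, Tensor};

// Sketch only: gate the fused kernel on `train == false`, so that the
// differentiable composed-ops path is used whenever gradients are needed.
impl ModuleT for RmsNorm {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        if !train && xs.is_contiguous() {
            // Fast custom op from candle-nn/src/ops.rs; no backward pass implemented.
            crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32)
        } else {
            // RMS norm built from basic tensor ops, which autograd can traverse.
            self.0.forward(xs)
        }
    }
}
```

With that in place, training code calling `forward_t(xs, true)` would always go through the differentiable path, while inference via `forward_t(xs, false)` keeps the fused kernel for contiguous tensors.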