coreylowman / dfdx

Deep learning in Rust, with shape checked tensors and neural networks

Adding initializers #279

Open cBournhonesque opened 2 years ago

cBournhonesque commented 2 years ago

We currently have the trait ResetParams for modules, and Randomize for tensors.

I'm trying to see how we can make them more user-friendly by providing something similar to https://keras.io/api/layers/initializers/#randomnormal-class

1) Would it be worthwhile to have the ResetParams trait take a Distribution object provided by the user, so that the user has more control over how the network's parameters are reset? (This could also be a separate trait.)

2) Would it be useful to have an initializers module similar to Keras's? Note that some initializers (e.g. Xavier) use the shape of the module that calls them.
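To make the two questions above concrete, here is a minimal, std-only sketch of what a user-supplied, shape-aware initializer could look like. Everything in it is hypothetical and not part of dfdx: the `Initializer` trait, the `RandomNormal` and `GlorotUniform` structs (named after the Keras classes), and the small xorshift `SimpleRng`, which stands in for the `rand` crate so the example has no dependencies. `GlorotUniform` reads the shape to compute fan-in and fan-out, which is the shape dependence mentioned for Xavier.

```rust
/// Tiny xorshift64 PRNG so the sketch needs no external crates.
/// A real implementation would use the `rand` crate instead.
struct SimpleRng(u64);

impl SimpleRng {
    fn new(seed: u64) -> Self {
        // Scramble the seed and force it odd: xorshift must never hold zero.
        SimpleRng(seed.wrapping_mul(0x9E37_79B9_7F4A_7C15) | 1)
    }
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }
    /// Uniform sample in [0, 1) built from the top 24 bits.
    fn next_f32(&mut self) -> f32 {
        (self.next_u64() >> 40) as f32 / (1u64 << 24) as f32
    }
}

/// Hypothetical initializer trait: it receives the tensor's shape so that
/// shape-dependent schemes such as Glorot/Xavier can compute fan-in/fan-out.
trait Initializer {
    fn init(&self, shape: &[usize], data: &mut [f32], rng: &mut SimpleRng);
}

/// Keras-style RandomNormal: ignores the shape entirely.
struct RandomNormal {
    mean: f32,
    stddev: f32,
}

impl Initializer for RandomNormal {
    fn init(&self, _shape: &[usize], data: &mut [f32], rng: &mut SimpleRng) {
        for x in data.iter_mut() {
            // Box-Muller transform; `1 - u` keeps the log argument in (0, 1].
            let u1 = 1.0 - rng.next_f32();
            let u2 = rng.next_f32();
            let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
            *x = self.mean + self.stddev * z;
        }
    }
}

/// Glorot (Xavier) uniform: samples from U(-limit, limit) with
/// limit = sqrt(6 / (fan_in + fan_out)). Assumes a 2D [out, in] weight.
struct GlorotUniform;

impl Initializer for GlorotUniform {
    fn init(&self, shape: &[usize], data: &mut [f32], rng: &mut SimpleRng) {
        assert_eq!(shape.len(), 2, "this sketch only handles 2D weights");
        let (fan_out, fan_in) = (shape[0], shape[1]);
        let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
        for x in data.iter_mut() {
            *x = (2.0 * rng.next_f32() - 1.0) * limit;
        }
    }
}

fn main() {
    let mut rng = SimpleRng::new(42);
    // e.g. the weight of a layer with 3 inputs and 4 outputs, stored as [out, in]
    let mut w = vec![0.0f32; 12];
    GlorotUniform.init(&[4, 3], &mut w, &mut rng);

    // The user picks the distribution; the module would just pass its shapes.
    let init: Box<dyn Initializer> = Box::new(RandomNormal { mean: 0.0, stddev: 0.02 });
    let mut v = vec![0.0f32; 8];
    init.init(&[8], &mut v, &mut rng);
    println!("{w:?}\n{v:?}");
}
```

Taking the initializer as a trait object (or generic parameter) is one way a `reset_params`-style method could accept a user-provided distribution, as question 1 suggests.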

coreylowman commented 2 years ago

Great question. Before ResetParams existed, modules just implemented the same Randomize trait that tensors implement, but that meant you couldn't have initializers that depend on shape, as you mentioned.

I wonder if it's possible to have some sort of fallback impl for initializers? Thinking along the lines of this:

// NOTE: use like: `StandardInitializers.reset_params(&mut model)`
pub struct StandardInitializers;
impl<const I: usize, const O: usize> ResetParams<Linear<I, O>> for StandardInitializers { }

Then a user could override with:

pub struct MyLinearInitializers;
impl<const I: usize, const O: usize> ResetParams<Linear<I, O>> for MyLinearInitializers { ... }
impl<M> ResetParams<M> for MyLinearInitializers where StandardInitializers: ResetParams<M> { ... }

However, I think Rust would reject this with a conflicting implementations error: StandardInitializers already implements ResetParams<Linear<I, O>>, so the blanket impl for MyLinearInitializers would also cover Linear<I, O> and overlap with the explicit one. There's probably some way to disambiguate them, but I'd have to think more about it.
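One way around the overlap, offered as an alternative technique rather than the blanket-impl design above, is to give the initializer trait one default method per layer kind: the library type uses the defaults as-is, and a user type overrides only the methods it cares about, so no two impls ever overlap. This is a std-only sketch; `LayerInitializers`, `init_linear_weight`, `init_bias` and the toy `Linear` struct are hypothetical names, not dfdx APIs, and the default weight rule is a deterministic placeholder rather than a random draw.

```rust
/// Hypothetical per-layer hooks. Every method has a default body, so an
/// implementor overrides only what it cares about and inherits the rest.
/// There is no blanket impl, so nothing can conflict.
trait LayerInitializers {
    /// Placeholder default: a constant scaled by fan-in.
    /// A real default would sample from a distribution instead.
    fn init_linear_weight(&mut self, fan_in: usize, _fan_out: usize, w: &mut [f32]) {
        let v = 1.0 / (fan_in as f32).sqrt();
        w.iter_mut().for_each(|x| *x = v);
    }
    fn init_bias(&mut self, b: &mut [f32]) {
        b.iter_mut().for_each(|x| *x = 0.0);
    }
}

/// The library-provided defaults: uses every default method unchanged.
struct StandardInitializers;
impl LayerInitializers for StandardInitializers {}

/// A user type that replaces only the Linear weight rule.
struct MyLinearInitializers;
impl LayerInitializers for MyLinearInitializers {
    fn init_linear_weight(&mut self, _fan_in: usize, _fan_out: usize, w: &mut [f32]) {
        w.iter_mut().for_each(|x| *x = 0.1);
    }
}

/// Minimal stand-in for a Linear layer with the weight stored as [out, in].
struct Linear {
    in_features: usize,
    out_features: usize,
    weight: Vec<f32>,
    bias: Vec<f32>,
}

impl Linear {
    fn new(in_features: usize, out_features: usize) -> Self {
        Linear {
            in_features,
            out_features,
            weight: vec![f32::NAN; in_features * out_features],
            bias: vec![f32::NAN; out_features],
        }
    }
    /// The module decides which hook applies to which of its tensors.
    fn reset_params<I: LayerInitializers>(&mut self, init: &mut I) {
        init.init_linear_weight(self.in_features, self.out_features, &mut self.weight);
        init.init_bias(&mut self.bias);
    }
}

fn main() {
    let mut a = Linear::new(4, 2);
    a.reset_params(&mut StandardInitializers); // weights = 1/sqrt(4)
    let mut b = Linear::new(4, 2);
    b.reset_params(&mut MyLinearInitializers); // weights = 0.1, bias inherited
    println!("{:?} {:?}", a.weight, b.weight);
}
```

The trade-off is that the trait has to list every layer kind up front, so adding a new layer type means adding a new method (with a default) to the trait.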

coreylowman commented 1 year ago

This might be doable with TensorCollection as it stands now, especially since ResetParams was moved to use it. Check out the existing implementation of ResetParams:

struct Resetter;
impl<E: Dtype, D: DeviceStorage> TensorVisitor<E, D> for Resetter {
    type Viewer = ViewTensorMut;
    type Err = D::Err;

    fn visit<S: Shape>(
        &mut self,
        _: String,
        opts: TensorOptions<S, E, D>,
        t: &mut Tensor<S, E, D>,
    ) -> Result<(), D::Err> {
        (opts.reset)(t)
    }
}
pub trait ResetParams<E: Dtype, D: DeviceStorage>: TensorCollection<E, D> {
    fn reset_params(&mut self) {
        self.try_reset_params().unwrap();
    }
    fn try_reset_params(&mut self) -> Result<(), D::Err> {
        Self::iter_tensors(&mut RecursiveWalker {
            m: self,
            f: &mut Resetter,
            path: &mut Vec::new(),
        })
    }
}
impl<E: Dtype, D: DeviceStorage, M: TensorCollection<E, D>> ResetParams<E, D> for M {}
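For readers unfamiliar with the visitor machinery, this std-only mock mimics the structure of the code above: a collection walks its tensors depth-first, builds a dotted path for each one, and hands the path, per-tensor options, and the tensor to a visitor; `Resetter` simply calls the default rule stored in the options. It is an analogy, not dfdx's implementation: `Options`, the `fn` pointer `reset`, the `(A, B)` model and the `"0.weight"` path format are assumptions made for illustration.

```rust
/// Hypothetical stand-in for TensorOptions: the module supplies each
/// tensor's shape and its default reset rule.
struct Options {
    shape: Vec<usize>,
    reset: fn(&[usize], &mut [f32]),
}

trait TensorVisitor {
    fn visit(&mut self, path: &str, opts: &Options, t: &mut [f32]);
}

trait TensorCollection {
    /// Walk every tensor, building a dotted path such as "0.weight".
    fn iter_tensors(&mut self, prefix: &str, v: &mut dyn TensorVisitor);
}

fn join(prefix: &str, name: &str) -> String {
    if prefix.is_empty() { name.to_string() } else { format!("{prefix}.{name}") }
}

// Default rules: placeholder constant scaled by fan-in, and zeros.
fn linear_weight_reset(shape: &[usize], t: &mut [f32]) {
    let v = 1.0 / (shape[1] as f32).sqrt();
    t.iter_mut().for_each(|x| *x = v);
}
fn zero_reset(_shape: &[usize], t: &mut [f32]) {
    t.iter_mut().for_each(|x| *x = 0.0);
}

struct Linear { in_f: usize, out_f: usize, weight: Vec<f32>, bias: Vec<f32> }

impl Linear {
    fn new(in_f: usize, out_f: usize) -> Self {
        Linear { in_f, out_f, weight: vec![f32::NAN; in_f * out_f], bias: vec![f32::NAN; out_f] }
    }
}

impl TensorCollection for Linear {
    fn iter_tensors(&mut self, prefix: &str, v: &mut dyn TensorVisitor) {
        let w_opts = Options { shape: vec![self.out_f, self.in_f], reset: linear_weight_reset };
        v.visit(&join(prefix, "weight"), &w_opts, &mut self.weight);
        let b_opts = Options { shape: vec![self.out_f], reset: zero_reset };
        v.visit(&join(prefix, "bias"), &b_opts, &mut self.bias);
    }
}

/// A two-layer model recurses into its children with "0" and "1" prefixes.
impl<A: TensorCollection, B: TensorCollection> TensorCollection for (A, B) {
    fn iter_tensors(&mut self, prefix: &str, v: &mut dyn TensorVisitor) {
        self.0.iter_tensors(&join(prefix, "0"), v);
        self.1.iter_tensors(&join(prefix, "1"), v);
    }
}

/// Analogue of Resetter: apply each tensor's own default rule.
struct Resetter;
impl TensorVisitor for Resetter {
    fn visit(&mut self, _path: &str, opts: &Options, t: &mut [f32]) {
        (opts.reset)(&opts.shape, t);
    }
}

/// Records the paths the walker visits.
struct PathRecorder(Vec<String>);
impl TensorVisitor for PathRecorder {
    fn visit(&mut self, path: &str, _opts: &Options, _t: &mut [f32]) {
        self.0.push(path.to_string());
    }
}

fn main() {
    let mut model = (Linear::new(4, 3), Linear::new(3, 2));
    let mut rec = PathRecorder(Vec::new());
    model.iter_tensors("", &mut rec);
    model.iter_tensors("", &mut Resetter);
    println!("{:?}", rec.0);
}
```

A custom initializer in this scheme is just another `TensorVisitor`, which is why the same walker can serve both the default reset and user-defined rules.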

So if you wanted to do a custom initialization, you could copy the above and change the visit method to do something like:

struct MyCustomInit;
impl<E: Dtype, D: DeviceStorage> TensorVisitor<E, D> for MyCustomInit {
    type Viewer = ViewTensorMut;
    type Err = D::Err;

    fn visit<S: Shape>(
        &mut self,
        path: String,
        opts: TensorOptions<S, E, D>,
        t: &mut Tensor<S, E, D>,
    ) -> Result<(), D::Err> {
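        if S::NUM_DIMS == 2 { ... } // 2D tensors, e.g. Linear weights
        else if S::NUM_DIMS == 4 { ... } // 4D tensors, e.g. Conv2D weights
        else if path.contains("weight") { ... } // other tensors named "weight"
        else { (opts.reset)(t) } // fall back to the module's default reset
    }
}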
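The branch logic in `MyCustomInit::visit` can be tried out without dfdx. In this hypothetical sketch, `custom_init_limit` picks the bound of a U(-limit, limit) draw from the tensor's rank and path, using Glorot/Xavier uniform for 2D `[out, in]` tensors and He/Kaiming uniform for 4D `[out, in, k_h, k_w]` tensors; those choices, the fixed 0.1 bound, and the layouts are assumptions about what one might put in the `{ ... }` branches. Returning `None` means "fall back to `(opts.reset)(t)`". `apply_custom_init` takes the random source and the default reset as closures so the example stays std-only.

```rust
/// Hypothetical dispatch mirroring the branches of MyCustomInit::visit.
/// Returns the bound `limit` for a U(-limit, limit) draw, or None to fall
/// back to the tensor's default reset. The first matching branch wins.
fn custom_init_limit(path: &str, shape: &[usize]) -> Option<f32> {
    if shape.len() == 2 {
        // 2D weight assumed to be [out, in]: Glorot/Xavier uniform bound
        let (fan_out, fan_in) = (shape[0], shape[1]);
        Some((6.0 / (fan_in + fan_out) as f32).sqrt())
    } else if shape.len() == 4 {
        // 4D weight assumed to be [out, in, k_h, k_w]: He/Kaiming uniform bound
        let fan_in = shape[1] * shape[2] * shape[3];
        Some((6.0 / fan_in as f32).sqrt())
    } else if path.contains("weight") {
        // any other tensor whose path mentions "weight": placeholder bound
        Some(0.1)
    } else {
        // e.g. biases: keep the module's default rule
        None
    }
}

/// Applies the chosen rule. `uniform01` must return samples in [0, 1);
/// `default_reset` plays the role of `(opts.reset)(t)`.
fn apply_custom_init(
    path: &str,
    shape: &[usize],
    t: &mut [f32],
    mut uniform01: impl FnMut() -> f32,
    default_reset: impl FnOnce(&mut [f32]),
) {
    match custom_init_limit(path, shape) {
        Some(limit) => t.iter_mut().for_each(|x| *x = (2.0 * uniform01() - 1.0) * limit),
        None => default_reset(t),
    }
}

fn main() {
    // A 3x3 conv with 3 input channels: fan_in = 27.
    // A constant "random" source keeps the output predictable here.
    let mut conv_w = vec![0.0f32; 8 * 3 * 3 * 3];
    apply_custom_init("0.weight", &[8, 3, 3, 3], &mut conv_w, || 0.75, |_| {});

    let mut bias = vec![1.0f32; 8];
    apply_custom_init("0.bias", &[8], &mut bias, || 0.75, |t| {
        t.iter_mut().for_each(|x| *x = 0.0)
    });
    println!("{:?} {:?}", &conv_w[..3], bias);
}
```

Because the checks run in order, a 2D tensor named `weight` takes the rank-2 branch, and the path-based branch only sees tensors of other ranks.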

However, I'm not sure we have the correct pub exports for this to work properly, and we should add an example of how to do this before closing this ticket.