sbrunk / storch

GPU accelerated deep learning and numeric computing for Scala 3.
https://storch.dev
Apache License 2.0

Add Python-like apply method to Module to initialize weights and biases #61

Open hmf opened 1 year ago

hmf commented 1 year ago

Add a weight and bias initialization method to nn.Module, so that we can set these values via an apply method, as PyTorch does.

Reference to Python documentation here. Code here.

This code is required to complete issue #51.

hmf commented 1 year ago

I am trying to re-implement the following Python function that initializes the values of a module's weights and biases:


    # called from the model's __init__; self.apply runs the function
    # recursively on every submodule
    # better init, not covered in the original GPT video, but important, will cover in followup video
    self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

After adding some additional init functions to Storch, I coded the following function:

    private def init_weights[D <: FloatNN | ComplexNN](m: Module with HasWeight[D]): Unit =
      m match
        case lm: nn.Linear[_] =>
          torch.nn.init.normal_(lm.weight, mean = 0.0, std = 0.02)
          if true // placeholder: lm.options.bias() is private, so we cannot check for bias
          then torch.nn.init.zeros_(lm.bias)
        case _: nn.Embedding[_] =>
          ??? // TODO: initialize the embedding weights
        case _ => ??? // TODO: other module types

The first thing to note is that Module does not have a weight member, so I had to use HasWeight[D]. Unlike the other traits related to Module, HasWeight[D] does not extend nn.Module.

The second thing of note is that we don't have a HasBias trait (adapted here from HasWeight[D]):

trait HasBias[ParamType <: FloatNN | ComplexNN]:
  def bias: Tensor[ParamType]

The issue I now have is finding a way to test whether a Module has a bias. nn.Linear, for example, has LinearOptions that I could use, but it is private. I assume the objective is to keep this hidden in order to maintain an idiomatic Scala API. Moreover, not all modules have options that include bias (Embedding, for example).

The simplest solution is to have a hasBias(): Boolean method. The Module trait could provide a default implementation that returns false. Any class that could have a bias would have to override this method and access its options to return the correct Boolean value.
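A minimal sketch of this first option (the class and parameter names below are hypothetical stand-ins, not Storch's actual definitions):

    // Hypothetical sketch of the default-method approach.
    trait Module:
      def hasBias: Boolean = false // safe default: assume no bias

    // A module with an optional bias overrides the default, reading the
    // same constructor option (or native options) it already holds.
    class Linear(inFeatures: Int, outFeatures: Int, bias: Boolean = true) extends Module:
      override def hasBias: Boolean = bias

    // A module that never has a bias, like Embedding, needs no override.
    class Embedding(numEmbeddings: Int, embeddingDim: Int) extends Module

An init function could then guard the zero-initialization of the bias with a plain `if m.hasBias then ...`.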

Alternatively, one could add a HasBias trait with the hasBias(): Boolean method. In this case, overriding the method to return true may not be safe (it may depend on the order in which a class mixes in its traits).
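Sketched out (again hypothetical, modeled on the existing HasWeight[D]):

    // Hypothetical capability trait, mirroring HasWeight[D].
    trait HasBias[ParamType <: FloatNN | ComplexNN]:
      def bias: Tensor[ParamType]
      def hasBias: Boolean = true // overridden when the bias is optional

    // An init function could then match on the capability:
    //   case hb: HasBias[?] if hb.hasBias =>
    //     torch.nn.init.zeros_(hb.bias)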

Finally, we could try something fancier with type parameters, so that bias existence is known at compile time, but I am uncertain about this.
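One possible shape for the compile-time variant (purely a sketch, not Storch code): encode bias presence as a type parameter, so that calling bias on a bias-less module fails to compile.

    // Hypothetical sketch: track bias presence in the type.
    sealed trait BiasMode
    trait WithBias extends BiasMode
    trait NoBias extends BiasMode

    // T stands in for the tensor type of the parameters.
    trait LinearLike[T, B <: BiasMode]:
      def weight: T
      // The <:< evidence is only summonable when B <: WithBias,
      // so `bias` cannot be called on a LinearLike[T, NoBias].
      def bias(using B <:< WithBias): T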

Any suggestions on how I should proceed?

TIA

sbrunk commented 1 year ago

Sorry @hmf, I missed this somehow. I'd suggest we start with the simplest option: adding hasBias(): Boolean to Module.

Since enabling/disabling bias is often a constructor parameter, I think it is harder to encode in the types than HasWeight. We can still improve this later if we see that it makes sense.
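With a hasBias on Module, the init_weights draft above could drop its hard-coded guard. A sketch, assuming hasBias lands as proposed and that nn.Linear exposes its bias tensor:

    private def init_weights[D <: FloatNN | ComplexNN](m: Module with HasWeight[D]): Unit =
      m match
        case lm: nn.Linear[_] =>
          torch.nn.init.normal_(lm.weight, mean = 0.0, std = 0.02)
          if lm.hasBias then torch.nn.init.zeros_(lm.bias)
        case em: nn.Embedding[_] =>
          torch.nn.init.normal_(em.weight, mean = 0.0, std = 0.02)
        case _ => () // leave other modules at their default initialization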