hmf opened this issue 1 year ago
I am trying to re-implement the following Python function that initializes the values of a module's weights and biases:
```python
# better init, not covered in the original GPT video, but important, will cover in followup video
self.apply(self._init_weights)

def _init_weights(self, module):
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
```
After adding some additional init functions to Storch, I coded the following function:
```scala
private def init_weights[D <: FloatNN | ComplexNN](m: Module with HasWeight[D]): Unit =
  m match
    case lm: nn.Linear[_] =>
      torch.nn.init.normal_(lm.weight, mean = 0.0, std = 0.02)
      if true // lm.options.bias()
      then torch.nn.init.zeros_(lm.bias)
    case _: nn.Embedding[_] =>
      ???
    case _ => ???
  ???
```
The first thing to note is that `Module` does not have a `weight` member, so I had to use `HasWeight[D]`. Also, unlike other traits in `Module`, `HasWeight[D]` does not extend `nn.Module`.
The second thing to note is that we don't have a `HasBias` trait (adapted here from `HasWeight[D]`):

```scala
trait HasBias[ParamType <: FloatNN | ComplexNN]:
  def bias: Tensor[ParamType]
```
The issue I now have is finding a way to test whether a `Module` has a bias. `nn.Linear`, for example, has a `LinearOptions` that I could use, but it is private. I assume the objective is to keep this hidden to maintain an idiomatic Scala API. Moreover, not all modules have options that include a bias (`Embedding`, for example).
The simplest solution is to have a `hasBias(): Boolean` method. The `Module` trait could provide a default implementation that returns `false`. Any class that can have a bias would override this method and consult its options to return the appropriate value.
Alternatively, one could add a `HasBias` trait with the `hasBias(): Boolean` method. In this case overriding the method to return `true` may not be safe (it depends on the order in which the class/trait is extended?).
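For comparison, a hypothetical marker-trait variant could look like this. On the ordering concern: if `HasBias` itself extends `Module` and overrides the method, Scala's linearization resolves the call to the trait's `true` for any class that mixes it in, so this particular layout should be safe:

```scala
import torch.{ComplexNN, FloatNN, Tensor}

trait Module:
  def hasBias(): Boolean = false

// Marker trait: only modules that can actually carry a bias mix this in.
trait HasBias[ParamType <: FloatNN | ComplexNN] extends Module:
  def bias: Tensor[ParamType]
  override def hasBias(): Boolean = true
```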
Finally, we could try something fancy with type parameters so that the existence of a bias is known at compile time, but I am uncertain about this.
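For completeness, one way such a compile-time encoding could look is a type class whose instances exist only for module types that actually have a bias; everything below is a sketch, not Storch API:

```scala
import torch.{ComplexNN, FloatNN, Tensor}

// Hypothetical type class: evidence that module type M exposes a bias.
trait BiasOf[M, D <: FloatNN | ComplexNN]:
  def bias(m: M): Tensor[D]

// Compiles only where a BiasOf instance is in scope, so it is impossible
// to ask for the bias of a module type that has none.
def initBias[M, D <: FloatNN | ComplexNN](m: M)(using ev: BiasOf[M, D]): Unit =
  torch.nn.init.zeros_(ev.bias(m))
```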
Any suggestions on how I should proceed?
TIA
Sorry @hmf, missed that somehow. I'd suggest we start with the simplest option, adding `hasBias(): Boolean` to `Module`. Since enabling/disabling the bias is often a constructor parameter, I think it is harder to type compared to `HasWeight`. We can still improve later if we see that it makes sense.
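With `hasBias()` on `Module`, the match from the original post could then read roughly like this (still a sketch; `hasBias()` is the proposed method, not existing Storch API):

```scala
private def init_weights[D <: FloatNN | ComplexNN](m: Module with HasWeight[D]): Unit =
  m match
    case lm: nn.Linear[_] =>
      torch.nn.init.normal_(lm.weight, mean = 0.0, std = 0.02)
      if lm.hasBias() then torch.nn.init.zeros_(lm.bias)
    case em: nn.Embedding[_] =>
      // mirror the Python version's Embedding initialization
      torch.nn.init.normal_(em.weight, mean = 0.0, std = 0.02)
    case _ => () // leave other modules at their default initialization
```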
Add a weight and bias initialization method to `nn.Module` so we can set these values via an `apply` method, like PyTorch does. Reference to the Python documentation here. Code here.
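For reference, PyTorch's `Module.apply(fn)` calls `fn` recursively on every submodule and then on the module itself. A Storch counterpart might be sketched like this, where `children` is an assumed accessor for the direct submodules, not necessarily the real API:

```scala
trait Module:
  // Assumed accessor for direct submodules; the name is hypothetical.
  def children: Seq[Module]

  /** Apply fn depth-first to all submodules, then to this module. */
  def apply(fn: Module => Unit): this.type =
    children.foreach(_.apply(fn))
    fn(this)
    this
```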
This code is required to complete issue #51.