cgarciae / treex

A Pytree Module system for Deep Learning in JAX
https://cgarciae.github.io/treex/
MIT License

Add LayerNorm and GroupNorm #58

Closed: lkhphuc closed this pull request 2 years ago

lkhphuc commented 2 years ago

Add LayerNorm and GroupNorm; also rename the BatchNorm file to Norm.

Unrelated question: I saw that in Treex's modules you define module properties both as class variables (dataclass style) and as __init__ parameters. What is the reason behind this? I thought the point of dataclass-style attributes was to reduce the boilerplate in __init__.

cgarciae commented 2 years ago

Hey @lkhphuc, thanks a lot! I'll give it a review soon.

Meanwhile I'll answer your question:


I think there are two reasons. First, class variables are a good way to provide metadata for fields, regardless of whether the class is a dataclass. Second, declaring fields this way gives you immediate support for both dataclasses and non-dataclasses, which is ideal because you want flexibility; Treex (or rather treeo) supports both.
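To illustrate the two reasons in the answer above, here is a minimal, library-free sketch. It does not use Treex or treeo: `Parameter`, `Linear`, `LinearDC`, and `parameter_fields` are hypothetical names chosen for illustration, and the real field-kind machinery in treeo differs. The idea shown is that class-level annotations carry the field metadata, so one helper can read it from a plain class with a hand-written __init__ and from a dataclass alike.

```python
import dataclasses
import typing


class Parameter:
    """Hypothetical marker used as a field annotation (not treeo's real kind)."""


class Linear:
    # Class-level annotations declare the fields and their kind (the metadata).
    w: Parameter
    b: Parameter

    # Hand-written __init__, so the class works without @dataclass.
    def __init__(self, w: float, b: float):
        self.w = w
        self.b = b


@dataclasses.dataclass
class LinearDC:
    # Same annotations; here @dataclass generates __init__ from them.
    w: Parameter
    b: Parameter


def parameter_fields(cls: type) -> list:
    """Return the names of fields annotated as Parameter.

    Only class annotations are inspected, so dataclasses and plain
    classes are handled the same way.
    """
    hints = typing.get_type_hints(cls)
    return [name for name, kind in hints.items() if kind is Parameter]


print(parameter_fields(Linear))    # ['w', 'b']
print(parameter_fields(LinearDC))  # ['w', 'b']
```

A side benefit of this pattern: if a user later decorates a class like `Linear` with `@dataclasses.dataclass`, the hand-written __init__ is kept, because dataclasses does not replace an __init__ already defined in the class body.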

cgarciae commented 2 years ago

@lkhphuc reviewed it, I have no comments, great job! 🎉