cgarciae / treex

A Pytree Module system for Deep Learning in JAX
https://cgarciae.github.io/treex/
MIT License

Instances of the same tx.Module have different tree_structure #78

Open · jiyuuchc opened 1 year ago

jiyuuchc commented 1 year ago

Bug:

Instances of the same tx.Module have different tree_structure

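import jax
import treex as tx

class T(tx.Module):
    pass

t1 = T()
t2 = T()

# Two instances of the same class should share one tree structure, but:
print(jax.tree_util.tree_structure(t1) == jax.tree_util.tree_structure(t2))
# False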

The patch is simple:

diff --git a/treex/module.py b/treex/module.py
index bafcdc9..fb036d7 100644
--- a/treex/module.py
+++ b/treex/module.py
@@ -76,6 +76,7 @@ class Module(Treex, Filters, metaclass=ModuleMeta):
     _training: bool = to.static(True)
     _initialized: bool = to.static(False)
     _frozen: bool = to.static(False)
+    name: str = to.static(False)

     def __init__(self, name: tp.Optional[str] = None):
         self.name = (
jiyuuchc commented 1 year ago

Looking into this a bit further, it seems the bug is actually in Treeo.

The problem is that Treeo tries to auto-annotate fields, which creates instance-level field_metadata, and these metadata objects differ from one instance to the next. This causes problems under JIT: the tree structure of the arguments is part of the compilation cache key, so passing a new instance to a jitted function triggers an unnecessary recompilation.
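To illustrate the mechanism outside Treeo, here is a minimal sketch in plain JAX. Box, FieldMeta, box_flatten and box_unflatten are hypothetical names, not Treeo's code; the assumption is that per-instance metadata objects end up in the pytree's static aux data and compare by identity, as this comment describes. A Python counter incremented inside the jitted function shows how often jax.jit traces.

```python
import jax
import jax.numpy as jnp

class FieldMeta:
    """Hypothetical per-field metadata; no __eq__, so equality is identity."""
    def __init__(self, node: bool):
        self.node = node

class Box:
    """Hypothetical pytree that stores metadata per instance (like auto-annotation)."""
    def __init__(self, x):
        self.x = x
        # A fresh metadata object is created for every instance
        self.meta = {"x": FieldMeta(node=True)}

def box_flatten(b):
    # Children are the array leaves; the metadata goes into static aux data
    return (b.x,), tuple(b.meta.items())

def box_unflatten(aux, children):
    b = object.__new__(Box)
    b.x = children[0]
    b.meta = dict(aux)
    return b

jax.tree_util.register_pytree_node(Box, box_flatten, box_unflatten)

trace_count = 0

@jax.jit
def f(b):
    global trace_count
    trace_count += 1  # Python side effect: only runs when jit (re)traces
    return b.x * 2

f(Box(jnp.ones(3)))
f(Box(jnp.ones(3)))  # same shapes and dtypes, but a different tree structure
print(trace_count)  # 2
```

Since the two FieldMeta objects are never equal, the two treedefs differ, so the second call misses the jit cache and traces again even though the array shapes are identical.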

I see two ways to solve the problem:

(1) Do not perform auto-annotation after the __init__ call. Instead, the user should explicitly invoke auto-annotation, with the understanding that this call alters the tree structure.

or

(2) Manually annotate all tx classes so that class-level metadata already exists for every field created in __init__.
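Option (2) can be sketched in plain JAX as follows. This is not Treeo's API: Box, FieldMeta and the class attribute _field_metadata are hypothetical names. The point is that when the metadata is created once at class level, every instance's aux data refers to the same objects, so even identity-based equality makes the tree structures match.

```python
import jax

class FieldMeta:
    """Hypothetical per-field metadata with default identity-based equality."""
    def __init__(self, node: bool):
        self.node = node

class Box:
    # Class-level metadata: created once and shared by every instance
    _field_metadata = {"x": FieldMeta(node=True)}

    def __init__(self, x):
        self.x = x

jax.tree_util.register_pytree_node(
    Box,
    # aux data references the shared class-level objects, so it compares equal
    lambda b: ((b.x,), tuple(Box._field_metadata.items())),
    lambda aux, children: Box(children[0]),
)

s1 = jax.tree_util.tree_structure(Box(1.0))
s2 = jax.tree_util.tree_structure(Box(2.0))
print(s1 == s2)  # True
```

The trade-off is the one named in option (2): every field created in __init__ has to be declared on the class up front.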

jiyuuchc commented 1 year ago

After some more testing, I found that adding an __eq__() method to the class treeo.types.FieldMetadata solves all of these issues.

This seems to be the cleanest solution.
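A minimal sketch of why value-based equality fixes the problem, assuming FieldMetadata can be modeled as a small frozen dataclass (the class below is a hypothetical stand-in, not Treeo's actual definition). Note that in Python, a class that defines __eq__ without __hash__ becomes unhashable, and JAX expects pytree aux data to be hashable as well as comparable; a frozen dataclass generates both methods.

```python
import dataclasses
import jax
import jax.numpy as jnp

@dataclasses.dataclass(frozen=True)
class FieldMetadata:
    """Hypothetical stand-in for treeo.types.FieldMetadata.

    frozen=True (with the default eq=True) generates value-based
    __eq__ and a matching __hash__.
    """
    node: bool

class Box:
    def __init__(self, x):
        self.x = x
        # Still a fresh metadata object per instance, as with auto-annotation
        self.meta = {"x": FieldMetadata(node=True)}

def box_flatten(b):
    return (b.x,), tuple(b.meta.items())

def box_unflatten(aux, children):
    b = object.__new__(Box)
    b.x = children[0]
    b.meta = dict(aux)
    return b

jax.tree_util.register_pytree_node(Box, box_flatten, box_unflatten)

trace_count = 0

@jax.jit
def f(b):
    global trace_count
    trace_count += 1  # only runs when jit traces
    return b.x * 2

f(Box(jnp.ones(3)))
f(Box(jnp.ones(3)))  # equal metadata -> equal treedef -> cache hit
print(trace_count)  # 1
```

Instances still carry their own metadata objects, but because those objects now compare and hash by value, the tree structures match and jit reuses the compiled function.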

cgarciae commented 1 year ago

Hey @jiyuuchc!

I'm not actively maintaining Treex/Treeo, but I'm happy to approve a PR and create a new release if you want to send one.