google / ml_collections

ML Collections is a library of Python Collections designed for ML use cases.
https://ml-collections.readthedocs.io/
Apache License 2.0

FieldReference op assumes type does not change when it does #18

Open ciupakabra opened 2 years ago


Hi, I am trying to build a more complicated configuration setup that combines a couple of lazy ints into a lazy tuple. The script below shows how I got it to work. However, the resulting FieldReference objects hold tuples while their recorded type (`_field_type`) is still `int`.

```python
from ml_collections import ConfigDict, FieldReference
from ml_collections.config_dict import _Op

if __name__ == "__main__":
    a = FieldReference(None, int)
    b = FieldReference(None, int)

    a_tuple = FieldReference(a, op=_Op(lambda x: (x,), ()))
    b_tuple = FieldReference(b, op=_Op(lambda x: (x,), ()))
    c = a_tuple + b_tuple

    a.set(1)
    b.set(2)

    print(f"a_tuple: {a_tuple.get()}")
    print(f"b_tuple: {b_tuple.get()}")
    print(f"c: {c.get()}")

    print(f"a_tuple type: {a_tuple._field_type}")
    print(f"b_tuple type: {b_tuple._field_type}")
    print(f"c type: {c._field_type}")
```

This outputs:

```
a_tuple: (1,)
b_tuple: (2,)
c: (1, 2)
a_tuple type: <class 'int'>
b_tuple type: <class 'int'>
c type: <class 'int'>
```

This is probably a bug: the computed values are tuples, but `_field_type` still reports `int` for all three references. Tested with Python 3.10.8 and ml_collections 0.1.1.
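To make the mechanism concrete, here is a toy model of the inference rule that the output above suggests. `ToyRef` is a hypothetical class, not ml_collections code: a reference derived from another reference through an op, with no explicit type, simply copies its parent's type, whatever the op returns. Its `__add__` is likewise an assumption, chosen to mirror how `c` ends up typed as `int`.

```python
from typing import Any, Callable, Optional


class ToyRef:
    """Hypothetical lazy reference; NOT the ml_collections FieldReference API."""

    def __init__(self, default: Any, field_type: Optional[type] = None,
                 op: Optional[Callable[[Any], Any]] = None):
        # A reference passed as `default` becomes the parent we read lazily.
        self._parent = default if isinstance(default, ToyRef) else None
        self._value = None if self._parent is not None else default
        self._op = op
        if field_type is None and self._parent is not None:
            # Assumed inference rule: copy the parent's type, ignoring that
            # `op` may return a value of a different type.
            field_type = self._parent.field_type
        self.field_type = field_type

    def set(self, value: Any) -> None:
        self._value = value

    def get(self) -> Any:
        # Read the parent's current value (or our own), then apply the op.
        value = self._parent.get() if self._parent is not None else self._value
        return self._op(value) if self._op is not None else value

    def __add__(self, other: "ToyRef") -> "ToyRef":
        # The sum is derived from `self`, so it inherits self.field_type.
        return ToyRef(self, op=lambda x: x + other.get())


a = ToyRef(None, int)
b = ToyRef(None, int)
a_tuple = ToyRef(a, op=lambda x: (x,))
b_tuple = ToyRef(b, op=lambda x: (x,))
c = a_tuple + b_tuple

a.set(1)
b.set(2)
print(c.get())        # (1, 2)
print(c.field_type)   # <class 'int'>
```

The value and the recorded type drift apart because the type is fixed at construction time from the parent, while the op only runs later inside `get()`.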

If I instead pass `tuple` as the field type explicitly:

```python
from ml_collections import ConfigDict, FieldReference
from ml_collections.config_dict import _Op

if __name__ == "__main__":
    a = FieldReference(None, int)
    b = FieldReference(None, int)

    a_tuple = FieldReference(a, tuple, op=_Op(lambda x: (x,), ()))
    b_tuple = FieldReference(b, tuple, op=_Op(lambda x: (x,), ()))
    c = a_tuple + b_tuple

    a.set(1)
    b.set(2)

    print(f"a_tuple: {a_tuple.get()}")
    print(f"b_tuple: {b_tuple.get()}")
    print(f"c: {c.get()}")

    print(f"a_tuple type: {a_tuple._field_type}")
    print(f"b_tuple type: {b_tuple._field_type}")
    print(f"c type: {c._field_type}")
```

the constructor raises a `TypeError`:

```
Traceback (most recent call last):
  File "/home/andrius/repos/sde-sampling/tmp.py", line 8, in <module>
    a_tuple = FieldReference(a, tuple, op=_Op(lambda x : (x,), ()))
  File "/home/andrius/env/lib/python3.10/site-packages/ml_collections/config_dict/config_dict.py", line 248, in __init__
    self.set(default)
  File "/home/andrius/env/lib/python3.10/site-packages/ml_collections/config_dict/config_dict.py", line 305, in set
    raise TypeError('Reference is of type {} but should be of type {}'
TypeError: Reference is of type <class 'int'> but should be of type <class 'tuple'>
```
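The traceback shows where the check happens: `__init__` calls `set(default)`, and `set` compares the type of the passed reference (`int`, inherited from `a`) with the declared `tuple`. In effect, the declared type is checked against the op's input rather than its output. As a purely hypothetical sketch (not ml_collections code, and not necessarily what the reporter has in mind), a reference could instead treat `field_type` as the type of the op's result and validate that lazily in `get()`. The class name `TypedOpRef` and the `__add__` typing rule are assumptions.

```python
from typing import Any, Callable, Optional


class TypedOpRef:
    """Hypothetical lazy reference; NOT the ml_collections FieldReference API.

    `field_type` describes the value this reference *returns*, i.e. the
    output of `op`, rather than the type of its parent reference.
    """

    def __init__(self, default: Any, field_type: type,
                 op: Optional[Callable[[Any], Any]] = None):
        self.field_type = field_type
        self._op = op
        self._parent: Optional["TypedOpRef"] = None
        self._value: Any = None
        if isinstance(default, TypedOpRef):
            if op is None and default.field_type is not field_type:
                # With no op the value passes through unchanged, so the
                # parent's type must still match (as in the reported check).
                raise TypeError(
                    f"Reference is of type {default.field_type} "
                    f"but should be of type {field_type}")
            self._parent = default
        else:
            self.set(default)

    def set(self, value: Any) -> None:
        if self._parent is not None:
            raise ValueError("cannot set a derived reference")
        if value is not None and not isinstance(value, self.field_type):
            raise TypeError(f"{value!r} is not of type {self.field_type}")
        self._value = value

    def get(self) -> Any:
        value = self._parent.get() if self._parent is not None else self._value
        if self._op is not None:
            value = self._op(value)
        # Validate the op's *output* against the declared type, lazily.
        if value is not None and not isinstance(value, self.field_type):
            raise TypeError(f"op returned {type(value)}, expected {self.field_type}")
        return value

    def __add__(self, other: "TypedOpRef") -> "TypedOpRef":
        # Assumption: the sum keeps the left operand's declared type, which
        # holds for tuple + tuple but not for every op.
        return TypedOpRef(self, self.field_type, op=lambda x: x + other.get())


a = TypedOpRef(None, int)
b = TypedOpRef(None, int)
a_tuple = TypedOpRef(a, tuple, op=lambda x: (x,))
b_tuple = TypedOpRef(b, tuple, op=lambda x: (x,))
c = a_tuple + b_tuple

a.set(1)
b.set(2)
print(c.get())        # (1, 2)
print(c.field_type)   # <class 'tuple'>
```

The trade-off in this sketch is that a wrong declared type is only caught when `get()` runs, not at construction time, since the op's output type is unknown until then.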

I think it would be better to change the behaviour as follows:

I'd be happy to file a PR if the authors agree.