google / objax

Apache License 2.0

ResNetV2 from model.zoo does not specify train arg for ResNetV2Block #240

Closed: a1302z closed this issue 2 years ago

a1302z commented 2 years ago

Hi, I tried to use resnet_v2.ResNet18, but I get an error when calling the model: ResNetV2Block expects a training argument in its __call__ method, and the implementation does not provide it. A minimal example:

import numpy as np
from objax.zoo import resnet_v2

fake_data = np.random.randn(2, 3, 224, 224)
model = resnet_v2.ResNet18(
    in_channels=3,
    num_classes=1000,
)
model(fake_data)

This produces the following error:

Traceback (most recent call last):
  File "/home/a/anaconda3/envs/o/lib/python3.10/site-packages/objax/nn/layers.py", line 488, in run_layer
    return f(*args, **util.local_kwargs(kwargs, f))
TypeError: ResNetV2Block.__call__() missing 1 required positional argument: 'training'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/a/anaconda3/envs/o/lib/python3.10/site-packages/objax/nn/layers.py", line 488, in run_layer
    return f(*args, **util.local_kwargs(kwargs, f))
  File "/home/a/anaconda3/envs/o/lib/python3.10/site-packages/objax/nn/layers.py", line 497, in __call__
    args = self.run_layer(i, f, args, kwargs)
  File "/home/a/anaconda3/envs/o/lib/python3.10/site-packages/objax/nn/layers.py", line 490, in run_layer
    raise type(e)(f'Sequential layer[{layer}] {f} {e}') from e
TypeError: Sequential layer[0] <objax.zoo.resnet_v2.ResNetV2Block object at 0x7f92c4a38310> ResNetV2Block.__call__() missing 1 required positional argument: 'training'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/a/O/resnet_minimal_example.py", line 9, in <module>
    model(fake_data)
  File "/home/a/anaconda3/envs/o/lib/python3.10/site-packages/objax/nn/layers.py", line 497, in __call__
    args = self.run_layer(i, f, args, kwargs)
  File "/home/a/anaconda3/envs/o/lib/python3.10/site-packages/objax/nn/layers.py", line 490, in run_layer
    raise type(e)(f'Sequential layer[{layer}] {f} {e}') from e
TypeError: Sequential layer[3] objax.zoo.resnet_v2.ResNetV2BlockGroup(
  [0] <objax.zoo.resnet_v2.ResNetV2Block object at 0x7f92c4a38310>
  [1] <objax.zoo.resnet_v2.ResNetV2Block object at 0x7f92b3fc7250>
) Sequential layer[0] <objax.zoo.resnet_v2.ResNetV2Block object at 0x7f92c4a38310> ResNetV2Block.__call__() missing 1 required positional argument: 'training'

Am I missing something here? If not, I think one way to solve this would be to introduce train() and eval() methods, as in PyTorch, which set the training value as an instance variable. I am happy to open a pull request if you think this is a good option.
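For illustration, the train()/eval() idea proposed above could be sketched as a thin wrapper around any model that takes a training keyword. This is only a sketch: ModeWrapper and toy_model are hypothetical names made up for this example, and nothing here is part of Objax.

```python
class ModeWrapper:
    """Hypothetical wrapper (not part of Objax): keeps the mode on the instance."""

    def __init__(self, model):
        self.model = model
        self.training = True  # assumed default, mirroring PyTorch's convention

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, x):
        # Forward the stored flag as the explicit `training` keyword.
        return self.model(x, training=self.training)


def toy_model(x, training):
    """Stand-in for a real model; only reports which mode it was called in."""
    return "train" if training else "eval"


wrapped = ModeWrapper(toy_model)
mode_default = wrapped(0)
mode_eval = wrapped.eval()(0)
```

The wrapper only moves the flag from the call site onto the object; the wrapped model still receives training explicitly on every call.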

aterzis-google commented 2 years ago

Alex,

Thank you for your interest in Objax!

You need to pass an additional parameter: model(.., training=True/False). See the example in:

https://github.com/google/objax/blob/master/examples/image_classification/imagenet_resnet50_train.py

Hope this helps!
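In the reproduction above, that would mean calling model(fake_data, training=False) instead of model(fake_data). The toy sketch below shows why the keyword has to be given at the top-level call. It assumes the container forwards to each layer only the keyword arguments that layer's signature accepts, as the util.local_kwargs call in the traceback suggests. ToySequential, ToyBlock and add_one are hypothetical stand-ins, and this local_kwargs is a simplified re-implementation, not Objax's code.

```python
import inspect


def local_kwargs(kwargs, f):
    """Toy version: keep only the keyword arguments that f's signature accepts."""
    params = inspect.signature(f).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)  # f takes **kwargs, so pass everything through
    return {k: v for k, v in kwargs.items() if k in params}


class ToySequential(list):
    """Hypothetical stand-in for a sequential container."""

    def __call__(self, x, **kwargs):
        for layer in self:
            x = layer(x, **local_kwargs(kwargs, layer))
        return x


class ToyBlock:
    """Stand-in for a block that requires `training` (e.g. for batch norm)."""

    def __call__(self, x, training):
        return x * 2 if training else x


def add_one(x):
    """A layer that does not take `training` and never receives it."""
    return x + 1


model = ToySequential([add_one, ToyBlock()])

try:
    model(1)  # no `training` supplied, so ToyBlock raises a TypeError
    error_message = ""
except TypeError as e:
    error_message = str(e)

out_eval = model(1, training=False)
out_train = model(1, training=True)
```

Because the container only filters and forwards what the caller supplies, a layer with a required training parameter fails unless the flag is passed to the outermost call.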

a1302z commented 2 years ago

Thanks, I just figured it out. Sorry for the distraction and thanks for the answer :)