a1302z closed this issue 2 years ago.
Alex,
Thank you for your interest in Objax!
You need to pass an additional training argument when calling the model: model(..., training=True) during training and model(..., training=False) during evaluation. See the example in:
https://github.com/google/objax/blob/master/examples/image_classification/imagenet_resnet50_train.py
Hope this helps!
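The explicit-flag convention described in the reply above can be illustrated with a toy sketch. The classes ToyDropout, ToyBlock and ToyModel are hypothetical stand-ins written in plain Python, not the Objax classes: every layer whose behaviour differs between training and inference takes training in its call, and containers forward it to their children. Omitting the flag raises a TypeError, much like calling a model without it.

```python
# Hypothetical toy classes, NOT the Objax API: they only mimic the
# convention of passing `training` explicitly through every call.

class ToyDropout:
    """Toy rule: zero the input in training mode, identity at inference."""
    def __call__(self, x, training):
        return [0.0 for _ in x] if training else list(x)

class ToyBlock:
    """A block that, like a residual block, requires `training`."""
    def __init__(self):
        self.dropout = ToyDropout()

    def __call__(self, x, training):
        # The flag is forwarded explicitly; no mode is stored on the object.
        return self.dropout(x, training=training)

class ToyModel:
    def __init__(self, num_blocks=2):
        self.blocks = [ToyBlock() for _ in range(num_blocks)]

    def __call__(self, x, training):
        for block in self.blocks:
            x = block(x, training=training)
        return x

model = ToyModel()
x = [1.0, 2.0, 3.0]
print(model(x, training=False))  # [1.0, 2.0, 3.0]
print(model(x, training=True))   # [0.0, 0.0, 0.0]

try:
    model(x)  # forgetting the flag fails loudly
except TypeError as err:
    print("TypeError:", err)
```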
Thanks, I just figured it out. Sorry for the distraction and thanks for the answer :)
Hi, I tried to use resnet_v2.ResNet18, but I get an error when calling the model: ResNetV2Block expects a training argument in its __call__ method, which is not provided by the implementation. A minimal example:
This produces the following error:
Am I missing something here? If not, one way to solve this could be to introduce train() and eval() methods, as in PyTorch, which store the training mode as an instance variable. I am happy to open a pull request if you think this is a good option.
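The train()/eval() proposal can be sketched roughly as follows. TrainEvalWrapper and ToyBlock are hypothetical names used only for illustration, not part of Objax: the wrapper stores the mode on the instance and forwards it as the explicit training argument, so existing layers keep their signatures.

```python
# Hypothetical sketch of a PyTorch-style mode switch; not an Objax API.

class ToyBlock:
    """Toy layer with the explicit-flag signature (hypothetical)."""
    def __call__(self, x, training):
        # Toy rule: zero everything in training mode, identity at inference.
        return [0.0 for _ in x] if training else list(x)

class TrainEvalWrapper:
    """Stores the training mode as an instance variable."""
    def __init__(self, module):
        self.module = module
        self.training = True  # torch.nn.Module also starts in training mode

    def train(self):
        self.training = True
        return self  # returning self allows chaining, as in PyTorch

    def eval(self):
        self.training = False
        return self

    def __call__(self, x):
        # The stored mode is forwarded as the explicit argument.
        return self.module(x, training=self.training)

model = TrainEvalWrapper(ToyBlock())
x = [1.0, 2.0, 3.0]
print(model(x))         # [0.0, 0.0, 0.0] (default: training mode)
print(model.eval()(x))  # [1.0, 2.0, 3.0]
```

The trade-off to weigh: an explicit argument keeps the mode visible at every call site, whereas stored state lets a forgotten eval() silently run the model in training mode.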