google / objax

Apache License 2.0

More control over var/module namespace. #42

Closed rwightman closed 4 years ago

rwightman commented 4 years ago

I got my first 'hello world' model experiment working w/ Objax. I adapted my PyTorch EfficientNet impl. Overall pretty smooth; currently I'm wrapping Conv2d so I can get the padding I want.

One thing that stuck out after inspecting the model: the var namespace is a mess. An aspect of modelling I value highly is the ability to have sensible checkpoint/var maps to work with. I often end up dealing with conversions between frameworks and exports for mobile or embedded targets, so having your vars (parameters) sensibly named, and being able to control those names in the originating framework, is important.

Any thoughts on improving this? The current name/scoping mechanism forces the inclusion of the Module class names; is that necessary? Shouldn't attr names through the tree be enough for uniqueness?

Also, there is no ability to specify names for modules in sequential containers. I use this quite often in frameworks that have it. Sometimes I don't care much (for a long list of block repeats, 0..n is fine), but for finer-grained blocks I like to know which conv is which by looking at the var names. '0.b, 0.w' etc. isn't very useful.

I'll post an example of the var keys below, and a PyTorch comparison point.

rwightman commented 4 years ago

Objax

'(EfficientNet).stem(ConvBnAct).conv(Conv2d).w',
 '(EfficientNet).stem(ConvBnAct).bn(BatchNorm2D).running_mean',
 '(EfficientNet).stem(ConvBnAct).bn(BatchNorm2D).running_var',
 '(EfficientNet).stem(ConvBnAct).bn(BatchNorm2D).beta',
 '(EfficientNet).stem(ConvBnAct).bn(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[1](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[2](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[3](Sequential)[2](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[4](Sequential)[2](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[1](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[2](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[5](Sequential)[3](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).conv_exp(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_exp(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).conv_dw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_dw(BatchNorm2D).gamma',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc1(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).b',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).se(SqueezeExcite).fc2(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).conv_pw(Conv2d).w',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).running_mean',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).running_var',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).beta',
 '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).gamma',
 '(EfficientNet).head(Head).conv_1x1(Conv2d).w',
 '(EfficientNet).head(Head).bn(BatchNorm2D).running_mean',
 '(EfficientNet).head(Head).bn(BatchNorm2D).running_var',
 '(EfficientNet).head(Head).bn(BatchNorm2D).beta',
 '(EfficientNet).head(Head).bn(BatchNorm2D).gamma',
 '(EfficientNet).head(Head).classifier(Linear).b',
 '(EfficientNet).head(Head).classifier(Linear).w'
rwightman commented 4 years ago

PyTorch (I slightly re-orged the Objax version, with head and stem extracted into their own modules and some attr naming improved)

conv_stem.weight
bn1.weight
bn1.bias
blocks.0.0.conv_dw.weight
blocks.0.0.bn1.weight
blocks.0.0.bn1.bias
blocks.0.0.se.conv_reduce.weight
blocks.0.0.se.conv_reduce.bias
blocks.0.0.se.conv_expand.weight
blocks.0.0.se.conv_expand.bias
blocks.0.0.conv_pw.weight
blocks.0.0.bn2.weight
blocks.0.0.bn2.bias
blocks.1.0.conv_pw.weight
blocks.1.0.bn1.weight
blocks.1.0.bn1.bias
blocks.1.0.conv_dw.weight
blocks.1.0.bn2.weight
blocks.1.0.bn2.bias
blocks.1.0.se.conv_reduce.weight
blocks.1.0.se.conv_reduce.bias
blocks.1.0.se.conv_expand.weight
blocks.1.0.se.conv_expand.bias
blocks.1.0.conv_pwl.weight
blocks.1.0.bn3.weight
blocks.1.0.bn3.bias
blocks.1.1.conv_pw.weight
blocks.1.1.bn1.weight
blocks.1.1.bn1.bias
blocks.1.1.conv_dw.weight
blocks.1.1.bn2.weight
blocks.1.1.bn2.bias
blocks.1.1.se.conv_reduce.weight
blocks.1.1.se.conv_reduce.bias
blocks.1.1.se.conv_expand.weight
blocks.1.1.se.conv_expand.bias
blocks.1.1.conv_pwl.weight
blocks.1.1.bn3.weight
blocks.1.1.bn3.bias
blocks.2.0.conv_pw.weight
blocks.2.0.bn1.weight
blocks.2.0.bn1.bias
blocks.2.0.conv_dw.weight
blocks.2.0.bn2.weight
blocks.2.0.bn2.bias
blocks.2.0.se.conv_reduce.weight
blocks.2.0.se.conv_reduce.bias
blocks.2.0.se.conv_expand.weight
blocks.2.0.se.conv_expand.bias
blocks.2.0.conv_pwl.weight
blocks.2.0.bn3.weight
blocks.2.0.bn3.bias
blocks.2.1.conv_pw.weight
blocks.2.1.bn1.weight
blocks.2.1.bn1.bias
blocks.2.1.conv_dw.weight
blocks.2.1.bn2.weight
blocks.2.1.bn2.bias
blocks.2.1.se.conv_reduce.weight
blocks.2.1.se.conv_reduce.bias
blocks.2.1.se.conv_expand.weight
blocks.2.1.se.conv_expand.bias
blocks.2.1.conv_pwl.weight
blocks.2.1.bn3.weight
blocks.2.1.bn3.bias
blocks.3.0.conv_pw.weight
blocks.3.0.bn1.weight
blocks.3.0.bn1.bias
blocks.3.0.conv_dw.weight
blocks.3.0.bn2.weight
blocks.3.0.bn2.bias
blocks.3.0.se.conv_reduce.weight
blocks.3.0.se.conv_reduce.bias
blocks.3.0.se.conv_expand.weight
blocks.3.0.se.conv_expand.bias
blocks.3.0.conv_pwl.weight
blocks.3.0.bn3.weight
blocks.3.0.bn3.bias
blocks.3.1.conv_pw.weight
blocks.3.1.bn1.weight
blocks.3.1.bn1.bias
blocks.3.1.conv_dw.weight
blocks.3.1.bn2.weight
blocks.3.1.bn2.bias
blocks.3.1.se.conv_reduce.weight
blocks.3.1.se.conv_reduce.bias
blocks.3.1.se.conv_expand.weight
blocks.3.1.se.conv_expand.bias
blocks.3.1.conv_pwl.weight
blocks.3.1.bn3.weight
blocks.3.1.bn3.bias
blocks.3.2.conv_pw.weight
blocks.3.2.bn1.weight
blocks.3.2.bn1.bias
blocks.3.2.conv_dw.weight
blocks.3.2.bn2.weight
blocks.3.2.bn2.bias
blocks.3.2.se.conv_reduce.weight
blocks.3.2.se.conv_reduce.bias
blocks.3.2.se.conv_expand.weight
blocks.3.2.se.conv_expand.bias
blocks.3.2.conv_pwl.weight
blocks.3.2.bn3.weight
blocks.3.2.bn3.bias
blocks.4.0.conv_pw.weight
blocks.4.0.bn1.weight
blocks.4.0.bn1.bias
blocks.4.0.conv_dw.weight
blocks.4.0.bn2.weight
blocks.4.0.bn2.bias
blocks.4.0.se.conv_reduce.weight
blocks.4.0.se.conv_reduce.bias
blocks.4.0.se.conv_expand.weight
blocks.4.0.se.conv_expand.bias
blocks.4.0.conv_pwl.weight
blocks.4.0.bn3.weight
blocks.4.0.bn3.bias
blocks.4.1.conv_pw.weight
blocks.4.1.bn1.weight
blocks.4.1.bn1.bias
blocks.4.1.conv_dw.weight
blocks.4.1.bn2.weight
blocks.4.1.bn2.bias
blocks.4.1.se.conv_reduce.weight
blocks.4.1.se.conv_reduce.bias
blocks.4.1.se.conv_expand.weight
blocks.4.1.se.conv_expand.bias
blocks.4.1.conv_pwl.weight
blocks.4.1.bn3.weight
blocks.4.1.bn3.bias
blocks.4.2.conv_pw.weight
blocks.4.2.bn1.weight
blocks.4.2.bn1.bias
blocks.4.2.conv_dw.weight
blocks.4.2.bn2.weight
blocks.4.2.bn2.bias
blocks.4.2.se.conv_reduce.weight
blocks.4.2.se.conv_reduce.bias
blocks.4.2.se.conv_expand.weight
blocks.4.2.se.conv_expand.bias
blocks.4.2.conv_pwl.weight
blocks.4.2.bn3.weight
blocks.4.2.bn3.bias
blocks.5.0.conv_pw.weight
blocks.5.0.bn1.weight
blocks.5.0.bn1.bias
blocks.5.0.conv_dw.weight
blocks.5.0.bn2.weight
blocks.5.0.bn2.bias
blocks.5.0.se.conv_reduce.weight
blocks.5.0.se.conv_reduce.bias
blocks.5.0.se.conv_expand.weight
blocks.5.0.se.conv_expand.bias
blocks.5.0.conv_pwl.weight
blocks.5.0.bn3.weight
blocks.5.0.bn3.bias
blocks.5.1.conv_pw.weight
blocks.5.1.bn1.weight
blocks.5.1.bn1.bias
blocks.5.1.conv_dw.weight
blocks.5.1.bn2.weight
blocks.5.1.bn2.bias
blocks.5.1.se.conv_reduce.weight
blocks.5.1.se.conv_reduce.bias
blocks.5.1.se.conv_expand.weight
blocks.5.1.se.conv_expand.bias
blocks.5.1.conv_pwl.weight
blocks.5.1.bn3.weight
blocks.5.1.bn3.bias
blocks.5.2.conv_pw.weight
blocks.5.2.bn1.weight
blocks.5.2.bn1.bias
blocks.5.2.conv_dw.weight
blocks.5.2.bn2.weight
blocks.5.2.bn2.bias
blocks.5.2.se.conv_reduce.weight
blocks.5.2.se.conv_reduce.bias
blocks.5.2.se.conv_expand.weight
blocks.5.2.se.conv_expand.bias
blocks.5.2.conv_pwl.weight
blocks.5.2.bn3.weight
blocks.5.2.bn3.bias
blocks.5.3.conv_pw.weight
blocks.5.3.bn1.weight
blocks.5.3.bn1.bias
blocks.5.3.conv_dw.weight
blocks.5.3.bn2.weight
blocks.5.3.bn2.bias
blocks.5.3.se.conv_reduce.weight
blocks.5.3.se.conv_reduce.bias
blocks.5.3.se.conv_expand.weight
blocks.5.3.se.conv_expand.bias
blocks.5.3.conv_pwl.weight
blocks.5.3.bn3.weight
blocks.5.3.bn3.bias
blocks.6.0.conv_pw.weight
blocks.6.0.bn1.weight
blocks.6.0.bn1.bias
blocks.6.0.conv_dw.weight
blocks.6.0.bn2.weight
blocks.6.0.bn2.bias
blocks.6.0.se.conv_reduce.weight
blocks.6.0.se.conv_reduce.bias
blocks.6.0.se.conv_expand.weight
blocks.6.0.se.conv_expand.bias
blocks.6.0.conv_pwl.weight
blocks.6.0.bn3.weight
blocks.6.0.bn3.bias
conv_head.weight
bn2.weight
bn2.bias
classifier.weight
classifier.bias
david-berthelot commented 4 years ago

Right now it's hardcoded; I chose the current format to maximize ease of debugging.

It seems feasible to make it customizable, for example by having some FORMAT string in the module, module list, etc. It would be something like what we did for custom checkpoint naming: https://github.com/google/objax/blob/master/objax/io/checkpoint.py#L41
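One hypothetical reading of that idea (FORMAT is not an existing Objax attribute; the classes and format keys below are purely illustrative):

# Hypothetical sketch, not Objax API: a class-level format string that the
# variable-name builder would interpolate while walking the module tree.
class Module:
    FORMAT = '{attr}({cls})'            # default: attribute name plus class name

class ModuleList(Module):
    FORMAT = '{attr}({cls})[{index}]'   # lists also interpolate the element index

# A user preferring PyTorch-style names could then override:
# Module.FORMAT = '{attr}'
# ModuleList.FORMAT = '{attr}.{index}'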

@AlexeyKurakin What do you think?

AlexeyKurakin commented 4 years ago

@rwightman, you are right that the class name is redundant and all sub-modules are uniquely identified by the name of the attribute.

Overall I think we can change the naming format for the variables.

@david-berthelot my understanding is that several things are being asked for, and we should decide which of them to implement and how. Here is a list:

  1. Control over the specific printing format for each module (whether to include the class name, put extra parentheses, etc.). This could be done with a static attribute FORMAT in the module which the user may change as needed. My concerns about allowing custom formatting are the following:

    • we need to make the syntax of the formatting string flexible enough to allow what various users want to do, and we need to write extra code to perform this formatting.
    • how is the formatting spec changed? If it's a static attribute of the class then it's essentially a global variable which needs to be set/changed at the beginning of the program. If the user forgets to change the formatting string then the checkpoint will not load. For example, if user A changed the formatting, trained a resnet50 model and saved a checkpoint, then shared the checkpoint with user B but forgot to mention the custom formatting, user B won't be able to load the checkpoint.
  2. Instead of doing 1, just remove the class name (and parentheses) from the variable name. So (EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).beta becomes blocks[6][0].bn_pw.beta. I'm not sure how useful class names are for debugging, but they are definitely not needed for uniqueness of identifiers.

  3. Allow overriding the indices of a sequential module with meaningful names. This could be done by adding a list with the names of the elements to Sequential.

david-berthelot commented 4 years ago

The way I see it, the class name is very informative for people who didn't write the model. And I find it informative for myself too.

Consequently the only way to make everyone happy is to let the naming be customizable.

I don't think making the syntax sufficiently flexible is too much of a challenge; the core question is whether we want to do it. Indeed, as you pointed out, saving and loading is going to be a problem since names are used to reference variables, and if they get changed from the default, it won't be easy to load a checkpoint saved by a program that modified the variable names without knowing how they were modified. Now, how common is this situation going to be?

As to (3), if one doesn't want indexes, they shouldn't use a list or a sequential module. Overriding names in lists would add more confusion than it would save.

rwightman commented 4 years ago

Something to point out wrt having class names in the var names: it ties your checkpoints to the names of the classes. Deciding that I want to refactor and change the name of my model class or any other block/layer abstraction shouldn't change checkpoint compat if the layer definitions and attr names remain the same. This matters when you have many users of a framework, pretrained model zoos forming around an ecosystem, etc.

For the sequential with names, it's used quite a bit in PyTorch. The default is sequential numbers starting at 0, like here. It lets you skip writing the __call__ method for blocks while retaining identifiable names. Another, less common, reason to like the ability to fix names in Sequential: you can add/remove layers without trainable params without breaking checkpoints.

It could be two lists, an OrderedDict, a list of tuples, etc. Sequential() also has an .add_module() fn that takes a name arg.

It's not uncommon to see code like

        self.stack = nn.Sequential()
        for i in range(3):
            if start_with_relu:
                self.stack.add_module(f'act{i}', nn.ReLU(inplace=i > 0))
            self.stack.add_module(f'conv{i}', SeparableConv2d(
                in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type,
                act_layer=separable_act_layer, norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            in_chs = out_chs[i]

or

            self.features = nn.Sequential(OrderedDict([
                ('conv1', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)),
                ('norm1', norm_layer(stem_chs_1)),
                ('conv2', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)),
                ('norm2', norm_layer(stem_chs_2)),
                ('conv3', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)),
                ('norm3', norm_layer(num_init_features)),
                ('pool', stem_pool),
            ]))
david-berthelot commented 4 years ago

I sort of like associating the variable name with class and attribute, since those typically don't change (like Conv2D.w, for example). If one refactors something, there's a chance it's no longer doing the same thing, so having matching names would lead to silent errors.

Also note, it's very easy to copy a VarCollection to a differently named model by doing target.vars().assign(source.vars().tensors()), since these assignments are order-based.
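For example, a minimal sketch (the two toy classes are my own illustration):

import objax

class SourceNet(objax.Module):
    def __init__(self):
        self.fc = objax.nn.Linear(2, 3)

class TargetNet(objax.Module):
    def __init__(self):
        self.classifier = objax.nn.Linear(2, 3)  # same structure, different attr name

source, target = SourceNet(), TargetNet()
# Names differ, but assign() pairs tensors with variables by order, so this copies weights.
target.vars().assign(source.vars().tensors())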

Now if someone reorders fields and renames them, then old checkpoints will have to be matched by name using some conversion rule but that's probably the safest thing at this point.

I specifically wanted to avoid string-based naming since you never know where a variable is coming from, and I've had too many headaches in the past debugging third-party code because of that (in other frameworks).

AlexeyKurakin commented 4 years ago

I’ll try to reconcile all options which were discussed, along with their pros and cons.

Also, when thinking about these options, keep in mind the following:

Sequential module with names (in addition to indices)

  1. Add names to Sequential.

    • Advantage: This will let us write PyTorch-style code where elements of Sequential have assigned names.
    • Disadvantage: This complicates the code of Sequential, and generally speaking users can just make a subclass of Module with named sub-modules.
  2. Keep Sequential as is.

    • Opposite of the previous option. It keeps Sequential simple, but to make a sequential module with named elements a user has to manually write a new subclass of Module.
  3. Add another class, NamedSequential, which would be a sequential container with named elements (see the sketch after this list).

    • This keeps the code of Sequential simple, but it still adds a new class carrying the additional complexity of a named sequential collection.
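A minimal sketch of what option 3 could look like (NamedSequential is hypothetical, not part of Objax; it just leans on the existing behavior that attribute names become part of variable names):

import objax

# Hypothetical sketch, not Objax API: store sub-modules as named attributes
# (so vars() names them by attribute) while remembering the call order.
class NamedSequential(objax.Module):
    def __init__(self, items):
        # items: iterable of (name, module_or_fn) pairs
        self._order = []
        for name, op in items:
            setattr(self, name, op)     # attr name shows up in the var names
            self._order.append(name)

    def __call__(self, x):
        for name in self._order:
            x = getattr(self, name)(x)
        return x

se = NamedSequential([
    ('reduce', objax.nn.Linear(16, 4)),
    ('act', objax.functional.relu),
    ('expand', objax.nn.Linear(4, 16)),
])
print(se.vars())  # names like '(NamedSequential).reduce(Linear).w' instead of '[0].w'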

Formatting of variables

  1. Keep formatting as is (i.e. name of the attribute and name of the class).
    • Advantage: Simplifies debugging, especially when a user needs to look at a checkpoint written by someone else.
    • Disadvantage: The class name in the var name leads to pretty long overall names.
    • Disadvantage: Fragile if a class name changes, which could break checkpoints.
    • Disadvantage: When converting a checkpoint from another framework, it's harder to map the names of the variables.

Note that it’s unlikely that the name of the module class will change for basic building blocks like Conv2D. However the name of the module may change for modules which are describing complex pieces of some neural network (for example rename ResnetBlock into ResnetV2Block). Also we can potentially rename the class with the entire neural network (for example rename Resnet50 into ResnetV2_50).

  2. Variable name only includes the attribute name, without the class name.

    • Advantage: Easier to map variable names to PyTorch checkpoints.
    • Advantage: Less likely to break when a class name changes.
    • Disadvantage: When debugging, users have to look at the code to understand the class of each module.
  3. Provide a custom formatter for variable names.

    • Advantage: Variable names can either include or omit module class names depending on user needs.
    • Disadvantage: Extra code complexity to support a variable-formatting syntax.
    • Disadvantage: Checkpoints become incompatible if the formatting changes. So if we provide a pre-trained checkpoint but the user wants different formatting, the user won't be able to just load our pre-trained checkpoint.
    • Also, if we decide to go this route, we should decide what the default formatting is.
rwightman commented 4 years ago

@AlexeyKurakin thanks for the summary... I was going to implement something like NamedSequential myself in the absence of a desire to make changes here. At the Module level, though, it's hard to supplement with my own changes (for variable naming).

One other thought on the naming. One of the model collections I maintain is a fairly extensive set of image classification models. By itself, not so interesting. The big use case is that they can serve as backbones for a wide variety of other tasks like segmentation, obj detection, etc.

Tying checkpoints to class names, and generally having a less flexible interface for manipulating module hierarchies, would make some of the model manipulations I rely on (embedding nets within nets; flattening and remixing nets to extract feature maps from the middle; etc.) much more challenging.

Also, what about modules like SyncBatchNorm or FrozenBatchNorm (one way of freezing a BN layer is to override the class)? You train with one and init the model for inference with the other, and it won't load?

david-berthelot commented 4 years ago

@AlexeyKurakin For (2) you forgot to mention that it's error-prone: if I replace Conv2D with Conv3D, the variable names are the same but their weights aren't. It will load successfully and introduce a silent error which will only be seen when running the model. This is also an argument in favor of the safety of (1), which you also omitted.

@rwightman If you're simply looking to map names, the current interface can do it easily. Remember that variables don't have names; they're only given one in a VarCollection (which references variables). Since VarCollection is simply a dictionary, you can easily rename variables by making a new VarCollection from it.

Here's an example:

import re
import objax

class MiniNet(objax.Module):
    def __init__(self, m, n, p):
        self.f1 = objax.nn.Linear(m, n)
        self.f2 = objax.nn.Linear(n, p)

    def __call__(self, x):
        y = self.f1(x)
        y = objax.functional.relu(y)  # Apply a non-linearity.
        return self.f2(y)

r = re.compile(r'\([^)]+\)')

m = MiniNet(1, 2, 3)
print(m.vars())
# (MiniNet).f1(Linear).b        2 (2,)
# (MiniNet).f1(Linear).w        2 (1, 2)
# (MiniNet).f2(Linear).b        3 (3,)
# (MiniNet).f2(Linear).w        6 (2, 3)
# +Total(4)                    13

print(objax.VarCollection((r.sub('', k), v) for k, v in m.vars().items()))
# .f1.b                       2 (2,)
# .f1.w                       2 (1, 2)
# .f2.b                       3 (3,)
# .f2.w                       6 (2, 3)
# +Total(4)                  13

As to loading a checkpoint referring to a SyncBatchNorm into a FrozenBatchNorm, that should not work by default given the safety risk it involves. But again, it's really easy to solve using the fact that VarCollection is really a dictionary.

import objax

class MiniNet(objax.Module):
    def __init__(self, m, n, p):
        self.f1 = objax.nn.Linear(m, n)
        self.f2 = objax.nn.Linear(n, p)

    def __call__(self, x):
        y = self.f1(x)
        y = objax.functional.relu(y)  # Apply a non-linearity.
        return self.f2(y)

class MyLinear(objax.nn.Linear):
    pass

m1 = MiniNet(1, 2, 3)
objax.io.save_var_collection('/tmp/mininet.npz', m1.vars())

# We redefine MiniNet, this time using a different module (MyLinear)
class MiniNet(objax.Module):
    def __init__(self, m, n, p):
        self.f1 = MyLinear(m, n)
        self.f2 = MyLinear(n, p)

    def __call__(self, x):
        y = self.f1(x)
        y = objax.functional.relu(y)  # Apply a non-linearity.
        return self.f2(y)

m2 = MiniNet(1, 2, 3)
var_map = objax.VarCollection((k.replace('MyLinear', 'Linear'), v) for k, v in m2.vars().items())
objax.io.load_var_collection('/tmp/mininet.npz', var_map)
print(m2.vars())

Having to make it explicit that MyLinear should load the weights for Linear is more verbose, but it's a lot safer too in my opinion.

rwightman commented 4 years ago

@david-berthelot thanks for the explanation. One thing I might be missing though (I haven't gone through the whole train/save/load cycle yet): the .w shapes for, say, Conv2d vs 3d are different. Wouldn't that error out the moment you try to load and assign to the incompatible vars?

Working with models across different frameworks, and evolving model code without breaking checkpoints, matching 'if the slipper fits' has been pretty robust when you have dozens if not hundreds of layers identified by a combination of attribute-name hierarchies (strings) and shapes.
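A rough sketch of that 'if the slipper fits' idea (my own illustration, not code from either library): match checkpoint entries to model variables by normalized name, and refuse any pairing whose shapes disagree.

import numpy as np

# Illustrative only: checkpoint and model_vars are plain name -> array dicts here.
def slipper_fits(checkpoint: dict, model_vars: dict, normalize=lambda s: s):
    """Pair checkpoint tensors with model vars by normalized name + shape."""
    ckpt = {normalize(k): v for k, v in checkpoint.items()}
    matched, skipped = {}, []
    for name, var in model_vars.items():
        tensor = ckpt.get(normalize(name))
        if tensor is not None and tensor.shape == var.shape:
            matched[name] = tensor          # the slipper fits
        else:
            skipped.append(name)            # wrong name or wrong shape
    return matched, skipped

matched, skipped = slipper_fits(
    {'blocks.0.conv.w': np.zeros((3, 3, 16, 16))},
    {'blocks.0.conv.w': np.zeros((3, 3, 16, 16)),
     'blocks.0.bn.gamma': np.zeros((1, 16, 1, 1))},
)
print(sorted(matched), sorted(skipped))  # ['blocks.0.conv.w'] ['blocks.0.bn.gamma']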

AlexeyKurakin commented 4 years ago

Error-prone

Being error-prone is a valid concern. It's not an issue for Conv2d vs Conv3d because the kernel shapes are different. With the current set of basic ops, I think only the variables for Conv2d and Conv2dTransposed could be interchanged, and only in the case when nin == nout. Another case is BatchNorm vs SyncedBatchNorm, though one can argue that these two ops should be interchangeable by design.

Nevertheless, I could imagine that we add some ops which have exactly the same variable shapes but which work differently. For example, we may add GroupNorm, LayerNorm and InstanceNorm at some point; I think all of these will have the same variable shapes. Then one user may save a checkpoint for a ResnetV2 model with one type of normalization and another user may try to load it with another type of normalization.

Overall it does not seem like this type of error should happen very often, but it's possible.

Ability to easily read checkpoints with a different variable naming convention

The ability to convert checkpoints seems to be the main concern with the current way of naming variables in VarCollection. So maybe we can come up with another way of solving this problem, without relying too much on a specific way to name variables.

Some ideas that came to my mind:

@david-berthelot The solution you proposed would work, but it considers a very simple neural network (you need to do only one substitution of values). For this simple example, the variable name remapping can be expressed in one line:

var_map = objax.VarCollection((k.replace('MyLinear', 'Linear'), v) for k, v in m2.vars().items())

What if the model is more complex and requires a lot more substitutions? Examples could include converting a TensorFlow ResnetV2 checkpoint to Objax, or converting the EfficientNet mentioned above. What would be a good way for a user to write more complex variable remappings?

david-berthelot commented 4 years ago

In the end, VarCollection is simply a dictionary. So if we were to provide some renaming utility, for the sake of generality it should simply take a dict mapping (old_name, new_name) and work on dictionary keys, since there's nothing specific about VarCollection in the renaming task. Following on from this, I don't think we should touch the loading/saving API, which simply transfers a dict to/from disk. I'm okay with extending VarCollection with a renaming method to make it user friendly. Basically, we'd want it to take something like an Iterable[Tuple[Union[str, regex], str]] to allow many renamings at once.

Another point I wanted to make about naming is that it's easier to remove text than to add it. In the first example I provided, the module name is removed entirely; the opposite is hard, since if the class name is missing there's no way it can be recovered.

I think we've talked a lot about "what if" and abstract situations. Let's have more concrete examples to help ground our thinking. The first PyTorch example earlier appears to be handled relatively well by my first example (replace (.*) with ''), or is it?

AlexeyKurakin commented 4 years ago

I think it would be better if the remapping function takes a callable which maps a string to a string:

# in Objax code
class VarCollection(Dict[str, BaseVar]):
    # ...
    def rename_vars(self, remap_name: Callable[[str], str]) -> 'VarCollection':
        return VarCollection({remap_name(k): v for k, v in self.items()})

# in user code
import re

find_class_re = re.compile(r'\([^)]+\)')

def remap_name(name: str) -> str:
    name = find_class_re.sub('', name)  # drop '(ClassName)' segments
    name = name[1:]                     # drop the leading '.'
    name = name.replace(']', '')
    name = name.replace('[', '.')
    return name

var_map = model.vars()
var_map = var_map.rename_vars(remap_name)
objax.io.load_var_collection('converted_pytorch_ckpt.npz', var_map)

In addition to the code above, we may provide a version of remap_name which takes an Iterable[Tuple[Union[str, regex], str]] and would cover most common use cases:

# additional piece in Objax code
def rename_with_regex(name: str, substitutions: Iterable[Tuple[Union[str, regex], str]]) -> str:
    # ...

var_map = model.vars()
# the following line probably won't work as-is, it's just for illustration
var_map = var_map.rename_vars(lambda x: objax.utils.rename_with_regex(
    x, [(r'\([^)]+\)', ''), (r'\[', '.'), (r'\]', ''), (r'^\.', '')]))
objax.io.load_var_collection('converted_pytorch_ckpt.npz', var_map)

Regarding the code example which replaces (.*) with an empty string: it will convert (EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).gamma into .blocks[6][0].bn_pw.gamma, while the equivalent variable from a PyTorch checkpoint is blocks.6.0.bn_pw.gamma. Technically it's just a matter of adding two or three additional regex replacements to convert the Objax name into the PyTorch name. But I think as soon as several substitutions are involved it can no longer be done nicely in one line of code, which is why a helper method would be useful. A worked example of the full substitution chain is below.
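For concreteness, here is a small self-contained sketch of that chain (my own illustration, reusing the regexes from the snippet above):

import re

# Substitutions turning an Objax-style name into a PyTorch-style one.
SUBS = [
    (r'\([^)]+\)', ''),  # drop class names: '(EfficientNet)' -> ''
    (r'\[', '.'),        # '[6]' -> '.6]'
    (r'\]', ''),         # '.6]' -> '.6'
    (r'^\.', ''),        # drop the leading '.'
]

def to_pytorch_name(name: str) -> str:
    for pattern, replacement in SUBS:
        name = re.sub(pattern, replacement, name)
    return name

name = '(EfficientNet).blocks(Sequential)[6](Sequential)[0](InvertedResidual).bn_pw(BatchNorm2D).gamma'
print(to_pytorch_name(name))  # blocks.6.0.bn_pw.gamma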

AlexeyKurakin commented 4 years ago

@rwightman Could you comment whether any of the described ideas would work well for your use case of converting checkpoints?

rwightman commented 4 years ago

@AlexeyKurakin I'm not sure there is a need to define formal interfaces for this; most of the work is in defining the mappings, and it's easy enough to create a fn to apply the mapping to the vars dict after it's loaded. This can all work; the main point was that this approach feels like it adds more friction for benefits I've never needed.

For pretrained zoo model loading, I usually have a remap/filter fn argument (a level above the IO loader, before mapping to the model). This is used for potentially significant changes that may have happened over the years a pretrained model weight is available to download, across many versions of the underlying model code, forked, and forked, etc. It possibly goes beyond simple string mapping, though, so I wouldn't bake that limitation into an interface; just vc -> vc. Sometimes you may want to convert a Linear layer to a 1x1 conv or make other changes that go beyond name mapping.

The biggest issue I have with the class name approach is that it may require multiple regex string mappings simply for loading checkpoints between training and prediction in a standard workflow. I could also foresee it adding extra boilerplate for loading checkpoints into a variant of the model set up for extracting feature maps, etc., as already mentioned.

What about having a flag (no_type/no_type_check) that doesn't include the class name when the vars are collected, and on restore doesn't fire an error (maybe warns?). It would give a little more flexibility to those who want it, without completely removing all name checks as in the assign-by-order alternative.

For typical norm layer defaults, batch, layer, and group norm should all have distinct signatures if you combine the shape of the affine parameters with the presence of running vars. They can, however, be set up in a manner that is ambiguous.

AlexeyKurakin commented 4 years ago

It sounds like the main inconvenience of the current approach is the need to write boilerplate code to remap class names when restoring a model from a checkpoint.

I would be reluctant to add a no_type flag which would control whether the class name is generated when a module is created. It's essentially just a limited version of a custom formatter for variable names, and I mentioned above the reasons why I'm concerned about a custom formatter.

On the other hand, no_type_check could be achieved with the current code by modifying the checkpoint load function. The load function is controlled by Checkpoint.LOAD_FN, see https://github.com/google/objax/blob/d2c8b98b86faac47e19db7a9c2a98342a7d568b3/objax/io/checkpoint.py#L44, and it can be reassigned to any custom loader.

@rwightman for your use case you can have a loader which normalizes variable names when matching a variable from the var collection to the equivalent variable from the checkpoint. The code could look like the following:

import collections
import re
from typing import BinaryIO, Callable, IO, Union

import numpy as np
import jax.numpy as jn
import objax
from objax.variable import TrainRef, VarCollection

FIND_CLASS_RE = re.compile(r'\([^)]+\)')

def remove_class_from_var_name(name: str) -> str:
    return FIND_CLASS_RE.sub('', name)

def custom_load_var_collection(file: Union[str, IO[BinaryIO]],
                               vc: VarCollection,
                               name_norm_fn: Callable[[str], str] = lambda name: name):
    # This is a modified copy of objax.io.ops.load_var_collection.
    # There are only two changes compared to the original load_var_collection:
    # name_norm_fn is applied to the names of vars from both the checkpoint and the var collection.
    do_close = isinstance(file, str)
    if do_close:
        file = open(file, 'rb')
    data = np.load(file, allow_pickle=False)
    name_index = {name_norm_fn(k): str(i) for i, k in enumerate(data['names'])}   # <-- added name_norm_fn
    name_vars = collections.defaultdict(list)
    for k, v in vc.items():
        if isinstance(v, TrainRef):
            v = v.ref
        name_vars[v].append(name_norm_fn(k))   # <-- added name_norm_fn
    for v, names in name_vars.items():
        for name in names:
            index = name_index.get(name)
            if index is not None:
                v.assign(jn.array(data[index]))
                break
        else:
            raise ValueError(f'Missing value for variables {names}')
    if do_close:
        file.close()

# ...
# in the beginning of the program
objax.io.Checkpoint.LOAD_FN = lambda f, vc: custom_load_var_collection(f, vc, remove_class_from_var_name)

To avoid repeating this code, remove_class_from_var_name and custom_load_var_collection could be put into a small shared library used by all your models.

@david-berthelot you mentioned that you would prefer to avoid modifying the IO API; what do you think about the specific custom_load_var_collection code above? Does it make sense to extend the current load_var_collection to be something like this?

david-berthelot commented 4 years ago

How about reusing load_var_collection like this:

import functools
import re
from typing import BinaryIO, Callable, IO, Union

import objax
from objax.variable import VarCollection

RE_CLASS = re.compile(r'\([^)]+\)')

def remove_class_from_var_name(name: str) -> str:
    return RE_CLASS.sub('', name)

def custom_load_var_collection(file: Union[str, IO[BinaryIO]],
                               vc: VarCollection,
                               name_fn: Callable[[str], str] = None):
    # vc.rename is the API proposed earlier in this thread, not yet in Objax.
    return objax.io.load_var_collection(file, vc.rename(name_fn))

objax.io.Checkpoint.LOAD_FN = functools.partial(custom_load_var_collection, name_fn=remove_class_from_var_name)

AlexeyKurakin commented 4 years ago

Yes, this will work if the checkpoint is saved without class names.

I wrote the example above thinking about the situation where a checkpoint has already been saved with class names, and the class names in the checkpoint do not match the class names in the model. I think this is one of the use cases @rwightman described in his latest post.

In such a situation, loading the checkpoint can be achieved if the loader removes class names from both the var names in the VarCollection and the var names in the checkpoint (which is what my example does). A slightly more generic way to handle it is to provide two name_fn arguments to load_var_collection, one for renaming variables in the checkpoint and another for renaming variables in the VarCollection, as sketched below.
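A sketch of that two-callable variant, modeled directly on the custom_load_var_collection above (the function and parameter names are made up for illustration):

import collections
from typing import BinaryIO, Callable, IO, Union

import numpy as np
import jax.numpy as jn
from objax.variable import TrainRef, VarCollection

def load_var_collection_2way(file: Union[str, IO[BinaryIO]],
                             vc: VarCollection,
                             ckpt_name_fn: Callable[[str], str] = lambda n: n,
                             vc_name_fn: Callable[[str], str] = lambda n: n):
    # Same matching logic as above, but checkpoint names and VarCollection
    # names get independent renaming functions.
    do_close = isinstance(file, str)
    if do_close:
        file = open(file, 'rb')
    data = np.load(file, allow_pickle=False)
    name_index = {ckpt_name_fn(k): str(i) for i, k in enumerate(data['names'])}
    name_vars = collections.defaultdict(list)
    for k, v in vc.items():
        if isinstance(v, TrainRef):
            v = v.ref
        name_vars[v].append(vc_name_fn(k))
    for v, names in name_vars.items():
        for name in names:
            index = name_index.get(name)
            if index is not None:
                v.assign(jn.array(data[index]))
                break
        else:
            raise ValueError(f'Missing value for variables {names}')
    if do_close:
        file.close()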

If the user can, or is willing to, save a new checkpoint with updated variable names, then @david-berthelot's example works fine.

AlexeyKurakin commented 4 years ago

@rwightman the solution proposed in the last few posts may not be exactly what you expected initially, but it can essentially achieve the same thing: saving/loading checkpoints while ignoring class names. It does add some boilerplate code, but if you have many models to maintain, most of this code could be factored out into a small reusable piece.

AlexeyKurakin commented 4 years ago

@david-berthelot I think we should decide what code changes we are going to make, if any.

Right now there seem to be the following ideas:

  1. Maybe add VarCollection.rename_vars as a helper method to facilitate variable renaming in VarCollections
  2. Maybe consider modifying load_var_collection to support remapping of the variables upon reloading.

@david-berthelot do you think we should implement any of these?

david-berthelot commented 4 years ago

Concerning your example: I think it's unsafe to map by dropping the class (like I mentioned previously).

  1. I'm okay with adding a VarCollection.rename API.
  2. I don't see the need to change load_var_collection; I'd like to see a counter-example (two sets of strings (vc, ckpt) that won't work with the method I proposed).

To do a class renaming for remapping, I would simply do: vc.rename([('(class_in_checkpoint)', '(class_in_model)')]). This way there's no surprise and one knows exactly what gets replaced.
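For instance (class names here are placeholders, and the parentheses need escaping if the patterns are treated as regexes):

vc = model.vars()
vc = vc.rename([(r'\(ClassInCheckpoint\)', '(ClassInModel)')])  # proposed API, not yet in Objax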

rwightman commented 4 years ago

@AlexeyKurakin @david-berthelot the best approach for me seems to be to maintain a fn like the proposed custom_load_var_collection. At this stage there is no HTTP zoo/checkpoint download & cache functionality here, so other helper fns will be needed regardless.

AlexeyKurakin commented 4 years ago

After some more thinking, we decided to prototype an additional API to facilitate remapping of variable names when loading a checkpoint.

david-berthelot commented 4 years ago

Update: I'm going to work on a PR to prototype the ideas discussed here. Will update once I have something.

david-berthelot commented 4 years ago

Okay, I've created a simple demo with minimal code changes. It also allows two-way renaming (.npz file variable names and VarCollection variable names). There are no tests and almost no documentation; it's just to get feedback on the design before I put more work into polishing it. So let me know what you think!

david-berthelot commented 4 years ago

Sample output from examples/renaming.py

(EfficientNet).stem(ConvBnAct).conv(Conv2d).w                                                               1 ()
(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).conv_dw(Conv2d).w                   1 ()
(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).bn_dw(BatchNorm2D).running_mean        1 ()
(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).se(SqueezeExcite).fc1(Conv2d).b        1 ()
(EfficientNet).blocks(Sequential)[0](Sequential)[0](DepthwiseSeparable).se(SqueezeExcite).fc1(Conv2d).w        1 ()
(EfficientNet).head(Head).conv_1x1(Conv2d).w                                                                1 ()
(EfficientNet).head(Head).bn(BatchNorm2D).running_mean                                                      1 ()
(EfficientNet).head(Head).bn(BatchNorm2D).running_var                                                       1 ()
(EfficientNet).head(Head).bn(BatchNorm2D).beta                                                              1 ()
(EfficientNet).head(Head).bn(BatchNorm2D).gamma                                                             1 ()
(EfficientNet).head(Head).classifier(Linear).b                                                              1 ()
(EfficientNet).head(Head).classifier(Linear).w                                                              1 ()
+Total(12)                                                                                                 12
-------------------------------- Regex renaming --------------------------------
.stem.conv.w                            1 ()
.blocks[0][0].conv_dw.w                 1 ()
.blocks[0][0].bn_dw.running_mean        1 ()
.blocks[0][0].se.fc1.b                  1 ()
.blocks[0][0].se.fc1.w                  1 ()
.head.conv_1x1.w                        1 ()
.head.bn.running_mean                   1 ()
.head.bn.running_var                    1 ()
.head.bn.beta                           1 ()
.head.bn.gamma                          1 ()
.head.classifier.b                      1 ()
.head.classifier.w                      1 ()
+Total(12)                             12
-------------------------------- Dict renaming ---------------------------------
stem(ConvBnAct).conv(Conv2d).w                                                                1 ()
blocks(Sequential).0(Sequential).0(DepthwiseSeparable).conv_dw(Conv2d).w                      1 ()
blocks(Sequential).0(Sequential).0(DepthwiseSeparable).bn_dw(BatchNorm2D).running_mean        1 ()
blocks(Sequential).0(Sequential).0(DepthwiseSeparable).se(SqueezeExcite).fc1(Conv2d).b        1 ()
blocks(Sequential).0(Sequential).0(DepthwiseSeparable).se(SqueezeExcite).fc1(Conv2d).w        1 ()
head(Head).conv_1x1(Conv2d).w                                                                 1 ()
head(Head).bn(BatchNorm2D).running_mean                                                       1 ()
head(Head).bn(BatchNorm2D).running_var                                                        1 ()
head(Head).bn(BatchNorm2D).beta                                                               1 ()
head(Head).bn(BatchNorm2D).gamma                                                              1 ()
head(Head).classifier(Linear).b                                                               1 ()
head(Head).classifier(Linear).w                                                               1 ()
+Total(12)                                                                                   12
------------------------------ Function renaming -------------------------------
stem.conv.w                          1 ()
blocks.0.0.conv_dw.w                 1 ()
blocks.0.0.bn_dw.running_mean        1 ()
blocks.0.0.se.fc1.b                  1 ()
blocks.0.0.se.fc1.w                  1 ()
head.conv_1x1.w                      1 ()
head.bn.running_mean                 1 ()
head.bn.running_var                  1 ()
head.bn.beta                         1 ()
head.bn.gamma                        1 ()
head.classifier.b                    1 ()
head.classifier.w                    1 ()
+Total(12)                          12
-------------------------------- Saving/Loading --------------------------------
Saving dict-renamed var collection to disk.
Loading dict-renamed var collection to default var collection fails, names mismatch.
Loading dict-renamed var collection to dict-renamed var collection works.
Loading dict-renamed var collection to function-renamed var collection needs mapping.
david-berthelot commented 4 years ago

@rwightman Can you share some feedback? I'd like to commit/edit the changes this week if possible.

rwightman commented 4 years ago

@david-berthelot I switched gears and haven't swung back to jax/objax yet, but I did scan the renaming.py example and it appears to have the needed flexibility and a sensible interface.