rwightman opened 4 years ago
Another addition here instead of creating issue spam.
Something else that cropped up when bringing in weights from elsewhere: the Objax combination of NCHW data, HWIO conv kernels, and NCHW output differs from both the Tensorflow/Keras defaults (channels_last, i.e. NHWC data, HWIO kernels, NHWC output) and PyTorch (NCHW, OIHW, NCHW). Objax matches the less common channels_first mode of TF/Keras, which leaves the kernel shapes as-is. Other JAX interfaces match the TF defaults. Curious why Objax didn't do that, or fully match PyTorch kernel shapes if deciding to use NCHW for data?
Also, the shapes of single-feature-dim vars, like conv bias (C, 1, 1), linear bias (C,), bn bias (1, C, 1, 1), bn gamma (1, C, 1, 1), differ both among the different layers in Objax and from PyTorch, which uses (C,) for all, and I think TF does too in normal use.
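For illustration, here is a minimal NumPy sketch (pure shape bookkeeping, not Objax or PyTorch code; the tensor sizes are made up) of the conversions implied above when bringing PyTorch-style weights into this layout, with the HWIO kernel and (C, 1, 1) conv bias targets taken from the description above:

```python
import numpy as np

# Hypothetical PyTorch-style conv parameters (OIHW kernel, flat bias).
out_c, in_c, kh, kw = 64, 3, 7, 7
torch_kernel = np.random.randn(out_c, in_c, kh, kw).astype(np.float32)
torch_bias = np.random.randn(out_c).astype(np.float32)

# OIHW -> HWIO: spatial dims first, then input channels, then output channels.
objax_kernel = torch_kernel.transpose(2, 3, 1, 0)   # (kh, kw, in_c, out_c)

# Flat (C,) bias -> (C, 1, 1) so it broadcasts over H and W under NCHW data.
objax_bias = torch_bias.reshape(out_c, 1, 1)

assert objax_kernel.shape == (kh, kw, in_c, out_c)
assert objax_bias.shape == (out_c, 1, 1)
```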
On the whole, when there is consensus for what the correct implementation of something should be, objax tries to follow that. For example, there is a definitive resnet-18, and so we try to be exact there.
But when there's no consensus (for example, whether it should be called `kernel` or `weight`, or whether the best default parameter should be .9 or .99) then we pick something that we think is most reasonable. This is an entirely subjective definition, because there's no "correct" value of momentum that works for every application; it has to be tuned. If there's some strong evidence that a momentum of .9 is objectively better (say, across a large suite of tasks) then it would make sense to change.
But on the whole, we aren't going to try and be consistent with either tensorflow or pytorch just for the sole purpose of being consistent with one. When they diverge, it's likely because people have different subjective definitions of best. We use `w` and `b` because that's what they're typically called in pseudocode in papers, and they're short and descriptive. Is it clearly better than `weight` and `bias`? Nope. But neither is the converse.
Thinking about the beta/gamma definitions here might make sense: they're arbitrary letters with no inherent meaning. Whether @david-berthelot had a good reason to pick the ones he did, I'm not sure. In this one case, switching them around to match tensorflow doesn't sound like a bad idea.
For the weight shapes, there's actually a good reason for this. @AlexeyKurakin had found somewhere that JAX is actually fastest when using this combination, and so we decided to stick with it because it's best for JAX.
The reason the weights are assigned with the shape baked in is mostly just to keep the implementation clean and to make explicit where the values are supposed to go. It makes things slightly less user-friendly, because now someone has to know to reshape to (1, C, 1, 1), but it also makes it very clear what's going on: the value is being added along the second axis. If the library does the reshape for you, then you can't be sure where exactly it's being applied.
This second choice is another subjective decision. It may not be best, but it's what we've started with. If you think there's a good case to be made that it should be flat, then we'd be happy to hear it.
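A small jax.numpy illustration of that point (not Objax code; the shapes are made up): with the broadcast shape stored in the variable itself, the axis the value is applied along is visible from the shape, whereas a flat (C,) parameter relies on someone knowing to do the reshape.

```python
import jax.numpy as jn

N, C, H, W = 2, 8, 4, 4
x = jn.ones((N, C, H, W))

# Stored with the broadcast shape baked in: clearly added along axis 1 (channels).
beta_explicit = jn.zeros((1, C, 1, 1))
y = x + beta_explicit

# Stored flat: the caller (or the library) has to know to reshape before adding.
beta_flat = jn.zeros((C,))
y2 = x + beta_flat.reshape((1, C, 1, 1))

assert y.shape == y2.shape == (N, C, H, W)
```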
@carlini thanks for the explanations.
Taken in aggregate, these details are an interface of sorts, a low-level one, but one that certain users will encounter. I'm not sure standing out / being different here is an asset. In some respects it turned out to be easier to move model weights between PyTorch and Tensorflow than from either to Objax, which was surprising.
Re the conv kernel layout + data format combo: with XLA on the backend, I would have expected that decision not to matter much at all from a performance standpoint. I figured XLA would remap layouts as it sees fit for the target accelerator. Modern GPUs will end up channels-last to take advantage of the optimized Tensor Core kernels, TPUs will have their own preference, etc.
I'm not sure if this was by design, or just the way it worked out, but Objax forges its own path for many parameter/variable and state naming/ordering conventions, default arguments, etc.
While it may seem trivial if one is just using Objax, it adds cognitive overhead when you're working with multiple frameworks, moving weights around, or building components that work with multiple modeling interfaces.
Some examples:
- Objax BatchNorm eps/momentum defaults differ from both TF/Keras and PyTorch. Momentum isn't that important unless training, but eps impacts existing weight compatibility if not matched (see the eps sketch after this list).
- Lots of layer variable names and their creation orderings are different (a rough name-mapping sketch follows the list):
  - `Conv2d`: Objax uses `.b` and `.w`, TF/Keras uses `.kernel` and `.bias`, PyTorch uses `.weight` and `.bias`.
  - `BatchNorm`: Objax uses `.running_mean`, `.running_var`, `.beta`, `.gamma`; TF/Keras uses `.moving_mean`, `.moving_variance`, `.gamma`, `.beta`; PyTorch uses `.weight`, `.bias`, `.running_mean`, `.running_var`.
- Ordering of variables: if iterating in creation order, Objax is often bias first, while almost every other framework I'm used to has weight (or the gamma/scale/etc. equivalent) and then bias in creation order (and thus often iteration order).
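To make the eps point above concrete, here is a framework-free NumPy sketch (the 1e-5 value is PyTorch's BatchNorm default; the 1e-6 value is just a smaller stand-in to show the mechanism): with a small running variance, a mismatched eps alone shifts the normalized output by a few percent.

```python
import numpy as np

# Inference-time batch norm: y = gamma * (x - mean) / sqrt(var + eps) + beta.
def bn_infer(x, mean, var, gamma, beta, eps):
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

x = np.float32(0.5)
mean, var = np.float32(0.0), np.float32(1e-4)   # small running variance, common in real checkpoints
gamma, beta = np.float32(1.0), np.float32(0.0)

y_eps_1e5 = bn_infer(x, mean, var, gamma, beta, eps=1e-5)  # PyTorch's default eps
y_eps_1e6 = bn_infer(x, mean, var, gamma, beta, eps=1e-6)  # a smaller default for comparison

# With var=1e-4 the two eps values change the output by roughly 4%.
print(y_eps_1e5, y_eps_1e6, abs(y_eps_1e5 - y_eps_1e6) / abs(y_eps_1e5))
```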
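And as a rough illustration of why the naming and ordering differences matter when porting checkpoints, here is a sketch of a name-level remap from PyTorch `state_dict` keys to Objax-style variable names. The module prefixes (`layer1.0.bn`, etc.), the `SUFFIX_MAP`/`remap_key` names, and the exact Objax naming scheme are assumptions for illustration only; the gamma/beta correspondence follows the conventional reading (PyTorch's weight is the scale, bias is the shift) and should be double-checked against Objax's definitions given the beta/gamma discussion above. The point is that mapping by name, rather than zipping variables in creation/iteration order, sidesteps both the naming and the ordering differences.

```python
# Hypothetical suffix remap from PyTorch parameter/buffer names to Objax-style
# variable names, based on the correspondence listed above. A real port would
# also need the shape fixes shown earlier in the thread.
SUFFIX_MAP = {
    # Conv2d
    "conv.weight": "conv.w",
    "conv.bias": "conv.b",
    # BatchNorm (conventionally, PyTorch weight/bias play the role of gamma/beta)
    "bn.weight": "bn.gamma",
    "bn.bias": "bn.beta",
    "bn.running_mean": "bn.running_mean",
    "bn.running_var": "bn.running_var",
}

def remap_key(torch_key: str) -> str:
    """Translate a single (hypothetical) PyTorch state_dict key to an Objax-style name."""
    for torch_suffix, objax_suffix in SUFFIX_MAP.items():
        if torch_key.endswith(torch_suffix):
            return torch_key[: -len(torch_suffix)] + objax_suffix
    return torch_key  # leave anything unrecognized alone

# Example (made-up key): "layer1.0.bn.weight" -> "layer1.0.bn.gamma"
print(remap_key("layer1.0.bn.weight"))
```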
I'm not sure if any of this is still in flux; if so, it could help to align more of these conventions with one of the existing options.