Open michelwi opened 3 years ago
The name of `transform_param_axes_split_info_to_new_shape` and its argument `new_shape` is maybe confusing: it's actually the old shape here, but it's used the other way around. So in your case, it means you start with 2 input layers and then remove one.
The way the output of `transform_param_axes_split_info_to_new_shape` is used in `copy_with_new_split_axes` actually requires that we put the 0 in there.
So what we want:

```python
assert_equal(
    transform_param_axes_split_info_to_new_shape([[10025, 500], [10025]], (500, 10025)),
    [[0, 500], [10025]])
```
You should add that to `test_transform_param_axes_split_info_to_new_shape`.
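For illustration, the intended per-axis mapping could be sketched like this. This is a hypothetical, heavily simplified re-implementation (not the actual RETURNN code, and the function name is made up); it only covers the two cases discussed in this issue:

```python
def transform_split_info_sketch(old_split_info, new_shape):
    """Hypothetical simplified sketch (NOT the actual RETURNN function):
    per axis, keep the parts unchanged if the dim did not change,
    and zero out removed parts if the new dim equals one old part."""
    result = []
    for parts, new_dim in zip(old_split_info, new_shape):
        if sum(parts) == new_dim:
            # axis unchanged: keep the split as-is
            result.append(list(parts))
        elif new_dim in parts:
            # one input was removed: keep the surviving part, mark the rest with 0
            result.append([p if p == new_dim else 0 for p in parts])
        else:
            raise NotImplementedError(
                "case not covered by this sketch: %r -> %r" % (parts, new_dim))
    return result

print(transform_split_info_sketch([[10025, 500], [10025]], (500, 10025)))
# -> [[0, 500], [10025]]
```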
> maybe `transform_param_axes_split_info_to_new_shape` is used elsewhere

Just check that. I think not.
> Is it safe to allow the case `new_parts[0] == 0` ...

Why not? Also, allowing something which was not allowed before cannot possibly break anything.
> Now the first parts ... hits my newly defined third condition ..., then the second parts ... overwrites it ...

That's way too ugly. Your heuristic should not depend on a specific order of the axes to work. (Yes, you can already construct strange edge cases where the result depends on the order, but ignore those; do not make it explicitly depend on such behavior.)
I would add something like this to the first loop:

```python
elif new_dim in parts:
    dim_diff[new_dim] = new_dim
```

And in the second loop, in the `sum(...) != new_dim` if-branch, I would add another check:

```python
if new_dim in new_parts:
    new_parts = ...
else:
    ...  # as before
```

Also change `assert new_parts[j] > 0` to `assert new_parts[j] >= 0`.
> Can we rewrite the whole mechanism to not use a dict `dim_diff` ...

I'm quite sure this would break other cases.
Maybe instead of `dict[int,int]`, it could be `dict[int,set[int]]`.
However, I would also try to not make this too complicated.
Note that there will always be cases which this does not fully cover, no matter what you do. I think it's fine now to add one or two more heuristics to cover your case but we should not make it too complicated.
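To make the overwrite problem concrete (purely illustrative, using the numbers from the attempts described below in this issue):

```python
# With dict[int, int], a later axis silently overwrites an earlier candidate:
dim_diff_scalar = {}
dim_diff_scalar[10025] = 0      # from parts [10025, 500] vs. new dim 500
dim_diff_scalar[10025] = 10025  # from parts [10025] vs. new dim 10025 -- overwrites!
assert dim_diff_scalar == {10025: 10025}  # the 0-candidate is lost

# With dict[int, set[int]], both candidates survive,
# and a later step could still choose per axis:
dim_diff_sets = {}
dim_diff_sets.setdefault(10025, set()).add(0)
dim_diff_sets.setdefault(10025, set()).add(10025)
assert dim_diff_sets == {10025: {0, 10025}}
```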
Maybe there is a better way to infer the `old_axes_splits`, or to copy over the parameters.
Is this resolved?
Note that we also don't really need to invest so much energy into making the heuristic correct in all cases (which is anyway not possible). We could instead invest some energy into doing it the right way, so that it is always correct, in a clean and predictable way, without relying on such a heuristic. This should certainly be possible, right?
E.g. the problem this heuristic tries to solve is to recover the old shape split information, which is not available at this point anymore. Maybe we could just store this in the checkpoint or somewhere else?
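One way such extra information could be persisted is a side file next to the checkpoint. This is only a sketch with hypothetical helper names (RETURNN has no such API at this point), mapping each parameter name to its axes-split info:

```python
import json

def save_axes_split_info(ckpt_path, split_info_by_param):
    """Hypothetical helper: store the axes-split info of each parameter,
    e.g. {"output/W": [[10025, 500], [10025]]}, next to the checkpoint,
    so it never needs to be reconstructed heuristically later."""
    with open(ckpt_path + ".axes_split_info.json", "w") as f:
        json.dump(split_info_by_param, f)

def load_axes_split_info(ckpt_path):
    """Hypothetical counterpart: read the split info back when importing."""
    with open(ckpt_path + ".axes_split_info.json") as f:
        return json.load(f)
```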
## Problem Statement
I would like to extend an already existing layer with an additional input. In my example, I have trained an attention-based encoder-decoder model, and now I would like to add an external LM to the inputs of the softmax layer:
My RETURNN config for training the checkpoint was:
Now I continue training with this extended layer, while importing the existing checkpoint with the `preload_from_files` mechanism:

## First Attempt
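For context, a `preload_from_files` setup looks roughly like this (hypothetical config fragment; the entry name and checkpoint path are illustrative, the option keys are from the RETURNN config as I understand it):

```python
# Hypothetical RETURNN config fragment (paths/names illustrative):
preload_from_files = {
    "base": {
        "filename": "/path/to/old-model/checkpoint",  # the existing checkpoint
        "init_for_train": True,   # only used to initialize training
        "ignore_missing": True,   # newly added params stay freshly initialized
    },
}
```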
I am on commit face0c3be7d5336ac38c106ccc6606af70bd7ac9, which I extended with the changes introduced in #412.
The relevant variables in the `returnn/tf/util/basic.py` function `transform_param_axes_split_info_to_new_shape` are:

Here is the stack trace I get:
```
Unhandled exception
```

Ok, so in the loop of https://github.com/rwth-i6/returnn/blob/face0c3be7d5336ac38c106ccc6606af70bd7ac9/returnn/tf/util/basic.py#L155 the first `parts = [10025, 500]` hits neither condition, and only the second `parts = [10025]` sets `dim_diff = {10025: 10025}`, so that we end up in https://github.com/rwth-i6/returnn/blob/face0c3be7d5336ac38c106ccc6606af70bd7ac9/returnn/tf/util/basic.py#L167 and then the heuristics fail.

## Second Attempt
Then I added my case to the end of the loop of https://github.com/rwth-i6/returnn/blob/face0c3be7d5336ac38c106ccc6606af70bd7ac9/returnn/tf/util/basic.py#L155
Now I get the following variables:
And here is the new stack trace:
```
Unhandled exception
```

Now the first `parts = [10025, 500]` hits my newly defined third condition and sets `dim_diff = {10025: 0, 500: 500}`; then the second `parts = [10025]` overwrites it to `dim_diff = {10025: 10025, 500: 500}`. `new_parts` for the first dim should have been `[0, 500]`, but due to the overwriting it is now `[10025, 500]`, and we trigger https://github.com/rwth-i6/returnn/blob/face0c3be7d5336ac38c106ccc6606af70bd7ac9/returnn/tf/util/basic.py#L175 which fortunately saves us and sets `new_parts = [0, 500]`.
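The overwriting can be reproduced in isolation. This is a hypothetical, heavily simplified stand-in for the loop (not the real RETURNN code), just to show the order dependence:

```python
old_split_info = [[10025, 500], [10025]]
new_shape = (500, 10025)

dim_diff = {}
for parts, new_dim in zip(old_split_info, new_shape):
    if sum(parts) == new_dim:
        # axis unchanged
        for p in parts:
            dim_diff[p] = p
    elif new_dim in parts:
        # the newly added third condition: one part survives, the rest drop to 0
        for p in parts:
            dim_diff[p] = p if p == new_dim else 0
    # after the first axis:  dim_diff == {10025: 0, 500: 500}
    # after the second axis: dim_diff == {10025: 10025, 500: 500}  (10025 overwritten)

assert dim_diff == {10025: 10025, 500: 500}
```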
Now only the assertion `new_parts[0] > 0` fails, as expected.

## Questions
Is it safe to allow the case `new_parts[0] == 0`? For me yes: the logic in `copy_with_new_split_axes` will work as I expect, but maybe `transform_param_axes_split_info_to_new_shape` is used elsewhere.

Can we rewrite the whole mechanism to not use a dict `dim_diff`, but to infer the new shapes in order as we progress through the lists? This would alleviate the issues with square layers.
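The "square layer" issue mentioned in the last question can be shown with a minimal example (illustrative only): a dict keyed by the old dim cannot distinguish the two axes of a square weight matrix, while processing the axes in order can.

```python
# A (512, 512) weight grows to (1024, 512):
old_shape = (512, 512)
new_shape = (1024, 512)

# dict keyed by old dim: both axes collapse onto the key 512,
# so the 512 -> 1024 change of the first axis is lost
dim_map = {}
for old, new in zip(old_shape, new_shape):
    dim_map[old] = new  # second axis overwrites the first
assert dim_map == {512: 512}

# per-axis, in order: unambiguous
per_axis = list(zip(old_shape, new_shape))
assert per_axis == [(512, 1024), (512, 512)]
```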