rwth-i6 / pytorch-to-returnn

Make PyTorch code runnable within RETURNN

randint: allow dynamic inputs #98

Closed · vieting closed 2 years ago

vieting commented 2 years ago

So far, we have only supported torch.randint for static integer inputs. In this PR, support is extended to dynamic inputs, i.e. bounds that are derived from dynamic tensor sizes.
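
For illustration, a minimal sketch (not the actual test from this PR; the module name and shapes are made up) of the kind of call this is about: the upper bound passed to torch.randint is derived from a dynamic tensor size instead of being a constant.

```python
import torch


class RandIntDynamic(torch.nn.Module):
    """Toy module: the randint bound comes from a dynamic size, not a constant."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x.shape[1] is a dynamic size; under pytorch-to-returnn tracing this
        # corresponds to a dynamic size (SizeValue) rather than a plain int.
        high = x.shape[1] * 2  # derived bound, e.g. twice the sequence length
        return torch.randint(low=0, high=high, size=(x.shape[0],))
```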

vieting commented 2 years ago

Right now, this crashes because in CallEntry.apply_call, the module cannot be found in this line: torch_mod = naming.import_params_from_torch_namespace.get_module_by_abs_id_name(mod_abs_name). It should be found by the logic we added to get_module_by_abs_id_name in #90 for a WrappedTorchFunction like RandInt here. In this case, however, part_name="mul_randint" instead of "randint", so it is not contained in mod_.module.func_name="pytorch_to_returnn.import_wrapper._torch_traced.torch.randint", and therefore no module is matched. Do you know how to generalize this?
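
To make the mismatch concrete, here is a simplified stand-in for that containment check (the real lookup logic in get_module_by_abs_id_name is more involved):

```python
# Simplified illustration of the failing lookup, not the actual implementation.
func_name = "pytorch_to_returnn.import_wrapper._torch_traced.torch.randint"

print("randint" in func_name)      # True: the plain randint test matches a module
print("mul_randint" in func_name)  # False: the mul_ prefix breaks the match, so
                                   # apply_call cannot resolve the module and crashes
```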

vieting commented 2 years ago

I added updates based on the discussion in https://github.com/rwth-i6/returnn/pull/919. We cannot directly access SizeValue.originating_tensor, so the logic is a bit more involved. We also need to add the dependency for the converter (call.inputs_tensor_deps), as otherwise the feed_dict for Module.make_output_tensor_from_returnn will not be complete.
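
To illustrate why the missing dependency matters, here is a small generic TF1-style example (not the converter's actual code): if the feed_dict of a session run lacks a placeholder that the requested output depends on, the run fails.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

size = tf1.placeholder(tf.int32, shape=(), name="dynamic_size")
out = size * 2  # the output depends on the placeholder

with tf1.Session() as session:
    print(session.run(out, feed_dict={size: 7}))  # 14
    # If the dependency is not tracked, the placeholder is not fed:
    # session.run(out, feed_dict={})  # raises: "You must feed a value for placeholder ..."
```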

vieting commented 2 years ago

For me locally, the test passes now. I'm not yet sure why it fails here (current RETURNN master, python3.8, torch 1.8.1, tf 2.3.1...)

albertz commented 2 years ago

> I added updates based on the discussion in rwth-i6/returnn#919. We cannot directly access SizeValue.originating_tensor, so the logic is a bit more involved.

Why?

vieting commented 2 years ago

> > I added updates based on the discussion in rwth-i6/returnn#919. We cannot directly access SizeValue.originating_tensor, so the logic is a bit more involved.
>
> Why?

Because if a SizeValue is an input to this module, we get x.get_tensor() at that point (see here). From there, it's not straightforward to retrieve the SizeValue, at least not that I'm aware of.

albertz commented 2 years ago

> > > I added updates based on the discussion in rwth-i6/returnn#919. We cannot directly access SizeValue.originating_tensor, so the logic is a bit more involved.
> >
> > Why?
>
> Because if a SizeValue is an input to this module, we get x.get_tensor() at that point (see here). From there, it's not straightforward to retrieve the SizeValue, at least not that I'm aware of.

But that should be straightforward. We could simply change is_dim to hold the SizeValue directly. Or maybe add an is_size_value. Or so.
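
A minimal, self-contained sketch of that suggestion (the class and field names here are hypothetical, not the actual converter API): instead of only marking a tensor as a dim via a boolean, the bookkeeping entry could carry the SizeValue itself, so it can be recovered later from the wrapped tensor.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class TensorEntrySketch:
    """Hypothetical stand-in for the converter's per-tensor bookkeeping entry."""
    tensor: Any
    is_dim: bool = False               # current idea: only a boolean flag
    size_value: Optional[Any] = None   # suggestion: store the SizeValue itself


entry = TensorEntrySketch(tensor="<wrapped tensor>", is_dim=True, size_value="<SizeValue>")
# With the SizeValue stored on the entry, retrieving it later is trivial:
assert entry.size_value is not None
```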

vieting commented 2 years ago

> For me locally, the test passes now. I'm not yet sure why it fails here (current RETURNN master, python3.8, torch 1.8.1, tf 2.3.1...)

I still don't get why this behaves differently, especially since it does not seem to be related to external dependencies.

albertz commented 2 years ago

> > For me locally, the test passes now. I'm not yet sure why it fails here (current RETURNN master, python3.8, torch 1.8.1, tf 2.3.1...)
>
> I still don't get why this behaves differently, especially since it does not seem to be related to external dependencies.

I don't know. What if you run it the same way it is run here, using nosetests?

vieting commented 2 years ago

> > > For me locally, the test passes now. I'm not yet sure why it fails here (current RETURNN master, python3.8, torch 1.8.1, tf 2.3.1...)
> >
> > I still don't get why this behaves differently, especially since it does not seem to be related to external dependencies.
>
> I don't know. What if you run it the same way it is run here, using nosetests?

It's related to the fact that I only ran that single test. When running both test_randint and test_randint_dynamic, I get the same error.

vieting commented 2 years ago

> > > > For me locally, the test passes now. I'm not yet sure why it fails here (current RETURNN master, python3.8, torch 1.8.1, tf 2.3.1...)
> > >
> > > I still don't get why this behaves differently, especially since it does not seem to be related to external dependencies.
> >
> > I don't know. What if you run it the same way it is run here, using nosetests?
>
> It's related to the fact that I only ran that single test. When running both test_randint and test_randint_dynamic, I get the same error.

In both tests, torch.randint is a WrappedTorchFunction during the traced run. However, in whichever test runs second (regardless of the order), Naming.get_instance().modules is empty, because WrappedIndirectModule.__getattr__ with item="randint" is never called there; it is that call which would create the entries in naming.modules.
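
For intuition, a self-contained toy illustration of that kind of effect (not the actual WrappedIndirectModule code; the assumption here is that the wrapper caches the wrapped attribute after the first lookup): __getattr__ then only fires on the first access in the process, so per-test bookkeeping done inside it is skipped for whichever test runs second.

```python
import math


class ToyWrapper:
    """Toy attribute wrapper that caches wrapped attributes after the first lookup."""

    def __init__(self, wrapped):
        self._wrapped = wrapped
        self.registered = []  # stands in for the entries created in naming.modules

    def __getattr__(self, item):
        # Only reached when `item` is not yet an attribute of self.
        value = getattr(self._wrapped, item)
        self.registered.append(item)  # bookkeeping only happens on the first lookup
        setattr(self, item, value)    # cache: later lookups bypass __getattr__
        return value


wrapper = ToyWrapper(math)
_ = wrapper.sqrt   # "first test": __getattr__ fires and registers "sqrt"
_ = wrapper.sqrt   # "second test": cached attribute, __getattr__ is not called again
print(wrapper.registered)  # ['sqrt'] -- only one registration despite two accesses
```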

vieting commented 2 years ago

Thanks for the fix! From my perspective, the PR is ready then.