rwth-i6 / returnn

The RWTH extensible training framework for universal recurrent neural networks
http://returnn.readthedocs.io/

CompareLayer: allow_broadcast_all_sources missing #798

Closed robin-p-schmitt closed 2 years ago

robin-p-schmitt commented 2 years ago

The CompareLayer does not yet set the allow_broadcast_all_sources parameter of Data.get_common_data; it leaves it at the default NotSpecified regardless of whether out_shape or out_type is set. The CompareLayer calls Data.get_common_data both in __init__ and in get_out_data_from_opts. For the latter, we could use the same code as the CombineLayer:

    if sources:
      allow_broadcast_all_sources = NotSpecified
      if out_shape is not None:
        allow_broadcast_all_sources = True
      elif out_type and isinstance(out_type, dict) and ("shape" in out_type or "dim_tags" in out_type):
        allow_broadcast_all_sources = True
      out_type_.update(
        Data.get_common_data(
          [s.output for s in sources], allow_broadcast_all_sources=allow_broadcast_all_sources).get_kwargs())

For the former, however, we currently cannot check whether out_shape is set because out_shape is not an attribute of CompareLayer. Should we set self.out_shape in LayerBase? What do you think? @albertz @Zettelkasten
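The decision logic from the CombineLayer snippet above can be isolated into a small standalone function. This is a minimal sketch for illustration, not RETURNN code: the helper name resolve_allow_broadcast_all_sources is hypothetical, and NotSpecified is a stand-in for RETURNN's sentinel of the same name.

```python
class NotSpecified:
  """Stand-in for RETURNN's NotSpecified sentinel (assumption, not the real class)."""


def resolve_allow_broadcast_all_sources(out_shape=None, out_type=None):
  """Hypothetical helper mirroring the CombineLayer logic:
  broadcasting from all sources is only allowed when the user has
  explicitly pinned the output shape, either via out_shape or via a
  "shape"/"dim_tags" entry in out_type; otherwise the flag stays at
  the NotSpecified default."""
  if out_shape is not None:
    return True
  if out_type and isinstance(out_type, dict) and ("shape" in out_type or "dim_tags" in out_type):
    return True
  return NotSpecified
```

With such a helper, both get_out_data_from_opts and __init__ could derive the flag from the same layer options instead of duplicating the checks.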

albertz commented 2 years ago

I don't understand. Why does it need to be an attribute? Why don't you just check for it?

But also, instead of recomputing common_data = ... in __init__, it is better anyway to use common_data = self.output.