rwth-i6 / returnn_common

Common building blocks for RETURNN configs, such as models, training concepts, etc

Definition of losses, `mark_as_loss`? #56

Open albertz opened 2 years ago

albertz commented 2 years ago

Currently, to define some tensor (layer ref) as a loss, you call mark_as_loss on it. The idea was to be somewhat analogous to calling loss.backward() in PyTorch.

Common code in PyTorch looks like this (see here):

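from torch import nn, optim

class MyModel(nn.Module):
  ...

model = MyModel(...)
loss_fn = nn.CrossEntropyLoss()  # the loss is defined outside of the model
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

for inputs, targets in dataset:
  optimizer.zero_grad()
  output = model(inputs)
  loss = loss_fn(output, targets)
  loss.backward()  # compute gradients
  optimizer.step()  # update parameters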

So the loss.backward() call, and also the definition of the loss itself, is separate from MyModel. You would usually not define the loss inside MyModel, so model and loss are decoupled.

This is not how it currently works in returnn-common, where the two cannot be separated. When you call make_root_net_dict (#44) on the model, it just calls model(...) (using extern data) and that's it.

So the current API (make_root_net_dict) implies that the loss is defined inside the model, inside MyModel, and cannot be decoupled. Or can it?

I think we should be able to decouple it, if we want to. Any module (e.g. Transformer #53) should just define the model and not be specific about losses.

The question is how exactly.

Maybe we can extend make_root_net_dict to also accept a train_loss argument or similar.

(I opened a separate issue on this because #38 is only about which loss functions or modules we want, and their naming and usage conventions.)

albertz commented 2 years ago

I just checked ESPNet as an example. See here and here. It looks like they don't separate this and have the loss logic as part of the model.

I wonder though whether this is a good idea or not.

Certainly some losses are related to specific models, e.g. transducer or CTC (at least the original literature introduces both together).

I also wonder how other frameworks are doing this.

albertz commented 2 years ago

Fairseq seems to have it decoupled. See here for the model which does not have anything about the loss. See train_step, where you find that the model and criterion are separated. It is called like this:

loss, sample_size, logging_output = criterion(model, sample)
optimizer.backward(loss)

So this kind of assumes a single criterion (loss), although you could simply add up multiple criteria into a single one (see CompositeLoss).
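To make the "add up multiple criteria" point concrete, here is a minimal pure-Python sketch of a composite criterion that keeps the `criterion(model, sample) -> (loss, sample_size, logging_output)` convention quoted above. `SumCriterion`, `squared_error`, `abs_error` and the toy `model` are hypothetical names for illustration; this is not fairseq's actual CompositeLoss implementation.

```python
from typing import Any, Callable, Dict, List, Tuple

# A criterion takes the model and a minibatch dict and returns
# (loss, sample_size, logging_output), like the fairseq call quoted above.
Criterion = Callable[[Any, Dict[str, Any]], Tuple[float, int, Dict[str, float]]]


class SumCriterion:
    """Hypothetical composite criterion: weighted sum of several sub-criteria."""

    def __init__(self, criteria: List[Tuple[str, float, Criterion]]):
        self.criteria = criteria  # list of (name, weight, criterion)

    def __call__(self, model, sample):
        total = 0.0
        sample_size = 0
        logging_output: Dict[str, float] = {}
        for name, weight, criterion in self.criteria:
            # Assumes all sub-criteria agree on the sample size.
            loss, sample_size, _ = criterion(model, sample)
            logging_output[name] = loss
            total += weight * loss
        logging_output["loss"] = total
        return total, sample_size, logging_output


def squared_error(model, sample):
    out = model(sample["net_input"])
    loss = sum((o - t) ** 2 for o, t in zip(out, sample["target"]))
    return loss, len(sample["target"]), {"loss": loss}


def abs_error(model, sample):
    out = model(sample["net_input"])
    loss = sum(abs(o - t) for o, t in zip(out, sample["target"]))
    return loss, len(sample["target"]), {"loss": loss}


def model(xs):
    # Toy model: it knows nothing about any criterion.
    return [2.0 * x for x in xs]


sample = {"net_input": [1.0, 2.0], "target": [2.0, 3.0]}
criterion = SumCriterion([("mse", 1.0, squared_error), ("mae", 0.5, abs_error)])
loss, sample_size, logging_output = criterion(model, sample)
```

Note that each sub-criterion calls the model again in this sketch; a real implementation would probably compute the model output once and share it between the criteria.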

See CrossEntropyCriterion as an example. It gets the model and calls the model and then computes the loss:

net_output = model(**sample["net_input"])
loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)

sample is the minibatch dict consisting of net_input, maybe target, and others. In compute_loss, you find that it calls model.get_targets(sample, net_output), so this is where the model is still slightly coupled to the loss. The default BaseFairseqModel.get_targets just returns sample["target"].
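The call pattern described above (the criterion calls the model, then asks the model for the targets) can be sketched in a few lines of plain Python. `ToyModel` and `ToyCrossEntropyCriterion` are hypothetical stand-ins that only mimic the structure; the model outputs probabilities directly to keep the arithmetic simple.

```python
import math
from typing import Any, Dict, List


class ToyModel:
    """Hypothetical model: maps each input token to a fixed label distribution."""

    def __init__(self, table: List[List[float]]):
        self.table = table

    def __call__(self, src_tokens: List[int]) -> List[List[float]]:
        # The model only computes outputs; it knows nothing about the loss.
        return [self.table[t] for t in src_tokens]

    def get_targets(self, sample: Dict[str, Any], net_output) -> List[int]:
        # The only point where the model touches the loss logic:
        # by default it just reads the targets from the minibatch.
        return sample["target"]


class ToyCrossEntropyCriterion:
    """Hypothetical criterion following the call pattern quoted above."""

    def __call__(self, model, sample):
        net_output = model(**sample["net_input"])
        loss = self.compute_loss(model, net_output, sample)
        sample_size = len(sample["target"])
        return loss, sample_size, {"loss": loss}

    def compute_loss(self, model, net_output, sample):
        targets = model.get_targets(sample, net_output)
        # Summed negative log-likelihood of the target labels.
        return -sum(math.log(probs[t]) for probs, t in zip(net_output, targets))


model = ToyModel([[0.5, 0.5], [0.25, 0.75]])
sample = {"net_input": {"src_tokens": [0, 1]}, "target": [1, 1]}
loss, sample_size, logging_output = ToyCrossEntropyCriterion()(model, sample)
```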

albertz commented 2 years ago

I would propose to extend our make_root_net_dict API to be able to decouple the losses.

I would also propose to remove mark_as_loss. When you look at some model code, I think it would be bad if some mark_as_loss could be hidden somewhere deep inside it. It should be explicit. This implies that make_root_net_dict (or some other API) would have to collect all losses explicitly.

The details of the extended API need to be worked out.

albertz commented 2 years ago

Currently make_root_net_dict usage looks like this:

model = MyModel(...)
network = make_root_net_dict(model, "data")

Where "data" refers to extern data here. Any arguments you pass here after model are passed on to the model but wrapped via get_extern_data. (See #44.)

In this API, what model returns does not really matter. It could be a single layer ref or also a tuple.

I would propose to define a Criterion or Loss interface like:

(model_output, *targets, **targets_kwargs) -> loss

And then maybe:

model = MyModel(...)
loss = cross_entropy_logits  # should comply with the loss interface

network = make_root_net_dict(
  model, "data",
  train_loss=make_loss(loss, "classes"))

The train_loss argument could also accept a tuple or dict. make_loss would wrap the passed loss (a function or module instance) in the same way as make_root_net_dict wraps model, but it additionally passes model_output (whatever it is) as the first argument. This example uses cross_entropy_logits via #38, which implies that model returns logits. It also assumes that cross_entropy_logits takes two arguments (out_logits, targets), in exactly that order.

If model returns something more custom here (e.g. a tuple), you would probably implement a custom loss, like:

def loss(model_output, target1, target2, ...):
  ...
  return combined_loss

And then:

network = make_root_net_dict(
  model, "data",
  train_loss=make_loss(loss, "classes1", "classes2", ...))
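The proposed flow (call the model on extern data, then call a separately defined loss with model_output first and the named targets after it) can be simulated without RETURNN. This is a minimal sketch under loud assumptions: `toy_make_loss` and `toy_make_root_net_dict` are hypothetical toy stand-ins, not the real returnn_common functions; extern data is a plain dict of Python lists; and `frame_errors` is a toy loss standing in for something like cross_entropy_logits. It covers both the simple case and the custom tuple-output case.

```python
from typing import Any, Callable, Dict, Optional

# Stand-in for RETURNN extern data: data key -> value (plain Python lists here).
ExternData = Dict[str, Any]


def toy_make_loss(loss: Callable[..., Any], *target_keys: str):
    """Hypothetical make_loss: pass model_output first, then the named targets."""

    def wrapped(model_output: Any, extern_data: ExternData) -> Any:
        targets = [extern_data[key] for key in target_keys]
        return loss(model_output, *targets)

    return wrapped


def toy_make_root_net_dict(model: Callable[..., Any], *input_keys: str,
                           extern_data: ExternData,
                           train_loss: Optional[Callable[..., Any]] = None) -> Dict[str, Any]:
    """Hypothetical: call the model on extern data, then the decoupled loss on its output."""
    model_output = model(*[extern_data[key] for key in input_keys])
    result = {"output": model_output}
    if train_loss is not None:
        result["loss"] = train_loss(model_output, extern_data)
    return result


def frame_errors(model_output, targets):
    """Toy loss: number of frames whose argmax differs from the target label."""
    predictions = [max(range(len(frame)), key=frame.__getitem__) for frame in model_output]
    return sum(p != t for p, t in zip(predictions, targets))


def model(features):
    # Toy model returning two scores per frame; it knows nothing about losses.
    return [[f, 1.0 - f] for f in features]


# Simple case: the model output goes directly into a standard loss.
net = toy_make_root_net_dict(
    model, "data",
    extern_data={"data": [0.9, 0.2, 0.6], "classes": [0, 0, 1]},
    train_loss=toy_make_loss(frame_errors, "classes"))


# Custom case: the model returns a tuple, so a custom loss unpacks it.
def two_head_model(features):
    return model(features), model(features)


def two_head_loss(model_output, classes1, classes2):
    out1, out2 = model_output
    return frame_errors(out1, classes1) + frame_errors(out2, classes2)


net2 = toy_make_root_net_dict(
    two_head_model, "data",
    extern_data={"data": [0.9, 0.2], "classes1": [0, 1], "classes2": [1, 1]},
    train_loss=toy_make_loss(two_head_loss, "classes1", "classes2"))
```

The order of the target keys passed to the wrapper has to match the positional order of the loss arguments, which is exactly the ordering assumption mentioned above.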
albertz commented 2 years ago

Note that we have some other things coupled to the model definition as well (but this is usually the same in many other frameworks):

In RETURNN, we are also particularly proud of having a unified way to define the network both for training and for decoding with search (ChoiceLayer). Admittedly, this mostly only works for simple cases. But for many more complex cases (e.g. min WER training), ChoiceLayer with search enabled during training is still very useful.

Related is #18 on training behavior and a potential train flag on this level, and also on search.

Further, some models might define other auxiliary losses, or other multi-task losses. E.g. we often have some auxiliary CTC loss on top of the encoder. This is maybe going to be extended in the future by many more auxiliary losses, and maybe also more local losses.

Considering this, it may become unnecessarily complicated if we strictly try to decouple the loss definitions from the model definitions, because then the root module (main model) would also need to return every intermediate output, to be able to define all the auxiliary losses.

But keeping the losses inside the model implies that all loss options for all these variations (auxiliary losses, multi-task losses, unsupervised losses, and the main loss, where many variations are also possible) need to be passed to the model somehow. This is also not nice, as many models already have many other options, and the number of options could blow up.

Maybe we can also have both. The model could define one or more outputs, maybe also specifying the output type (logits, log prob, prob, or so). Losses could then be defined on all of those outputs, on some of them, or only on the final output.

Maybe some option like loss: LossCollector could be passed to some modules, which can then call loss.add_unsupervised(...) or loss.add_supervised(...)? But there could be multiple targets, and this is maybe unintuitive and not generic...
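A minimal sketch of the LossCollector idea, assuming a hypothetical `LossCollector` class with the `add_supervised` / `add_unsupervised` methods named above. Plain floats stand in for layer refs. Compared to `mark_as_loss`, the module receives the collector explicitly, and the caller decides how the collected losses are scaled and combined. The string `target_name` also shows the weak spot mentioned above: with multiple targets, how to refer to them generically is unclear.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class CollectedLoss:
    name: str
    value: float
    target_name: Optional[str] = None  # None for unsupervised losses


@dataclass
class LossCollector:
    """Hypothetical collector that modules receive explicitly (no global side effect)."""
    losses: List[CollectedLoss] = field(default_factory=list)

    def add_supervised(self, name: str, value: float, target_name: str) -> None:
        self.losses.append(CollectedLoss(name, value, target_name))

    def add_unsupervised(self, name: str, value: float) -> None:
        self.losses.append(CollectedLoss(name, value))

    def total(self, scales: Optional[Dict[str, float]] = None) -> float:
        # The caller, not the module, decides how losses are scaled and combined.
        scales = scales or {}
        return sum(scales.get(item.name, 1.0) * item.value for item in self.losses)


class ToyEncoder:
    """Toy module that defines auxiliary losses only if a collector is passed."""

    def __call__(self, x: List[float], loss: Optional[LossCollector] = None,
                 targets: Optional[List[float]] = None) -> List[float]:
        y = [2.0 * v for v in x]
        if loss is not None:
            loss.add_unsupervised("act_l2", sum(v * v for v in y))
            if targets is not None:
                loss.add_supervised(
                    "aux_mse", sum((a - b) ** 2 for a, b in zip(y, targets)),
                    target_name="classes")
        return y


collector = LossCollector()
y = ToyEncoder()([1.0, 0.5], loss=collector, targets=[2.0, 2.0])
total = collector.total({"act_l2": 0.1})
```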

If the models define the losses themselves (e.g. via the existing mark_as_loss), then besides the input data, we also need to pass the targets. Or this could be optional: if no targets are passed, no losses are defined. But I still don't like mark_as_loss and its global side effect of adding the loss, which cannot really be handled further: the caller of such a module cannot operate on the losses. I think it would be cleaner if the module had to return the losses somehow.

But then I don't really like that forward does both. Maybe there could be a separate compute_losses next to forward, and make_root_net_dict or similar could automatically call all such functions recursively.
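The separate compute_losses idea can be sketched with a hypothetical minimal module tree. Assumptions: `Module`, `collect_all_losses`, `Encoder` and `Model` are illustrative names; modules cache whatever compute_losses needs during forward; and forward is called before the losses are collected. Loss names are prefixed by the submodule path, so they stay unique.

```python
from typing import Dict, List


class Module:
    """Minimal hypothetical module base class with named submodules."""

    def __init__(self) -> None:
        self.submodules: Dict[str, "Module"] = {}

    def compute_losses(self) -> Dict[str, float]:
        # Default: the module defines no losses of its own.
        return {}


def collect_all_losses(module: Module, prefix: str = "") -> Dict[str, float]:
    """Recursively call compute_losses on the module tree; names are prefixed by path."""
    losses = {prefix + name: value for name, value in module.compute_losses().items()}
    for sub_name, sub in module.submodules.items():
        losses.update(collect_all_losses(sub, prefix + sub_name + "/"))
    return losses


class Encoder(Module):
    def forward(self, x: List[float]) -> List[float]:
        self.out = [2.0 * v for v in x]  # cached so that compute_losses can use it
        return self.out

    def compute_losses(self) -> Dict[str, float]:
        return {"act_l2": sum(v * v for v in self.out)}


class Model(Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = Encoder()
        self.submodules["encoder"] = self.encoder

    def forward(self, x: List[float]) -> List[float]:
        # forward only computes outputs; there is no loss logic here.
        return [v + 1.0 for v in self.encoder.forward(x)]


model = Model()
output = model.forward([1.0, 2.0])
losses = collect_all_losses(model)
```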

We could also make use of context managers in some way. E.g. if we stick to mark_as_loss, there could be a context manager which collects all the mark_as_loss calls, something like:

def forward(self, x):
  with collect_losses() as losses:
    y = self.encoder(x)  
  for loss in losses.collected():
    y = y + loss
  return y

So you would at least have the possibility to make some special use of the losses, if that were ever needed. Maybe that could even be how mark_as_loss is implemented internally.

albertz commented 2 years ago

To summarize my previous post: I'm unsure. Maybe mark_as_loss is actually fine.

albertz commented 2 years ago

We should maybe think about some of the more unusual settings, e.g. not just standard frame-wise (or label-wise) cross entropy, but min expected WER training, or also some meta learning. What would this look like? And I assume this should probably be decoupled from the model definition?

albertz commented 2 years ago

Min expected WER training needs some decoder which performs search, so that we get the beam and the beam scores.

#49 is relevant here for the common interface of the decoder, which should support that.

Then basically:

log_prob, labels = decoder(..., search=True)
prob = exp(log_prob)
loss = reduce_sum(
  axis=beam,
  value=prob * edit_distance(labels, targets))

Or with second pass:

log_prob, labels = decoder(..., search=True)
log_prob, labels = decoder(..., search=False, targets=labels)
prob = exp(log_prob)
loss = reduce_sum(
  axis=beam,
  value=prob * edit_distance(labels, targets))
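The beam reduction above can be spelled out for a single sequence with an explicit n-best list. In this minimal sketch, plain Python lists stand in for the beam axis, `expected_edit_distance` plays the role of the `reduce_sum(axis=beam, ...)` expression, and no gradients are computed. As in the pseudocode, the beam probabilities are not renormalized by default; many min-WER setups renormalize over the n-best list instead, which the hypothetical `renormalize` flag illustrates.

```python
import math
from typing import List, Sequence


def edit_distance(hyp: Sequence[str], ref: Sequence[str]) -> int:
    """Levenshtein distance between a hypothesis and a reference label sequence."""
    prev = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, start=1):
        cur = [i]
        for j, r in enumerate(ref, start=1):
            cur.append(min(prev[j] + 1,              # skip hypothesis label h
                           cur[j - 1] + 1,           # skip reference label r
                           prev[j - 1] + (h != r)))  # match or substitution
        prev = cur
    return prev[-1]


def expected_edit_distance(beam_log_probs: List[float], beam_labels: List[Sequence[str]],
                           targets: Sequence[str], renormalize: bool = False) -> float:
    """Sum over the beam of prob * edit_distance(labels, targets)."""
    probs = [math.exp(lp) for lp in beam_log_probs]
    if renormalize:
        # Optional: normalize the probabilities over the n-best list.
        z = sum(probs)
        probs = [p / z for p in probs]
    return sum(p * edit_distance(labels, targets) for p, labels in zip(probs, beam_labels))


targets = ["a", "b", "c"]
beam_labels = [["a", "b", "c"], ["a", "c"], ["a", "x", "c", "d"]]
beam_log_probs = [math.log(0.5), math.log(0.3), math.log(0.2)]
risk = expected_edit_distance(beam_log_probs, beam_labels, targets)
```

Here the three hypotheses have edit distances 0, 1 and 2, so the expected edit distance is 0.5 * 0 + 0.3 * 1 + 0.2 * 2 = 0.7.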
albertz commented 2 years ago

I think this boils down to one fundamental question: whether the loss definition should be decoupled from the model definition or not.

mark_as_loss, and basically the way we defined losses in RETURNN, couples the model definition and the losses.

I think it would be cleaner to decouple them more.

However, I think there can still be valid cases where they are more coupled, or where decoupling would just complicate everything because the loss is then not defined locally alongside the module.

So I think both should be possible. We should prefer decoupling when coupling is not really needed, e.g. for model parts like Transformer etc. But otherwise, modules can also define losses where this makes sense.

Generic regularizing local losses like L2 param norm should be handled in a more generic way (#59).

albertz commented 2 years ago

Ok, so to come back to this: On the question whether to separate losses or have them alongside the module: I think we should support both ways, as it was argued before.

With the decision to make all dim tags and axes arguments explicit (#17), we also need to think about extending make_root_net_dict that way.

I am currently thinking about getting rid of such a single function make_root_net_dict and making it more flexible and explicit, maybe via some context manager scope. Maybe the example in test_simple_net_module_explicit_root_ctx is already kind of like that, i.e.:

net = Net(...)

with nn.NameCtx.new_root() as name_ctx:
  out = net(nn.get_extern_data("data"), name=name_ctx)
  name_ctx.make_default_output(out)
  net_dict = name_ctx.make_net().make_net_dict_raw()

Note that the root network (net here) or rather its module class (Net here) is supposed to be reusable in other context, so that is why we pass extern data explicitly to it. But this is not really new, this is what we already did before. I just want to emphasize again that this is what we want. So this is good.

One aspect I'm undecided here is the way to declare the net module as the root network. In the example, it happens by passing name=name_ctx where name_ctx is the root name context. This might be unintuitive.

Also, currently we only support one root module. But maybe we want to support multiple? E.g. some losses would not be part of the root module but separate. This should be supported, as it was discussed here. So like:

with nn.NameCtx.new_root() as name_ctx:
  net = Net(...)
  out = net(nn.get_extern_data("data"))

  loss1 = nn.cross_entropy(out, nn.get_extern_data("classes"))
  loss1.mark_as_loss()

  loss_model = LossModel(...)
  loss2 = loss_model(out, ...)
  loss2.mark_as_loss()

In this example, loss1 would not have further parameters, so its naming doesn't matter. But loss2 in this example would be another module with potentially other trainable parameters, so the absolute layer names matter. So should this be another root module, just like net? Is this easy to do? And what about naming conflicts then? Or would the user make net the root network and give loss2 some other explicit sub name context, like name=NameCtx(name='loss2', parent=name_ctx) but in some cleaner way, or just name="loss2"?

Just as a reminder: The assignment of a module call (or module itself) to some unique name context is necessary to get unique names for the parameters. This is the only reason. And the unique names for parameters is necessary for the checkpoint file.
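The reminder above (name contexts exist only to give parameters unique names for the checkpoint) can be illustrated with a tiny tree of names. This is a minimal sketch with a hypothetical `ToyNameCtx` class and `param_name` helper, not the actual `nn.NameCtx` API. It assumes the variant where `net` owns the root context directly and `loss2` gets an explicit sub name, and it shows where a naming conflict would surface.

```python
from typing import Dict, List, Optional


class ToyNameCtx:
    """Hypothetical name context: a node in a tree of unique names."""

    def __init__(self, name: Optional[str] = None, parent: Optional["ToyNameCtx"] = None):
        self.name = name
        self.parent = parent
        self.children: Dict[str, "ToyNameCtx"] = {}
        if parent is not None:
            if name is None:
                raise ValueError("only the root context may be unnamed")
            if name in parent.children:
                raise ValueError(f"name conflict: {name!r} already used under {parent.path() or '<root>'}")
            parent.children[name] = self

    def path(self) -> str:
        """Absolute name, e.g. 'loss2/linear'; the root itself has the empty path."""
        parts: List[str] = []
        ctx: Optional[ToyNameCtx] = self
        while ctx is not None and ctx.parent is not None:
            assert ctx.name is not None
            parts.append(ctx.name)
            ctx = ctx.parent
        return "/".join(reversed(parts))


def param_name(ctx: ToyNameCtx, param: str) -> str:
    """Checkpoint name of a parameter owned by the module in ctx."""
    prefix = ctx.path()
    return f"{prefix}/{param}" if prefix else param


root = ToyNameCtx()  # the root network `net` owns the root context directly
encoder_ctx = ToyNameCtx("encoder", parent=root)  # a submodule of net
loss2_ctx = ToyNameCtx("loss2", parent=root)  # separate loss module, explicit sub name
linear_ctx = ToyNameCtx("linear", parent=loss2_ctx)  # a submodule of the loss module
checkpoint_names = [param_name(encoder_ctx, "W"), param_name(linear_ctx, "W")]

try:
    ToyNameCtx("loss2", parent=root)  # e.g. net itself also has a submodule named "loss2"
    conflict_detected = False
except ValueError:
    conflict_detected = True
```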

albertz commented 2 years ago

Note: The handling of extern data was updated, and there is no make_root_net_dict anymore. See https://github.com/rwth-i6/returnn_common/issues/44#issuecomment-1002686988.

albertz commented 2 years ago

Related is also having the training loop and stages explicit (#96).