Open albertz opened 2 years ago
I just checked ESPnet as an example. See here and here. It looks like they don't separate this and instead have the loss logic as part of the model.
I wonder though whether this is a good idea or not.
Certainly some losses are related to specific models, e.g. transducer or CTC (at least the original literature introduces the model and the loss together).
I also wonder how other frameworks are doing this.
Fairseq seems to have it decoupled. See here for the model, which contains nothing about the loss.
See `train_step`, where you find that the model and the criterion are separated. It is called like this:
```python
loss, sample_size, logging_output = criterion(model, sample)
optimizer.backward(loss)
```
So this more or less assumes a single criterion (loss), although you could simply add up multiple criteria into a single one (see `CompositeLoss`).
See `CrossEntropyCriterion` as an example. It gets the model, calls it, and then computes the loss:
```python
net_output = model(**sample["net_input"])
loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
```
`sample` is the minibatch dict consisting of `net_input`, maybe `target`, and others.
In `compute_loss`, you find that it calls `model.get_targets(sample, net_output)`. So this is where the model is still slightly coupled to the loss. The default `BaseFairseqModel.get_targets` is just `return sample["target"]`.
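To make the fairseq split concrete, here is a dependency-free toy in plain Python. It only mirrors the shape of the `criterion(model, sample)` call; `ToyModel`, `toy_cross_entropy_criterion` and `toy_composite` are hypothetical names, not fairseq classes, and lists of floats stand in for tensors.

```python
import math
from typing import Any, Callable, Dict, List, Tuple


class ToyModel:
    """Hypothetical model: one row of logits per input token id. Knows nothing about losses."""

    def __init__(self, table: List[List[float]]):
        self.table = table

    def __call__(self, src_tokens: List[int]) -> List[List[float]]:
        return [self.table[t] for t in src_tokens]

    def get_targets(self, sample: Dict[str, Any], net_output: Any) -> List[int]:
        # The only coupling point: by default just read the target from the sample.
        return sample["target"]


def log_softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def toy_cross_entropy_criterion(model: Any, sample: Dict[str, Any]) -> Tuple[float, int, Dict[str, float]]:
    # The criterion, not the model, runs the forward pass and owns the loss.
    net_output = model(**sample["net_input"])
    targets = model.get_targets(sample, net_output)
    loss = -sum(log_softmax(row)[t] for row, t in zip(net_output, targets))
    return loss, len(targets), {"loss": loss}


def toy_composite(*criteria: Callable) -> Callable:
    # Adds up several criteria into one scalar loss (each one reruns the model here).
    def criterion(model: Any, sample: Dict[str, Any]) -> Tuple[float, int, Dict[str, float]]:
        results = [c(model, sample) for c in criteria]
        total = sum(r[0] for r in results)
        return total, results[0][1], {"loss": total}
    return criterion


model = ToyModel([[2.0, 0.0], [0.0, 2.0]])
sample = {"net_input": {"src_tokens": [0, 1]}, "target": [0, 1]}
loss, sample_size, logging_output = toy_cross_entropy_criterion(model, sample)
```

The training step only sees one scalar, so adding another criterion means composing, not touching the model.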
I would propose to extend our `make_root_net_dict` API to be able to decouple the losses.
I would also propose to remove `mark_as_loss`. When you look at some model code, I think it would be bad if some `mark_as_loss` could be hidden somewhere deep inside it. It should be explicit. This implies that `make_root_net_dict` (or some other API) would have to collect all losses explicitly.
The details of the extended API need to be worked out.
Currently `make_root_net_dict` usage looks like this:
```python
model = MyModel(...)
network = make_root_net_dict(model, "data")
```
Here, `"data"` refers to extern data. Any arguments you pass after `model` are passed on to the model, but wrapped via `get_extern_data`. (See #44.)
In this API, what `model` returns does not really matter. It could be a single layer ref or also a tuple.
I would propose to define a `Criterion` or `Loss` interface like:
```
(model_output, *targets, **targets_kwargs) -> loss
```
And then maybe:
```python
model = MyModel(...)
loss = cross_entropy_logits  # should comply with the loss interface
network = make_root_net_dict(
    model, "data",
    train_loss=make_loss(loss, "classes"))
```
`train_loss` could also accept a tuple or dict.
`make_loss` would wrap the passed function or module instance `loss` in the same way as `make_root_net_dict` does for `model`, but it also adds `model_output` as the first argument (whatever it is).
This example uses `cross_entropy_logits` via #38. So this implies that `model` returns logits. It also assumes that `cross_entropy_logits` gets two arguments `(out_logits, targets)` (the order is important).
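A runnable way to check that the proposed interface composes is to simulate it eagerly in plain Python. This is only a sketch under assumptions: `toy_root_run` is a hypothetical executable stand-in for `make_root_net_dict` (it runs the model instead of building a net dict), a plain dict stands in for extern data, and this `make_loss` is one possible reading of the proposal, not an existing function.

```python
import math
from typing import Any, Callable, Dict, List

ExternData = Dict[str, Any]  # hypothetical stand-in for RETURNN extern data


def cross_entropy_logits(out_logits: List[List[float]], targets: List[int]) -> float:
    """Complies with the proposed interface: (model_output, *targets) -> loss."""
    total = 0.0
    for logits, t in zip(out_logits, targets):
        m = max(logits)
        total += m + math.log(sum(math.exp(x - m) for x in logits)) - logits[t]
    return total


def make_loss(loss_fn: Callable[..., float], *target_keys: str, **target_kw_keys: str) -> Callable:
    """Resolve extern data keys to targets and pass model_output as the first argument."""
    def bound(model_output: Any, extern_data: ExternData) -> float:
        targets = [extern_data[k] for k in target_keys]
        kw_targets = {name: extern_data[k] for name, k in target_kw_keys.items()}
        return loss_fn(model_output, *targets, **kw_targets)
    return bound


def toy_root_run(model: Callable, *input_keys: str, extern_data: ExternData, train_loss=None):
    """Hypothetical eager analogue of make_root_net_dict(model, "data", train_loss=...)."""
    model_output = model(*[extern_data[k] for k in input_keys])
    if isinstance(train_loss, dict):
        losses = {name: fn(model_output, extern_data) for name, fn in train_loss.items()}
    elif train_loss is not None:
        losses = {"train_loss": train_loss(model_output, extern_data)}
    else:
        losses = {}
    return model_output, losses


extern_data = {"data": [0, 1], "classes": [0, 1]}
table = [[2.0, 0.0], [0.0, 2.0]]


def my_model(data: List[int]) -> List[List[float]]:
    return [table[t] for t in data]  # returns logits, as cross_entropy_logits expects


_, losses = toy_root_run(my_model, "data", extern_data=extern_data,
                         train_loss=make_loss(cross_entropy_logits, "classes"))
```

The model never sees `"classes"`; only the wrapped loss does, which is the decoupling the proposal is after.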
If `model` returned something more custom here (e.g. a tuple), you would probably implement a custom `loss` here, like:
```python
def loss(model_output, target1, target2, ...):
    ...
    return combined_loss
```
And then:
```python
network = make_root_net_dict(
    model, "data",
    train_loss=make_loss(loss, "classes1", "classes2", ...))
```
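For a model returning a tuple, such a custom loss just unpacks the tuple and combines per-output losses. A minimal sketch in plain Python, assuming a `(main_logits, aux_logits)` output and a hypothetical `aux_scale` weight (neither is from the original proposal):

```python
import math
from typing import List, Tuple

Logits = List[List[float]]  # one row of unnormalized scores per frame


def ce(logits: Logits, targets: List[int]) -> float:
    """Summed frame-wise cross entropy from unnormalized logits."""
    total = 0.0
    for row, t in zip(logits, targets):
        m = max(row)
        total += m + math.log(sum(math.exp(x - m) for x in row)) - row[t]
    return total


def loss(model_output: Tuple[Logits, Logits], target1: List[int], target2: List[int],
         aux_scale: float = 0.5) -> float:
    """Custom loss for a hypothetical model that returns (main_logits, aux_logits)."""
    main_logits, aux_logits = model_output
    return ce(main_logits, target1) + aux_scale * ce(aux_logits, target2)


model_output = ([[2.0, 0.0]], [[0.0, 0.0]])
combined_loss = loss(model_output, [0], [1])
```

Which output goes with which target is decided by the argument order, which is exactly what `make_loss(loss, "classes1", "classes2")` would fix.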
Note that we have some other things coupled to the model definition as well (but this is usually the same in many other frameworks), e.g. `tf.summary`.
In RETURNN, we are also particularly proud of having a unified way to define the network both for training and for decoding with search (`ChoiceLayer`), although this mostly only works for simple cases. For many more complex cases (e.g. min WER training), `ChoiceLayer` with search enabled during training is still very useful.
Related is #18 on training behavior and a potential train flag on this level, and also on search.
Further, some models might define other auxiliary losses, or other multi-task losses. E.g. we often have some auxiliary CTC loss on top of the encoder. This is maybe going to be extended in the future by many more auxiliary losses, and maybe also more local losses.
Considering this, it might make things unnecessarily complicated if we strictly try to decouple the loss definitions from the model definitions, because then the root module (main model) would need to return every intermediate output as well, to be able to define all the auxiliary losses.
On the other hand, not decoupling them implies that all loss options for all these variations (auxiliary losses, multi-task losses, unsupervised losses, and the main loss, where many variations are also possible) need to be passed to the model somehow. This is also not nice, as many models already have many other options, and the number of options could blow up.
Maybe we can also have both. The model can somehow define one or more outputs, maybe also specifying the output type (logits, log prob, prob or so). Then losses can be defined on all those outputs, on some of them, or only on the final output.
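The "both" variant could look roughly like this: the model exposes named outputs tagged with their type, and losses pick the output they apply to. This is a plain-Python sketch; `ModelOutput`, `to_log_prob`, `nll` and `compute_losses` are hypothetical names, and nested lists stand in for tensors.

```python
import math
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class ModelOutput:
    value: List[List[float]]  # one score vector per frame
    kind: str  # "logits", "log_prob" or "prob"


def to_log_prob(out: ModelOutput) -> List[List[float]]:
    if out.kind == "log_prob":
        return out.value
    if out.kind == "prob":
        return [[math.log(p) for p in row] for row in out.value]
    if out.kind == "logits":
        result = []
        for row in out.value:
            m = max(row)
            log_z = m + math.log(sum(math.exp(x - m) for x in row))
            result.append([x - log_z for x in row])
        return result
    raise ValueError(f"unknown output kind {out.kind!r}")


def nll(out: ModelOutput, targets: List[int]) -> float:
    # Works for any output kind, because the output declares what it is.
    return -sum(row[t] for row, t in zip(to_log_prob(out), targets))


LossSpec = Tuple[str, Callable[[ModelOutput, List[int]], float], List[int]]


def compute_losses(outputs: Dict[str, ModelOutput], specs: Dict[str, LossSpec]) -> Dict[str, float]:
    # Each loss names the model output it is defined on, so the model itself
    # does not need any options for auxiliary losses.
    return {name: fn(outputs[out_name], targets) for name, (out_name, fn, targets) in specs.items()}


outputs = {
    "encoder": ModelOutput([[0.0, 0.0]], kind="logits"),
    "decoder": ModelOutput([[0.25, 0.75]], kind="prob"),
}
losses = compute_losses(outputs, {
    "aux": ("encoder", nll, [0]),
    "main": ("decoder", nll, [1]),
})
```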
Maybe some option like `loss: LossCollector` can be passed to some modules, and they can then call `loss.add_unsupervised(...)` or `loss.add_supervised(...)`? But there could be multiple targets, and this is maybe somewhat unintuitive and not generic...
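One way to address the multiple-targets concern is to key supervised losses by the target they need. A minimal sketch; `LossCollector` follows the name floated above, while its fields, the `target_key`/`scale` parameters and `ToyEncoder` are hypothetical illustrations.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Tuple


class LossCollector:
    """Hypothetical collector passed explicitly into modules (no global side effect)."""

    def __init__(self) -> None:
        # Supervised losses are grouped by the target they depend on.
        self.supervised: Dict[str, List[Tuple[str, float, float]]] = defaultdict(list)
        self.unsupervised: List[Tuple[str, float, float]] = []

    def add_supervised(self, name: str, value: float, *, target_key: str, scale: float = 1.0) -> None:
        self.supervised[target_key].append((name, value, scale))

    def add_unsupervised(self, name: str, value: float, *, scale: float = 1.0) -> None:
        self.unsupervised.append((name, value, scale))

    def all_losses(self) -> List[Tuple[str, float, float]]:
        entries = list(self.unsupervised)
        for group in self.supervised.values():
            entries.extend(group)
        return entries

    def total(self) -> float:
        return sum(value * scale for _, value, scale in self.all_losses())


class ToyEncoder:
    def __call__(self, x: List[float], *, loss: Optional[LossCollector] = None,
                 aux_targets: Optional[List[float]] = None) -> List[float]:
        y = [2.0 * v for v in x]
        if loss is not None:
            # Unsupervised: needs no targets.
            loss.add_unsupervised("act_l2", sum(v * v for v in y), scale=0.01)
            if aux_targets is not None:
                # Supervised: only added if the caller passed the targets.
                mse = sum((a - b) ** 2 for a, b in zip(y, aux_targets))
                loss.add_supervised("aux_mse", mse, target_key="aux")
        return y


collector = LossCollector()
y = ToyEncoder()([1.0, 2.0], loss=collector, aux_targets=[2.0, 3.0])
```

Unlike a global `mark_as_loss`, the caller owns `collector` and can inspect, rescale or drop entries afterwards.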
If the models define the losses themselves (e.g. via the existing `mark_as_loss`), then besides the input data, we also need to pass the targets. Or this could be optional, and if no targets are passed, no losses are defined.
But I still don't like `mark_as_loss` and its global side effect of adding this loss, which cannot really be handled further. The caller of such a module cannot operate on the losses. I think it would be cleaner if the module had to return the losses somehow.
But I don't like it too much that `forward` then does both. Maybe there could be a separate `compute_losses` next to `forward`, and `make_root_net_dict` or so could automatically call all such functions recursively.
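The recursive `compute_losses` idea can be sketched in plain Python as follows. All names here (`Module`, `children`, `gather_losses`, `Encoder`, `Model`) are hypothetical, and this sketch assumes `forward` caches whatever `compute_losses` needs on the module.

```python
from typing import Dict, Iterator, List, Tuple


class Module:
    def children(self) -> Iterator[Tuple[str, "Module"]]:
        # Submodules are simply attributes that are themselves modules.
        for name, value in vars(self).items():
            if isinstance(value, Module):
                yield name, value

    def compute_losses(self) -> Dict[str, float]:
        return {}  # default: a module defines no losses


def gather_losses(module: Module, prefix: str = "") -> Dict[str, float]:
    """Call compute_losses on the module and, recursively, on all submodules."""
    losses = {f"{prefix}{k}": v for k, v in module.compute_losses().items()}
    for name, child in module.children():
        losses.update(gather_losses(child, prefix=f"{prefix}{name}/"))
    return losses


class Encoder(Module):
    def forward(self, x: List[float]) -> List[float]:
        self._out = [v * 0.5 for v in x]  # cached for compute_losses
        return self._out

    def compute_losses(self) -> Dict[str, float]:
        return {"act_l2": sum(v * v for v in self._out)}


class Model(Module):
    def __init__(self) -> None:
        self.encoder = Encoder()

    def forward(self, x: List[float]) -> List[float]:
        return [v + 1.0 for v in self.encoder.forward(x)]


model = Model()
out = model.forward([2.0, 4.0])
losses = gather_losses(model)
```

`forward` stays loss-free, but the price is hidden state between the two calls: `gather_losses` must run after `forward`, otherwise `Encoder.compute_losses` has nothing to work on.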
We could also make use of context managers in some way. E.g. if we stick to `mark_as_loss`, there could be a context manager which collects all the `mark_as_loss` calls, something like:
```python
def forward(self, x):
    with collect_losses() as losses:
        y = self.encoder(x)
    for loss in losses.collected():
        y = y + loss
    return y
```
So you would at least have the possibility to make some special use of the losses, if that were ever needed.
Maybe that could even be how `mark_as_loss` is implemented internally.
To summarize my previous post: I'm unsure. Maybe `mark_as_loss` is actually fine.
We should maybe think about some of the more unusual settings, e.g. not just standard frame-wise (or label-wise) cross entropy, but min expected WER training, or also some meta learning. What would this look like? And this probably should be decoupled from the model definition, I assume?
Min expected WER needs some decoder which performs search, and we get the beam and the beam scores.
Then basically:
```python
log_prob, labels = decoder(..., search=True)
prob = exp(log_prob)
loss = reduce_sum(
    axis=beam,
    value=prob * edit_distance(labels, targets))
```
Or with a second pass:
```python
log_prob, labels = decoder(..., search=True)
log_prob, labels = decoder(..., search=False, targets=labels)
prob = exp(log_prob)
loss = reduce_sum(
    axis=beam,
    value=prob * edit_distance(labels, targets))
```
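The arithmetic of that loss, for a single sequence, is sketched below in plain Python: a Python list plays the role of the beam axis and `edit_distance` is a standard Levenshtein distance. The `renormalize` flag is an addition not present in the snippet above; renormalizing the probabilities over the beam is a common variant in N-best MWER training, and with `renormalize=False` the function computes exactly `sum(prob * edit_distance)`.

```python
import math
from typing import List, Sequence


def edit_distance(hyp: Sequence[str], ref: Sequence[str]) -> int:
    """Levenshtein distance (substitutions, insertions, deletions all cost 1)."""
    prev = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, start=1):
        cur = [i] + [0] * len(ref)
        for j, r in enumerate(ref, start=1):
            cur[j] = min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (h != r))
        prev = cur
    return prev[-1]


def expected_edit_distance(beam_log_probs: List[float], beam_labels: List[Sequence[str]],
                           targets: Sequence[str], renormalize: bool = False) -> float:
    """Sum over the beam of prob * edit_distance(labels, targets)."""
    if renormalize:
        m = max(beam_log_probs)
        log_z = m + math.log(sum(math.exp(lp - m) for lp in beam_log_probs))
        beam_log_probs = [lp - log_z for lp in beam_log_probs]
    probs = [math.exp(lp) for lp in beam_log_probs]
    return sum(p * edit_distance(labels, targets) for p, labels in zip(probs, beam_labels))


targets = ["a", "b", "c"]
beam_labels = [["a", "b", "c"], ["a", "c"]]
beam_log_probs = [math.log(0.5), math.log(0.25)]
risk = expected_edit_distance(beam_log_probs, beam_labels, targets)
risk_renorm = expected_edit_distance(beam_log_probs, beam_labels, targets, renormalize=True)
```

Nothing here refers to the model's internals, only to the decoder's beam output, which supports the assumption that this loss can live outside the model definition.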
I think this boils down to one fundamental question: whether the loss definition should be decoupled from the model definition or not.
`mark_as_loss`, and basically the way we defined losses in RETURNN, coupled the model definition and the losses.
I think it would be more clean to decouple it more.
However, I think there can still be valid cases where it can be more coupled, or where it would just complicate everything if the loss were not defined locally alongside the module.
So I think both should be possible. We should prefer decoupling when coupling is not really needed, e.g. for model parts like Transformer etc. But otherwise, `Module`s can also define losses when this makes sense.
Generic regularizing local losses like L2 param norm should be handled in a more generic way (#59).
Ok, so to come back to this: on the question of whether to separate losses or to have them alongside the module, I think we should support both ways, as argued before.
With the decision to make all dim tags and axes arguments explicit (#17), we also need to think about extending `make_root_net_dict` in that way.
I am currently thinking about getting rid of such a single function `make_root_net_dict` and making it more flexible and explicit, maybe by some context manager scope. Maybe the example in `test_simple_net_module_explicit_root_ctx` is already kind of like that, i.e.:
```python
net = Net(...)
with nn.NameCtx.new_root() as name_ctx:
    out = net(nn.get_extern_data("data"), name=name_ctx)
    name_ctx.make_default_output(out)
    net_dict = name_ctx.make_net().make_net_dict_raw()
```
Note that the root network (`net` here), or rather its module class (`Net` here), is supposed to be reusable in other contexts, which is why we pass extern data explicitly to it. But this is not really new; this is what we already did before. I just want to emphasize again that this is what we want. So this is good.
One aspect I'm undecided about here is the way to declare the `net` module as the root network. In the example, it happens by passing `name=name_ctx`, where `name_ctx` is the root name context. This might be unintuitive.
Also, currently we only support one root module. But maybe we want to support multiple? E.g. some losses would not be part of the root module but separate. This should be supported, as discussed here. So like:
```python
with nn.NameCtx.new_root() as name_ctx:
    net = Net(...)
    out = net(nn.get_extern_data("data"))
    loss1 = nn.cross_entropy(out, nn.get_extern_data("classes"))
    loss1.mark_as_loss()
    loss_model = LossModel(...)
    loss2 = loss_model(out, ...)
    loss2.mark_as_loss()
```
In this example, `loss1` would not have further parameters, so the naming doesn't matter. But `loss2` in this example would come from another module with potentially other trainable parameters, so the absolute layer names matter. So should this be another root module, just like `net`? Is this easy to do? And what about naming conflicts then? Or the user would make `net` the root network, and `loss2` gets some other specific explicit sub name ctx, like `name=NameCtx(name='loss2', parent=name_ctx)` but in some cleaner way, or just `name="loss2"`?
Just as a reminder: the assignment of a module call (or module itself) to some unique name context is necessary to get unique names for the parameters. This is the only reason. And unique parameter names are necessary for the checkpoint file.
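The naming-conflict question above can be illustrated with a toy name-context tree in plain Python. `ToyNameCtx` and `param_name` are hypothetical and deliberately much simpler than `nn.NameCtx`; the sketch assumes `net` acts as the root, so its submodules live directly under the root context, and `loss2` gets an explicit sibling name.

```python
from typing import Dict, List, Optional


class ToyNameCtx:
    """Hypothetical name context: a tree whose paths become parameter names."""

    def __init__(self, name: Optional[str] = None, parent: Optional["ToyNameCtx"] = None):
        self.name, self.parent = name, parent
        self.children: Dict[str, "ToyNameCtx"] = {}
        if parent is not None:
            if name in parent.children:
                # Two modules claiming the same name would clash in the checkpoint.
                raise ValueError(f"name conflict: {name!r} under {parent.abs_name()!r}")
            parent.children[name] = self

    def abs_name(self) -> str:
        parts: List[str] = []
        ctx: Optional[ToyNameCtx] = self
        while ctx is not None and ctx.name is not None:
            parts.append(ctx.name)
            ctx = ctx.parent
        return "/".join(reversed(parts))


def param_name(ctx: ToyNameCtx, param: str) -> str:
    prefix = ctx.abs_name()
    return f"{prefix}/{param}" if prefix else param


root = ToyNameCtx()  # shared by `net` (as the root module) and the separate loss module
encoder_ctx = ToyNameCtx("encoder", root)  # a submodule of net
lstm_ctx = ToyNameCtx("lstm", encoder_ctx)
loss2_ctx = ToyNameCtx("loss2", root)  # explicit name for the loss module
```

Parameter names stay unique only as long as `"loss2"` does not collide with a submodule name of `net`, which is the conflict risk raised above.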
Note: The handling of extern data was updated, and there is no `make_root_net_dict` anymore. See https://github.com/rwth-i6/returnn_common/issues/44#issuecomment-1002686988.
Related is also having the training loop and stages explicit (#96).
Currently, to define some tensor (layer ref) as a loss, you call `mark_as_loss` on it. The idea was to be somewhat analogous to calling `loss.backward()` in PyTorch.
Common code in PyTorch looks like this (see here):
So the `loss.backward()` call, and also the definition of `loss` itself, is somewhat separate from `MyModel`. In `MyModel`, you would not really define the loss. So this is usually decoupled.
This is not how it would work for returnn-common currently, where it cannot be separated. When you call `make_root_net_dict` (#44) on `model`, it just calls `model(...)` (using extern data), and that's it.
So the current API (`make_root_net_dict`) implies that the loss is defined inside the model, inside `MyModel`, and cannot be decoupled. Or can it?
I think we should be able to decouple it, if we want to. Any module (e.g. `Transformer`, #53) should just define the model and not be specific about losses.
The question is how exactly.
Maybe we can extend `make_root_net_dict` to pass `train_loss` as well or so.
(I'm opening a separate issue on this because #38 is just about what loss functions or modules we want, and their naming and usage conventions.)