k2-fsa / snowfall

Moved to https://github.com/k2-fsa/icefall

Torchscriptable Conformer + high-level "simple" object for decoding, alignments, posteriors, plotting them, etc. #206

Closed: pzelasko closed this pull request 3 years ago

pzelasko commented 3 years ago

Note: this is unfinished; I put together multiple code snippets but didn't polish them in any way. It could become a part of Icefall with a bit of work and documentation.

You can use it, e.g., in a Jupyter notebook to plot posteriors and alignments; note the high-level methods, which all accept cuts as inputs.

You can get the plots from #203 with this.

danpovey commented 3 years ago

I don't think I want to go in this direction of having things all bound together in classes -- at least for the time being. I want to go with simple utility functions with interfaces as small as possible. However, I am open to merging this simply as a useful reference for things we might need to do, like plotting.

pzelasko commented 3 years ago

That makes sense to me TBH because I am not even sure how to make a class like this generic enough to handle the different types of models, topologies, techniques, etc. that we're going to introduce. Anyway, I find it helpful to work with the "current best" model in other projects. I'll test it a bit more to make sure everything works as intended and I'll merge then.

danpovey commented 3 years ago

Thanks!

csukuangfj commented 3 years ago

> I find it helpful to work with the "current best" model in other projects.

I would propose letting the code support any PyTorch model, or at least those supported by TorchScript.

The user only needs to provide a .pt file, which contains everything needed to run the model.

I just wrote a small demo (see below) to show that the idea is feasible. You can see that, with demo.pt at hand, we don't need the definition of the model to run its forward function.

I believe the following small utilities would be helpful:

(1) compute-post

(2) decode

(3) show-ali

#!/usr/bin/env python3

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.linear(x)


@torch.no_grad()
def main():
    m = Model(in_dim=2, out_dim=3)

    x = torch.tensor([1, 2.0])
    y = m(x)

    # Compile the model to TorchScript and serialize it; demo.pt now
    # contains both the code and the parameters.
    script_module = torch.jit.script(m)
    script_module.save('demo.pt')

    # Load it back without access to the Model class definition.
    new_m = torch.jit.load('demo.pt')
    new_y = new_m(x)
    print(y)
    print(new_y)


if __name__ == '__main__':
    main()
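
To make demo.pt carry "everything needed to run the model" in a stronger sense, TorchScript's _extra_files mechanism can bundle arbitrary metadata (e.g., hyperparameters) in the same archive. A minimal sketch, separate from the demo above:

import json

import torch
import torch.nn as nn

hparams = {'in_dim': 2, 'out_dim': 3}
sm = torch.jit.script(nn.Linear(hparams['in_dim'], hparams['out_dim']))

# Store the hyperparameters next to the code and weights in the archive.
torch.jit.save(sm, 'demo_with_meta.pt',
               _extra_files={'hparams.json': json.dumps(hparams)})

# At load time, pass the same keys; their contents are filled in for us.
files = {'hparams.json': ''}
m = torch.jit.load('demo_with_meta.pt', _extra_files=files)
print(json.loads(files['hparams.json']))  # {'in_dim': 2, 'out_dim': 3}
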
danpovey commented 3 years ago

Perhaps someone can test whether our current models are supported by TorchScript, or at least whether it would be possible to make them supported?
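
One quick generic way to check this (a sketch, not snowfall-specific code) is to attempt scripting and compare the scripted module's output against eager mode:

import torch

def check_scriptable(model: torch.nn.Module, *example_inputs):
    # torch.jit.script raises a detailed error if the model uses
    # constructs that TorchScript does not support.
    scripted = torch.jit.script(model)
    with torch.no_grad():
        eager_out = model(*example_inputs)
        scripted_out = scripted(*example_inputs)
    # Assumes the model returns a single tensor.
    assert torch.allclose(eager_out, scripted_out)
    return scripted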

pzelasko commented 3 years ago

I was just able to convert the Conformer to TorchScript with some changes; I'll make it part of this PR.

danpovey commented 3 years ago

Great!

pzelasko commented 3 years ago

OK, a summary of this PR:

Things that still don't work:

I'm not sure if I have the capacity to work on this further for now -- in any case, the Conformer is now scriptable, which should open the way for others.

pzelasko commented 3 years ago

Saving models to TorchScript works during training with the --torchscript-epoch <start-saving-epoch> flag, with no speed penalty, because the model is converted to TorchScript just before saving. The issues arise only when we train using a ScriptModule (the --torchscript true flag).
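
For reference, the flag's behavior amounts to something like this sketch (the output path is illustrative):

import torch

def maybe_save_torchscript(model: torch.nn.Module, epoch: int, start_epoch: int):
    # Convert only at save time, so training itself still runs the
    # normal eager module at full speed.
    if epoch >= start_epoch:
        torch.jit.script(model).save(f'exp/epoch-{epoch}.pt')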

It is OK to merge from my side -- please review and merge if it's adequate.

pzelasko commented 3 years ago

One last remark which I forgot to mention -- I only did very naive benchmarking, running the following snippet in Jupyter:

%%timeit
with torch.no_grad():
    model(features, supervisions)

The improvement from the normal to the TorchScripted model was small -- 140 ms vs. 130 ms, on a V100 GPU with ~30 cuts in the batch. So it was only during training that I noticed the slowdown.
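
As an aside, %%timeit on a CUDA model can be skewed by asynchronous kernel launches; a more careful sketch (reusing the names from the snippet above) synchronizes around the timed region:

import time

import torch

torch.cuda.synchronize()        # drain previously queued kernels
start = time.perf_counter()
with torch.no_grad():
    for _ in range(100):
        model(features, supervisions)
torch.cuda.synchronize()        # wait for all timed kernels to finish
print((time.perf_counter() - start) / 100, 'seconds per forward pass')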

mthrok commented 3 years ago

> One last remark which I forgot to mention -- I only did very naive benchmarking, running the following snippet in Jupyter:
>
> %%timeit
> with torch.no_grad():
>     model(features, supervisions)
>
> The improvement from the normal to the TorchScripted model was small -- 140 ms vs. 130 ms, on a V100 GPU with ~30 cuts in the batch. So it was only during training that I noticed the slowdown.

FYI: though TorchScript was originally advertised for performance, it now mainly solves the problem of deployment: the resulting object is deployable to C++/iOS/Android. The performance-improvement efforts have moved to AI compilers, so in general we can't expect a speedup from TorchScript; consider yourself lucky if you get any.

This might not be relevant to your application, but when scripting a model it is possible to perform irreversible operations. For example, I recently added a TorchScript-able wav2vec2 to torchaudio. For the sake of supporting quantization, I added a hook for scripting that removes the weight-normalization forward hook. My rationale was that the model was mainly intended for inference, so removing the hook is fine. However, if a model is scripted during training, then the wav2vec2 model from torchaudio becomes incompatible with snowfall in an unexpected way.
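
To illustrate the kind of irreversible change being described, here is a generic weight-normalization example (not the actual torchaudio hook):

import torch.nn as nn

layer = nn.utils.weight_norm(nn.Conv1d(4, 4, kernel_size=3))
print(hasattr(layer, 'weight_g'))   # True: weight is parametrized as g * v / |v|

nn.utils.remove_weight_norm(layer)  # folds g and v back into a plain weight
print(hasattr(layer, 'weight_g'))   # False: the parametrization is gone for good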

Since a TorchScript object file contains only the architecture and parameters, it feels to me that creating a tool that builds the scripted model from a training checkpoint file would be simpler.
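
A sketch of such a conversion tool (the checkpoint layout and the way the model is instantiated are assumptions):

import torch

def checkpoint_to_torchscript(model: torch.nn.Module,
                              checkpoint_path: str,
                              output_path: str) -> None:
    # The caller must already know how to instantiate `model` with the
    # right hyperparameters -- exactly the information a bare state_dict lacks.
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state.get('state_dict', state))
    model.eval()
    torch.jit.script(model).save(output_path)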

However, I do not know the design principles of snowfall or the context of this work; if that's what's desired, I think it's okay.

mthrok commented 3 years ago

If you are looking for a way to speed up training, quantization-aware training is one approach. I heard of a case where it improved both training time and accuracy at the same time.
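
A minimal sketch of eager-mode quantization-aware training (API details vary across PyTorch versions):

import torch
import torch.nn as nn

model = nn.Sequential(
    torch.quantization.QuantStub(),    # marks where tensors enter the int8 domain
    nn.Linear(2, 3),
    nn.ReLU(),
    torch.quantization.DeQuantStub(),  # marks where tensors leave the int8 domain
)
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)  # insert fake-quant modules

# ... run the usual training loop here; quantization is simulated in fp32 ...

quantized = torch.quantization.convert(model.eval())  # produce the real int8 model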

pzelasko commented 3 years ago

Thanks @mthrok, that makes a lot of sense. I wondered whether people use TorchScript to speed up training, but now it's clear that's not the case. BTW, could you elaborate on the AI compiler? Is it Glow, or something else?

The --torchscript-epoch option is basically what a checkpoint-conversion tool would do -- except that to write such a tool, we would need to provide it all the info (architecture, hyperparameters, etc.). So I guess the hope is that by storing a TorchScripted model directly from the training script, downstream applications don't have to know all the hyperparameters needed to instantiate the model. But maybe @danpovey and @csukuangfj will have a different view. In any case, we should also support weight averaging before storing the TorchScripted model, as it consistently improves the results.
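
A sketch of "average, then script" (assumes each checkpoint file is a plain state_dict with identical keys):

import torch

def average_checkpoints(paths):
    # Non-floating-point entries (e.g. num_batches_tracked) keep the
    # value from the first checkpoint.
    avg = torch.load(paths[0], map_location='cpu')
    for p in paths[1:]:
        state = torch.load(p, map_location='cpu')
        for k, v in state.items():
            if v.is_floating_point():
                avg[k] += v
    for v in avg.values():
        if v.is_floating_point():
            v /= len(paths)
    return avg

The averaged state_dict would then be loaded into the model with load_state_dict before scripting and saving.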

mthrok commented 3 years ago

> Thanks @mthrok, that makes a lot of sense. I wondered whether people use TorchScript to speed up training, but now it's clear that's not the case. BTW, could you elaborate on the AI compiler? Is it Glow, or something else?

Yeah, I think one of those (but I don't know much about it, so please take it with a grain of salt 😥). There is also torch.fx.

pzelasko commented 3 years ago

There’s also an ONNX exporter that converts TorchScript modules: https://pytorch.org/docs/stable/onnx.html

Maybe it’s worth looking into.
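
For reference, the export goes through torch.onnx.export with an example input (a generic sketch; a Conformer would also need dynamic axes for the time dimension):

import torch
import torch.nn as nn

model = nn.Linear(2, 3).eval()
example = torch.randn(1, 2)
torch.onnx.export(
    model, example, 'model.onnx',
    input_names=['x'], output_names=['y'],
    dynamic_axes={'x': {0: 'batch'}, 'y': {0: 'batch'}},  # variable batch size
)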

danpovey commented 3 years ago

Thanks! I'll merge so we don't get too out of date.