k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

adding inference_mode() causes "Inference tensors do not track version counter" #1210

Open Slyne opened 1 year ago

Slyne commented 1 year ago

Hi K2 experts,

I'm trying to add the k2 fast beam search decoder to my acoustic model. It works well when torch.inference_mode() is disabled, but raises the error below when I enable it.

  File "/home/slyne/projects/zoom/zoom_asr_transducer/src/k2_beam_search.py", line 152, in fast_beam_search
    lattice = decoding_streams.format_output(encoder_out_lens.tolist())
  File "/home/slyne/anaconda3/envs/zoom_asr/lib/python3.8/site-packages/k2/rnnt_decode.py", line 186, in format_output
    fsa = Fsa(ragged_arcs)
  File "/home/slyne/anaconda3/envs/zoom_asr/lib/python3.8/site-packages/k2/fsa.py", line 230, in __init__
    self.labels_version = self._tensor_attr["labels"]._version
RuntimeError: Inference tensors do not track version counter.
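The failing line reads the private _version attribute of a tensor. A minimal sketch that reproduces the same error with plain PyTorch, without k2, assuming only that the tensor is created under the respective mode (the helper name read_version is hypothetical):

```python
import torch


def read_version(mode):
    """Create a tensor under `mode` and try to read its version counter."""
    with mode():
        t = torch.zeros(3)
        try:
            return t._version
        except RuntimeError as err:
            # Inference tensors raise here instead of returning an int.
            return f"RuntimeError: {err}"


print(read_version(torch.no_grad))         # 0
print(read_version(torch.inference_mode))  # reports the RuntimeError message
```

Under torch.no_grad() the tensor is a normal tensor with a version counter; under torch.inference_mode() it is an inference tensor, which is what trips up the Fsa constructor.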

This flag is important in my use case because it significantly speeds up acoustic model inference. Does k2 require this flag to be set to False?

Any suggestion is appreciated! Thanks!

csukuangfj commented 1 year ago

From https://pytorch.org/docs/stable/generated/torch.inference_mode.html

Code run under this mode gets better performance by disabling view tracking and version counter bumps

It looks like we need an alternative way to handle _version, since tensors created in inference mode do not have a version counter.
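One possible shape for such an alternative, sketched here as an assumption and not as k2's actual fix: check Tensor.is_inference() before touching _version and return None when no counter exists. The helper name safe_version is hypothetical; any code that later compares versions would have to treat None as "unknown".

```python
from typing import Optional

import torch


def safe_version(t: torch.Tensor) -> Optional[int]:
    """Hypothetical helper: return t._version, or None for inference
    tensors, which do not track a version counter."""
    if t.is_inference():
        return None
    return t._version


with torch.inference_mode():
    a = torch.ones(2)       # inference tensor
b = torch.ones(2)           # normal tensor
b.add_(1)                   # in-place op bumps the version counter

print(safe_version(a))  # None
print(safe_version(b))  # 1
```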

I suggest that you use torch.no_grad() as a workaround for now.
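A variation on this workaround that may keep most of the speedup, sketched under assumptions not verified against k2: run only the acoustic model under torch.inference_mode(), then clone its output outside inference mode (clone() on an inference tensor outside the mode returns a normal tensor) and run the k2 decoding under torch.no_grad(). The names encode_then_decode, encoder and decode_fn are hypothetical placeholders for the acoustic model and the fast beam search call; other inference tensors passed to k2 (e.g. encoder_out_lens) would need the same treatment.

```python
from typing import Callable

import torch


def encode_then_decode(
    encoder: torch.nn.Module,
    decode_fn: Callable[[torch.Tensor], object],
    features: torch.Tensor,
):
    # Stage 1: run the expensive acoustic model under inference_mode.
    # Its outputs are inference tensors.
    with torch.inference_mode():
        encoder_out = encoder(features)

    # Stage 2: leave inference_mode. clone() is a functional op, so outside
    # the mode it returns a normal tensor that tracks a version counter.
    # Tensors created inside decode_fn here are normal tensors as well.
    with torch.no_grad():
        encoder_out = encoder_out.clone()
        return decode_fn(encoder_out)


# Toy usage: a linear "encoder" and a decode_fn that reads _version,
# the attribute that fails on inference tensors.
encoder = torch.nn.Linear(4, 2)
result = encode_then_decode(
    encoder, lambda t: (t.is_inference(), t._version), torch.randn(3, 4)
)
print(result)
```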

Slyne commented 1 year ago


Sure.