DigitalPhonetics / IMS-Toucan

Controllable and fast Text-to-Speech for over 7000 languages!

For info: Toucan doesn't torch.compile() #156

Closed. tomschelsen closed this issue 3 months ago.

tomschelsen commented 1 year ago

This is more for information (in case anyone wants to try it) than a bug report, since I think the torch.compile() situation isn't fully stabilised yet. I tried this new PyTorch 2 facility on IMS-Toucan, with the initial goal of benchmarking different compiler backends (the default "inductor" versus the more recently released third-party "hidet") and different parameters.

I tried to compile ToucanTTSInterface, as it is the highest-level torch.nn.Module that I am currently using for inference.
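The setup was roughly the following (a minimal sketch, not the exact benchmark script; the constructor arguments and the test sentence are placeholders, and the "hidet" backend has to be installed separately via pip):

```python
# Sketch of the intended benchmark: compile the same interface with two
# backends and time the warm-up call. Constructor arguments are placeholders.
import time

import torch
from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface

tts = ToucanTTSInterface(device="cuda")  # assumption: default checkpoint paths

for backend in ["inductor", "hidet"]:  # "hidet" requires `pip install hidet`
    compiled_tts = torch.compile(tts, backend=backend)
    start = time.perf_counter()
    compiled_tts("Hello world, this is a compilation benchmark.")  # first call triggers compilation
    print(f"{backend}: first (compiling) call took {time.perf_counter() - start:.2f}s")
```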

I first encountered the following (I edited the backtrace to remove irrelevant paths and call info):

[2023-08-21 08:52:07,749] torch._dynamo.symbolic_convert: [WARNING] IMS-Toucan/Preprocessing/TextFrontend.py <function english_text_expansion at 0x7f8b1ce7a5e0> [UserDefinedObjectVariable(ArticulatoryCombinedTextFrontend), ConstantVariable(str)] {} too many positional arguments
Failed to collect metadata on function, produced code may be suboptimal.  Known situations this can occur are inference mode only compilation involving resize_ or prims (!schema.hasAnyAliasInfo() INTERNAL ASSERT FAILED); if your situation looks different please file a bug to PyTorch.
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 1676, in aot_wrapper_dedupe
    fw_metadata, _out = run_functionalized_fw_and_collect_metadata(
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 607, in inner
    flat_f_outs = f(*flat_f_args)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 2793, in functional_call
    out = Interpreter(mod).run(*args[params_len:], **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/interpreter.py", line 177, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/interpreter.py", line 271, in call_method
    return getattr(self_obj, target)(*args_tail, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/overrides.py", line 38, in __torch_function__
    return func(*args, **kwargs)
RuntimeError: Inference tensors do not track version counter.

While executing %unsqueeze : [#users=1] = call_method[target=unsqueeze](args = (%lang_id, 0), kwargs = {})
Original traceback:
  File "IMS-Toucan/InferenceInterfaces/InferenceArchitectures/InferenceToucanTTS.py", line 301, in forward
    lang_id = lang_id.unsqueeze(0).to(text.device)

Looking at PyTorch's issues, it seems that currently (PyTorch 2.0.1) torch.compile() doesn't always play well with inference_mode() (though they are working on it). So I went ahead and blindly replaced all the inference_mode() contexts/decorators with no_grad(), the end goal still being to benchmark and see whether performance improves over the starting point.
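The swap itself is mechanical; a minimal sketch of what I mean (tts stands in for any of the decorated inference entry points):

```python
import torch

def synthesize(tts, text):
    # Previously: with torch.inference_mode(): ...
    # inference_mode() produces "inference tensors" that skip version-counter
    # tracking, which is exactly what the RuntimeError above complains about.
    with torch.no_grad():  # ordinary tensors, just no autograd recording
        return tts(text)
```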

As a result I got the following:

[2023-08-21 08:20:06,686] torch._dynamo.symbolic_convert: [WARNING] IMS-Toucan/Preprocessing/TextFrontend.py <function english_text_expansion at 0x7fc6e0a485e0> [UserDefinedObjectVariable(ArticulatoryCombinedTextFrontend), ConstantVariable(str)] {} too many positional arguments
[2023-08-21 08:20:10,581] torch._dynamo.symbolic_convert: [WARNING] IMS-Toucan/Preprocessing/TextFrontend.py <function english_text_expansion at 0x7fc6e0a485e0> [UserDefinedObjectVariable(ArticulatoryCombinedTextFrontend), ConstantVariable(str)] {} too many positional arguments
[2023-08-21 08:20:11,183] torch._inductor.utils: [WARNING] DeviceCopy in input program
[2023-08-21 08:20:11,251] torch._inductor.utils: [WARNING] DeviceCopy in input program
/usr/local/lib/python3.8/dist-packages/torch/_inductor/compile_fx.py:90: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "IMS-Toucan/InferenceInterfaces/ToucanTTSInterface.py", line 157, in forward
    phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device))
  File "IMS-Toucan/InferenceInterfaces/ToucanTTSInterface.py", line 158, in <graph break in forward>
    mel, durations, pitch, energy = self.phone2mel(phones,
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "IMS-Toucan/InferenceInterfaces/InferenceArchitectures/InferenceToucanTTS.py", line 308, in forward
    energy_predictions = self._forward(text.unsqueeze(0),
  File "IMS-Toucan/InferenceInterfaces/InferenceArchitectures/InferenceToucanTTS.py", line 205, in _forward
    text_masks = make_non_pad_mask(text_lengths, device=text_lengths.device).unsqueeze(-2)
  File "IMS-Toucan/InferenceInterfaces/InferenceArchitectures/InferenceToucanTTS.py", line 206, in <graph break in _forward>
    encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)
  File "IMS-Toucan/InferenceInterfaces/InferenceArchitectures/InferenceToucanTTS.py", line 230, in <graph break in _forward>
    embedded_pitch_curve = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py", line 215, in forward
    def forward(self, input):
  File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 2836, in forward
    return compiled_fn(full_args)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 1224, in g
    return f(*args)
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 1900, in runtime_wrapper
    all_outs = call_func_with_args(
  File "/usr/local/lib/python3.8/dist-packages/torch/_functorch/aot_autograd.py", line 1249, in call_func_with_args
    out = normalize_as_list(f(args))
  File "/usr/local/lib/python3.8/dist-packages/torch/_inductor/compile_fx.py", line 248, in run
    return model(new_inputs)
  File "/tmp/torchinductor_myuser/5m/c5mnt2leqacdwn7k4pdgt76rn7k7sbrhqma5qg2iec6nhy5zaoj5.py", line 53, in call
    assert_size_stride(buf0, (1, 192, 246), (47232, 246, 1))
AssertionError: expected size 192==192, stride 1==246 at dim=1

And... this goes way beyond my understanding of how all of that works. So if anyone wants to give it a shot ;)

Flux9665 commented 12 months ago

Yes, the model is somehow not compatible with compiling. I tried it on the day PyTorch 2.0 was released, and after a bit of trial and error I managed to compile the model; the biggest obstacle was the use of weight norm. However, after the model was compiled, the forward pass produced unexpected results, basically just nonsense outputs. I'm not sure where exactly the problem lies, but I was frustrated and haven't used the compile feature since. I also played around with using jit to compile some components of the model, which works ok but can sometimes cause unexpected behaviour as well. Overall I decided not to pursue these compilation features further for now, because I have to manage my priorities, and this would require a full re-write of many modules to make even most of the TTS partially compilable.
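For anyone who wants to pick this up: the weight-norm obstacle can typically be side-stepped by folding the normalisation into plain weights before compiling. A hypothetical sketch (strip_weight_norm is not part of the repo, and this assumes the affected layers use torch.nn.utils.weight_norm):

```python
# Hypothetical workaround sketch: bake weight norm into ordinary weights so
# that torch.compile() no longer sees the weight-norm pre-forward hooks.
import torch
from torch.nn.utils import remove_weight_norm

def strip_weight_norm(module: torch.nn.Module) -> torch.nn.Module:
    for m in module.modules():
        try:
            remove_weight_norm(m)  # folds the reparametrised weight into m.weight
        except ValueError:
            pass  # no weight-norm hook on this submodule; leave it as-is
    return module

# usage (untested): model = strip_weight_norm(model); model = torch.compile(model)
```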