Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

`TensorBase.repeat_interleave` #173

Open tfogal opened 5 months ago

tfogal commented 5 months ago

🚀 Feature

Implement TensorBase.repeat_interleave

Motivation

NeMo text-to-image model.

cc @tfogal

kshitij12345 commented 4 months ago

Which overload of torch.repeat_interleave does the model use? The overload that takes a tensor of repeats is data-dependent. There is an output_size argument which can be used to specify the output shape, making the op data-independent. (Maybe the model could be re-written to use this argument?)

```python
>>> torch.repeat_interleave(torch.tensor([1, 2]), repeats=3)  # output shape depends on repeats (fine if 3 is constant; data-dependent if it is symbolic)
tensor([1, 1, 1, 2, 2, 2])
>>> torch.repeat_interleave(torch.tensor([1, 2, 1]))  # data-dependent
tensor([0, 1, 1, 2])
>>> torch.repeat_interleave(torch.tensor([1, 2, 1]), repeats=torch.tensor([1, 1, 2]))  # data-dependent
tensor([1, 2, 1, 1])
```
tfogal commented 4 months ago

> Which overload of torch.repeat_interleave does the model use?

Looks like they are data-dependent, mostly:

https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/asr/models/classification_models.py#L1170
https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/asr/models/classification_models.py#L1177
https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/asr/modules/audio_modules.py#L677
https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/asr/modules/squeezeformer_encoder.py#L356 <-- constant!
https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/asr/parts/k2/graph_transducer.py#L192
https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/asr/parts/k2/loss_mixins.py#L144
https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/asr/parts/k2/loss_mixins.py#L162
https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py#L336
https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/asr/parts/utils/offline_clustering.py#L480
https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/asr/parts/utils/offline_clustering.py#L526
https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/asr/parts/utils/vad_utils.py#L1527
https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/multimodal/modules/imagen/diffusionmodules/attention.py#L225
https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/multimodal/modules/imagen/diffusionmodules/attention.py#L277
https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/multimodal/modules/nerf/loss/normal_consistency_loss.py#L59 <-- constant!
https://github.com/NVIDIA/NeMo/blob/8c8c667405cc046c15e180682fbd47de3195b8af/nemo/collections/tts/modules/aligner.py#L160

As above, it looks like most of them are data-dependent. Maybe we should graph-break at those ops?

There could still be some value in a limited implementation that handles the constant cases and simply errors out in the non-constant ones; we might be able to get through another NeMo model or two that way, at least.
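A "constant cases only" policy like the one suggested above might look like the following sketch. This is not Thunder's API; it is a hypothetical NumPy-based stand-in (`repeat_interleave_constant_only` is an invented name) that supports an integer `repeats` and refuses a tensor of repeats, since that would make the output shape data-dependent.

```python
import numpy as np

def repeat_interleave_constant_only(x, repeats, axis=None):
    """Handle the shape-static case; reject the data-dependent one."""
    if not isinstance(repeats, int):
        raise NotImplementedError(
            "tensor `repeats` makes the output shape data-dependent; "
            "pass a constant int or fall back to eager execution"
        )
    # With a constant int, the output shape is a pure function of x's shape.
    return np.repeat(x, repeats, axis=axis)

# Constant repeats works: [1, 2] repeated 3x each.
out = repeat_interleave_constant_only(np.array([1, 2]), 3)
assert out.tolist() == [1, 1, 1, 2, 2, 2]

# A tensor of repeats raises instead of silently producing an unknown shape.
try:
    repeat_interleave_constant_only(np.array([1, 2, 1]), np.array([1, 1, 2]))
except NotImplementedError:
    pass
```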

> (Maybe the model could be re-written to use [the output_size] argument?)

That would be great! It seems likely to benefit non-Thunder use cases as well. This might be beyond my torch-fu, however. Help welcome, anyone!

mruberry commented 4 months ago

Marking "triage review" to update the team on the latest thinking for NeMo-related ops.

mruberry commented 4 months ago

triage review —

- Are the ranks actually known by the practitioner in advance? (That would allow us to specify the output shape at compile time.)
- Could we upstream a perf improvement to stock NeMo that specifies the output shape?