Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Support missing interpolate() options #680

Open tfogal opened 1 week ago

tfogal commented 1 week ago

🚀 Feature

It seems that

Upsample(scale_factor=2.0, mode='nearest')

calls torch.nn.functional.interpolate under the hood and passes some options that Thunder does not yet support.
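For context, mode='nearest' with scale_factor=2.0 repeats each input pixel in a 2x2 block. The pure-Python sketch below shows that index mapping on a single 2D channel; the function name upsample_nearest_2d is hypothetical, and this is neither PyTorch's nor Thunder's implementation. Note that align_corners plays no role in nearest mode (F.interpolate rejects it for that mode), which is why Upsample leaves it at its default of None.

```python
import math


def upsample_nearest_2d(grid, scale_factor):
    """Nearest-neighbor upsampling of one 2D channel given as a list of rows.

    Hypothetical sketch: output pixel (i, j) copies input pixel
    (floor(i / s), floor(j / s)), clamped to the input bounds.
    """
    in_h, in_w = len(grid), len(grid[0])
    out_h = math.floor(in_h * scale_factor)
    out_w = math.floor(in_w * scale_factor)
    return [
        [
            grid[min(math.floor(i / scale_factor), in_h - 1)][
                min(math.floor(j / scale_factor), in_w - 1)
            ]
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


print(upsample_nearest_2d([[1, 2], [3, 4]], 2.0))
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
```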

Motivation

This comes from a proxy model for a customer that NVIDIA would like to support.

Additional context

Full test script:

import torch
from torch.nn import BatchNorm2d, Upsample

import thunder

class ProxyModel(torch.nn.Module):
    def __init__(self, _ignored=False):
        super().__init__()

        self.bn0 = BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        # Upsample forwards to torch.nn.functional.interpolate, passing align_corners=None by default.
        self.upsample0 = Upsample(scale_factor=2.0, mode='nearest')

    def forward(self):
        # The input is created inside forward, so the jitted module is called with no arguments.
        i0 = torch.randn((22, 288, 15, 20), dtype=torch.float16, device="cuda:0")
        i1 = self.bn0(i0)
        del i0
        i2 = torch.nn.functional.relu6(i1, inplace=False)
        del i1
        i3 = self.upsample0(i2)  # this call fails under thunder.jit
        return i3

m = ProxyModel().cuda()
mdl = thunder.jit(m)
output = mdl()

Running the script produces the following traceback:

Traceback (most recent call last):
  File "/home/tfogal/xxx/driver.py", line 39, in <module>
    output = mdl(input_data)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/module.py", line 61, in forward
    res = self._forward_fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 676, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 223, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 502, in get_computation_and_inputs
    jit_results: TraceResults = interpreter(
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 211, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
  File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 1719, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6696, in fn_
    raise e
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6664, in fn_2
    return fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/home/tfogal/xxx/driver.py", line 18, in forward
    i3 = self.upsample0(i2)
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6061, in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/upsampling.py", line 157, in forward
    return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners,
  File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 1273, in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
  File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 703, in wrapper
    return fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/symbol.py", line 268, in __call__
    result = self.meta(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 132, in _fn
    result = fn(*args, **kwargs)
  File "/home/tfogal/dev/thunder/thunder/torch/__init__.py", line 4558, in interpolate
    utils.check(align_corners == False,
  File "/home/tfogal/dev/thunder/thunder/core/baseutils.py", line 103, in check
    raise exception_type(s())
RuntimeError: 'align_corners=True' is not yet supported.
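For reference, had the call succeeded, output would have shape (22, 288, 30, 40): BatchNorm2d and relu6 preserve the input shape, and when interpolate is given a scale_factor, each spatial size becomes floor(size * scale_factor) while the batch and channel dimensions stay the same. A small sketch of that arithmetic (the helper name interpolate_output_shape is hypothetical):

```python
import math


def interpolate_output_shape(input_shape, scale_factor):
    """Output shape of an (N, C, *spatial) tensor resized by a uniform scale_factor.

    Hypothetical helper: N and C are kept, and each spatial size
    becomes floor(size * scale_factor).
    """
    n, c, *spatial = input_shape
    return (n, c, *(math.floor(s * scale_factor) for s in spatial))


print(interpolate_output_shape((22, 288, 15, 20), 2.0))  # (22, 288, 30, 40)
```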

cc @apaz-cli

tfogal commented 1 week ago

triage review:

tfogal commented 4 days ago

Soooo this might actually be fixed by #679. I just re-tried today and the script runs without error.

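My guess as to what happened: the original utils.check required align_corners == False, but after running CI I realized that the default is actually None. Since Upsample passes align_corners=None through to interpolate, that check failed (and the error message reported it as 'align_corners=True'). The committed version properly defaults to None and checks for that.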
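The failure mode described above can be reproduced in plain Python. In this sketch, check is a simplified, hypothetical stand-in modeled on the traceback (it raises exception_type(msg_fn()) when the condition is false), and corrected_check is only one possible way to write the fixed condition, not necessarily the code committed in #679. The key point is that None == False evaluates to False, so the original condition rejects the default.

```python
def check(cond, msg_fn, exception_type=RuntimeError):
    # Hypothetical stand-in: raise exception_type(msg_fn()) if cond is false.
    if not cond:
        raise exception_type(msg_fn())


def original_check(align_corners=None):
    # Only an explicit False passes; the default None fails because None == False is False.
    check(align_corners == False, lambda: "'align_corners=True' is not yet supported.")  # noqa: E712


def corrected_check(align_corners=None):
    # One possible fix: accept the default None as well as an explicit False.
    check(align_corners in (None, False), lambda: "'align_corners=True' is not yet supported.")


for fn in (original_check, corrected_check):
    try:
        fn()  # Upsample passes align_corners=None by default
        print(fn.__name__, "accepts None")
    except RuntimeError as err:
        print(fn.__name__, "rejects None:", err)
```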

Anyway, leaving it open + on @kshitij12345 to verify that it's not just my environment :-)