Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Desired inplace idioms (add yours here) #657

Open t-vi opened 2 days ago

t-vi commented 2 days ago

We do support some in-place operations thanks to @crcrpar's great patches. This has enabled running additional models already, but support is limited to a few cases of particular interest.

However, we want to identify and prioritize other important in-place idioms, so we would appreciate it if you chimed in with the needs you or your favourite model have.

Please do try to post a minimal, self-contained example.

We don't want to support all corner cases of in-place operations (e.g. "does reshape produce a view or not"), but we do care a lot about enabling users and models. Thank you!

P.S.: Also "don't need to support but should error" is a thing.
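The reshape ambiguity mentioned above can be made concrete. The sketch below uses NumPy for illustration (its view-vs-copy rules for `reshape` mirror PyTorch's): whether an in-place write through a reshape result mutates the original depends on memory layout, which is exactly the kind of corner case a compiler may refuse to support and should error on loudly instead.

```python
import numpy as np

# Case 1: contiguous array -> reshape returns a VIEW (shares memory).
contig = np.arange(6).reshape(2, 3)
flat = contig.reshape(-1)
flat[0] = 100
assert contig[0, 0] == 100  # the in-place write mutated the original

# Case 2: non-contiguous array (transposed) -> reshape returns a COPY.
strided = np.arange(6).reshape(2, 3).T
flat2 = strided.reshape(-1)
flat2[0] = 100
assert strided[0, 0] == 0   # the original is untouched
```

The same in-place write has two different observable behaviors depending on strides, so there is no single correct trace for it without tracking aliasing.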

crcrpar commented 1 day ago

thunder.jit'ing the following function with the nvfuserex executor fails with the message below. By moving the copy for a += b to the end of the trace and replacing a += b with t0 = a + b, I expect it to work.

def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    a += b
    c = torch.exp(a)
    d = torch.tanh(b)

    e = c.view(-1)
    e.add_(d.flatten())

    d.div_(a)
    return c, d, e / 2.0
Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snippet.py", line 55, in <module>
    main()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snippet.py", line 24, in main
    c, d, e = jit_f(a, b)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 676, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 223, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 615, in get_computation_and_inputs
    thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/transform_common.py", line 86, in _inplace_copy_sanity_check
    check(copy_to_out, "output")
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/transform_common.py", line 80, in check
    raise NotImplementedError(
NotImplementedError: t8 = prims.div(d, t1)  # t8: "cuda:0 f32[2, 2]" trying to use t1 (the output of 'prims.copy_') as input, which is not safe. There is a risk of accessing the wrong memory. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.
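The rewrite proposed above can be sketched in plain NumPy. This is a hypothetical, hand-functionalized version (not Thunder's actual transform): every in-place op becomes out-of-place, and the single copy back into `a`'s storage is deferred to the end of the trace, so no later operation reads the output of a `copy_`.

```python
import numpy as np

def f_inplace(a, b):
    # NumPy rendition of the original failing function, for comparison.
    a += b
    c = np.exp(a)
    d = np.tanh(b)
    e = c.reshape(-1)   # view of c
    e += d.flatten()    # mutates c through the view
    d /= a
    return c, d, e / 2.0

def f_functionalized(a, b):
    # Hand-functionalized sketch: in-place ops rewritten out-of-place,
    # with the copy back into a's storage deferred to the end.
    t0 = a + b                # a += b  ->  out-of-place add
    c0 = np.exp(t0)
    d0 = np.tanh(b)
    c1 = c0 + d0              # e.add_(d.flatten()) through the view of c
    e1 = c1.reshape(-1)
    d1 = d0 / t0              # d.div_(a), reading t0 instead of the copy
    np.copyto(a, t0)          # deferred copy_ at the end of the trace
    return c1, d1, e1 / 2.0
```

Both versions produce the same outputs and the same final contents of `a`, but in the second one the division never consumes the output of a copy, which is what the sanity check above rejects.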
lantiga commented 1 day ago

The failing loudly (and not obscurely) part is super important. Let's keep this issue open to track this aspect.

I would prioritize enumerating corner cases as much as we can, and creating good unhappy paths for them (with sane error messages). Then we address them and lift limitations where it makes sense. WDYT @crcrpar?