Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0

`test_vjp_correctness` fails with ops that return tensors that do not require grads. #120

Open nikitaved opened 5 months ago

nikitaved commented 5 months ago

🐛 Bug

As per the title. To reproduce, uncomment these tests in https://github.com/Lightning-AI/lightning-thunder/pull/118 to get:

thunder/tests/test_grad.py:423: in test_vjp_correctness
    result = run_snippet(
thunder/tests/framework.py:483: in run_snippet
    raise ex
thunder/tests/framework.py:475: in run_snippet
    snippet(*args, **kwargs)
thunder/tests/test_grad.py:394: in snippet_vjp_correctness
    check_vjp(func, *args, executor=executor)
thunder/tests/test_grad.py:304: in check_vjp
    _, J_star_v = executor.make_callable_legacy(vjp(f), disable_torch_autograd_support=True)(primals, v)
thunder/common.py:783: in _fn
    trc_or_result = trace(compile_data=cd)(processed_function, *args, **kwargs)
thunder/core/interpreter.py:1298: in fn_
    return fn(*args, **kwargs)
thunder/common.py:534: in _trace
    result = fn(*proxyargs, **proxykwargs)
thunder/core/transforms.py:3629: in _vjp
    result, vjp_result = vjp_call(flat_args, cotangents, trace=trace)
thunder/core/transforms.py:3603: in vjp_call_metafunc
    result, env = augmented_forward_pass(*primals, trace=trace, **kwargs)
thunder/core/transforms.py:3414: in augmented_forward_pass
    result, env = eval_trace(
thunder/core/transforms.py:1693: in eval_trace
    prim_func = symbol_mapper(symbol)
thunder/core/transforms.py:3338: in vjp_symbol_mapper
    vjp_impl, backward_fn = make_aug_forward_and_backward(symbol)
thunder/core/vjp_utils.py:99: in make_aug_forward_and_backward
    backward_bsyms = utils.find_producer_symbols(joint_trace, flat_bw_outputs, tree_flatten(bw_inputs)[0])
thunder/core/utils.py:1062: in find_producer_symbols
    if arg_name not in map(lambda x: x.name, stop_proxies) and arg_name not in seen:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

x = None

>   if arg_name not in map(lambda x: x.name, stop_proxies) and arg_name not in seen:
E   AttributeError: 'NoneType' object has no attribute 'name'

thunder/core/utils.py:1062: AttributeError

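The last frame shows the lambda receiving `x = None`: `stop_proxies` is `tree_flatten(bw_inputs)[0]`, so at least one flattened backward input is `None`, presumably the slot for an output that does not require grad. The stdlib-only sketch below isolates that failure mode. `FakeProxy`, `stop_names_unsafe` and `stop_names_safe` are hypothetical stand-ins, not Thunder APIs, and skipping `None` is only one possible guard; it may just paper over the problem if the root cause lies in `vjp` itself.

```python
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass(frozen=True)
class FakeProxy:
    """Hypothetical stand-in for a Thunder proxy; only .name matters here."""
    name: str


def stop_names_unsafe(stop_proxies: Iterable[Optional[FakeProxy]]) -> set:
    # Same expression as the failing check, evaluated eagerly:
    # it assumes every entry has a .name attribute.
    return set(map(lambda x: x.name, stop_proxies))


def stop_names_safe(stop_proxies: Iterable[Optional[FakeProxy]]) -> set:
    # Hypothetical guard: ignore None placeholders, e.g. the slot for an
    # output such as topk's integer indices, which does not require grad.
    return {x.name for x in stop_proxies if x is not None}


# Flattened backward inputs: one real proxy plus a None placeholder.
bw_inputs = [FakeProxy("t_values_grad"), None]

try:
    stop_names_unsafe(bw_inputs)
except AttributeError as e:
    print(e)  # 'NoneType' object has no attribute 'name'

print(stop_names_safe(bw_inputs))  # {'t_values_grad'}
```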
kshitij12345 commented 4 months ago

The root cause seems to be in vjp itself.

import thunder
import torch

def foo(x):
    # topk returns (values, indices); indices is an integer tensor that does not require grad
    return thunder.torch.topk(x, k=2)

x = torch.ones(3, 3) * 2
outputs = torch.topk(x, k=2)
# One cotangent per output, including one for the non-differentiable indices
cotangents = tuple(torch.ones_like(o) for o in outputs)
vjp_foo = thunder.core.transforms.vjp(foo)
jfoo = thunder.compile(vjp_foo, disable_preprocessing=True)
# jfoo = thunder.jit(vjp_foo)  # Doesn't work currently.

# Fails with
# File "/home/kkalambarkar/lightning-thunder/thunder/core/utils.py", line 1062, in <lambda>
#     if arg_name not in map(lambda x: x.name, stop_proxies) and arg_name not in seen:
# AttributeError: 'NoneType' object has no attribute 'name'
jfoo(primals=(x,), cotangents=cotangents)

NOTE: The test currently uses make_callable_legacy (which uses thunder.compile). We should probably wait until thunder.jit(vjp(fn)) is supported and then verify. (Related issue: https://github.com/Lightning-AI/lightning-thunder/issues/198)