Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0

torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable()] {} #664

Open wprazuch opened 2 weeks ago

wprazuch commented 2 weeks ago

🐛 Bug

Dynamo raises an Unsupported error when running models with Thunder inductor under FSDP with zero2/zero3 sharding.

To Reproduce

Steps to reproduce the behavior:

mkdir -p output
docker run --pull=always --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864  -v $PWD/output:/output -it INTERNAL_IMAGE:pjnl-20240621

Run in the container:

torchrun --nproc-per-node=8 /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py --model_name Nous-Hermes-13b --compile thunder_inductor_cat_cudnn --distributed_mode fsdp --shard_mode zero2 

Expected behavior

The model should either run or fail with an OOM error.

Environment

As in the Docker image

Additional context

We reproduced this with fsdp (1/2 nodes, 8 GPUs) and zero2/zero3. The traceback is below:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 639, in <module>
[rank0]:     CLI(benchmark_main)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/jsonargparse/_cli.py", line 96, in CLI
[rank0]:     return _run_component(components, cfg_init)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/jsonargparse/_cli.py", line 196, in _run_component
[rank0]:     return component(**cfg)
[rank0]:   File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 584, in benchmark_main
[rank0]:     benchmark.train()
[rank0]:   File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 491, in train
[rank0]:     loss.backward()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 522, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 288, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 306, in apply
[rank0]:     return user_fn(self, *args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 599, in wrapper
[rank0]:     outputs = fn(ctx, *args)
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 96, in backward
[rank0]:     grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "thunder.backward_fn_333", line 462, in backward_fn
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/executors/torch_compile.py", line 97, in compiled_func_wrapper
[rank0]:     return compiled_func(*args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
[rank0]:     return self._torchdynamo_orig_callable(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
[rank0]:     return _compile(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 84, in wrapper_function
[rank0]:     return StrobelightCompileTimeProfiler.profile_compile_time(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:     return func(*args, **kwds)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
[rank0]:     r = func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
[rank0]:     out_code = transform_code_object(code, transform)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
[rank0]:     transformations(instructions, code_options)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 582, in transform
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2463, in run
[rank0]:     super().run()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
[rank0]:     while self.step():
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 510, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1470, in CALL_FUNCTION
[rank0]:     self.call_function(fn, args, {})
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 754, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 298, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 95, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2678, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2794, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
[rank0]:     while self.step():
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 510, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1470, in CALL_FUNCTION
[rank0]:     self.call_function(fn, args, {})
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 754, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 298, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 95, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2678, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2794, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
[rank0]:     while self.step():
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1916, in CONTAINS_OP
[rank0]:     self.push(right.call_method(self, "__contains__", [left], {}))
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/user_defined.py", line 644, in call_method
[rank0]:     return super().call_method(tx, name, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/base.py", line 320, in call_method
[rank0]:     unimplemented(f"call_method {self} {name} {args} {kwargs}")
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 221, in unimplemented
[rank0]:     raise Unsupported(msg)
[rank0]: torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable()] {}

tfogal commented 2 weeks ago

[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 599, in wrapper
[rank0]:     outputs = fn(ctx, *args)
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 96, in backward
[rank0]:     grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)

Looks like we're eventually asking dynamo to do something that it cannot due to our autograd.

triage: is there something we can do to not tickle dynamo or do we need to just report this upstream?

IvanYashchuk commented 2 weeks ago

Asking dynamo to do something that it cannot due to our generated backward trace:

File "thunder.backward_fn_333", line 462, in backward_fn

and the use of fullgraph=True (added in https://github.com/Lightning-AI/lightning-thunder/commit/e0ab64867a5be914d0548c195a3f850a76c8c397; see https://github.com/Lightning-AI/lightning-thunder/blob/72e033a0e0dfe44d4770dec2399a9058971003ec/thunder/executors/torch_compile.py#L86). Setting fullgraph=False might fix this problem.

tfogal commented 2 weeks ago

@wprazuch can I ask you to do a one-off that tests this with fullgraph=False, as Ivan points out above?

(I don't know that this is the long-term solution, but it will allow us to have a more reasoned discussion about one.)

mpatel31415 commented 1 week ago

We can confirm that after changing the call in torch_compile.py to compiled_func = torch.compile(trace_callable, fullgraph=False), there is no error :)
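
For the long-term discussion, one possible middle ground is to keep fullgraph=True as the first attempt and recompile with fullgraph=False only when the strict attempt raises. The sketch below is an illustration, not Thunder's code: compile_with_fallback, fake_compile and FakeUnsupported are hypothetical names, and the compiler is injected so the pattern runs without PyTorch. In real use one would presumably pass torch.compile as compile_fn and torch._dynamo.exc.Unsupported as strict_error. The error is caught at call time because the traceback above shows it surfacing when the compiled function is first called, not when it is created.

```python
from typing import Any, Callable, Type


def compile_with_fallback(
    fn: Callable[..., Any],
    compile_fn: Callable[..., Callable[..., Any]],
    strict_error: Type[BaseException] = Exception,
) -> Callable[..., Any]:
    """Try fullgraph=True first; on strict_error, recompile with fullgraph=False.

    compile_fn is assumed to accept (fn, fullgraph=...) like torch.compile does.
    """
    state = {"compiled": compile_fn(fn, fullgraph=True), "strict": True}

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if state["strict"]:
            try:
                return state["compiled"](*args, **kwargs)
            except strict_error:
                # The strict attempt failed: switch to the permissive mode
                # and stay there for all later calls.
                state["compiled"] = compile_fn(fn, fullgraph=False)
                state["strict"] = False
        return state["compiled"](*args, **kwargs)

    return wrapper


# Stand-ins for demonstration only (hypothetical, not part of PyTorch or Thunder).
class FakeUnsupported(Exception):
    pass


def fake_compile(fn: Callable[..., Any], fullgraph: bool) -> Callable[..., Any]:
    def compiled(*args: Any, **kwargs: Any) -> Any:
        if fullgraph:
            # Simulate a construct the strict mode cannot trace.
            raise FakeUnsupported("simulated graph break")
        return fn(*args, **kwargs)

    return compiled


safe = compile_with_fallback(lambda x: x + 1, fake_compile, FakeUnsupported)
print(safe(1))  # prints 2, after falling back to fullgraph=False
```

The trade-off is that a silent fallback hides graph breaks that fullgraph=True was meant to expose, so a real version would probably want to log a warning when it switches modes.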

tfogal commented 1 week ago

Thanks Martyna, Wojciech!

tfogal commented 1 week ago

triage review: