Traceback (most recent call last):
File "/opt/conda/bin/litgpt", line 8, in <module>
sys.exit(main())
File "litgpt/__main__.py", line 57, in main
CLI(parser_data)
File "/opt/conda/lib/python3.10/site-packages/jsonargparse/_cli.py", line 119, in CLI
return _run_component(component, init.get(subcommand))
File "/opt/conda/lib/python3.10/site-packages/jsonargparse/_cli.py", line 196, in _run_component
return component(**cfg)
File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "litgpt/generate/base.py", line 252, in main
y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id)
File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "litgpt/generate/base.py", line 127, in generate
token = next_token(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
return fn(*args, **kwargs)
File "litgpt/generate/base.py", line 74, in next_token
logits = model(x, input_pos)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/fabric/wrappers.py", line 138, in forward
with precision.forward_context():
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 786, in _convert_frame
result = inner_convert(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
return _compile(
File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
out_code = transform_code_object(code, transform)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
assert (
AssertionError: Global state changed while dynamo tracing, please report a bug
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Sorry that I can't be more helpful here — I have never used compilation in LitGPT myself, but I remember from my colleagues that torch.compile does not fully support everything and that there are some known issues with it.
CLI command:
$ litgpt generate stabilityai/stablelm-base-alpha-3b --prompt "Hello, my name is" --compile true
Workaround:
Adding `torch._dynamo.config.suppress_errors = True` to generate/base.py didn't work.
Replacing `torch.compile(next_token, ...)` with `model = torch.compile(model)` placed before `model = fabric.setup_module(model)` works.
Env:
Full error: