mlfoundations / open_clip

An open source implementation of CLIP.

Error when using torchcompile option for CLIP training #726

Open kkjh0723 opened 11 months ago

kkjh0723 commented 11 months ago

Hello,

When I tried to enable the --torchcompile option while training a CLIP ViT-B-32 model, I got an error. Below is the script I used to run training.

torchrun --nproc_per_node 16 -m training.main --save-frequency 1 --zeroshot-frequency 1 --report-to tensorboard --train-data={data_dir} --csv-img-key filepath --csv-caption-key title --imagenet-val={imagenet val dir} --workers=8 --model ViT-B-32 --precision amp_bf16 --workers 4 --csv-separator "," --local-loss --gather-with-grad --aug-cfg scale='(0.5, 1.0)' --name test --accum-freq 4 --grad-checkpointing --torchcompile

I got the error message below. How can I fix this? Note that my PyTorch version is 2.1.0, and no error occurs when I run the above script without the --torchcompile option.

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workspace/open_clip/src/training/main.py", line 508, in <module>
    main(sys.argv[1:])
  File "/workspace/open_clip/src/training/main.py", line 436, in main
    train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer)
  File "/workspace/open_clip/src/training/train.py", line 117, in train_one_epoch
    model_out = model(images, texts)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 487, in catch_errors
    return hijacked_callback(frame, cache_entry, hooks, frame_state)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 586, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1167, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 307, in call_function
    return super().call_function(tx, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 331, in call_function
    return tx.inline_user_function_return(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1155, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 307, in call_function
    return super().call_function(tx, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 331, in call_function
    return tx.inline_user_function_return(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1155, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 307, in call_function
    return super().call_function(tx, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 1123, in call_function
    p_args, _, example_value = self.create_wrapped_node(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 1025, in create_wrapped_node
    ) = speculate_subgraph(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 203, in speculate_subgraph
    f"to trace function `{f.get_name()}` into a single graph. This means "
torch._dynamo.exc.InternalTorchDynamoError: 'NNModuleVariable' object has no attribute 'get_name'

from user code:
   File "/workspace/open_clip/src/open_clip/model.py", line 293, in forward
    image_features = self.encode_image(image, normalize=True) if image is not None else None
  File "/workspace/open_clip/src/open_clip/model.py", line 266, in encode_image
    features = self.visual(image)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/open_clip/src/open_clip/transformer.py", line 516, in forward
    x = self.transformer(x)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/open_clip/src/open_clip/transformer.py", line 322, in forward
    x = checkpoint(r, x, None, None, attn_mask)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
rwightman commented 11 months ago

@kkjh0723 I think it might be breaking because of gradient checkpointing. I'm not sure there is a workaround; possibly using non-reentrant mode?
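The suggestion above corresponds to passing `use_reentrant=False` to `torch.utils.checkpoint.checkpoint`. As a hedged sketch (the toy `Block`/`Stack` modules and the `grads` helper are hypothetical stand-ins, not open_clip's `Transformer`), the following checks in eager mode that non-reentrant checkpointing yields the same outputs and gradients as an uncheckpointed forward pass. It says nothing about whether `torch.compile` will accept the call on a given PyTorch version.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Toy residual MLP block standing in for a transformer resblock (assumption)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + torch.tanh(self.fc(x))


class Stack(nn.Module):
    """Hypothetical stack of blocks with an optional checkpointing switch."""

    def __init__(self, dim: int, depth: int, grad_checkpointing: bool) -> None:
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.grad_checkpointing = grad_checkpointing

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for blk in self.blocks:
            if self.grad_checkpointing:
                # Activations inside blk are recomputed during backward instead of stored.
                # Non-reentrant mode still sends gradients to the parameters even though
                # x itself does not require grad (reentrant mode would not).
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        return x


def grads(model: nn.Module, x: torch.Tensor) -> list:
    """Return cloned parameter gradients of sum(model(x))."""
    model.zero_grad()
    model(x).sum().backward()
    return [p.grad.clone() for p in model.parameters()]


torch.manual_seed(0)
ref = Stack(8, 4, grad_checkpointing=False)
ckpt = Stack(8, 4, grad_checkpointing=True)
ckpt.load_state_dict(ref.state_dict())  # identical weights in both models
x = torch.randn(2, 8)
```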

EIFY commented 11 months ago

I got the same error when running with both --grad-checkpointing and --torchcompile. However, since PyTorch 2.1.0, --torchcompile works with --accum-freq > 1, so dropping --grad-checkpointing in favor of gradient accumulation is the next best option.
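With a contrastive loss, gradient accumulation cannot simply sum per-micro-batch losses, because each example's negatives come from the whole batch. The sketch below assumes a feature-caching scheme: compute every micro-batch's features under `torch.no_grad()`, then re-run each micro-batch with gradients while splicing its fresh features into the cached ones, calling `backward()` once per micro-batch. The `ToyCLIP` model, `clip_loss`, and both helper functions are hypothetical illustrations, not open_clip's implementation; the check compares the accumulated gradients against a single full-batch backward pass.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def clip_loss(img: torch.Tensor, txt: torch.Tensor, scale: float = 10.0) -> torch.Tensor:
    """Symmetric InfoNCE over the whole (accumulated) batch."""
    img, txt = F.normalize(img, dim=-1), F.normalize(txt, dim=-1)
    logits = scale * img @ txt.t()
    labels = torch.arange(img.shape[0])
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2


class ToyCLIP(nn.Module):
    """Stand-in for a CLIP model: two linear 'towers' (assumption)."""

    def __init__(self, d_img: int = 6, d_txt: int = 5, d_emb: int = 4) -> None:
        super().__init__()
        self.visual = nn.Linear(d_img, d_emb)
        self.text = nn.Linear(d_txt, d_emb)

    def forward(self, images, texts):
        return self.visual(images), self.text(texts)


def accumulated_grads(model, images, texts, accum_freq):
    img_chunks, txt_chunks = images.chunk(accum_freq), texts.chunk(accum_freq)
    # Pass 1: cache features of every micro-batch without building a graph.
    with torch.no_grad():
        cached = [model(i, t) for i, t in zip(img_chunks, txt_chunks)]
    model.zero_grad()
    # Pass 2: re-run one micro-batch with grad; the others come from the cache.
    for j, (i, t) in enumerate(zip(img_chunks, txt_chunks)):
        img_j, txt_j = model(i, t)
        all_img = torch.cat([c[0] for c in cached[:j]] + [img_j] + [c[0] for c in cached[j + 1:]])
        all_txt = torch.cat([c[1] for c in cached[:j]] + [txt_j] + [c[1] for c in cached[j + 1:]])
        clip_loss(all_img, all_txt).backward()  # .grad sums contributions across j
    return [p.grad.clone() for p in model.parameters()]


def full_batch_grads(model, images, texts):
    model.zero_grad()
    clip_loss(*model(images, texts)).backward()
    return [p.grad.clone() for p in model.parameters()]


torch.manual_seed(0)
model = ToyCLIP()
images, texts = torch.randn(8, 6), torch.randn(8, 5)
```

Like checkpointing, this trades compute for memory: the cached no-grad pass is an extra forward, but only one micro-batch's autograd graph is alive at a time.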

rwightman commented 11 months ago

@EIFY did you try forcing non-reentrant checkpointing? We could look at changing the default if that works...

EIFY commented 11 months ago

@rwightman No I haven't tried that.

In that regard, the good news is that https://github.com/pytorch/pytorch/issues/79887, relevant to the checkpoint call at https://github.com/mlfoundations/open_clip/blob/91923dfc376afb9d44577a0c9bd0930389349438/src/open_clip/transformer.py#L320-L322, is now fixed, so we should be able to do, e.g.:

if self.grad_checkpointing and not torch.jit.is_scripting(): 
    x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)

The bad news is that, beyond that call, grad checkpointing is either delegated to the vision/text trunks without argument support (https://github.com/mlfoundations/open_clip/blob/91923dfc376afb9d44577a0c9bd0930389349438/src/open_clip/model.py#L260-L263) or not supported at all (https://github.com/mlfoundations/open_clip/blob/91923dfc376afb9d44577a0c9bd0930389349438/src/open_clip/modified_resnet.py#L161-L164), so fairly involved changes would be necessary. When I get a chance, I will try the easy part and see whether it at least gets past that point.
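On the good-news side: with `use_reentrant=False`, `torch.utils.checkpoint.checkpoint` forwards extra keyword arguments to the wrapped callable, so an argument like `attn_mask` can be passed by name rather than through positional `None` placeholders. A minimal eager-mode sketch, with a hypothetical `MaskedBlock` standing in for a residual attention block; it does not address the `torch.compile` failure, which, as reported further down, persists with `use_reentrant=False` on PyTorch 2.1.0.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class MaskedBlock(nn.Module):
    """Hypothetical residual block with a keyword-only mask argument."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, *, attn_mask=None) -> torch.Tensor:
        h = self.fc(x)
        if attn_mask is not None:
            h = h.masked_fill(attn_mask, 0.0)  # zero out masked output features
        return x + h


torch.manual_seed(0)
blk = MaskedBlock(4)
x = torch.randn(3, 4, requires_grad=True)
mask = torch.tensor([False, True, False, True]).expand(3, 4)

# use_reentrant=False passes **kwargs through to blk.forward.
y_ckpt = checkpoint(blk, x, attn_mask=mask, use_reentrant=False)
y_ref = blk(x, attn_mask=mask)
y_ckpt.sum().backward()
```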

EIFY commented 11 months ago

@rwightman OK, so it turns out that use_reentrant=False doesn't help. It still breaks at the same point:

[2023-11-08 12:56:29,383] [0/0] torch._utils_internal: [INFO] CompilationMetrics(frame_key='1', co_name='forward', co_filename='/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/model.py', co_firstlineno=256, cache_size=0, guard_count=None, graph_op_count=None, graph_node_count=None, graph_input_count=None, entire_frame_compile_time_s=None, backend_compile_time_s=None, fail_reason="'NNModuleVariable' object has no attribute 'get_name'")
Traceback (most recent call last):
(...)
torch._dynamo.exc.InternalTorchDynamoError: 'NNModuleVariable' object has no attribute 'get_name'

from user code:
   File "/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/model.py", line 274, in forward
    image_features = dim_scale_img * self.encode_image(image, normalize=self.normalize) if image is not None else None
  File "/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/model.py", line 239, in encode_image
    features = self.visual(image)
  File "/home/jason-chou/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/transformer.py", line 486, in forward
    x = self.transformer(x)
  File "/home/jason-chou/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/transformer.py", line 319, in forward
    x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
lavoiems commented 9 months ago

Is there any update on this? I am facing the same issue.