pytorch / pytorch


Error while exporting BLIP #138111

Open agunapal opened 4 days ago

agunapal commented 4 days ago

🐛 Describe the bug

Steps:

1) pip install timm==0.4.12 fairscale transformers

2) git clone https://github.com/salesforce/BLIP.git

3) cd BLIP and run the code below from the repository root, so that `models.blip` can be imported

Code for reproducing the problem

```python
import torch

from models.blip import blip_decoder  # models/ comes from the cloned BLIP repo

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

image_size = 384
# Dummy input image, created on the same device as the model
image = torch.randn(1, 3, image_size, image_size, device=device)
caption_input = ""

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'

model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)

exported_program: torch.export.ExportedProgram = torch.export.export(
    model, args=(image, caption_input), strict=False
)
```

Error

  File "/home/ubuntu/experiments/export/BLIP/image_captioning.py", line 17, in <module>
    exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(image,caption_input,), strict=False)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/export/__init__.py", line 366, in export
    return _export(
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/export/_trace.py", line 1014, in wrapper
    raise e
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/export/_trace.py", line 987, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/export/exported_program.py", line 116, in wrapper
    return fn(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/export/_trace.py", line 1964, in _export
    export_artifact = export_func(  # type: ignore[operator]
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/export/_trace.py", line 1754, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/export/_trace.py", line 643, in _export_to_aten_ir
    gm, graph_signature = transform(aot_export_module)(
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/export/_trace.py", line 1684, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1262, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1497, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 524, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 762, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 112, in aot_dispatch_export
    graph, _, _ = aot_dispatch_base_graph(
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 136, in aot_dispatch_base_graph
    fw_module = _create_graph(
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 55, in _create_graph
    fx_g = make_fx(
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2147, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2085, in trace
    return self._trace_inner(f, *args)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 2056, in _trace_inner
    t = dispatch_trace(
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 654, in _fn
    return fn(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1133, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1652, in trace
    res = super().trace(root, concrete_args)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 827, in trace
    (self.create_arg(fn(*args)),),
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1188, in wrapped
    out = f(*tensors)
  File "<string>", line 1, in <lambda>
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 693, in inner_fn
    outs = fn(*args)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 413, in _functionalized_f_helper
    f_outs = fn(*f_args)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 78, in inner_fn
    outs = fn(*args)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 863, in functional_call
    out = mod(*args[params_len:], **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 805, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1722, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 523, in call_module
    ret_val = forward(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 798, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/export/_trace.py", line 1671, in forward
    tree_out = mod(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 805, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1722, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 523, in call_module
    ret_val = forward(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 798, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/experiments/export/BLIP/models/blip.py", line 112, in forward
    text.input_ids[:,0] = self.tokenizer.bos_token_id
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1236, in __torch_function__
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1274, in __torch_function__
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_export/non_strict_utils.py", line 551, in __torch_function__
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_ops.py", line 840, in handler
    return torch._library.utils.handle_dispatch_mode(
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_library/utils.py", line 284, in handle_dispatch_mode
    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
  File "/home/ubuntu/anaconda3/envs/export/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 541, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
RuntimeError: cannot mutate tensors with frozen storage

Versions

Collecting environment information...
PyTorch version: 2.6.0.dev20241015+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.10.0 | packaged by conda-forge | (default, Nov 20 2021, 02:24:10) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-1063-aws-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA A10G
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      48 bits physical, 48 bits virtual
CPU(s):                             8
On-line CPU(s) list:                0-7
Thread(s) per core:                 2
Core(s) per socket:                 4
Socket(s):                          1
NUMA node(s):                       1
Vendor ID:                          AuthenticAMD
CPU family:                         23
Model:                              49
Model name:                         AMD EPYC 7R32
Stepping:                           0
CPU MHz:                            2799.998
BogoMIPS:                           5599.99
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          128 KiB
L1i cache:                          128 KiB
L2 cache:                           2 MiB
L3 cache:                           16 MiB
NUMA node0 CPU(s):                  0-7
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save rdpid

Versions of relevant libraries:
[pip3] numpy==2.0.2
[pip3] torch==2.6.0.dev20241015+cpu
[pip3] torchvision==0.20.0.dev20241015+cpu
[conda] numpy                     2.0.2                    pypi_0    pypi
[conda] torch                     2.6.0.dev20241015+cpu          pypi_0    pypi
[conda] torchvision               0.20.0.dev20241015+cpu          pypi_0    pypi

cc @ezyang @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

agunapal commented 4 days ago

This is similar to https://github.com/pytorch/pytorch/issues/127571

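This is resolved by adding `text.input_ids = text.input_ids.clone()` at https://github.com/salesforce/BLIP/blob/main/models/blip.py#L111, just before the in-place write `text.input_ids[:,0] = self.tokenizer.bos_token_id` that fails in the traceback, so the write modifies a copy instead of the tensor returned by the tokenizer.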