pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[V-JEPA] Inductor assert fail due to mismatched stride #121767

Closed williamwen42 closed 3 weeks ago

williamwen42 commented 7 months ago

When attempting to compile V-JEPA (https://github.com/facebookresearch/jepa/tree/main), I get errors such as:

Traceback (most recent call last):
  File "/data/users/williamwen/jepa-env/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/data/users/williamwen/jepa-env/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/data/users/williamwen/jepa/app/main.py", line 65, in <module>
    process_main(0, args.fname, 1, args.devices[:1])
  File "/data/users/williamwen/jepa/app/main.py", line 60, in process_main
    app_main(params['app'], args=params)
  File "/data/users/williamwen/jepa/app/scaffold.py", line 19, in main
    return importlib.import_module(f'app.{app}.train').main(
  File "/data/users/williamwen/jepa/app/vjepa/train.py", line 525, in main
    (loss, loss_jepa, loss_reg, _new_lr, _new_wd, grad_stats, grad_stats_pred, optim_stats,), gpu_etime_ms = train_step_timed()
  File "/data/users/williamwen/jepa/app/vjepa/train.py", line 520, in train_step_timed
    result = gpu_timer(train_step_opt)
  File "/data/users/williamwen/jepa/src/utils/logging.py", line 24, in gpu_timer
    result = closure()
  File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 437, in _fn
    return fn(*args, **kwargs)
  File "/data/users/williamwen/jepa/app/vjepa/train.py", line 476, in train_step
    scaler.scale(loss).backward()
  File "/data/users/williamwen/jepa/app/vjepa/train.py", line 476, in torch_dynamo_resume_in_train_step_at_476
    scaler.scale(loss).backward()
  File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/autograd/function.py", line 301, in apply
    return user_fn(self, *args)
  File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 880, in backward
    out = call_compiled_backward()
  File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 829, in call_compiled_backward
    out = call_func_at_runtime_with_args(
  File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 437, in _fn
    return fn(*args, **kwargs)
  File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 825, in __call__
    return self.get_current_callable()(inputs)
  File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 804, in run
    return model(new_inputs)
  File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 853, in _run_from_cache
    return compiled_graph.compiled_artifact(inputs)
  File "/tmp/torchinductor_williamwen/zp/czpkywozkllqeyxlcpcnjgwvyqb67bycvrbmzegz3qugqlmrbejj.py", line 4497, in call
    assert_size_stride(repeat_2, (8, s2, 1024), (360, 1, 0))
AssertionError: expected size 8==8, stride 424==360 at dim=0
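For readers unfamiliar with this kind of guard, here is a minimal pure-Python sketch of what a size/stride check does. The function check_size_stride is a hypothetical stand-in written for illustration, not Inductor's actual assert_size_stride implementation, and the concrete value s2 = 424 is an assumption inferred from the error message. One plausible reading of the failure (also an assumption) is that the dim-0 stride was baked in as the constant 360 at compile time while the size s2 stayed symbolic, so a later input with s2 = 424 trips the guard.

```python
def check_size_stride(actual_size, actual_stride, expected_size, expected_stride):
    """Hypothetical, simplified stand-in for a compiled-graph size/stride guard."""
    if len(actual_size) != len(expected_size):
        raise AssertionError(
            f"wrong number of dimensions {len(actual_size)}=={len(expected_size)}"
        )
    dims = zip(actual_size, actual_stride, expected_size, expected_stride)
    for dim, (a_size, a_stride, e_size, e_stride) in enumerate(dims):
        # Fail on the first dimension whose size or stride differs.
        if a_size != e_size or a_stride != e_stride:
            raise AssertionError(
                f"expected size {a_size}=={e_size}, "
                f"stride {a_stride}=={e_stride} at dim={dim}"
            )


# Assumed runtime value of the symbolic dimension s2.
s2 = 424

try:
    # Actual tensor: shape (8, s2, 1024) with strides (s2, 1, 0).
    # Expected by the generated code: strides (360, 1, 0).
    check_size_stride((8, s2, 1024), (s2, 1, 0), (8, s2, 1024), (360, 1, 0))
except AssertionError as exc:
    message = str(exc)

print(message)
```

With these assumed values the sketch raises a message of the same form as the one in the traceback, with the runtime stride on the left of each comparison and the compile-time expectation on the right.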

In order to run the model, I had to do 2 things:

Fix input shapes (app/vjepa/train.py):

            clips = torch.randn(8, 3, 16, 224, 224, dtype=torch.float, device=device)
            masks_enc = [
                torch.zeros(8, 472, dtype=torch.long, device=device),
                torch.zeros(8, 48, dtype=torch.long, device=device),
            ]
            masks_pred = [
                torch.zeros(8, 808, dtype=torch.long, device=device),
                torch.zeros(8, 1232, dtype=torch.long, device=device),
            ]

Remove the context manager "with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):" (src/models/utils/modules.py). This line also causes graph breaks in eager mode.

To repro:

Download part of the Kinetics 400 dataset. Follow a script such as https://github.com/cvdfoundation/kinetics-dataset/blob/main/k400_downloader.sh, but you only need part of the dataset, e.g. https://s3.amazonaws.com/kinetics/400/train/part_0.tar.gz and the annotations.

Create a .csv file describing the downloaded dataset according to the V-JEPA repo instructions.

Modify the config to point to the right files. Change batch size if necessary (e.g. 8 if training on single GPU).

Run python -m app.main --fname configs/pretrain/vitl16.yaml --devices cuda:0
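The .csv step above could be scripted roughly as follows. This is a sketch under assumptions: it assumes the index is a space-delimited file with one "<absolute video path> <integer label>" line per clip (check the V-JEPA README for the authoritative format), the helper write_video_index and the file name k400_part0.csv are hypothetical, and the label 0 is a placeholder rather than a real Kinetics class. Paths containing spaces would not survive a space-delimited format.

```python
import os
import tempfile


def write_video_index(video_dir, csv_path, label=0, extensions=(".mp4",)):
    """Write one '<absolute path> <label>' line per video found under video_dir.

    Hypothetical helper; the space-delimited layout is an assumption.
    """
    rows = []
    for root, _dirs, files in os.walk(video_dir):
        for name in files:
            if name.lower().endswith(extensions):
                full_path = os.path.abspath(os.path.join(root, name))
                rows.append(f"{full_path} {label}")
    rows.sort()  # deterministic order regardless of filesystem listing
    with open(csv_path, "w") as f:
        for row in rows:
            f.write(row + "\n")
    return len(rows)


# Small self-contained demo using empty placeholder files instead of real videos.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("a.mp4", "b.mp4", "notes.txt"):
        open(os.path.join(tmp, name), "w").close()
    index_path = os.path.join(tmp, "k400_part0.csv")
    count = write_video_index(tmp, index_path)
    with open(index_path) as f:
        lines = f.read().splitlines()

print(count, lines)
```

For an actual run, video_dir would be the directory where part_0.tar.gz was extracted, and the config would point at the resulting .csv file.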

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire

BoyuanFeng commented 3 weeks ago

Hi @williamwen42, we are scrubbing old issues. Is this still relevant?

williamwen42 commented 3 weeks ago

I don't see a stride error anymore, and I added sdpa_kernel tracing support. Can update on V-JEPA side to make it pt2 compatible.