Traceback (most recent call last):
File "/data/users/williamwen/jepa-env/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/data/users/williamwen/jepa-env/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/data/users/williamwen/jepa/app/main.py", line 65, in <module>
process_main(0, args.fname, 1, args.devices[:1])
File "/data/users/williamwen/jepa/app/main.py", line 60, in process_main
app_main(params['app'], args=params)
File "/data/users/williamwen/jepa/app/scaffold.py", line 19, in main
return importlib.import_module(f'app.{app}.train').main(
File "/data/users/williamwen/jepa/app/vjepa/train.py", line 525, in main
(loss, loss_jepa, loss_reg, _new_lr, _new_wd, grad_stats, grad_stats_pred, optim_stats,), gpu_etime_ms = train_step_timed()
File "/data/users/williamwen/jepa/app/vjepa/train.py", line 520, in train_step_timed
result = gpu_timer(train_step_opt)
File "/data/users/williamwen/jepa/src/utils/logging.py", line 24, in gpu_timer
result = closure()
File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 437, in _fn
return fn(*args, **kwargs)
File "/data/users/williamwen/jepa/app/vjepa/train.py", line 476, in train_step
scaler.scale(loss).backward()
File "/data/users/williamwen/jepa/app/vjepa/train.py", line 476, in torch_dynamo_resume_in_train_step_at_476
scaler.scale(loss).backward()
File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_tensor.py", line 525, in backward
torch.autograd.backward(
File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward
_engine_run_backward(
File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/autograd/function.py", line 301, in apply
return user_fn(self, *args)
File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 880, in backward
out = call_compiled_backward()
File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 829, in call_compiled_backward
out = call_func_at_runtime_with_args(
File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 437, in _fn
return fn(*args, **kwargs)
File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
return fn(*args, **kwargs)
File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 825, in __call__
return self.get_current_callable()(inputs)
File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 804, in run
return model(new_inputs)
File "/data/users/williamwen/jepa-env/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 853, in _run_from_cache
return compiled_graph.compiled_artifact(inputs)
File "/tmp/torchinductor_williamwen/zp/czpkywozkllqeyxlcpcnjgwvyqb67bycvrbmzegz3qugqlmrbejj.py", line 4497, in call
assert_size_stride(repeat_2, (8, s2, 1024), (360, 1, 0))
AssertionError: expected size 8==8, stride 424==360 at dim=0
Note: one change needed to get the model running is to remove the line "with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):" in src/models/utils/modules.py. This line also causes graph breaks in eager mode.
When attempting to compile V-JEPA (https://github.com/facebookresearch/jepa/tree/main),
I get errors such as the traceback shown above.
In order to run the model, I had to do two things:
1. Fix the input shapes (app/vjepa/train.py).
2. Remove the line
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
(src/models/utils/modules.py). This line also causes graph breaks in eager mode.
To repro:
Download part of the Kinetics 400 dataset. Follow a script such as https://github.com/cvdfoundation/kinetics-dataset/blob/main/k400_downloader.sh, but you only need part of the dataset, e.g. https://s3.amazonaws.com/kinetics/400/train/part_0.tar.gz and the annotations.
Create a .csv file describing the downloaded dataset, following the V-JEPA repo instructions.
Modify the config to point to the right files. Change batch size if necessary (e.g. 8 if training on single GPU).
Run:
python -m app.main --fname configs/pretrain/vitl16.yaml --devices cuda:0
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire