LargeWorldModel / LWM

Large World Model -- Modeling Text and Video with Millions Context
https://largeworldmodel.github.io/
Apache License 2.0

ValueError: Incompatible shapes for broadcasting: (2, 1, 1, 526464) and requested shape (2, 1, 32768, 32768) #62

Open scsonic opened 7 months ago

scsonic commented 7 months ago

Creating a single PNG image: success. Creating a video: fails. Hardware: A100 80GB.

Update: `--n_frames=2048` always fails.

```
!JAX_TRACEBACK_FILTERING=off python3 -u -m lwm.vision_generation \
    --prompt={prompt} \
    --output_file={output_filename} \
    --temperature_image=1.0 \
    --top_k_image=8192 \
    --cfg_scale_image=5.0 \
    --vqgan_checkpoint="{vqgan_checkpoint}" \
    --n_frames=2048 \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,use_flash_attention=True,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \
    --load_checkpoint="params::{lwm_checkpoint}" \
    --tokenizer.vocab_file="{llama_tokenizer_path}"
```
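For context, a rough token count suggests why this overflows. If each generated frame costs 257 tokens (256 VQGAN codes plus one delimiter — an inference from the reported shapes, not confirmed from the LWM source) and the prompt is truncated to `max_input_length=128` as in the traceback below, the total sequence far exceeds `max_sequence_length=32768`:

```python
# Back-of-the-envelope check; tokens_per_frame is an assumption inferred
# from the failing mask dimension, not taken from the LWM code.
max_input_length = 128   # from generate_video_pred(..., max_input_length=128)
tokens_per_frame = 257   # assumed: 256 VQGAN tokens + 1 separator
n_frames = 2048

total_len = max_input_length + n_frames * tokens_per_frame
print(total_len)  # 526464 -- exactly the failing mask dimension

# The KV cache / causal mask is sized to max_sequence_length=32768, so the
# largest n_frames that fits under these assumptions is:
print((32768 - max_input_length) // tokens_per_frame)  # 127
```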

The output:

```
/tmp/notebook/content/LWM
env: PYTHONPHAT=/tmp/notebook/content/LWM
env: NUMEXPR_MAX_THREADS=12
I0317 09:57:37.786386 139819503906816 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0317 09:57:37.787172 139819503906816 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-03-17 09:57:38.663241: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
100%|█████████████████████████████████████████████| 1/1 [00:24<00:00, 24.97s/it]
  0%|                                             | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/tmp/notebook/content/LWM/lwm/vision_generation.py", line 258, in <module>
    run(main)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/tmp/notebook/content/LWM/lwm/vision_generation.py", line 247, in main
    videos.extend(generate_video_pred(prompts, images, max_input_length=128))
  File "/tmp/notebook/content/LWM/lwm/vision_generation.py", line 215, in generate_video_pred
    output, sharded_rng = _sharded_forward_generate(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 257, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 781, in infer_params
    return common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 493, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 996, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
    ans = call(fun, *args)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 936, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2288, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2310, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/tmp/notebook/content/LWM/lwm/vision_generation.py", line 116, in _forward_generate
    output = model.generate_vision(
  File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 710, in generate_vision
    return self._sample_vision(
  File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 515, in _sample_vision
    model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
  File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 453, in prepare_inputs_for_generation
    past_key_values = self.init_cache(batch_size, max_length)
  File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 151, in init_cache
    init_variables = self.module.init(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 2319, in init
    _, v_out = self.init_with_output(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 2215, in init_with_output
    return init_with_output(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/scope.py", line 1137, in wrapper
    return apply(fn, mutable=mutable, flags=init_flags)(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/scope.py", line 1101, in wrapper
    y = fn(root, *args, **kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 2972, in scope_fn
    return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 396, in __call__
    outputs = self.transformer(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 315, in __call__
    outputs = self.h(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/tmp/notebook/content/LWM/lwm/llama.py", line 981, in __call__
    hidden_states, _ = nn.scan(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 378, in wrapped_fn
    ret = trafo_fn(module_scopes, *args, **kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/lift.py", line 325, in wrapper
    y, out_variable_groups_xs_t = fn(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/lift.py", line 1024, in inner
    broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 148, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 774, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 120, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/lift.py", line 1005, in scanned
    c, y = fn(scope, c, *args)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 370, in core_fn
    res = fn(cloned, *args, **kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/tmp/notebook/content/LWM/lwm/llama.py", line 757, in __call__
    attn_outputs = self.attention(
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
    y = run_fun(self, *args, **kwargs)
  File "/tmp/notebook/content/LWM/lwm/llama.py", line 627, in __call__
    attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1227, in broadcast_to
    return util._broadcast_to(array, shape)
  File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 428, in _broadcast_to
    raise ValueError(msg.format(arr_shape, shape))
ValueError: Incompatible shapes for broadcasting: (2, 1, 1, 526464) and requested shape (2, 1, 32768, 32768)
```

OkinoLeiba commented 7 months ago

My initial thought is padding with dummy values, or performing some sort of interpolation, or perhaps one-hot encoding. Are the dimensions strides?

ZQpengyu commented 6 months ago

Hello, has this issue been resolved? I've encountered a similar problem. Additionally, when I increase the number of frames, it throws a shape error. It also seems that the GPU isn't being used.
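Not a fix, but for the GPU question: the standard JAX device APIs show which backend was actually selected.

```python
import jax

# Lists the devices JAX can see; expect something like [cuda(id=0)]
# (the exact repr varies by JAX version) when the GPU backend loaded.
print(jax.devices())
print(jax.default_backend())  # 'gpu' on a working CUDA install, 'cpu' otherwise
```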