Open · scsonic opened 7 months ago
My initial thought is to add dummy values or perform some sort of interpolation... or one-hot encoding. Are the dimensions strides?
Hello, has this issue been resolved? I've run into a similar problem: when I increase the frame count, it throws a shape error. Also, it doesn't seem to be using the GPU.
- Create one image (PNG): success
- Create video: fail
- GPU: A100 80G

Update: --n_frames=2048 always fails.
!JAX_TRACEBACK_FILTERING=off python3 -u -m lwm.vision_generation \
    --prompt={prompt} \
    --output_file={output_filename} \
    --temperature_image=1.0 \
    --top_k_image=8192 \
    --cfg_scale_image=5.0 \
    --vqgan_checkpoint="{vqgan_checkpoint}" \
    --n_frames=2048 \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,use_flash_attention=True,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \
    --load_checkpoint="params::{lwm_checkpoint}" \
    --tokenizer.vocab_file="{llama_tokenizer_path}"
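As a rough sanity check (my own sketch, not part of the repo), you can compare the sequence length implied by --n_frames against the max_sequence_length=32768 set in --update_llama_config. The per-frame cost of 257 tokens (256 VQGAN codes plus one delimiter) is an assumption inferred from the error below, not a documented constant; the 128-token text budget matches the max_input_length that shows up in the traceback:

```python
# Hypothetical back-of-the-envelope check, not part of LWM.
# Assumption: each frame costs ~257 tokens (256 VQGAN codes + 1 delimiter);
# the text prompt is capped at 128 tokens (max_input_length in the traceback).
TOKENS_PER_FRAME = 257
TEXT_BUDGET = 128
MAX_SEQUENCE_LENGTH = 32768  # from --update_llama_config above

def required_positions(n_frames: int) -> int:
    """Approximate token positions needed for the prompt plus n_frames of video."""
    return TEXT_BUDGET + n_frames * TOKENS_PER_FRAME

for n in (8, 64, 2048):
    need = required_positions(n)
    verdict = "fits within" if need <= MAX_SEQUENCE_LENGTH else "exceeds"
    print(f"n_frames={n}: ~{need} positions, {verdict} max_sequence_length={MAX_SEQUENCE_LENGTH}")
```

Under that assumption, n_frames=2048 needs roughly 526k positions, far beyond what a 32768-position cache and causal mask can hold.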
The output:

/tmp/notebook/content/LWM
env: PYTHONPHAT=/tmp/notebook/content/LWM
env: NUMEXPR_MAX_THREADS=12
I0317 09:57:37.786386 139819503906816 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0317 09:57:37.787172 139819503906816 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-03-17 09:57:38.663241: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
100%|█████████████████████████████████████████████| 1/1 [00:24<00:00, 24.97s/it]
0%| | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/tmp/notebook/content/LWM/lwm/vision_generation.py", line 258, in <module>
run(main)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/tmp/notebook/content/LWM/lwm/vision_generation.py", line 247, in main
videos.extend(generate_video_pred(prompts, images, max_input_length=128))
File "/tmp/notebook/content/LWM/lwm/vision_generation.py", line 215, in generate_video_pred
output, sharded_rng = _sharded_forward_generate(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 257, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 781, in infer_params
return common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 493, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 996, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
ans = call(fun, *args)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 936, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2288, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2310, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/notebook/content/LWM/lwm/vision_generation.py", line 116, in _forward_generate
output = model.generate_vision(
File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 710, in generate_vision
return self._sample_vision(
File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 515, in _sample_vision
model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 453, in prepare_inputs_for_generation
past_key_values = self.init_cache(batch_size, max_length)
File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 151, in init_cache
init_variables = self.module.init(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 2319, in init
_, v_out = self.init_with_output(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 2215, in init_with_output
return init_with_output(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/scope.py", line 1137, in wrapper
return apply(fn, mutable=mutable, flags=init_flags)(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/scope.py", line 1101, in wrapper
y = fn(root, *args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 2972, in scope_fn
return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
y = run_fun(self, *args, **kwargs)
File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 396, in __call__
outputs = self.transformer(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
y = run_fun(self, *args, **kwargs)
File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 315, in __call__
outputs = self.h(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
y = run_fun(self, *args, **kwargs)
File "/tmp/notebook/content/LWM/lwm/llama.py", line 981, in __call__
hidden_states, _ = nn.scan(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 378, in wrapped_fn
ret = trafo_fn(module_scopes, *args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/lift.py", line 325, in wrapper
y, out_variable_groups_xs_t = fn(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/lift.py", line 1024, in inner
broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 148, in scanfn
, outpvals, = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 774, in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 120, in body_fn
broadcast_out, c, ys = fn(broadcast_in, c, xs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/lift.py", line 1005, in scanned
c, y = fn(scope, c, args)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 370, in core_fn
res = fn(cloned, *args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
y = run_fun(self, *args, **kwargs)
File "/tmp/notebook/content/LWM/lwm/llama.py", line 757, in __call__
attn_outputs = self.attention(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
y = run_fun(self, *args, **kwargs)
File "/tmp/notebook/content/LWM/lwm/llama.py", line 627, in call
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1227, in broadcast_to
return util._broadcast_to(array, shape)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 428, in _broadcast_to
raise ValueError(msg.format(arr_shape, shape))
ValueError: Incompatible shapes for broadcasting: (2, 1, 1, 526464) and requested shape (2, 1, 32768, 32768)
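For what it's worth, the numbers in this error are consistent with the config above: the requested (2, 1, 32768, 32768) is the causal mask, whose size comes from max_sequence_length=32768, while the attention mask covers the full generated sequence, and 128 + 2048 × 257 = 526464 matches its last dimension exactly (again assuming roughly 257 tokens per frame plus the 128-token text budget). Under that reading, either lowering --n_frames or raising max_sequence_length in --update_llama_config (with memory to match) should avoid this particular broadcast error; I haven't verified either fix.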