LargeWorldModel / LWM

Apache License 2.0

Error while running bash command: run_sample_video.sh | Error: "TypeError: missing a required argument: 'segment_ids'" #77

Open samitm-123 opened 2 weeks ago

samitm-123 commented 2 weeks ago

I get this error when I run the following bash command: !bash LWM/scripts/run_sample_video.sh. I have followed all the directions listed in the repo.

/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/content/LWM/lwm/vision_generation.py", line 256, in <module>
    run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/content/LWM/lwm/vision_generation.py", line 92, in main
    model = FlaxVideoLLaMAForCausalLM(
  File "/content/LWM/lwm/vision_llama.py", line 141, in __init__
    super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_flax_utils.py", line 224, in __init__
    params_shape_tree = jax.eval_shape(init_fn, self.key)
  File "/content/LWM/lwm/vision_llama.py", line 166, in init_weights
    random_params = self.module.init(rngs, input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False)["params"]
  File "/content/LWM/lwm/vision_llama.py", line 396, in __call__
    outputs = self.transformer(
  File "/content/LWM/lwm/vision_llama.py", line 315, in __call__
    outputs = self.h(
  File "/content/LWM/lwm/llama.py", line 945, in __call__
    hidden_states, _ = nn.scan(
  File "/usr/local/lib/python3.10/dist-packages/flax/core/axes_scan.py", line 151, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/usr/local/lib/python3.10/dist-packages/flax/core/axes_scan.py", line 123, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/content/LWM/lwm/llama.py", line 724, in __call__
    attn_outputs = self.attention(
  File "/content/LWM/lwm/llama.py", line 615, in __call__
    attn_output = ring_attention_sharded(
  File "/usr/lib/python3.10/inspect.py", line 3186, in bind
    return self._bind(args, kwargs)
  File "/usr/lib/python3.10/inspect.py", line 3101, in _bind
    raise TypeError(msg) from None
TypeError: missing a required argument: 'segment_ids'

Would appreciate some help here.
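For anyone reading the traceback above: it ends inside `inspect`'s `bind`, which raises this exact message when a call supplies fewer arguments than the target function's signature requires. One possible reading, which the thread does not confirm, is that the call site in lwm/llama.py and the installed attention function disagree about the argument list (for example after a dependency update). The sketch below only reproduces the mechanism; `ring_attention_stub` and its parameters are hypothetical and are not the real `ring_attention_sharded` signature.

```python
import inspect


def ring_attention_stub(q, k, v, attn_bias, segment_ids, axis_name=None):
    """Hypothetical stand-in; NOT the real ring_attention_sharded signature."""
    return q


def call_with_bound_args(fn, *args, **kwargs):
    """Bind arguments against fn's signature before calling it.

    Wrappers that validate arguments this way fail inside inspect.bind,
    which is where the traceback in this issue ends.
    """
    bound = inspect.signature(fn).bind(*args, **kwargs)
    return fn(*bound.args, **bound.kwargs)


def try_call(*args, **kwargs):
    """Return "ok" on success, or the TypeError message on failure."""
    try:
        call_with_bound_args(ring_attention_stub, *args, **kwargs)
        return "ok"
    except TypeError as exc:
        return str(exc)


# A caller written against a signature without segment_ids passes four
# arguments; the stub requires five, so bind() raises TypeError.
old_style = try_call("q", "k", "v", "bias")
# A caller that also passes segment_ids binds cleanly.
new_style = try_call("q", "k", "v", "bias", "seg")

print(old_style)
print(new_style)
```

If this is what is happening, the fix is to make the two sides agree, either by checking out a repo commit that matches the installed packages or by installing the package versions that the commit expects.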

gabeweisz commented 2 weeks ago

Seeing the same error. Commit 97ae4b672f0a9d8bc30ab536d4bac42c3d044aff works for me on GPU.

samitm-123 commented 2 weeks ago

@gabeweisz I get the following error:

(lwm) madhu@madhupc:~/LWM$ bash scripts/run_sample_image.sh
Traceback (most recent call last):
  File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/madhu/LWM/lwm/vision_generation.py", line 11, in <module>
    from tux import (
  File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/__init__.py", line 1, in <module>
    from .checkpoint import StreamingCheckpointer
  File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 4, in <module>
    import flax
  File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/__init__.py", line 18, in <module>
    from .configurations import (
  File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/configurations.py", line 92, in <module>
    flax_filter_frames = define_bool_state(
  File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/configurations.py", line 42, in define_bool_state
    return jax_config.define_bool_state('flax_' + name, default, help)
AttributeError: 'Config' object has no attribute 'define_bool_state'

This is our Google Colab notebook. Would you mind taking a look and telling us what changes should be made to run this model?

https://colab.research.google.com/drive/1Bx-wRzOspvq5JLctNKRHwHq-vIgw7wlv?usp=sharing

gabeweisz commented 2 weeks ago

For the version of the repo I pointed you to, it works for me using Jax 0.4.25 and with flax==0.8.2 and chex==0.1.86
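A quick way to see whether an environment matches these versions is to compare what pip has installed against them. This is a minimal sketch using only the standard library; the `REPORTED_WORKING` table is copied from this comment, and the helper names are made up for illustration. A match does not guarantee the scripts will run on a different GPU or setup.

```python
from importlib import metadata

# Versions reported in this thread as working with commit 97ae4b6.
REPORTED_WORKING = {"jax": "0.4.25", "flax": "0.8.2", "chex": "0.1.86"}


def installed_version(package):
    """Return the installed version string, or None if not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def compare_versions(expected, lookup=installed_version):
    """Map each package to (wanted, found, matches)."""
    report = {}
    for package, wanted in expected.items():
        found = lookup(package)
        report[package] = (wanted, found, found == wanted)
    return report


for package, (wanted, found, matches) in compare_versions(REPORTED_WORKING).items():
    status = "ok" if matches else "differs"
    print(f"{package}: wanted {wanted}, found {found} ({status})")
```

Any package listed as "differs" can then be pinned with pip, for example `pip install flax==0.8.2`.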

I'm not part of your google collaboration, but maybe the authors of this project will chime in with more information

madhuvanthp commented 2 weeks ago

> For the version of the repo I pointed you to, it works for me using Jax 0.4.25 and with flax==0.8.2 and chex==0.1.86
>
> I'm not part of your google collaboration, but maybe the authors of this project will chime in with more information

So for GPUs you used commit 97ae4b6 and only followed the instructions for that specific version? Or did you run any other commands? Also, would you mind sharing your entire requirements.txt file? The versions in the requirements.txt from 97ae4b6 are different from the ones you mentioned. I am struggling to get this working with my GPU.

gabeweisz commented 2 weeks ago

I used commit 97ae4b6 and did not change anything.

I installed packages using the requirements.txt in that commit, and then updated the two packages that I mention above manually using pip.

I most likely have a different GPU than you do, but this is what worked for me.