young-geng / EasyLM

Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, finetuning, evaluating, and serving LLMs in JAX/Flax.
Apache License 2.0
2.38k stars 254 forks

Serving errors: deprecated dependencies and structure error #103

Open sjw8793 opened 11 months ago

sjw8793 commented 11 months ago

When I tried to serve LLaMA on a v3-8 TPU as suggested in the example script, I ran into some errors.

Environment

* TPU: `v3-8`
* Software: `tpu-vm-base`

**Command**

```
$ git clone https://github.com/young-geng/EasyLM
$ cd EasyLM
$ ./scripts/tpu_vm_setup.sh
$
$ python -m EasyLM.models.llama.llama_train \
    --mesh_dim='1,-1,1' \
    --dtype='bf16' \
    --total_steps=500 \
    --log_freq=50 \
    --load_llama_config='1b' \
    --update_llama_config='' \
    --load_dataset_state='' \
    --load_checkpoint='' \
    --save_model_freq=100 \
    --tokenizer.vocab_file='/path/to/tokenizer.model' \
    --optimizer.type='adamw' \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=1e-3 \
    --optimizer.adamw_optimizer.end_lr=1e-4 \
    --optimizer.adamw_optimizer.lr_warmup_steps=10 \
    --optimizer.adamw_optimizer.lr_decay_steps=100 \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.path='/path/to/dataset.jsonl' \
    --train_dataset.json_dataset.seq_length=1024 \
    --train_dataset.json_dataset.batch_size=64 \
    --train_dataset.json_dataset.tokenizer_processes=1 \
    --checkpointer.save_optimizer_state=True \
    --checkpointer.float_dtype=bf16 \
    --logger.online=False \
    --logger.output_dir="~/ellama_checkpoints/" \
    |& tee $HOME/output1107_wiki.txt
$
$ python -m EasyLM.models.llama.llama_serve \
    --load_llama_config='1b' \
    --load_checkpoint="params::/path/to/streaming_train_state" \
    --tokenizer.vocab_file='/path/to/tokenizer.model' \
    --mesh_dim='1,-1,1' \
    --dtype='bf16' \
    --input_length=1024 \
    --seq_length=2048 \
    --lm_server.batch_size=4 \
    --lm_server.port=8888 \
    --lm_server.pre_compile='all'
```

1. Deprecation warning

ImportError: cannot import name 'soft_unicode' from 'markupsafe'
ImportError: Pandas requires version '3.0.0' or newer of 'jinja2'

These can be solved by adding two lines to `tpu_requirements.txt`:

markupsafe==2.0.1
jinja2~=3.0.0


DeprecationWarning: concurrency_count has been deprecated. Set the concurrency_limit directly on event listeners e.g. btn.click(fn, ..., concurrency_limit=10) or gr.Interface(concurrency_limit=10). If necessary, the total number of workers can be configured via max_threads in launch().

I was able to solve this by deleting `concurrency_count=1` in `serving.py`, line 403. According to the Gradio v4.0.0 changelog, `concurrency_count` was removed and can be replaced with `concurrency_limit`. Since I don't fully understand what it is supposed to do, and it defaults to 1 anyway, I simply removed it.
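As a more version-robust alternative to deleting the argument outright, one could pick the keyword based on the installed Gradio version's signature. Below is a hypothetical compatibility shim (the helper name and usage are my own, not EasyLM code): Gradio < 4 accepted `concurrency_count` on `queue()`, while Gradio >= 4 replaced it with `default_concurrency_limit`.

```python
import inspect

def queue_kwargs(queue_fn, limit=1):
    # Hypothetical shim (not part of EasyLM): inspect the installed
    # Gradio's queue() signature and pass whichever concurrency keyword
    # it actually accepts, so one codebase works across Gradio 3 and 4.
    params = inspect.signature(queue_fn).parameters
    if "concurrency_count" in params:       # Gradio < 4
        return {"concurrency_count": limit}
    if "default_concurrency_limit" in params:  # Gradio >= 4
        return {"default_concurrency_limit": limit}
    return {}

# Usage sketch: app.queue(**queue_kwargs(app.queue))
```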

2. Structure error

However, after solving the deprecation errors above, this error appears:

Error Log

```
I1107 06:16:48.996244 140573565926464 mesh_utils.py:260] Reordering mesh to physical ring order on single-tray TPU v2/v3.
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name user_fn already exists, using user_fn_1
  warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name model_fn already exists, using model_fn_1
  warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
$HOME/.local/lib/python3.8/site-packages/gradio/blocks.py:889: UserWarning: api_name model_fn already exists, using model_fn_2
  warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
Traceback (most recent call last):
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 386, in <module>
    mlxu.run(main)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 382, in main
    server.run()
  File "$HOME/EasyLM/EasyLM/serving.py", line 417, in run
    self.loglikelihood(pre_compile_data, pre_compile_data)
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 208, in loglikelihood
    loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 775, in infer_params
    return common_infer_params(pjit_info_args, *args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 88, in forward_loglikelihood
    logits = hf_model.module.apply(
  File "$HOME/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "$HOME/.local/lib/python3.8/site-packages/flax/linen/module.py", line 1511, in apply
    return apply(
  File "$HOME/.local/lib/python3.8/site-packages/flax/core/scope.py", line 930, in wrapper
    raise errors.ApplyScopeInvalidVariablesStructureError(variables)
jax._src.traceback_util.UnfilteredStackTrace: flax.errors.ApplyScopeInvalidVariablesStructureError: Expect the `variables` (first argument) passed to apply() to be a dict with the structure {"params": ...}, but got a dict with an extra params layer, i.e. {"params": {"params": ... } }. You should instead pass in your dict's ["params"]. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ApplyScopeInvalidVariablesStructureError)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 386, in <module>
    mlxu.run(main)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "$HOME/.local/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 382, in main
    server.run()
  File "$HOME/EasyLM/EasyLM/serving.py", line 417, in run
    self.loglikelihood(pre_compile_data, pre_compile_data)
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 208, in loglikelihood
    loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
  File "$HOME/EasyLM/EasyLM/models/llama/llama_serve.py", line 88, in forward_loglikelihood
    logits = hf_model.module.apply(
```

flax.errors.ApplyScopeInvalidVariablesStructureError: Expect the variables (first argument) passed to apply() to be a dict with the structure {"params": ...}, but got a dict with an extra params layer, i.e. {"params": {"params": ... } }. You should instead pass in your dict's ["params"]. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ApplyScopeInvalidVariablesStructureError)

It seems like something went wrong with the "params" loading in the function `load_trainstate_checkpoint` in `checkpoint.py`, but I couldn't figure out where. Does anyone know what's wrong?
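For illustration, the structure the error complains about looks like this: a checkpoint saved as a full train state keeps its weights under an extra `"params"` key, while `Module.apply()` expects exactly one level. A hypothetical helper (my own sketch, not EasyLM's actual fix) that strips the extra layer:

```python
def unwrap_params(variables):
    # A train-state checkpoint yields {"params": {"params": ...}}, but
    # flax's Module.apply() expects {"params": ...} with a single level.
    # Drop the outer layer if (and only if) the nesting is doubled.
    inner = variables.get("params")
    if isinstance(inner, dict) and set(inner.keys()) == {"params"}:
        return inner  # drop the outer "params" layer
    return variables
```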

sjw8793 commented 11 months ago

There was a misunderstanding on my part: I should have used `trainstate_params` instead of `params` in my case. The serving command should therefore look like this:

$ python -m EasyLM.models.llama.llama_serve \
    --load_llama_config='1b' \
    --load_checkpoint="trainstate_params::/path/to/streaming_train_state" \
    --tokenizer.vocab_file='/path/to/tokenizer.model' \
    --mesh_dim='1,-1,1' \
    --dtype='bf16' \
    --input_length=1024 \
    --seq_length=2048 \
    --lm_server.batch_size=4 \
    --lm_server.port=8888 \
    --lm_server.pre_compile='all'
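For context, the `--load_checkpoint` value follows a `<type>::<path>` convention, where the prefix (e.g. `params` or `trainstate_params`) tells the loader how to interpret the checkpoint file. A minimal sketch of how such a specifier splits (a hypothetical helper, not EasyLM's actual parser):

```python
def parse_load_spec(spec):
    # Split a "<type>::<path>" checkpoint specifier into its two parts.
    # The type selects the loading mode; the path locates the file.
    load_type, _, load_path = spec.partition("::")
    return load_type, load_path
```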
sjw8793 commented 11 months ago

Sorry for reopening; I thought it would be better to keep this issue open until the deprecated dependencies are resolved.