mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

[Bug] gptj is not a supported model architecture for WebGPU? #457

Closed jparismorgan closed 1 year ago

jparismorgan commented 1 year ago

🐛 Bug

Based on the documentation, it seems that gptj should be a supported model architecture. But when I try to build https://huggingface.co/EleutherAI/gpt-j-6b for WebGPU, I get AssertionError: Model type gptj not supported.

To Reproduce

Steps to reproduce the behavior:

  1. Run python3 build.py --hf-path EleutherAI/gpt-j-6b --target webgpu --quantization q4f32_0 and get error:
    (mlc-llm) ~/repo/mlc-llm python3 build.py --hf-path EleutherAI/gpt-j-6b --target webgpu --quantization q4f32_0
    Weights exist at dist/models/gpt-j-6b, skipping download.
    Traceback (most recent call last):
    File "/Users/parismorgan/repo/mlc-llm/build.py", line 431, in <module>
    ARGS = _parse_args()
    File "/Users/parismorgan/repo/mlc-llm/build.py", line 76, in _parse_args
    parsed = _setup_model_path(parsed)
    File "/Users/parismorgan/repo/mlc-llm/build.py", line 147, in _setup_model_path
    validate_config(args.model_path)
    File "/Users/parismorgan/repo/mlc-llm/build.py", line 196, in validate_config
    assert (
    AssertionError: Model type gptj not supported.
  2. To work around that, I tried adding it to supported_model_types = set(["llama", "gpt_neox", "moss", "rwkv"]), i.e. supported_model_types = set(["llama", "gpt_neox", "moss", "rwkv", "gptj"])
  3. But when I run again I get:
    (mlc-llm) ~/repo/mlc-llm python3 build.py --hf-path EleutherAI/gpt-j-6b --target webgpu --quantization q4f32_0                                                                                                         
    Weights exist at dist/models/gpt-j-6b, skipping download.
    Using path "dist/models/gpt-j-6b" for model "gpt-j-6b"
    Database paths: ['log_db/rwkv-raven-3b', 'log_db/redpajama-3b-q4f16', 'log_db/redpajama-3b-q4f32', 'log_db/rwkv-raven-1b5', 'log_db/dolly-v2-3b', 'log_db/rwkv-raven-7b', 'log_db/vicuna-v1-7b']
    Target configured: webgpu -keys=webgpu,gpu -max_num_threads=256
    Traceback (most recent call last):
    File "/Users/parismorgan/repo/mlc-llm/build.py", line 431, in <module>
    ARGS = _parse_args()
    File "/Users/parismorgan/repo/mlc-llm/build.py", line 110, in _parse_args
    utils.argparse_postproc_common(parsed)
    File "/Users/parismorgan/repo/mlc-llm/mlc_llm/utils.py", line 84, in argparse_postproc_common
    raise ValueError(
    ValueError: Cannot recognize model "gpt-j-6b". Supported ones: vicuna-, dolly-, stablelm-, redpajama-, moss-, open_llama, rwkv-, gorilla-

    This is because the model name gpt-j-6b doesn't match any supported model prefix:

    supported_model_prefix = {
        "vicuna-": ("vicuna_v1.1", "llama"),
        "dolly-": ("dolly", "gpt_neox"),
        "stablelm-": ("stablelm", "gpt_neox"),
        "redpajama-": ("redpajama_chat", "gpt_neox"),
        "moss-": ("moss", "moss"),
        "open_llama": ("LM", "llama"),
        "rwkv-": ("rwkv", "rwkv"),
        "gorilla-": ("gorilla", "llama"),
    }

    I stopped here because I wasn't sure whether just adding an entry made sense; a rough sketch of what I was considering is shown below.
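
For concreteness, this is roughly the change I was contemplating. It is only a sketch: the conversation-template name for the new entry is a placeholder guess on my part, not something I found in the repo.

```python
# build.py: also accept the gptj architecture (the edit from step 2 above).
supported_model_types = set(["llama", "gpt_neox", "moss", "rwkv", "gptj"])

# mlc_llm/utils.py: add a "gpt-j-" prefix entry to supported_model_prefix,
# mapping it to a (conversation template, model architecture) pair, following
# the pattern of the existing entries. The template name "gptj" is a guess.
supported_model_prefix["gpt-j-"] = ("gptj", "gptj")
```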

Expected behavior

Perhaps the docs could be updated to say what is supported when building for WebGPU? Or am I misunderstanding how things work and the docs are already clear? Or is this expected to work, in which case it is a bug?

Environment

Additional context

Please let me know if there is something I missed in the docs about this, thank you for any help!

yzh119 commented 1 year ago

@jparismorgan Sorry, the gptj codebase is a little bit outdated, I'll fix it.

jparismorgan commented 1 year ago

Thank you @yzh119! Just let me know if you have something you'd like help testing / validating, happy to give it a pass!

jparismorgan commented 1 year ago

Hi @yzh119, thanks for the work on this! I just ran and confirmed it's working for the default build, but it looks like something is now wrong with the webgpu build. Here the macOS (Metal) build succeeds:

(mlc-llm) ~/repo/mlc-llm python3 build.py --model gpt-j-6b --quantization q4f16_0                                                                                                
Using path "dist/models/gpt-j-6b" for model "gpt-j-6b"
Database paths: ['log_db/rwkv-raven-3b', 'log_db/redpajama-3b-q4f16', 'log_db/redpajama-3b-q4f32', 'log_db/rwkv-raven-1b5', 'log_db/dolly-v2-3b', 'log_db/rwkv-raven-7b', 'log_db/vicuna-v1-7b']
[09:54:01] /Users/runner/work/package/package/tvm/src/runtime/metal/metal_device_api.mm:165: Intializing Metal device 0, name=AMD Radeon Pro 555X
[09:54:01] /Users/runner/work/package/package/tvm/src/runtime/metal/metal_device_api.mm:165: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
Host CPU dection:
  Target triple: x86_64-apple-darwin22.3.0
  Process triple: x86_64-apple-darwin22.3.0
  Host CPU: skylake
Target configured: metal -keys=metal,gpu -max_function_args=31 -max_num_threads=256 -max_shared_memory_per_block=32768 -max_threads_per_block=1024 -thread_warp_size=32
Load cached module from dist/gpt-j-6b-q4f16_0/mod_cache_before_build_metal.pkl and skip tracing. You can use --use-cache=0 to retrace
Finish exporting to dist/gpt-j-6b-q4f16_0/gpt-j-6b-q4f16_0-metal.so
Finish exporting chat config to dist/gpt-j-6b-q4f16_0/params/mlc-chat-config.json

And here is what happens when I build for webgpu:

(mlc-llm) ~/repo/mlc-llm python3 build.py --model gpt-j-6b --quantization q4f16_0 --target webgpu                                                                                  ✹ ✭main 
Using path "dist/models/gpt-j-6b" for model "gpt-j-6b"
Database paths: ['log_db/rwkv-raven-3b', 'log_db/redpajama-3b-q4f16', 'log_db/redpajama-3b-q4f32', 'log_db/rwkv-raven-1b5', 'log_db/dolly-v2-3b', 'log_db/rwkv-raven-7b', 'log_db/vicuna-v1-7b']
Target configured: webgpu -keys=webgpu,gpu -max_num_threads=256
[12:35:04] /Users/runner/work/package/package/tvm/include/tvm/topi/transform.h:1075: Warning: Fast mode segfaults when there are out-of-bounds indices. Make sure input indices are in bound
[12:35:05] /Users/runner/work/package/package/tvm/include/tvm/topi/transform.h:1075: Warning: Fast mode segfaults when there are out-of-bounds indices. Make sure input indices are in bound
[12:47:51] /Users/runner/work/package/package/tvm/src/runtime/metal/metal_device_api.mm:165: Intializing Metal device 0, name=AMD Radeon Pro 555X
[12:47:51] /Users/runner/work/package/package/tvm/src/runtime/metal/metal_device_api.mm:165: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
Host CPU dection:
  Target triple: x86_64-apple-darwin22.3.0
  Process triple: x86_64-apple-darwin22.3.0
  Host CPU: skylake
Automatically using target for weight quantization: metal -keys=metal,gpu -max_function_args=31 -max_num_threads=256 -max_shared_memory_per_block=32768 -max_threads_per_block=1024 -thread_warp_size=32
Start computing and quantizing weights... This may take a while.
transformer.ln_f.weight
transformer.ln_f.bias
lm_head.weight
lm_head.bias
transformer.h.9.ln_1.weight
transformer.h.9.ln_1.bias
transformer.h.9.mlp.fc_out.weight
transformer.h.9.mlp.fc_out.bias
transformer.h.9.mlp.fc_in.weight
transformer.h.9.mlp.fc_in.bias
transformer.h.9.attn.out_proj.weight
transformer.h.9.attn.k_proj.weight
transformer.h.9.attn.v_proj.weight
transformer.h.9.attn.q_proj.weight
transformer.h.8.ln_1.weight
transformer.h.8.ln_1.bias
transformer.h.8.mlp.fc_out.weight
transformer.h.8.mlp.fc_out.bias
transformer.h.8.mlp.fc_in.weight
transformer.h.8.mlp.fc_in.bias
transformer.h.8.attn.out_proj.weight
transformer.h.8.attn.k_proj.weight
transformer.h.8.attn.v_proj.weight
transformer.h.8.attn.q_proj.weight
transformer.h.7.ln_1.weight
transformer.h.7.ln_1.bias
transformer.h.7.mlp.fc_out.weight
transformer.h.7.mlp.fc_out.bias
transformer.h.7.mlp.fc_in.weight
transformer.h.7.mlp.fc_in.bias
transformer.h.7.attn.out_proj.weight
transformer.h.7.attn.k_proj.weight
transformer.h.7.attn.v_proj.weight
transformer.h.7.attn.q_proj.weight
transformer.h.6.ln_1.weight
transformer.h.6.ln_1.bias
transformer.h.6.mlp.fc_out.weight
transformer.h.6.mlp.fc_out.bias
transformer.h.6.mlp.fc_in.weight
transformer.h.6.mlp.fc_in.bias
transformer.h.6.attn.out_proj.weight
transformer.h.6.attn.k_proj.weight
transformer.h.6.attn.v_proj.weight
transformer.h.6.attn.q_proj.weight
transformer.h.5.ln_1.weight
transformer.h.5.ln_1.bias
transformer.h.5.mlp.fc_out.weight
transformer.h.5.mlp.fc_out.bias
transformer.h.5.mlp.fc_in.weight
transformer.h.5.mlp.fc_in.bias
transformer.h.5.attn.out_proj.weight
transformer.h.5.attn.k_proj.weight
transformer.h.5.attn.v_proj.weight
transformer.h.5.attn.q_proj.weight
transformer.h.4.ln_1.weight
transformer.h.4.ln_1.bias
transformer.h.4.mlp.fc_out.weight
transformer.h.4.mlp.fc_out.bias
transformer.h.4.mlp.fc_in.weight
transformer.h.4.mlp.fc_in.bias
transformer.h.4.attn.out_proj.weight
transformer.h.4.attn.k_proj.weight
transformer.h.4.attn.v_proj.weight
transformer.h.4.attn.q_proj.weight
transformer.h.3.ln_1.weight
transformer.h.3.ln_1.bias
transformer.h.3.mlp.fc_out.weight
transformer.h.3.mlp.fc_out.bias
transformer.h.3.mlp.fc_in.weight
transformer.h.3.mlp.fc_in.bias
transformer.h.3.attn.out_proj.weight
transformer.h.3.attn.k_proj.weight
transformer.h.3.attn.v_proj.weight
transformer.h.3.attn.q_proj.weight
transformer.h.27.ln_1.weight
transformer.h.27.ln_1.bias
transformer.h.27.mlp.fc_out.weight
transformer.h.27.mlp.fc_out.bias
transformer.h.27.mlp.fc_in.weight
transformer.h.27.mlp.fc_in.bias
transformer.h.27.attn.out_proj.weight
transformer.h.27.attn.k_proj.weight
transformer.h.27.attn.v_proj.weight
transformer.h.27.attn.q_proj.weight
transformer.h.26.ln_1.weight
transformer.h.26.ln_1.bias
transformer.h.26.mlp.fc_out.weight
transformer.h.26.mlp.fc_out.bias
transformer.h.26.mlp.fc_in.weight
transformer.h.26.mlp.fc_in.bias
transformer.h.26.attn.out_proj.weight
transformer.h.26.attn.k_proj.weight
transformer.h.26.attn.v_proj.weight
transformer.h.26.attn.q_proj.weight
transformer.h.25.ln_1.weight
transformer.h.25.ln_1.bias
transformer.h.25.mlp.fc_out.weight
transformer.h.25.mlp.fc_out.bias
transformer.h.25.mlp.fc_in.weight
transformer.h.25.mlp.fc_in.bias
transformer.h.25.attn.out_proj.weight
transformer.h.25.attn.k_proj.weight
transformer.h.25.attn.v_proj.weight
transformer.h.25.attn.q_proj.weight
transformer.h.24.ln_1.weight
transformer.h.24.ln_1.bias
transformer.h.24.mlp.fc_out.weight
transformer.h.24.mlp.fc_out.bias
transformer.h.24.mlp.fc_in.weight
transformer.h.24.mlp.fc_in.bias
transformer.h.24.attn.out_proj.weight
transformer.h.24.attn.k_proj.weight
transformer.h.24.attn.v_proj.weight
transformer.h.24.attn.q_proj.weight
transformer.h.23.ln_1.weight
transformer.h.23.ln_1.bias
transformer.h.23.mlp.fc_out.weight
transformer.h.23.mlp.fc_out.bias
transformer.h.23.mlp.fc_in.weight
transformer.h.23.mlp.fc_in.bias
transformer.h.23.attn.out_proj.weight
transformer.h.23.attn.k_proj.weight
transformer.h.23.attn.v_proj.weight
transformer.h.23.attn.q_proj.weight
transformer.h.22.ln_1.weight
transformer.h.22.ln_1.bias
transformer.h.22.mlp.fc_out.weight
transformer.h.22.mlp.fc_out.bias
transformer.h.22.mlp.fc_in.weight
transformer.h.22.mlp.fc_in.bias
transformer.h.22.attn.out_proj.weight
transformer.h.22.attn.k_proj.weight
transformer.h.22.attn.v_proj.weight
transformer.h.22.attn.q_proj.weight
transformer.h.21.ln_1.weight
transformer.h.21.ln_1.bias
transformer.h.21.mlp.fc_out.weight
transformer.h.21.mlp.fc_out.bias
transformer.h.21.mlp.fc_in.weight
transformer.h.21.mlp.fc_in.bias
transformer.h.21.attn.out_proj.weight
transformer.h.21.attn.k_proj.weight
transformer.h.21.attn.v_proj.weight
transformer.h.21.attn.q_proj.weight
transformer.h.20.ln_1.weight
transformer.h.20.ln_1.bias
transformer.h.20.mlp.fc_out.weight
transformer.h.20.mlp.fc_out.bias
transformer.h.20.mlp.fc_in.weight
transformer.h.20.mlp.fc_in.bias
transformer.h.20.attn.out_proj.weight
transformer.h.20.attn.k_proj.weight
transformer.h.20.attn.v_proj.weight
transformer.h.20.attn.q_proj.weight
transformer.h.2.ln_1.weight
transformer.h.2.ln_1.bias
transformer.h.2.mlp.fc_out.weight
transformer.h.2.mlp.fc_out.bias
transformer.h.2.mlp.fc_in.weight
transformer.h.2.mlp.fc_in.bias
transformer.h.2.attn.out_proj.weight
transformer.h.2.attn.k_proj.weight
transformer.h.2.attn.v_proj.weight
transformer.h.2.attn.q_proj.weight
transformer.h.19.ln_1.weight
transformer.h.19.ln_1.bias
transformer.h.19.mlp.fc_out.weight
transformer.h.19.mlp.fc_out.bias
transformer.h.19.mlp.fc_in.weight
transformer.h.19.mlp.fc_in.bias
transformer.h.19.attn.out_proj.weight
transformer.h.19.attn.k_proj.weight
transformer.h.19.attn.v_proj.weight
transformer.h.19.attn.q_proj.weight
transformer.h.18.ln_1.weight
transformer.h.18.ln_1.bias
transformer.h.18.mlp.fc_out.weight
transformer.h.18.mlp.fc_out.bias
transformer.h.18.mlp.fc_in.weight
transformer.h.18.mlp.fc_in.bias
transformer.h.18.attn.out_proj.weight
transformer.h.18.attn.k_proj.weight
transformer.h.18.attn.v_proj.weight
transformer.h.18.attn.q_proj.weight
transformer.h.17.ln_1.weight
transformer.h.17.ln_1.bias
transformer.h.17.mlp.fc_out.weight
transformer.h.17.mlp.fc_out.bias
transformer.h.17.mlp.fc_in.weight
transformer.h.17.mlp.fc_in.bias
transformer.h.17.attn.out_proj.weight
transformer.h.17.attn.k_proj.weight
transformer.h.17.attn.v_proj.weight
transformer.h.17.attn.q_proj.weight
transformer.h.16.ln_1.weight
transformer.h.16.ln_1.bias
transformer.h.16.mlp.fc_out.weight
transformer.h.16.mlp.fc_out.bias
transformer.h.16.mlp.fc_in.weight
transformer.h.16.mlp.fc_in.bias
transformer.h.16.attn.out_proj.weight
transformer.h.16.attn.k_proj.weight
transformer.h.16.attn.v_proj.weight
transformer.h.16.attn.q_proj.weight
transformer.h.15.ln_1.weight
transformer.h.15.ln_1.bias
transformer.h.15.mlp.fc_out.weight
transformer.h.15.mlp.fc_out.bias
transformer.h.15.mlp.fc_in.weight
transformer.h.15.mlp.fc_in.bias
transformer.h.15.attn.out_proj.weight
transformer.h.15.attn.k_proj.weight
transformer.h.15.attn.v_proj.weight
transformer.h.15.attn.q_proj.weight
transformer.h.14.ln_1.weight
transformer.h.14.ln_1.bias
transformer.h.14.mlp.fc_out.weight
transformer.h.14.mlp.fc_out.bias
transformer.h.14.mlp.fc_in.weight
transformer.h.14.mlp.fc_in.bias
transformer.h.14.attn.out_proj.weight
transformer.h.14.attn.k_proj.weight
transformer.h.14.attn.v_proj.weight
transformer.h.14.attn.q_proj.weight
transformer.h.13.ln_1.weight
transformer.h.13.ln_1.bias
transformer.h.13.mlp.fc_out.weight
transformer.h.13.mlp.fc_out.bias
transformer.h.13.mlp.fc_in.weight
transformer.h.13.mlp.fc_in.bias
transformer.h.13.attn.out_proj.weight
transformer.h.13.attn.k_proj.weight
transformer.h.13.attn.v_proj.weight
transformer.h.13.attn.q_proj.weight
transformer.h.12.ln_1.weight
transformer.h.12.ln_1.bias
transformer.h.12.mlp.fc_out.weight
transformer.h.12.mlp.fc_out.bias
transformer.h.12.mlp.fc_in.weight
transformer.h.12.mlp.fc_in.bias
transformer.h.12.attn.out_proj.weight
transformer.h.12.attn.k_proj.weight
transformer.h.12.attn.v_proj.weight
transformer.h.12.attn.q_proj.weight
transformer.h.11.ln_1.weight
transformer.h.11.ln_1.bias
transformer.h.11.mlp.fc_out.weight
transformer.h.11.mlp.fc_out.bias
transformer.h.11.mlp.fc_in.weight
transformer.h.11.mlp.fc_in.bias
transformer.h.11.attn.out_proj.weight
transformer.h.11.attn.k_proj.weight
transformer.h.11.attn.v_proj.weight
transformer.h.11.attn.q_proj.weight
transformer.h.10.ln_1.weight
transformer.h.10.ln_1.bias
transformer.h.10.mlp.fc_out.weight
transformer.h.10.mlp.fc_out.bias
transformer.h.10.mlp.fc_in.weight
transformer.h.10.mlp.fc_in.bias
transformer.h.10.attn.out_proj.weight
transformer.h.10.attn.k_proj.weight
transformer.h.10.attn.v_proj.weight
transformer.h.10.attn.q_proj.weight
transformer.h.1.ln_1.weight
transformer.h.1.ln_1.bias
transformer.h.1.mlp.fc_out.weight
transformer.h.1.mlp.fc_out.bias
transformer.h.1.mlp.fc_in.weight
transformer.h.1.mlp.fc_in.bias
transformer.h.1.attn.out_proj.weight
transformer.h.1.attn.k_proj.weight
transformer.h.1.attn.v_proj.weight
transformer.h.1.attn.q_proj.weight
transformer.h.0.ln_1.weight
transformer.h.0.ln_1.bias
transformer.h.0.mlp.fc_out.weight
transformer.h.0.mlp.fc_out.bias
transformer.h.0.mlp.fc_in.weight
transformer.h.0.mlp.fc_in.bias
transformer.h.0.attn.out_proj.weight
transformer.h.0.attn.k_proj.weight
transformer.h.0.attn.v_proj.weight
transformer.h.0.attn.q_proj.weight
transformer.wte.weight
transformer.h.0.attn.bias
transformer.h.0.attn.masked_bias
/Users/parismorgan/repo/mlc-llm/mlc_llm/relax_model/gptj.py:711: RuntimeWarning: overflow encountered in cast
  return [(torch_pname, raw_param.astype(dtype))]
transformer.h.1.attn.bias
transformer.h.1.attn.masked_bias
transformer.h.2.attn.bias
transformer.h.2.attn.masked_bias
transformer.h.3.attn.bias
transformer.h.3.attn.masked_bias
transformer.h.4.attn.bias
transformer.h.4.attn.masked_bias
transformer.h.5.attn.bias
transformer.h.5.attn.masked_bias
transformer.h.6.attn.bias
transformer.h.6.attn.masked_bias
transformer.h.7.attn.bias
transformer.h.7.attn.masked_bias
transformer.h.8.attn.bias
transformer.h.8.attn.masked_bias
transformer.h.9.attn.bias
transformer.h.9.attn.masked_bias
transformer.h.10.attn.bias
transformer.h.10.attn.masked_bias
transformer.h.11.attn.bias
transformer.h.11.attn.masked_bias
transformer.h.12.attn.bias
transformer.h.12.attn.masked_bias
transformer.h.13.attn.bias
transformer.h.13.attn.masked_bias
transformer.h.14.attn.bias
transformer.h.14.attn.masked_bias
transformer.h.15.attn.bias
transformer.h.15.attn.masked_bias
transformer.h.16.attn.bias
transformer.h.16.attn.masked_bias
transformer.h.17.attn.bias
transformer.h.17.attn.masked_bias
transformer.h.18.attn.bias
transformer.h.18.attn.masked_bias
transformer.h.19.attn.bias
transformer.h.19.attn.masked_bias
transformer.h.20.attn.bias
transformer.h.20.attn.masked_bias
transformer.h.21.attn.bias
transformer.h.21.attn.masked_bias
transformer.h.22.attn.bias
transformer.h.22.attn.masked_bias
transformer.h.23.attn.bias
transformer.h.23.attn.masked_bias
transformer.h.24.attn.bias
transformer.h.24.attn.masked_bias
transformer.h.25.attn.bias
transformer.h.25.attn.masked_bias
transformer.h.26.attn.bias
transformer.h.26.attn.masked_bias
transformer.h.27.attn.bias
transformer.h.27.attn.masked_bias
Finish computing and quantizing weights.
Total param size: 3.1714653372764587 GB
Start storing to cache dist/gpt-j-6b-q4f16_0/params
[0455/0455] saving param_454
All finished, 101 total shards committed, record saved to dist/gpt-j-6b-q4f16_0/params/ndarray-cache.json
Save a cached module to dist/gpt-j-6b-q4f16_0/mod_cache_before_build_webgpu.pkl.
[13:05:13] /Users/runner/work/package/package/tvm/src/target/llvm/codegen_llvm.cc:185: Warning: Set native vector bits to be 128 for wasm32

Traceback (most recent call last):
  File "/Users/parismorgan/repo/mlc-llm/build.py", line 470, in <module>
    main()
  File "/Users/parismorgan/repo/mlc-llm/build.py", line 462, in main
    build(mod, ARGS)
  File "/Users/parismorgan/repo/mlc-llm/build.py", line 412, in build
    ex.export_library(lib_path, **args.export_kwargs)
  File "/Users/parismorgan/virtualenvs/mlc-llm/lib/python3.9/site-packages/tvm/relax/vm_build.py", line 147, in export_library
    return self.mod.export_library(
  File "/Users/parismorgan/virtualenvs/mlc-llm/lib/python3.9/site-packages/tvm/runtime/module.py", line 598, in export_library
    return fcompile(file_name, files, **kwargs)
  File "/Users/parismorgan/virtualenvs/mlc-llm/lib/python3.9/site-packages/tvm/contrib/emcc.py", line 79, in create_tvmjs_wasm
    raise RuntimeError(msg)
RuntimeError: Compilation error:
wasm-ld: error: initial memory too small, 298422128 bytes needed
emcc: error: '/usr/local/Cellar/emscripten/3.1.41/libexec/llvm/bin/wasm-ld -o dist/gpt-j-6b-q4f16_0/gpt-j-6b-q4f16_0-webgpu.wasm /Users/parismorgan/repo/mlc-llm/3rdparty/tvm/web/dist/wasm/wasm_runtime.bc /Users/parismorgan/repo/mlc-llm/3rdparty/tvm/web/dist/wasm/tvmjs_support.bc /Users/parismorgan/repo/mlc-llm/3rdparty/tvm/web/dist/wasm/webgpu_runtime.bc /var/folders/89/tw4l36q54g9bt_q0pzh8m36m0000gn/T/tmprguxn92d/lib0.o /var/folders/89/tw4l36q54g9bt_q0pzh8m36m0000gn/T/tmprguxn92d/devc.o -L/usr/local/Cellar/emscripten/3.1.41/libexec/cache/sysroot/lib/wasm32-emscripten /usr/local/Cellar/emscripten/3.1.41/libexec/cache/sysroot/lib/wasm32-emscripten/crt1_reactor.o -lGL -lal -lhtml5 -lstandalonewasm-nocatch-memgrow -lstubs -lnoexit -lc -ldlmalloc -lcompiler_rt -lc++-noexcept -lc++abi-noexcept -lsockets -mllvm -combiner-global-alias-analysis=false -mllvm -enable-emscripten-sjlj -mllvm -disable-lsr /var/folders/89/tw4l36q54g9bt_q0pzh8m36m0000gn/T/tmpr6ci9ki4libemscripten_js_symbols.so --import-undefined --strip-debug --export-if-defined=__start_em_asm --export-if-defined=__stop_em_asm --export-if-defined=__start_em_lib_deps --export-if-defined=__stop_em_lib_deps --export-if-defined=__start_em_js --export-if-defined=__stop_em_js --export-if-defined=stackSave --export-if-defined=stackRestore --export-if-defined=stackAlloc --export-if-defined=__errno_location --export-table -z stack-size=65536 --initial-memory=37748736 --entry=_initialize --max-memory=2147483648 --global-base=1024' failed (returned 1)

It looks like this post explains a fix: https://stackoverflow.com/a/66069665/4979029, but I haven't yet found where to set that flag in this build pipeline. Perhaps you know?
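
For reference, the linked answer boils down to giving the wasm linker a larger initial memory. Below is only a sketch of the Emscripten settings involved; where mlc-llm / TVM's emcc wrapper actually lets you pass them through is exactly the part I haven't found, so treat the placement as hypothetical.

```python
# Hypothetical extra arguments to append to the emcc command that
# tvm.contrib.emcc builds. The failing link above used
# --initial-memory=37748736 while 298422128 bytes were needed.
extra_emcc_args = [
    "-s", "INITIAL_MEMORY=536870912",    # start with 512 MB instead of ~36 MB
    "-s", "MAXIMUM_MEMORY=2147483648",   # keep the existing 2 GB cap
    "-s", "ALLOW_MEMORY_GROWTH=1",       # allow the runtime to grow up to that cap
]
```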

Also please let me know if you'd like me to create a new bug for this - thank you!

jparismorgan commented 1 year ago
Minimized because it was fixed by #541.

Hi @yzh119, thank you again! I believe I may have found another small issue: it seems that `q4f32_0` quantization doesn't work:

```
(mlc-llm) ~/repo/mlc-llm python build.py --model gpt-j-6b --quantization q4f32_0
Using path "dist/models/gpt-j-6b" for model "gpt-j-6b"
Database paths: ['log_db/rwkv-raven-3b', 'log_db/redpajama-3b-q4f16', 'log_db/redpajama-3b-q4f32', 'log_db/rwkv-raven-1b5', 'log_db/dolly-v2-3b', 'log_db/rwkv-raven-7b', 'log_db/vicuna-v1-7b']
[11:56:49] /Users/runner/work/package/package/tvm/src/runtime/metal/metal_device_api.mm:165: Intializing Metal device 0, name=AMD Radeon Pro 555X
[11:56:49] /Users/runner/work/package/package/tvm/src/runtime/metal/metal_device_api.mm:165: Intializing Metal device 1, name=Intel(R) UHD Graphics 630
Host CPU dection:
  Target triple: x86_64-apple-darwin22.3.0
  Process triple: x86_64-apple-darwin22.3.0
  Host CPU: skylake
Target configured: metal -keys=metal,gpu -max_function_args=31 -max_num_threads=256 -max_shared_memory_per_block=32768 -max_threads_per_block=1024 -thread_warp_size=32
[11:56:53] /Users/runner/work/package/package/tvm/include/tvm/topi/transform.h:1076: Warning: Fast mode segfaults when there are out-of-bounds indices. Make sure input indices are in bound
[11:56:54] /Users/runner/work/package/package/tvm/include/tvm/topi/transform.h:1076: Warning: Fast mode segfaults when there are out-of-bounds indices. Make sure input indices are in bound
Traceback (most recent call last):
  File "/Users/parismorgan/repo/mlc-llm/build.py", line 454, in <module>
    main()
  File "/Users/parismorgan/repo/mlc-llm/build.py", line 431, in main
    mod = mod_transform_before_build(mod, params, ARGS)
  File "/Users/parismorgan/repo/mlc-llm/build.py", line 309, in mod_transform_before_build
    mod = mlc_llm.transform.FuseDecodeTake()(mod)
  File "/Users/parismorgan/virtualenvs/mlc-llm/lib/python3.9/site-packages/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 331, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 262, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 251, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 181, in tvm._ffi._cy3.core.CHECK_CALL
AttributeError: Traceback (most recent call last):
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/Users/parismorgan/virtualenvs/mlc-llm/lib/python3.9/site-packages/tvm/ir/transform.py", line 307, in _pass_func
    return inst.transform_module(mod, ctx)
  File "/Users/parismorgan/repo/mlc-llm/mlc_llm/transform/decode_take.py", line 103, in transform_module
    mod = relax.transform.FuseOpsByPattern(
  File "/Users/parismorgan/virtualenvs/mlc-llm/lib/python3.9/site-packages/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 331, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 262, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 251, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 181, in tvm._ffi._cy3.core.CHECK_CALL
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/Users/parismorgan/repo/mlc-llm/mlc_llm/transform/decode_take.py", line 22, in pattern_check
    decode.args[0], relax.GlobalVar
  File "/Users/parismorgan/virtualenvs/mlc-llm/lib/python3.9/site-packages/tvm/runtime/object.py", line 75, in __getattr__
    raise AttributeError(f"{type(self)} has no attribute {name}") from None
AttributeError: has no attribute args
```

The issue is coming from this code:

```python
def pattern_check(ctx: relax.transform.PatternCheckContext) -> bool:
    take = ctx.annotated_expr["take"]
    print('ctx', ctx)
    print('take', take)
    print('take.args[0]', take.args[0], isinstance(take.args[0], relax.GlobalVar))
    decode = ctx.annotated_expr["decode"]
    print('decode', decode)
    print('decode.args[0]', decode.args[0], isinstance(decode.args[0], relax.GlobalVar))
    print()
    if not isinstance(take.args[0], relax.GlobalVar) or not isinstance(
        decode.args[0], relax.GlobalVar
    ):
        return False
    return "take" in take.args[0].name_hint and "decode" in decode.args[0].name_hint
```

When I run this with `q4f16_0` I only see `pattern_check()` called twice:

```
ctx relax.transform.PatternCheckContext(0x7fb04322e560)
take R.call_tir(take1, (lv655, lv1612), out_sinfo=R.Tensor((1, 4096), dtype="float16"))
take.args[0] I.GlobalVar("take1") True
decode R.call_tir(fused_decode1, (lv653, lv654), out_sinfo=R.Tensor((50400, 4096), dtype="float16"))
decode.args[0] I.GlobalVar("fused_decode1") True

ctx relax.transform.PatternCheckContext(0x7fb0429df4a0)
take R.call_tir(take, (lv2, lv), out_sinfo=R.Tensor((n, 4096), dtype="float16"))
take.args[0] I.GlobalVar("take") True
decode R.call_tir(fused_decode1, (lv, lv1), out_sinfo=R.Tensor((50400, 4096), dtype="float16"))
decode.args[0] I.GlobalVar("fused_decode1") True
```

But when I run with `q4f32_0`, `pattern_check()` is called three times, and the last time I see `decode lv7` (instead of something like `decode R.call_tir(fused_decode1, (lv, lv1), out_sinfo=R.Tensor((50400, 4096), dtype="float32"))`), so `decode` doesn't have an `args` attribute and the pass crashes:

```
ctx relax.transform.PatternCheckContext(0x7fb7090711f0)
take R.call_tir(take1, (lv626, lv1469), out_sinfo=R.Tensor((1, 4096), dtype="float32"))
take.args[0] I.GlobalVar("take1") True
decode R.call_tir(fused_decode1, (lv624, lv625), out_sinfo=R.Tensor((50400, 4096), dtype="float32"))
decode.args[0] I.GlobalVar("fused_decode1") True

ctx relax.transform.PatternCheckContext(0x7fb72afe2500)
take R.call_tir(take, (lv2, lv), out_sinfo=R.Tensor((n, 4096), dtype="float32"))
take.args[0] I.GlobalVar("take") True
decode R.call_tir(fused_decode1, (lv, lv1), out_sinfo=R.Tensor((50400, 4096), dtype="float32"))
decode.args[0] I.GlobalVar("fused_decode1") True

ctx relax.transform.PatternCheckContext(0x7fb7088f09f0)
take R.call_tir(NT_matmul, (lv8, lv6), out_sinfo=R.Tensor((1, n, 4096), dtype="float32"))
take.args[0] I.GlobalVar("NT_matmul") True
decode lv7
```

I'm not quite sure what the fix would be here, though I would think this is an unexpected error, since 32-bit quantization seems like something that should be supported. Anyways, thank you for your work here! And please let me know if you'd like me to close this original bug as completed and open two new ones.
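
In case it helps narrow things down, here is a minimal defensive sketch of pattern_check. This is only my guess, not a confirmed fix: it rejects the non-Call binding seen in the q4f32_0 trace instead of explaining why the pattern matched it in the first place.

```python
from tvm import relax

def pattern_check(ctx: relax.transform.PatternCheckContext) -> bool:
    take = ctx.annotated_expr["take"]
    decode = ctx.annotated_expr["decode"]
    # In the q4f32_0 trace above, `decode` is bound to a bare Var (lv7) rather
    # than a call_tir expression, so bail out for anything that is not a Call
    # before touching `.args`.
    if not isinstance(take, relax.Call) or not isinstance(decode, relax.Call):
        return False
    if not isinstance(take.args[0], relax.GlobalVar) or not isinstance(
        decode.args[0], relax.GlobalVar
    ):
        return False
    return "take" in take.args[0].name_hint and "decode" in decode.args[0].name_hint
```
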
jparismorgan commented 1 year ago

I believe the first error above is still happening, but it is not a blocker, and it is slightly different than the original bug I filed, so I'm going to close this issue and open a new one if I need to in the future. Thanks for your help here!