tlopex opened this pull request 2 weeks ago
@tlopex Thanks! Do you mind fixing the lint errors as shown in CI?
@MasterJH5574 Sorry for the delay. I thought I had solved the lint issue yesterday, but now there seems to be something wrong with model compilation:
[2024-11-05 10:28:28] INFO compile.py:185: Registering metadata: {'model_type': 'gptj', 'quantization': 'q4f32_1', 'context_window_size': 2048, 'sliding_window_size': -1, 'attention_sink_size': -1, 'prefill_chunk_size': 2048, 'tensor_parallel_shards': 1, 'pipeline_parallel_stages': 1, 'kv_state_kind': 'kv_cache', 'max_batch_size': 1}
error: Unsupported RoPE scaling type: gptj
--> /Users/catalyst/Workspace/miniforge3/envs/mlc-llm-ci/lib/python3.8/site-packages/tvm/relax/frontend/nn/llm/kv_cache.py:708:53
|
708 | _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling),
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Compiling with arguments:
--config GPTJConfig(vocab_size=50400, n_embd=4096, n_layer=28, n_head=16, layer_norm_epsilon=1e-05, rotary_dim=64, activation_function='gelu_new', n_inner=None, rope_scaling={'rope_type': 'gptj'}, context_window_size=2048, prefill_chunk_size=2048, tensor_parallel_shards=1, max_batch_size=1, head_dim=0, kwargs={})
--quantization GroupQuantize(name='q4f32_1', kind='group-quant', group_size=32, quantize_dtype='int4', storage_dtype='uint32', model_dtype='float32', linear_weight_layout='NK', quantize_embedding=True, quantize_final_fc=True, num_elem_per_storage=8, num_storage_per_group=4, max_int_value=7, tensor_parallel_shards=0)
--model-type gptj
--target {"thread_warp_size": runtime.BoxInt(32), "host": {"mtriple": "arm64-apple-darwin22.1.0", "tag": "", "kind": "llvm", "mcpu": "apple-m1", "keys": ["arm_cpu", "cpu"]}, "max_threads_per_block": runtime.BoxInt(1024), "max_function_args": runtime.BoxInt(31), "max_num_threads": runtime.BoxInt(256), "kind": "metal", "max_shared_memory_per_block": runtime.BoxInt(32768), "tag": "", "keys": ["metal", "gpu"]}
--opt flashinfer=0;cublas_gemm=0;faster_transformer=0;cudagraph=0;cutlass=0;ipc_allreduce_strategy=NONE
--system-lib-prefix ""
--output /var/folders/n1/5d_r6z251v39vwpj8hj_z1vc0000gp/T/tmpl4pq_51h/lib328.dylib
--overrides context_window_size=None;sliding_window_size=None;prefill_chunk_size=None;attention_sink_size=None;max_batch_size=None;tensor_parallel_shards=1;pipeline_parallel_stages=None
note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
[10:28:28] /Users/catalyst/Workspace/mlc-ai-package-self-runner/_work/package/package/tvm/src/relax/ir/block_builder.cc:65: Warning: BlockBuilder destroyed with remaining blocks!
It is the same problem I met on my own device when I did not update position_embedding in tvm. So I think maybe I need to open a pull request there.
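For context, here is a minimal NumPy sketch of what GPT-J-style (interleaved) RoPE computes, which is what the missing `gptj` rope-scaling branch needs to cover. This is only an illustration of the math, not the actual TIR code in position_embedding.py:

```python
import numpy as np

# Illustrative reference only: rotate adjacent pairs (0,1), (2,3), ... of the
# first `rotary_dim` dims of one head vector; the remaining dims pass through.
def gptj_rope(x: np.ndarray, position: int, rotary_dim: int, theta: float = 10000.0) -> np.ndarray:
    out = x.astype(np.float64)
    inv_freq = 1.0 / theta ** (np.arange(0, rotary_dim, 2) / rotary_dim)
    angles = position * inv_freq                  # one angle per dimension pair
    cos, sin = np.cos(angles), np.sin(angles)
    x1 = out[0:rotary_dim:2].copy()               # interleaved split, unlike the NeoX half-split
    x2 = out[1:rotary_dim:2].copy()               # copies, so the in-place writes below do not alias
    out[0:rotary_dim:2] = x1 * cos - x2 * sin
    out[1:rotary_dim:2] = x1 * sin + x2 * cos
    return out.astype(x.dtype)
```

With the GPTJConfig above (n_embd=4096, n_head=16, rotary_dim=64), only the first 64 of the 256 dims per head are rotated.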
@tlopex It looks to me that we will need to first send the positional embedding changes to apache/tvm in the file https://github.com/apache/tvm/blob/main/python/tvm/relax/frontend/nn/llm/position_embedding.py. Could you try sending your changes to position_embedding.py there? After we merge the PR there, we can bump tvm and follow up on this PR.
@MasterJH5574 Sure, I've already done that. Please take a look at it: https://github.com/apache/tvm/pull/17506 I'm just wondering why it couldn't pass the CI.
This PR supports the GPTJ architecture.
The model conversation demonstration is here:
I found that I had to change code in position_embedding of relax to run it locally. I wonder if I still need an update there.
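In case it helps review, here is a rough usage sketch for trying the converted model through the mlc_llm Python API; the local path below is a placeholder for the q4f32_1 GPT-J artifacts, not something produced by CI:

```python
from mlc_llm import MLCEngine

# Placeholder path to locally converted and compiled GPT-J weights (q4f32_1).
model = "./dist/gpt-j-6b-q4f32_1-MLC"
engine = MLCEngine(model)

# Stream a single chat completion from the model.
for response in engine.chat.completions.create(
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    model=model,
    stream=True,
):
    for choice in response.choices:
        print(choice.delta.content or "", end="", flush=True)
print()

engine.terminate()
```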