CharlieFRuan opened 1 year ago
I am interested in adding a new model, but after reading the tutorial I find myself confused about how to begin. I have several questions regarding the tutorial:
1. There are other parameters, like `n_positions` and `resid_pdrop` -- why weren't they included? Or why are some parameters passed directly into `kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)` without specifying their data types?
2. The internal structure of the `nn.Module` seems different, for example `GPT2Attention` in the tutorial: the original implementation differs significantly from the relaxed version. Does this mean that as long as the results are consistent, the implementation is acceptable?
3. Concerning the classes in the GPT model, some are present and others are not. For example, `class GPT2DoubleHeadsModel(GPT2PreTrainedModel)` in `modeling_gpt2.py` is not implemented. Why is this? Essentially, which modules need to be implemented and which do not?
4. I'm having trouble finding where to specify `debug=True` in `export_to_tvm()` and `jit()`. Can you tell me the location?
5. Again, regarding parameters, how were the parameters used in the validation stage chosen? They don't seem to be passed in.
As someone new to MLC-LLM, I apologize for the multitude of questions and hope you don't mind.
Hi @tlopex, thanks for the questions!
- Does this mean that as long as the results are consistent, the implementation is acceptable?
That is largely correct, as long as performance is also taken into account.
- There are other parameters, like `n_positions` and `resid_pdrop` -- why weren't they included?

Based on the answer to question 2, some of these parameters are simply not needed to produce consistent results. In this case, `resid_pdrop` controls dropout, which is only used during training -- since we only consider inference in mlc-llm, we do not need it. As for `n_positions`, we replace it with `context_window_size` in mlc-llm so that the same concept shares the same name across all models.
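For illustration, a config along these lines might look like the rough sketch below. Apart from `context_window_size` and the `kwargs` catch-all discussed above, the class name, field names, and defaults are hypothetical, not the actual mlc-llm definition:

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass
class ExampleGPT2Config:  # hypothetical sketch, not the actual mlc-llm class
    vocab_size: int = 50257
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    # Takes the place of HF's `n_positions`, so the same concept has one name
    # across all models.
    context_window_size: int = 1024
    # Any remaining fields from config.json (e.g. `resid_pdrop`) can land here
    # untyped; inference does not need them.
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
```

The idea is that unknown keys from `config.json` get routed into `kwargs` rather than each receiving its own typed field.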
- Essentially, which modules need to be implemented and which do not?
Typically, we want to implement the module that contains the transformer model plus a linear head, since our goal is autoregressive generation. In this case it is `GPT2LMHeadModel`, and for Llama it is `LlamaForCausalLM`. The docstrings in `transformers` may be helpful for understanding what each class does (e.g. here is the one for `GPT2LMHeadModel`).
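For intuition only, the kind of module we target has roughly the shape sketched below: a transformer body plus a linear LM head. This is an illustrative sketch, not the real implementation, and `ExampleTransformer` is an assumed placeholder for the stack of blocks (analogous to `GPT2Model` in the tutorial):

```python
from tvm.relax.frontend import nn


class ExampleLMHeadModel(nn.Module):  # illustrative only; not the real implementation
    def __init__(self, config):
        super().__init__()
        # Transformer body (embeddings + attention/MLP blocks); assumed to be
        # defined elsewhere, analogous to GPT2Model in the tutorial.
        self.transformer = ExampleTransformer(config)
        # Linear head projecting hidden states to vocabulary logits -- this is
        # the piece autoregressive generation needs.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, inputs, total_seq_len):
        hidden_states = self.transformer(inputs, total_seq_len)
        return self.lm_head(hidden_states)
```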
- I'm having trouble finding where to specify `debug=True` in `export_to_tvm()` and `jit()`. Can you tell me the location?

Sorry, I made a typo; it should be `export_tvm()` instead of `export_to_tvm()`. Simply specify `export_tvm(debug=True)`, and similarly `jit(debug=True)`, in addition to the other arguments.
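As a minimal usage sketch, assuming an attention module instance `gpt2_attn` and a spec `attn_spec` like the ones in the tutorial (both names are placeholders here):

```python
# Export the module as before, now with debug=True.
mod, named_params = gpt2_attn.export_tvm(spec=attn_spec, debug=True)

# Same idea for jit(): keep whatever arguments you already pass and add debug=True.
torch_attn = gpt2_attn.jit(spec=attn_spec, debug=True)
```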
- Again, regarding parameters, how were the parameters used in the validation stage chosen? They don't seem to be passed in.
That is correct; not all parameters in the `config.json` are used.
@CharlieFRuan Thank you for your previous excellent answers, which taught me a lot and allowed me to get started.
However, I still encountered some issues today:
1. I found that the code in the tutorial is not the latest version; I think I may need to keep up with the newest version?
2. A `tir.Var` type needs to be passed in `attn_spec`. However, after checking the documentation, I initially used `tir.Var` and encountered an error. Then I tried `spec.Int`, and it told me `TypeError: __init__() takes 1 positional argument but 2 were given`; I just don't know how to solve it.

Below is the relevant part of my `validate.py` and my `QWenAttention`:
```python
config_dict = {
    "hidden_size": 768,
    "num_hidden_layers": 12,
    "num_attention_heads": 12,
    "intermediate_size": 3072,
    "vocab_size": 50257,
    "rotary_pct": 0.25,
    "rotary_emb_base": 10000,
    "kv_channels": 128,
    "layer_norm_epsilon": 1e-05,
    "context_window_size": 1024,
}
qwen_config = QWenConfig(**config_dict)
qwen_attn = QWenAttention(qwen_config)

attn_spec = {
    "forward": {
        "hidden_states": spec.Tensor([1, 2, qwen_config.hidden_size], dtype="float32"),
        "attention_mask": spec.Tensor([1, 1, 2, 2], dtype="float32"),
        # This is where I pass total_seq_len; tir.Var here raised the error described above.
        "total_seq_len": tir.Var("total_seq_len", dtype=int),
    }
}

mod, named_params = qwen_attn.export_tvm(spec=attn_spec, debug=True)
for name, param in named_params:
    print(name, param.shape, param.dtype)
```
```python
class QWenAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.rope_theta = config.rotary_emb_base
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.query_key_value = nn.Linear(
            in_features=self.hidden_size,
            out_features=3 * self.num_attention_heads * self.head_dim,
            bias=True,
        )
        self.dense = nn.Linear(
            self.num_attention_heads * self.head_dim, self.hidden_size, bias=True
        )
        self.k_cache = nn.KVCache(
            config.context_window_size, [self.num_attention_heads, self.head_dim]
        )
        self.v_cache = nn.KVCache(
            config.context_window_size, [self.num_attention_heads, self.head_dim]
        )

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor,
        total_seq_len: tir.Var,
    ):
        batch_size, seq_len, _ = hidden_states.shape
        assert batch_size == 1, "Only support batch size 1 at this moment."
        # Compute query, key, and value
        qkv = self.query_key_value(hidden_states)
        qkv = op.reshape(qkv, (batch_size, seq_len, 3 * self.num_attention_heads, self.head_dim))
        q, k, v = op_ext.llama_rope(
            qkv, total_seq_len, self.rope_theta, self.num_attention_heads, self.num_attention_heads
        )
        # Update cache
        self.k_cache.append(op.squeeze(k, axis=0))
        self.v_cache.append(op.squeeze(v, axis=0))
        k = self.k_cache.view(total_seq_len)
        v = self.v_cache.view(total_seq_len)
        output = op_ext.attention(q, k, v, attention_mask)
        attn_output = self.dense(output)
        return attn_output
```
Moreover, it seems that the explanations for functions like `tvm.tir.IntImm` in the [documentation](https://tvm.apache.org/docs/reference/api/python/tir.html) (shown below) have not been updated, and their order is reversed.
![图片](https://github.com/mlc-ai/mlc-llm/assets/68688494/d2b59fc3-1900-47ce-aff0-20c02b409d25)
3. What does the `5` in `hf_attn = hf_model.transformer.h[5].attn` mean, and why is the fifth layer chosen?
![图片](https://github.com/mlc-ai/mlc-llm/assets/68688494/912fdb3a-2d37-4e3f-bb4d-08293b610c4f)
4. I'm a bit confused about how to set it up if I follow `mlc_y = torch_attn["forward"](x, mask)` and input `tir.Var`.
Sorry to have so many questions again; I hope you can answer them.
@tlopex Apologies for the late reply. Please keep the questions coming; it'd also be helpful for other people trying to learn the workflow.
- I found that the code in the tutorial is not the latest version; I think I may need to keep up with the newest version?
Yep, please use the newest version. The repo is updated continuously, but the main concepts and procedure should remain largely the same -- otherwise we would update the tutorial.
- A `tir.Var` type needs to be passed in `attn_spec`

Perhaps try `"total_seq_len": int` instead of `"total_seq_len": tir.Var("total_seq_len", dtype=int)`.
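Concretely, the spec from your snippet would then become something like:

```python
attn_spec = {
    "forward": {
        "hidden_states": spec.Tensor([1, 2, qwen_config.hidden_size], dtype="float32"),
        "attention_mask": spec.Tensor([1, 1, 2, 2], dtype="float32"),
        # Declare the symbolic sequence length as a plain `int` in the spec,
        # instead of constructing a tir.Var by hand.
        "total_seq_len": int,
    }
}
mod, named_params = qwen_attn.export_tvm(spec=attn_spec, debug=True)
```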
- What does the `5` in `hf_attn = hf_model.transformer.h[5].attn` mean, and why is the fifth layer chosen?

`5` is indeed just the fifth layer. There isn't a specific reason for picking it; we just wanted a single layer to demonstrate the validation process. Any layer would work.
- I'm a bit confused about how to set it up if I follow `mlc_y = torch_attn["forward"](x, mask)` and input `tir.Var`

Directly passing in an integer should work, e.g. `mlc_y = torch_attn["forward"](x, mask, 2)`.
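For instance, a minimal setup along those lines could look like the sketch below; the tensor shapes follow the spec in your snippet, and the `2` is just a concrete value for `total_seq_len`:

```python
import torch

# Random inputs matching the spec above: batch size 1, sequence length 2.
x = torch.rand(1, 2, qwen_config.hidden_size, dtype=torch.float32)
mask = torch.zeros(1, 1, 2, 2, dtype=torch.float32)

# The symbolic total_seq_len is bound by passing a concrete integer, here 2.
mlc_y = torch_attn["forward"](x, mask, 2)
```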
Let me know if there are other questions!
This is a pinned issue directed to the Model Request Tracking Board.