Same issue here. Any updates?
Sorry for the silence, there is a lot on my plate at the moment. Basically, the way blocks are split across GPUs is by calling get_block_compute_costs():
def get_block_compute_costs(self):
    """Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
    model_config = self.config
    d_ff = model_config.intermediate_size
    d_qkv = model_config.hidden_size // model_config.num_attention_heads
    block_compute_costs = {
        # CausalSelfAttention (qkv proj + attn out) + MLP
        LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
        + 3 * d_ff * model_config.hidden_size,
        # This is the last lm_head
        TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
    }
    return block_compute_costs
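To make the mechanism concrete, here is a minimal sketch of how per-block costs can drive a contiguous split across pipeline-parallel ranks. It is only an illustration under simplified assumptions (greedy bucketing toward an equal share of the total cost), not nanotron's actual allocation code, and the function name split_blocks_by_cost is made up:

```python
# Illustrative sketch only -- not nanotron's actual allocation logic.
# Given one cost per block (decoder layers + lm_head), greedily assign
# contiguous blocks to each pipeline rank until that rank reaches an
# equal share of the total cost.
def split_blocks_by_cost(block_costs, pp_size):
    target = sum(block_costs) / pp_size
    assignment, current, acc = [], [], 0.0
    for idx, cost in enumerate(block_costs):
        current.append(idx)
        acc += cost
        # Close this rank's bucket once it reaches the target share,
        # as long as we have not already reserved the last rank.
        if acc >= target and len(assignment) < pp_size - 1:
            assignment.append(current)
            current, acc = [], 0.0
    assignment.append(current)
    return assignment

# With skewed costs (a huge lm_head at the end, tiny decoder layers before it),
# the first rank swallows every block and the second rank ends up empty,
# which is the symptom described in this issue.
print(split_blocks_by_cost([4096] * 12 + [804432], pp_size=2))
# prints [[0, 1, ..., 12], []]: all 13 blocks on rank 0, rank 1 empty
```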
Since the tiny llama config has very small dimensions, this results in misallocations. You can increase the dimensions to make it work, e.g.:
model_config = LlamaConfig(
    # Config for a tiny model, with hidden_size and intermediate_size increased
    bos_token_id=1,
    eos_token_id=2,
    hidden_act="silu",
    hidden_size=1024,
    initializer_range=0.02,
    intermediate_size=1024,
    max_position_embeddings=50277,
    num_attention_heads=4,
    num_hidden_layers=12,
    num_key_value_heads=4,
    pretraining_tp=1,
    rms_norm_eps=1e-05,
    rope_scaling=None,
    tie_word_embeddings=True,
    use_cache=True,
    vocab_size=50277,
)
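To see why very small dimensions break the split, you can plug them into the cost formula above. The "tiny" dimensions below (hidden_size=16, intermediate_size=64) are only an assumption for illustration; the larger values match the config above:

```python
# Back-of-the-envelope check of the cost formula from get_block_compute_costs().
# The "tiny" dimensions are assumed for illustration; the "large" ones are the
# config shown above.
def block_costs(hidden_size, intermediate_size, num_attention_heads, vocab_size):
    d_ff = intermediate_size
    d_qkv = hidden_size // num_attention_heads
    decoder_layer = 4 * num_attention_heads * d_qkv * hidden_size + 3 * d_ff * hidden_size
    lm_head = vocab_size * hidden_size
    return decoder_layer, lm_head

# Tiny dims: the lm_head is roughly 200x more expensive than a decoder layer,
# so a cost-based split dumps almost everything onto a single pipeline rank.
print(block_costs(hidden_size=16, intermediate_size=64, num_attention_heads=4, vocab_size=50277))
# -> (4096, 804432)

# Larger dims: the two costs are within an order of magnitude, so the split is sane.
print(block_costs(hidden_size=1024, intermediate_size=1024, num_attention_heads=4, vocab_size=50277))
# -> (7340032, 51483648)
```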
If it is for learning purposes, you can manually set the splitting costs like this:
def get_block_compute_costs(self):
    """Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
    model_config = self.config
    block_compute_costs = {
        # CausalSelfAttention (qkv proj + attn out) + MLP
        LlamaDecoderLayer: 1,
        # This is the last lm_head
        TensorParallelColumnLinear: 0,
    }
    return block_compute_costs
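With these uniform costs, every decoder layer weighs the same and the lm_head adds nothing, so a cost-balanced split degenerates into an even layer count per rank. A quick, purely illustrative sanity check:

```python
# Illustrative only: with a uniform cost of 1 per decoder layer and 0 for the
# lm_head, each of 2 pipeline ranks targets half of the decoder layers.
costs = [1] * 12 + [0]       # 12 LlamaDecoderLayer blocks + the lm_head
pp_size = 2
print(sum(costs) / pp_size)  # -> 6.0, i.e. six decoder layers per rank
```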
@3outeille Shouldn't we solve this corner case by at least allocating one layer of parameters on each PP rank?
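One possible shape for such a fix, sketched only as an illustration (this is not an actual nanotron patch, and split_with_min_one_block is a made-up name): cap how many blocks a rank may take so that every later rank is still guaranteed at least one block.

```python
# Hypothetical variant of a greedy cost-based split (not an existing nanotron
# function): close a rank's bucket either when it reaches its cost target, or
# when stopping is required so every remaining rank can still get one block.
def split_with_min_one_block(block_costs, pp_size):
    n = len(block_costs)
    assert n >= pp_size, "need at least one block per pipeline rank"
    target = sum(block_costs) / pp_size
    assignment, start, acc = [], 0, 0.0
    for idx, cost in enumerate(block_costs):
        acc += cost
        ranks_left = pp_size - len(assignment) - 1   # ranks still to fill after this one
        blocks_left = n - idx - 1                    # blocks after the current one
        if len(assignment) < pp_size - 1 and (acc >= target or blocks_left == ranks_left):
            assignment.append(list(range(start, idx + 1)))
            start, acc = idx + 1, 0.0
    assignment.append(list(range(start, n)))
    return assignment

# Same skewed costs as before: the lm_head no longer drags every decoder layer
# onto rank 0, and rank 1 is guaranteed at least one block.
print(split_with_min_one_block([4096] * 12 + [804432], pp_size=2))
# prints [[0, ..., 11], [12]]: twelve decoder layers on rank 0, lm_head on rank 1
```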
Running the latest tiny llama config with a larger vocab size raises the following error.
Might this be related to incorrect PP allocations?