wise-east opened this issue 1 year ago (Open)
Hello there @wise-east! You seem to be working with modified code to support the Falcon architecture; have you maybe resolved #532 on your own?

I tested your config with the only change being the model, replacing `tiiuae/falcon-7b-instruct` with `reciprocate/tiny-llama` while keeping `max_new_tokens=256`, and I couldn't observe this error, so perhaps it is Falcon-specific. Also, which version of transformers did you use?
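For reference, the swap amounted to something like the following (a sketch, not your exact config; the field names `model.model_path`, `tokenizer.tokenizer_path`, and `method.gen_kwargs` follow trlx's `TRLConfig` layout as far as I remember, so adjust if yours differs):

```python
# Sketch only: swap the model while keeping max_new_tokens=256.
# Field names assume trlx's TRLConfig layout; adjust if your config differs.
from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()
config.model.model_path = "reciprocate/tiny-llama"        # swapped in for "tiiuae/falcon-7b-instruct"
config.tokenizer.tokenizer_path = "reciprocate/tiny-llama"
config.method.gen_kwargs["max_new_tokens"] = 256          # kept at 256, as in your report
```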
Yes, I adapted `modeling_ppo.py` by adding a `FalconModelBranch` based on what was done for the other branches:
```python
import torch
from typing import Optional, Tuple

from transformers.models.falcon import modeling_falcon

# ModelBranch, CausalLMOutputWithValue and hf_get_num_hidden_layers are the ones
# already defined in trlx's modeling_ppo.py / utils; this class lives in that file.


class FalconModelBranch(ModelBranch):
    def _prepare_attn_mask(
        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
    ) -> torch.BoolTensor:
        # create causal mask
        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

        if src_length > 1:
            combined_attention_mask = modeling_falcon._make_causal_mask(
                input_shape, device=device, past_key_values_length=past_key_values_length
            )

        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        expanded_attn_mask = modeling_falcon._expand_mask(attention_mask, tgt_length=src_length)
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
        )

        return combined_attention_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_shape: torch.Tensor,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False,
    ):
        """
        Reference: https://huggingface.co/tiiuae/falcon-7b-instruct/blob/main/modelling_RW.py
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length = hidden_states.shape[:2]

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.decoder_blocks))
        head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config))

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Compute alibi tensor: check build_alibi_tensor documentation
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        alibi = modeling_falcon.build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )

        # Run the branch's decoder blocks over the incoming hidden states
        for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    alibi,
                    causal_mask,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=alibi,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # Add last hidden state
        hidden_states = self.final_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return CausalLMOutputWithValue(
            logits=lm_logits,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
```
where `modeling_falcon` comes from https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py.
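One extra wiring detail, in case it helps anyone reproducing this: the new branch also has to be the one selected for Falcon checkpoints. Below is a purely hypothetical sketch of that dispatch; the helper name `hf_get_branch_class` and the architecture strings are my assumption about how `modeling_ppo.py` picks branch classes, not a copy of the actual trlx code:

```python
# Hypothetical sketch of hooking FalconModelBranch into the branch selection.
# The helper name and surrounding logic are assumptions, not trlx's real code.
FALCON_BRANCH_SUPPORTED_ARCHS = ["FalconForCausalLM", "RWForCausalLM"]


def hf_get_branch_class(config):
    """Return the ModelBranch subclass matching the model architecture."""
    arch = config.architectures[0]
    if arch in FALCON_BRANCH_SUPPORTED_ARCHS:
        return FalconModelBranch
    # ... existing GPT / OPT / BLOOM / LLaMA cases stay as they are ...
    raise ValueError(f"Unsupported architecture for a frozen branch: {arch}")
```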
My `pip show transformers` result:

```
Name: transformers
Version: 4.32.1
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: transformers@huggingface.co
License: Apache 2.0 License
Location: /opt/conda/envs/vllm/lib/python3.9/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: peft, trl, trlx
```
Oh, that's good! However, I'm seeing a slightly different version of `_expand_mask` (https://github.com/huggingface/transformers/blob/aea761499f4b1193f2706f471442da6f9df65d65/src/transformers/models/falcon/modeling_falcon.py#L197), which doesn't take `tgt_length` as an argument on 4.32.1, nor on master. If it's not too much work, could you open a PR with your changes? 🙏
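To double-check which signature the locally installed transformers provides, a quick sanity check (independent of trlx) might look like this:

```python
import inspect

from transformers.models.falcon import modeling_falcon

# Print the signature of the installed Falcon mask helper so the call in
# FalconModelBranch._prepare_attn_mask can be matched against it.
print(inspect.signature(modeling_falcon._expand_mask))
```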
🐛 Describe the bug

Here's my `TrainConfig`:

Simply changing `max_new_tokens` from 128 to 256 leads to an error:

Any help is appreciated!
Which trlX version are you using?
0.7.0
Additional system and package information
Python=3.9