Open heshenghuan opened 1 year ago
Hi all, I've recently been trying to run the LLaMA-2-70B model on a single GPU, with a lot of help from this project.
But I found that it is very dangerous to rely on the default `non_blocking=True` setting used here:

https://github.com/facebookresearch/fairscale/blob/main/fairscale/experimental/nn/offload.py#L328
https://github.com/facebookresearch/fairscale/blob/main/fairscale/experimental/nn/offload.py#L332
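To show what I mean, here is a simplified sketch (my own illustration with made-up shapes, not the actual code in offload.py) of how an unsynchronized `non_blocking=True` copy can race with compute:

```python
import torch

# Assumed pattern, not fairscale's exact code: weights are copied host->device on a
# side stream with non_blocking=True, while the matmul is launched on the default
# stream without waiting for that copy.
cpu_weight = torch.randn(4096, 4096, pin_memory=True)
x = torch.randn(8, 4096, device="cuda")

copy_stream = torch.cuda.Stream()
with torch.cuda.stream(copy_stream):
    gpu_weight = cpu_weight.to("cuda", non_blocking=True)  # asynchronous H2D copy

# UNSAFE: the default stream does not wait for copy_stream, so this kernel may
# run against stale or partially-copied weights.
y_unsafe = x @ gpu_weight

# SAFE: make the default stream wait for the copy to finish
# (or simply pass non_blocking=False above).
torch.cuda.current_stream().wait_stream(copy_stream)
y_safe = x @ gpu_weight
```

This is only my rough mental model of the failure, but it would explain why the diffs below blow up near the shard boundaries (layers 18-19 and 38-39).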
main code:
```python
tokenizer = LlamaTokenizer.from_pretrained(args.model_dir)
model = LlamaForCausalLM.from_pretrained(
    args.model_dir, low_cpu_mem_usage=True, torch_dtype=DTYPE
).eval()

origin_llama_model = model.get_decoder()
model.set_decoder(
    OffloadLlamaModel(origin_llama_model, device=device, num_slices=args.num_slices)
)
del origin_llama_model
model.lm_head.cuda()  # move model.lm_head to GPU

prompt = "Give me some suggestions on how to lose weight."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device=device)

logging.info("Generating response...")
s = time.time()
generate_ids = model.generate(
    input_ids, do_sample=False, num_beams=1, max_length=200
)
```
The OffloadLlamaModel code:
```python
# Imports assumed for this snippet (transformers ~4.31, where _make_causal_mask
# and _expand_mask still live in modeling_llama).
import logging
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from fairscale.experimental.nn.offload import OffloadModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaModel,
    _expand_mask,
    _make_causal_mask,
)


class DecodeOutput(object):
    def __init__(self, hidden_states, attention_mask, position_ids, past_key_values,
                 output_attentions, use_cache, output_hidden_states,
                 all_hidden_states, next_decoder_cache, all_self_attns):
        super().__init__()
        self.elements = [
            hidden_states, attention_mask, position_ids, past_key_values,
            output_attentions, use_cache, output_hidden_states,
            all_hidden_states, next_decoder_cache, all_self_attns
        ]

    def items(self):
        return self.elements

    def cuda(self):
        self.elements = [
            item.cuda() if hasattr(item, 'cuda') and callable(item.__getattribute__('cuda')) else item
            for item in self.elements
        ]
        return self

    def cpu(self):
        self.elements = [
            item.cpu() if hasattr(item, 'cpu') and callable(item.__getattribute__('cpu')) else item
            for item in self.elements
        ]
        return self

    def __str__(self):
        return "DecodeOutput(" + str(self.elements) + ")"

    def __getitem__(self, index: int):
        return self.elements[index]


class WrappedLlamaDecoderLayer(nn.Module):
    def __init__(self, index: int, decoder: LlamaDecoderLayer):
        super(WrappedLlamaDecoderLayer, self).__init__()
        self.idx = index
        self.decoder = decoder

    def forward(self, inputs: DecodeOutput):
        # unpack all parameters
        [hidden_states, attention_mask, position_ids, past_key_values,
         output_attentions, use_cache, output_hidden_states,
         all_hidden_states, next_decoder_cache, all_self_attns] = inputs.items()

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[self.idx] if past_key_values is not None else None

        # note: removed code like 'if self.gradient_checkpointing and self.training', so only for inference
        layer_outputs = self.decoder(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

        outputs = DecodeOutput(
            hidden_states, attention_mask, position_ids, past_key_values,
            output_attentions, use_cache, output_hidden_states,
            all_hidden_states, next_decoder_cache, all_self_attns
        )
        return outputs


class OffloadLlamaModel(nn.Module):
    def __init__(self, llama_model: LlamaModel, device=torch.device('cuda'),
                 offload_device=torch.device("cpu"), num_slices=3,
                 checkpoint_activation=False, num_microbatches=1):
        logging.info("OffloadLlamaModel Initializing.")
        super(OffloadLlamaModel, self).__init__()
        self.config = llama_model.config
        self.padding_idx = llama_model.padding_idx
        self.vocab_size = llama_model.vocab_size
        self.embed_tokens = llama_model.embed_tokens.cuda()

        logging.info("Convert origin LlamaModel.layers to a nn.Sequential of WrappedLlamaDecoders.")
        _sequential = nn.Sequential()
        for idx, decoder in enumerate(llama_model.layers):
            _sequential.add_module("layer_%d" % idx, WrappedLlamaDecoderLayer(idx, decoder))

        self.layers = OffloadModel(
            model=_sequential,
            device=device,
            offload_device=offload_device,
            num_slices=num_slices,
            checkpoint_activation=checkpoint_activation,
            num_microbatches=num_microbatches,
        )
        for sid, slc in enumerate(self.layers.model_slices):
            logging.debug(
                f"Shard {sid:d} holds WrappedLlamaDecodeLayer [{','.join(str(m.idx) for m in slc.model_shard)}]"
            )

        self.norm = llama_model.norm.cuda()

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        inputs = DecodeOutput(
            hidden_states, attention_mask, position_ids, past_key_values,
            output_attentions, use_cache, output_hidden_states,
            all_hidden_states, next_decoder_cache, all_self_attns
        )
        layer_outputs = self.layers.forward(inputs)

        [hidden_states, attention_mask, position_ids, past_key_values,
         output_attentions, use_cache, output_hidden_states,
         all_hidden_states, next_decoder_cache, all_self_attns] = layer_outputs.items()

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
```
I found that the model generated different responses with different `num_slices` settings, even with the random seed fixed.

The `pairwise_distance` between each decoder layer's outputs from the original model and from the offloaded model looked like this:
```
2023-10-27 17:46:28,544 - INFO: Loading LLaMA model and tokenizer.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
Loading checkpoint shards: 100%|██████████| 3/3 [00:17<00:00,  5.98s/it]
2023-10-27 17:46:47,599 - INFO: Running model natively.
2023-10-27 17:47:15,826 - INFO: Running OffloadModel using given num_slices setting.
2023-10-27 17:47:15,826 - INFO: OffloadLlamaModel Initializing.
2023-10-27 17:47:16,049 - INFO: Convert origin LlamaModel.layers to a nn.Sequential of WrappedLlamaDecoders.
2023-10-27 17:47:16,052 - INFO: This model has 12688.18M parameters, aiming for 6344.09M parameters per shard
2023-10-27 17:47:39,404 - INFO: Shard 0 holds 6344.09M parameters
2023-10-27 17:47:39,405 - INFO: Shard 1 holds 6344.09M parameters
2023-10-27 17:47:39,412 - DEBUG: Shard 0 holds WrappedLlamaDecodeLayer [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19]
2023-10-27 17:47:39,412 - DEBUG: Shard 1 holds WrappedLlamaDecodeLayer [20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39]
Embedding is same: True
RMSNorm is same: True
Checking layers:
Layer 00 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 01 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 02 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 03 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 04 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 05 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 06 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 07 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 08 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 09 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 10 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 11 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 12 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 13 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 14 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 15 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 16 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 17 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 18 obj is same: True, attention diff: 0.7080, hidden_state diff: 0.0001
Layer 19 obj is same: True, attention diff: 0.7104, hidden_state diff: 13.1719
Layer 20 obj is same: True, attention diff: 0.0465, hidden_state diff: 19.6094
Layer 21 obj is same: True, attention diff: 0.0314, hidden_state diff: 19.7031
Layer 22 obj is same: True, attention diff: 0.0166, hidden_state diff: 19.9062
Layer 23 obj is same: True, attention diff: 0.0146, hidden_state diff: 20.4062
Layer 24 obj is same: True, attention diff: 0.0164, hidden_state diff: 20.8125
Layer 25 obj is same: True, attention diff: 0.0153, hidden_state diff: 21.2500
Layer 26 obj is same: True, attention diff: 0.0143, hidden_state diff: 21.8281
Layer 27 obj is same: True, attention diff: 0.0151, hidden_state diff: 22.4375
Layer 28 obj is same: True, attention diff: 0.0112, hidden_state diff: 22.9844
Layer 29 obj is same: True, attention diff: 0.0150, hidden_state diff: 23.4531
Layer 30 obj is same: True, attention diff: 0.0098, hidden_state diff: 24.0781
Layer 31 obj is same: True, attention diff: 0.0129, hidden_state diff: 24.6562
Layer 32 obj is same: True, attention diff: 0.0098, hidden_state diff: 25.2656
Layer 33 obj is same: True, attention diff: 0.0164, hidden_state diff: 25.8750
Layer 34 obj is same: True, attention diff: 0.0106, hidden_state diff: 26.5000
Layer 35 obj is same: True, attention diff: 0.0133, hidden_state diff: 27.2188
Layer 36 obj is same: True, attention diff: 0.0166, hidden_state diff: 28.0000
Layer 37 obj is same: True, attention diff: 0.0179, hidden_state diff: 28.9688
Layer 38 obj is same: True, attention diff: 0.7056, hidden_state diff: 30.0312
Layer 39 obj is same: True, attention diff: 0.6108, hidden_state diff: 83.8750
['<s>me a examples for how to improve weight and\n']
['<s>me a examples on how to improve weight fast I']
```
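For reference, the per-layer diffs above were computed roughly like the sketch below (illustrative only, not my exact script; `native_*` / `offload_*` are hypothetical names for the outputs captured from the unmodified model and from the offloaded model):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def layer_diff(native_hidden, offload_hidden, native_attn, offload_attn, idx):
    # mean L2 pairwise distance between the two models' per-layer outputs,
    # computed over the last dimension after flattening the leading dims
    h_diff = F.pairwise_distance(
        native_hidden.flatten(0, -2).float(), offload_hidden.flatten(0, -2).float()
    ).mean().item()
    a_diff = F.pairwise_distance(
        native_attn.flatten(0, -2).float(), offload_attn.flatten(0, -2).float()
    ).mean().item()
    print(f"Layer {idx:02d} attention diff: {a_diff:.4f}, hidden_state diff: {h_diff:.4f}")
```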
Once I manually set `non_blocking=False`, all of the above diffs disappeared.
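Forcing `non_blocking=False` serializes the copies, so it gives up some compute/copy overlap. If the asynchrony is worth keeping, an explicit synchronization point should in principle work as well; here is an untested sketch of that idea (again my own illustration, not a patch against offload.py):

```python
import torch

# Untested sketch: keep the host->device copy asynchronous, but record a CUDA
# event on the copy stream and make the compute stream wait on it before any
# kernel touches the copied weights.
cpu_weight = torch.randn(4096, 4096, pin_memory=True)
x = torch.randn(8, 4096, device="cuda")

copy_stream = torch.cuda.Stream()
copy_done = torch.cuda.Event()

with torch.cuda.stream(copy_stream):
    gpu_weight = cpu_weight.to("cuda", non_blocking=True)
    copy_done.record(copy_stream)

# GPU-side wait only: the compute stream will not start this kernel
# until the copy has completed.
torch.cuda.current_stream().wait_event(copy_done)
y = x @ gpu_weight
```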