facebookresearch / fairscale

PyTorch extensions for high performance and large scale training.

It is dangerous to use the default non_blocking=True. #1146

Open heshenghuan opened 1 year ago

heshenghuan commented 1 year ago

Hi all, I've recently been trying to run the LLaMA-2-70B model on a single GPU, with a lot of help from this project.

But I found that it is very dangerous to use the default non_blocking=True setting, as in:

https://github.com/facebookresearch/fairscale/blob/main/fairscale/experimental/nn/offload.py#L328 https://github.com/facebookresearch/fairscale/blob/main/fairscale/experimental/nn/offload.py#L332
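
To illustrate the failure mode in isolation, here is a minimal sketch in plain PyTorch (this is illustrative code, not the fairscale internals): an asynchronous device-to-host copy returns before the data has actually landed, so reading the destination without a synchronization can observe stale values.

import torch

assert torch.cuda.is_available()

src = torch.zeros(1 << 24, device="cuda")
src += 3.0                                # pending GPU work the copy depends on

dst = torch.empty(src.shape, dtype=src.dtype, pin_memory=True)
dst.copy_(src, non_blocking=True)         # returns immediately; copy still in flight

# Reading `dst` here may observe stale/uninitialized data.
torch.cuda.synchronize()                  # wait for the pending copy to finish
assert bool((dst == 3.0).all())           # safe only after synchronizing

In fairscale's offload path the copies run on side CUDA streams, so the same kind of ordering problem can affect the shard weights that the compute stream reads.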

The main code:

import logging
import time

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

# args, DTYPE and device are assumed to be defined elsewhere (argparse etc.)
tokenizer = LlamaTokenizer.from_pretrained(args.model_dir)
model = LlamaForCausalLM.from_pretrained(
    args.model_dir,
    low_cpu_mem_usage=True,
    torch_dtype=DTYPE
).eval()

origin_llama_model = model.get_decoder()
model.set_decoder(
    OffloadLlamaModel(origin_llama_model, device=device, num_slices=args.num_slices)
)
del origin_llama_model
model.lm_head.cuda()  # move model.lm_head to GPU

prompt = "Give me some suggestions on how to lose weight."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device=device)

logging.info("Generating response...")
s = time.time()
generate_ids = model.generate(
    input_ids,
    do_sample=False,
    num_beams=1,
    max_length=200
)

The OffloadLlamaModel code:

import logging
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from fairscale.experimental.nn.offload import OffloadModel
from transformers.modeling_outputs import BaseModelOutputWithPast
# NOTE: the location of the mask helpers depends on the transformers version;
# in the version used here they live in modeling_llama.
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaModel,
    _expand_mask,
    _make_causal_mask,
)


class DecodeOutput(object):
    def __init__(self, hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache,
                 output_hidden_states, all_hidden_states, next_decoder_cache, all_self_attns):

        super().__init__()
        self.elements = [
            hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache,
            output_hidden_states, all_hidden_states, next_decoder_cache, all_self_attns
        ]

    def items(self):
        return self.elements

    def cuda(self):
        # Move every element that supports .cuda() (i.e. tensors) to the GPU;
        # leave bools, tuples and None untouched.
        self.elements = [item.cuda() if callable(getattr(item, 'cuda', None)) else item
                         for item in self.elements]
        return self

    def cpu(self):
        # Mirror of cuda(): move tensor elements back to the CPU.
        self.elements = [item.cpu() if callable(getattr(item, 'cpu', None)) else item
                         for item in self.elements]
        return self

    def __str__(self):
        return "DecodeOutput(" + str(self.elements) + ")"

    def __getitem__(self, index: int):
        return self.elements[index]

class WrappedLlamaDecoderLayer(nn.Module):
    def __init__(self, index: int, decoder: LlamaDecoderLayer):
        super(WrappedLlamaDecoderLayer, self).__init__()
        self.idx = index
        self.decoder = decoder

    def forward(self, inputs: DecodeOutput):
        # unpack all parameters
        [hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache,
         output_hidden_states, all_hidden_states, next_decoder_cache, all_self_attns] = inputs.items()

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[self.idx] if past_key_values is not None else None

        # note: removed the 'if self.gradient_checkpointing and self.training' branch, so this is inference-only
        layer_outputs = self.decoder(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

        outputs = DecodeOutput(
            hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache,
            output_hidden_states, all_hidden_states, next_decoder_cache, all_self_attns
        )

        return outputs

class OffloadLlamaModel(nn.Module):
    def __init__(self, llama_model: LlamaModel, device=torch.device('cuda'), offload_device=torch.device("cpu"),
                 num_slices=3, checkpoint_activation=False, num_microbatches=1):
        logging.info("OffloadLlamaModel Initializing.")
        super(OffloadLlamaModel, self).__init__()
        self.config = llama_model.config
        self.padding_idx = llama_model.padding_idx
        self.vocab_size = llama_model.vocab_size

        self.embed_tokens = llama_model.embed_tokens.cuda()

        logging.info("Convert origin LlamaModel.layers to a nn.Sequential of WrappedLlamaDecoders.")
        _sequential = nn.Sequential()
        for idx, decoder in enumerate(llama_model.layers):
            _sequential.add_module("layer_%d" % idx, WrappedLlamaDecoderLayer(idx, decoder))

        self.layers = OffloadModel(
            model=_sequential,
            device=device,
            offload_device=offload_device,
            num_slices=num_slices,
            checkpoint_activation=checkpoint_activation,
            num_microbatches=num_microbatches,
        )

        for sid, slc in enumerate(self.layers.model_slices):
            logging.debug(
                f"Shard {sid:d} holds WrappedLlamaDecodeLayer [{','.join(str(m.idx) for m in slc.model_shard)}]"
            )

        self.norm = llama_model.norm.cuda()

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        inputs = DecodeOutput(
            hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache,
            output_hidden_states, all_hidden_states, next_decoder_cache, all_self_attns
        )
        layer_outputs = self.layers.forward(inputs)

        [hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache,
         output_hidden_states, all_hidden_states, next_decoder_cache, all_self_attns] = layer_outputs.items()

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

I found that the model generated different responses with different num_slices settings, even with the random seed fixed.
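
To localize the discrepancy, the per-layer outputs of the native and offloaded models can be compared. A hedged sketch of such a check follows; `native_model` and `offloaded_model` are placeholders for the unwrapped and OffloadLlamaModel-wrapped models, inputs are assumed to already live on the right devices, and this is not the exact script behind the numbers below.

import torch
import torch.nn.functional as F

@torch.no_grad()
def compare_layer_outputs(native_model, offloaded_model, input_ids):
    ref = native_model(input_ids, output_hidden_states=True, output_attentions=True)
    off = offloaded_model(input_ids, output_hidden_states=True, output_attentions=True)
    # hidden_states[0] is the embedding output; [1:] are the per-layer outputs.
    for i, (h_ref, h_off) in enumerate(zip(ref.hidden_states[1:], off.hidden_states[1:])):
        h_off = h_off.to(h_ref.device, h_ref.dtype)
        a_ref = ref.attentions[i]
        a_off = off.attentions[i].to(h_ref.device, h_ref.dtype)
        h_diff = F.pairwise_distance(h_ref.flatten(1), h_off.flatten(1)).max().item()
        a_diff = F.pairwise_distance(a_ref.flatten(1), a_off.flatten(1)).max().item()
        print(f"Layer {i:02d} attention diff: {a_diff:.4f}, hidden_state diff: {h_diff:.4f}")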

The pairwise_distance between the outputs of each decoder layer of the original model and the offloaded model looked like this:

2023-10-27 17:46:28,544 - INFO: Loading LLaMA model and tokenizer.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
Loading checkpoint shards: 100%|██████████| 3/3 [00:17<00:00,  5.98s/it]
2023-10-27 17:46:47,599 - INFO: Running model natively.
2023-10-27 17:47:15,826 - INFO: Running OffloadModel using given num_slices setting.
2023-10-27 17:47:15,826 - INFO: OffloadLlamaModel Initializing.
2023-10-27 17:47:16,049 - INFO: Convert origin LlamaModel.layers to a nn.Sequential of WrappedLlamaDecoders.
2023-10-27 17:47:16,052 - INFO: This model has 12688.18M parameters, aiming for 6344.09M parameters per shard
2023-10-27 17:47:39,404 - INFO: Shard 0 holds 6344.09M parameters
2023-10-27 17:47:39,405 - INFO: Shard 1 holds 6344.09M parameters
2023-10-27 17:47:39,412 - DEBUG: Shard 0 holds WrappedLlamaDecodeLayer [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19]
2023-10-27 17:47:39,412 - DEBUG: Shard 1 holds WrappedLlamaDecodeLayer [20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39]
Embedding is same: True
RMSNorm is same: True
Checking layers:
Layer 00 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 01 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 02 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 03 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 04 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 05 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 06 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 07 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 08 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 09 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 10 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 11 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 12 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 13 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 14 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 15 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 16 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 17 obj is same: True, attention diff: 0.0000, hidden_state diff: 0.0001
Layer 18 obj is same: True, attention diff: 0.7080, hidden_state diff: 0.0001
Layer 19 obj is same: True, attention diff: 0.7104, hidden_state diff: 13.1719
Layer 20 obj is same: True, attention diff: 0.0465, hidden_state diff: 19.6094
Layer 21 obj is same: True, attention diff: 0.0314, hidden_state diff: 19.7031
Layer 22 obj is same: True, attention diff: 0.0166, hidden_state diff: 19.9062
Layer 23 obj is same: True, attention diff: 0.0146, hidden_state diff: 20.4062
Layer 24 obj is same: True, attention diff: 0.0164, hidden_state diff: 20.8125
Layer 25 obj is same: True, attention diff: 0.0153, hidden_state diff: 21.2500
Layer 26 obj is same: True, attention diff: 0.0143, hidden_state diff: 21.8281
Layer 27 obj is same: True, attention diff: 0.0151, hidden_state diff: 22.4375
Layer 28 obj is same: True, attention diff: 0.0112, hidden_state diff: 22.9844
Layer 29 obj is same: True, attention diff: 0.0150, hidden_state diff: 23.4531
Layer 30 obj is same: True, attention diff: 0.0098, hidden_state diff: 24.0781
Layer 31 obj is same: True, attention diff: 0.0129, hidden_state diff: 24.6562
Layer 32 obj is same: True, attention diff: 0.0098, hidden_state diff: 25.2656
Layer 33 obj is same: True, attention diff: 0.0164, hidden_state diff: 25.8750
Layer 34 obj is same: True, attention diff: 0.0106, hidden_state diff: 26.5000
Layer 35 obj is same: True, attention diff: 0.0133, hidden_state diff: 27.2188
Layer 36 obj is same: True, attention diff: 0.0166, hidden_state diff: 28.0000
Layer 37 obj is same: True, attention diff: 0.0179, hidden_state diff: 28.9688
Layer 38 obj is same: True, attention diff: 0.7056, hidden_state diff: 30.0312
Layer 39 obj is same: True, attention diff: 0.6108, hidden_state diff: 83.8750
['<s>me a examples for how to improve weight and\n']
['<s>me a examples on how to improve weight fast I']

Once I manually set non_blocking=False, all of the above diffs disappeared.
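
For anyone hitting the same thing, here is a sketch of two ways to make the shard copies safe, written in generic PyTorch (this is not a patch to fairscale's offload.py; `shard` and `copy_stream` are placeholders):

import torch
from torch import nn

def load_shard_blocking(shard: nn.Module, device: torch.device) -> None:
    # Option 1: give up copy/compute overlap and block until the weights arrive.
    shard.to(device, non_blocking=False)

def load_shard_async(shard: nn.Module, device: torch.device,
                     copy_stream: torch.cuda.Stream) -> None:
    # Option 2: keep the asynchronous copy, but make the compute (current)
    # stream wait for the copy stream before any kernel touches the weights.
    # (For the copy to truly overlap, the CPU-side weights also need to be
    # in pinned memory.)
    with torch.cuda.stream(copy_stream):
        shard.to(device, non_blocking=True)
    torch.cuda.current_stream().wait_stream(copy_stream)

Option 1 trades the prefetch overlap for correctness; option 2 keeps the overlap but adds an explicit stream dependency before the moved parameters are used.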