[BUG] Tensors are not on the same device when CPU activation offload is enabled #3679

Open Muggle666 opened 1 year ago

Muggle666 commented 1 year ago

I am trying to enable CPU activation offload when training my custom LLaMA model. However, an error occurs: image It seems that some inputs of the attention operation are offloaded to the CPU while others are not. I thought this was handled automatically by DeepSpeed, so am I using this feature in the wrong way? My model code is as follows:

class Transformer(nn.Module):
    """Decoder stack: token embedding, ``n_layers`` TransformerBlocks, final norm and LM head.

    When ``checkpoint_activations`` is true the blocks are run in chunks of
    ``checkpoint_num_layers`` through ``deepspeed.checkpointing.checkpoint``;
    otherwise they are run one after another.

    Args:
        params: model hyper-parameters (``dim``, ``n_layers``, ``n_heads``,
            ``vocab_size``, ``norm_eps``, ``max_seq_len`` are read here).
        checkpoint_activations: run the blocks under activation checkpointing.
        checkpoint_num_layers: number of consecutive blocks per checkpointed chunk.
    """

    def __init__(self, params: ModelArgs, checkpoint_activations=False, checkpoint_num_layers=5):
        # nn.Module must be initialised before any sub-module is assigned.
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps, name='last_norm')
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        # Rotary table for twice the maximum sequence length; kept as a plain
        # attribute (not a buffer), so forward() moves it to the input device.
        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

        self.checkpoint_activations = checkpoint_activations
        self.checkpoint_num_layers = checkpoint_num_layers

        # self.checkpoint only exists once deepspeed.checkpointing.configure()
        # has been called; forward() relies on it when checkpointing is enabled.
        if deepspeed.checkpointing.is_configured():
            self.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
            self.checkpoint = deepspeed.checkpointing.checkpoint

    # @torch.inference_mode()
    def forward(self, tokens: torch.Tensor):
        """Return float logits for ``tokens``.

        ``tokens`` is unpacked as ``(batch, seqlen)``; the result comes from
        ``self.output`` cast to float.
        """

        def custom(start, end):
            """Build a callable that runs blocks ``[start, end)`` on ``(h, freqs_cis, mask)``."""
            def custom_forward(*inputs):
                x_ = inputs[0]
                for layer in self.layers[start:end]:
                    # Each block returns (hidden, freqs_cis, mask); only the
                    # hidden state is threaded through the chunk.
                    x_, _, _ = layer(x_, inputs[1], inputs[2])
                return x_

            return custom_forward

        _bsz, seqlen = tokens.shape
        start_pos = 0
        h = self.tok_embeddings(tokens)

        # Keep the rotary table on the same device as the activations.
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]

        mask = None
        if seqlen > 1:
            # Additive causal mask: -inf strictly above the diagonal.
            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        if self.checkpoint_activations:
            num_layers = len(self.layers)
            chunk_length = self.checkpoint_num_layers
            for start in range(0, num_layers, chunk_length):
                h = self.checkpoint(custom(start, start + chunk_length),
                                    h, freqs_cis, mask)
        else:
            for layer in self.layers:
                h, freqs_cis, mask = layer(h, freqs_cis, mask)
        h = self.norm(h)
        # (bsz, seq_len, vocab_size)
        output = self.output(h)
        return output.float()

And the DeepSpeed config is:

{
  "train_batch_size": 40,
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 10,
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 3e-4,
      "eps": 1e-8,
      "weight_decay": 0.01
    }
  },
  "activation_checkpointing": {
    "partition_activations": true,
    "contiguous_memory_optimization": true,
    "cpu_checkpointing": true
  },
  "fp16": {
    "enabled": true,
    "min_loss_scale": 1,
    "loss_scale_window": 1000,
    "cpu_offload": true
  },
  "zero_optimization": {
    "stage": 3,
    "offload_param": {
      "device": "nvme",
      "nvme_path": "/foundation_model",
      "pin_memory": true
    },
    "offload_optimizer": {
      "device": "nvme",
      "nvme_path": "/foundation_model",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": 26214400,
    "allgather_bucket_size": 1e8,
    "reduce_bucket_size": 1e8,
    "round_robin_gradients": false,
    "stage3_prefetch_bucket_size": 23592960,
    "stage3_param_persistence_threshold": 51200,
    "stage3_max_live_parameters": 2e8,
    "stage3_max_reuse_distance": 2e8,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "aio": {
      "block_size": 524288,
      "queue_depth": 8,
      "thread_count": 1,
      "single_submit": true,
      "overlap_events": true
  },

  "steps_per_print": 1
}

Hope someone can help me out, thank you!

tjruwase commented 1 year ago

@Muggle666, can you please share repro steps including full script and command line? Thanks!

Muggle666 commented 1 year ago

The LLaMA model is from [link missing], and I just tried to add LoRA to it. The full model code is:

from typing import Optional, Tuple
from dataclasses import dataclass
import math

import torch
from torch import nn
import torch.nn.functional as F
from enum import IntEnum
from torch.backends.cuda import sdp_kernel
import loralib as lora
import deepspeed

@dataclass
class ModelArgs:
    """Model hyper-parameters; every field can be overridden by keyword at construction.

    The ``@dataclass`` decorator generates the keyword ``__init__`` that callers
    rely on (e.g. ``ModelArgs(max_seq_len=..., max_batch_size=..., **params)``).
    """

    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 2048

class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Args:
        dim: size of the normalised (last) dimension; ``weight`` has this length.
        eps: constant added to the mean square before the reciprocal square root.
        name: free-form label stored on the module.
    """

    def __init__(self, dim: int, eps: float = 1e-6, name=''):
        # nn.Module must be initialised before parameters are assigned.
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        # NOTE(review): the pasted line fused two statements
        # ("... torch.ones(dim)) = name"); storing the label on the module is
        # the presumed original intent - confirm against the real script.
        self.name = name

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """Normalise in float32, cast back to ``x``'s dtype, then apply the gain."""
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the table of unit-magnitude complex rotations used for rotary embeddings.

    Args:
        dim: per-head feature size; ``dim // 2`` frequencies are produced.
        end: number of positions (rows) in the table.
        theta: base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose entry ``[p, k]`` has
        magnitude 1 and angle ``p * theta ** (-2k / dim)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # polar(abs=1, angle) -> complex64 points on the unit circle
    return torch.polar(torch.ones_like(angles), angles)

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    #assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the complex factors in ``freqs_cis``.

    The last dimension of each input is paired up into complex numbers,
    multiplied by ``freqs_cis`` (reshaped to broadcast), and unpacked back to
    real values. Outputs are cast back to the dtype of their respective input.

    Returns:
        ``(rotated_xq, rotated_xk)``.
    """

    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        # (..., d) -> (..., d/2, 2) in float32 -> complex (..., d/2)
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)

class SDPBackend(IntEnum):
    """Enum class for the scaled dot product attention backends."""

    ERROR = -1
    MATH = 0
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2


# Keyword arguments for ``sdp_kernel`` that enable exactly one backend each.
backend_map = {
    SDPBackend.MATH: {
        "enable_math": True,
        "enable_flash": False,
        "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {
        "enable_math": False,
        "enable_flash": True,
        "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False,
        "enable_flash": False,
        "enable_mem_efficient": True},
}

class Attention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings and LoRA projections.

    All four projections are ``lora.Linear`` layers with rank 32, alpha 64 and
    LoRA dropout 0.05.
    """

    def __init__(self, args: ModelArgs):
        # nn.Module must be initialised before sub-modules are assigned.
        super().__init__()

        self.n_local_heads = args.n_heads // 1
        self.head_dim = args.dim // args.n_heads
        proj_dim = args.n_heads * self.head_dim

        self.wq = lora.Linear(args.dim,
                              proj_dim,
                              r=32, lora_alpha=64, lora_dropout=0.05)

        self.wk = lora.Linear(args.dim,
                              proj_dim,
                              r=32, lora_alpha=64, lora_dropout=0.05)

        self.wv = lora.Linear(args.dim,
                              proj_dim,
                              r=32, lora_alpha=64, lora_dropout=0.05)

        # The output projection consumes the concatenated heads and maps back
        # to the model width, so its input size is proj_dim and its output
        # size is args.dim (identical whenever dim is divisible by n_heads).
        self.wo = lora.Linear(proj_dim,
                              args.dim,
                              r=32, lora_alpha=64, lora_dropout=0.05)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """Apply self-attention to ``x``.

        Args:
            x: input unpacked as ``(bsz, seqlen, features)``.
            freqs_cis: rotary factors passed to ``apply_rotary_emb``.
            mask: accepted for interface compatibility but not used; causality
                is enforced through ``is_causal=True`` with ``attn_mask=None``.

        Returns:
            The output of the ``wo`` projection.
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # (bsz, seqlen, heads, head_dim) -> (bsz, heads, seqlen, head_dim)
        xq = xq.transpose(1, 2)
        keys = xk.transpose(1, 2)
        values = xv.transpose(1, 2)

        # Restrict SDPA to the memory-efficient kernel.
        with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
            output = F.scaled_dot_product_attention(xq, keys, values, attn_mask=None, dropout_p=0.0, is_causal=True)

        # Merge the heads back into a single feature dimension.
        output = output.transpose(
            1, 2
        ).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)

class FeedForward(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))`` with LoRA-adapted projections.

    Args:
        dim: input and output feature size.
        hidden_dim: nominal hidden size; it is scaled by 2/3 and rounded up to
            a multiple of ``multiple_of`` before use.
        multiple_of: granularity the hidden size is rounded up to.
    """

    def __init__(
            self,
            dim: int,
            hidden_dim: int,
            multiple_of: int,
    ):
        # nn.Module must be initialised before sub-modules are assigned.
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        # dim -> hidden (gate branch)
        self.w1 = lora.Linear(dim,
                              hidden_dim,
                              r=32, lora_alpha=64, lora_dropout=0.05)

        # hidden -> dim (down projection)
        self.w2 = lora.Linear(hidden_dim,
                              dim,
                              r=32, lora_alpha=64, lora_dropout=0.05)

        # dim -> hidden (value branch)
        self.w3 = lora.Linear(dim,
                              hidden_dim,
                              r=32, lora_alpha=64, lora_dropout=0.05)

    def forward(self, x):
        """Return ``w2(silu(w1(x)) * w3(x))``."""
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and feed-forward, each with a residual connection.

    Args:
        layer_id: index of this block within the stack (stored only).
        args: model hyper-parameters (``dim``, ``n_heads``, ``multiple_of``,
            ``norm_eps`` are read here).
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        # nn.Module must be initialised before sub-modules are assigned.
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps, name='attn_norm')
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps, name='ffn_norm')

    def forward(self, x: torch.Tensor, freqs_cis=None, mask=None):
        """Run the block and return ``(out, freqs_cis, mask)``.

        ``freqs_cis`` and ``mask`` are returned unchanged so that callers can
        thread the same triple through consecutive blocks.
        """
        h = x + self.attention.forward(self.attention_norm(x), freqs_cis, mask)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out, freqs_cis, mask

class Transformer(nn.Module):
    """Decoder stack: token embedding, ``n_layers`` TransformerBlocks, final norm and LM head.

    When ``checkpoint_activations`` is true the blocks are run in chunks of
    ``checkpoint_num_layers`` through ``deepspeed.checkpointing.checkpoint``;
    otherwise they are run one after another.

    Args:
        params: model hyper-parameters (``dim``, ``n_layers``, ``n_heads``,
            ``vocab_size``, ``norm_eps``, ``max_seq_len`` are read here).
        checkpoint_activations: run the blocks under activation checkpointing.
        checkpoint_num_layers: number of consecutive blocks per checkpointed chunk.
    """

    def __init__(self, params: ModelArgs, checkpoint_activations=False, checkpoint_num_layers=5):
        # nn.Module must be initialised before any sub-module is assigned.
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps, name='last_norm')
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        # Rotary table for twice the maximum sequence length; kept as a plain
        # attribute (not a buffer), so forward() moves it to the input device.
        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

        self.checkpoint_activations = checkpoint_activations
        self.checkpoint_num_layers = checkpoint_num_layers

        # self.checkpoint only exists once deepspeed.checkpointing.configure()
        # has been called; forward() relies on it when checkpointing is enabled.
        if deepspeed.checkpointing.is_configured():
            self.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
            self.checkpoint = deepspeed.checkpointing.checkpoint

    # @torch.inference_mode()
    def forward(self, tokens: torch.Tensor):
        """Return float logits for ``tokens``.

        ``tokens`` is unpacked as ``(batch, seqlen)``; the result comes from
        ``self.output`` cast to float.
        """

        def custom(start, end):
            """Build a callable that runs blocks ``[start, end)`` on ``(h, freqs_cis, mask)``."""
            def custom_forward(*inputs):
                x_ = inputs[0]
                for layer in self.layers[start:end]:
                    # Each block returns (hidden, freqs_cis, mask); only the
                    # hidden state is threaded through the chunk.
                    x_, _, _ = layer(x_, inputs[1], inputs[2])
                return x_

            return custom_forward

        _bsz, seqlen = tokens.shape
        start_pos = 0
        h = self.tok_embeddings(tokens)

        # Keep the rotary table on the same device as the activations.
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]

        mask = None
        if seqlen > 1:
            # Additive causal mask: -inf strictly above the diagonal.
            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        if self.checkpoint_activations:
            num_layers = len(self.layers)
            chunk_length = self.checkpoint_num_layers
            for start in range(0, num_layers, chunk_length):
                h = self.checkpoint(custom(start, start + chunk_length),
                                    h, freqs_cis, mask)
        else:
            for layer in self.layers:
                h, freqs_cis, mask = layer(h, freqs_cis, mask)
        h = self.norm(h)
        # (bsz, seq_len, vocab_size)
        output = self.output(h)
        return output.float()

My main training script is as follows:

# -- coding:utf-8 -
import torch
import argparse
import torch.multiprocessing as mp
from import DistributedSampler
import torch.distributed as dist
from import Dataset
from torch.nn.parallel import DistributedDataParallel as DDP
from import DataLoader
from torch.distributed import init_process_group, destroy_process_group
from llama.model_train import Transformer, ModelArgs
from llama.tokenizer import Tokenizer
import numpy as np
import random
from utils import CustomDataset
import deepspeed
import loralib as lora
import h5py
import pickle
import json
import os
import logging
import mpu
from import estimate_zero2_model_states_mem_needs_all_live
from import estimate_zero3_model_states_mem_needs_all_live

# # os.environ['DEEPSPEED_LOG_LEVEL'] = 'DEBUG'
# os.environ["BITSANDBYTES_NOWELCOME"] = "1"
# os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
# os.environ["NCCL_DEBUG"] = "DETAIL"
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# ==============config==============
PAD_ID = 2
accumulation_steps = 16
device = 'cuda'

# initialize the distributed group

def ddp_setup():
    # NOTE(review): the body of this function is missing from the pasted
    # script. Per the comment above it, it presumably initialises the
    # distributed process group - restore it from the original source.

def load_dataset_hdf5(filename):
    """Load the ``data`` and ``labels`` datasets of an HDF5 file into a CustomDataset.

    Both datasets are sliced in full (``[:]``) while the file is open; the
    file is closed before the dataset object is built.
    """
    with h5py.File(filename, 'r') as hdf:
        # Unpacking consumes the generator inside the ``with`` block, reading
        # 'data' first and 'labels' second.
        samples, targets = (hdf[key][:] for key in ('data', 'labels'))
    return CustomDataset(samples, targets)

def set_random_seeds(random_seed=0):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        random_seed: seed applied to ``random``, ``numpy.random`` and ``torch``.
    """
    # The seed argument was previously accepted but never applied.
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_model(ckpt_path, tokenizer_path, model_params, args):
    """Build the Transformer, load checkpoint weights into it and return the model.

    Args:
        ckpt_path: path passed to ``torch.load`` for the weights.
        tokenizer_path: path passed to ``Tokenizer``.
        model_params: path of the JSON file holding the model hyper-parameters,
            including the ``checkpointing`` and ``checkpoint_layers`` keys.
        args: parsed arguments; ``args.deepspeed_config`` is indexed as a dict.

    NOTE(review): this function was pasted with several lines lost (unclosed
    calls, missing ``else:`` branches, empty bodies). Each gap is marked below;
    restore them from the original script before running.
    """
    tokenizer = Tokenizer(model_path=tokenizer_path)

    with open(model_params, "r") as f:
        # NOTE(review): call is truncated; presumably ``json.loads(f.read())``.
        params = json.loads(

    # These two keys configure checkpointing and are removed before the rest
    # of ``params`` is forwarded to ModelArgs.
    checkpoint_activations = params.pop("checkpointing")
    checkpoint_num_layers = params.pop("checkpoint_layers")
    if checkpoint_activations:
        # mpu.initialize_model_parallel(int(os.environ["WORLD_SIZE"]))
        # Configure DeepSpeed activation checkpointing from the DeepSpeed
        # config and expose its helpers on the ``mpu`` module.
        deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=params['n_layers'])
        mpu.checkpoint = deepspeed.checkpointing.checkpoint
        mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
        mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed

    # NOTE(review): the closing parenthesis of this call is missing, and
    # CUTOFF_LEN / BATCH_SIZE are not defined anywhere in the visible script.
    model_args: ModelArgs = ModelArgs(
        max_seq_len=CUTOFF_LEN, max_batch_size=BATCH_SIZE, **params
    model_args.vocab_size = tokenizer.n_words
    # torch.set_default_tensor_type(torch.cuda.HalfTensor)

    if args.deepspeed_config['zero_optimization']['stage'] == 3:
        # Under ZeRO stage 3 only rank 0 reads the checkpoint from disk.
        if int(os.environ["RANK"]) == 0:
            weights = torch.load(ckpt_path, map_location=torch.device('cpu'))

            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(weights, "_metadata", None)
            weights = weights.copy()
            if metadata is not None:
                weights._metadata = metadata
            error_msgs = []
            # NOTE(review): the three assignments below presumably belong to a
            # lost ``else:`` branch for ranks other than 0; as pasted they
            # would discard the weights that were just loaded.
            weights = None
            metadata = None
            error_msgs = []
        # NOTE(review): presumably the body of a lost ``else:`` for the
        # non-stage-3 case - confirm.
        weights = torch.load(ckpt_path, map_location=torch.device('cpu'))

    def load(module: torch.nn.Module, prefix=""):
        # because zero3 puts placeholders in model params, this context
        # manager gathers (unpartitions) the params of the current layer, then loads from
        # the state dict and then re-partitions them again
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (weights, prefix, local_metadata, True, [], [], error_msgs)
        # NOTE(review): the context-manager expression and the body of the
        # rank-0 branch were lost in the paste.
        with, modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:

        # Recurse into every registered child module with an extended prefix.
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    if args.deepspeed_config['zero_optimization']['stage'] == 3:
            # NOTE(review): the extra indentation suggests a lost enclosing
            # ``with`` statement; the constructor call is also truncated.
            model = Transformer(model_args, checkpoint_activations=checkpoint_activations,
        load(model, prefix="")
        print("load finished!!!")
        # NOTE(review): presumably a lost ``else:`` precedes this second,
        # likewise truncated, constructor call.
        model = Transformer(model_args, checkpoint_activations=checkpoint_activations,
        # estimate_zero2_model_states_mem_needs_all_live(model, num_gpus_per_node=8, num_nodes=1)
        # estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=8, num_nodes=1)
        model.load_state_dict(weights, strict=False)

    # weights = torch.load(ckpt_path, map_location=torch.device('cpu'))
    # model.load_state_dict(weights, strict=False)
    # model = model.half()
    del weights

    if checkpoint_activations:
        # NOTE(review): the hook body and the statement that registers it
        # are missing from the paste.
        def make_inputs_require_grad(module, input, output):

        # First entry of named_modules() on the embedding, i.e. the embedding
        # module itself.
        module = next(getattr(model, 'tok_embeddings').named_modules())[-1]


    return model

class Trainer:
    """Wrap a model with a DeepSpeed engine and run the training loop.

    NOTE(review): the pasted class lost several lines (``self``/``args`` in the
    constructor signature, closing parentheses, tensor device moves, and
    apparently the backward/step calls). Each gap is marked below.
    """

    # NOTE(review): signature is truncated; ``self`` and an ``args`` parameter
    # (used by deepspeed.initialize below) plus the closing ``):`` are missing.
    def __init__(
            model: torch.nn.Module,
            train_data: DataLoader,
        # Ranks are read from environment variables.
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        self.vocab_size = model.vocab_size
        print('check vocab size', self.vocab_size)
        self.train_data = train_data
        self.model = model
        del model

        # Only parameters with requires_grad=True are handed to DeepSpeed.
        param_group = [p for p in self.model.parameters() if p.requires_grad]
        print('trainable params:', [n for n, p in self.model.named_parameters() if p.requires_grad])
        # Calculate the total number of parameters in the model
        total_params = sum(p.numel() for p in self.model.parameters())

        # Calculate the total number of trainable parameters in the model
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f'Trainable Parameters: {trainable_params}')

        # Calculate the percentage of trainable parameters
        # percentage_trainable = (trainable_params / total_params) * 100

        # print(f'Trainable Parameters: {trainable_params} ({percentage_trainable:.2f}% of total parameters)')
        print('start initialize')
        # NOTE(review): the closing parenthesis of this call is missing.
        self.model, self.optimizer, _, _ = deepspeed.initialize(
            args=args, model=self.model, model_parameters=param_group
        print('=============model initialized===================')

    def train(self, max_epoch):
        # Runs ``max_epoch`` passes over ``self.train_data`` and returns
        # ``(train_losses, validation_losses)``; neither list is appended to
        # in the visible code.
        # Label value -1 is ignored by the loss.
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1)
        step = 0
        # NOTE(review): hard-coded here; the module-level
        # ``accumulation_steps = 16`` is not used in the visible code.
        gradient_accumulation_steps = 10
        train_losses = []
        validation_losses = []

        for e in range(max_epoch):
            accumulated_loss = 0
            num_micro_batches = 0

            for source, label in self.train_data:
                # NOTE(review): right-hand sides lost in the paste; presumably
                # the tensors are moved to the training device here.
                source =
                label =
                logits = self.model(source)
                # Next-token objective: logits at position t are scored
                # against the label at position t + 1.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = label[..., 1:].contiguous()
                loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))

                accumulated_loss += loss.item()
                num_micro_batches += 1

                if num_micro_batches == gradient_accumulation_steps:
                    # Calculate and print the average loss for the full batch
                    # NOTE(review): this call is truncated (closing part missing).
                    avg_loss = torch.tensor([(accumulated_loss / num_micro_batches)],
                    # Sum across ranks, then divide by the world size to get
                    # the global mean.
                    dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                    avg_loss = avg_loss.item() / dist.get_world_size()
                    if self.global_rank == 0:
                        print(f'epoch {e} global mean loss:', avg_loss)
                    # Reset the accumulated loss and micro-batch counter
                    accumulated_loss = 0
                    num_micro_batches = 0

                # NOTE(review): no backward pass or optimizer step is visible
                # in this loop; presumably those lines were lost here.


                step += 1

            # Only global rank 0 writes the per-epoch checkpoint.
            if self.global_rank == 0:
                save_path = f"/home/LM_new/save_model/consolidated_tele_768finetune_v2_epoch{e}.pth"
                original_model = self.model.module
                # NOTE(review): the start of this save call was lost in the paste.
      , save_path)

        return train_losses, validation_losses

# Build the training DataLoader and the model; returns ``(dataloader, model)``.
# NOTE(review): the signature is truncated - the ``data_path``,
# ``tokenizer_path`` and ``model_params`` parameters used below, and the
# closing ``):``, are missing from the paste.
def get_train_obj(args, ckpt_path='/home/LM_new//model/merged_state_dict.pth',
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # use it on trusted dataset files.
    with open(data_path, "rb") as file:
        dataset = pickle.load(file)
    print('data size', len(dataset))

    # NOTE(review): call is truncated (remaining arguments and the closing
    # parenthesis are missing); BATCH_SIZE is not defined in the visible script.
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False,
    print('data loaded')

    model = get_model(ckpt_path, tokenizer_path, model_params, args)

    return dataloader, model

def main():
    """Parse command-line arguments, load the DeepSpeed config and build the trainer.

    ``LOCAL_RANK`` must be set in the environment because it supplies the
    default for ``--local_rank``. The JSON file named by ``--config-file`` is
    parsed and stored as a dict on ``args.deepspeed_config``.
    """
    ds_config = '/home/LM_new//model/deepspeed_config.json'
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=int(os.environ["LOCAL_RANK"]),
                        help='Local rank for distributed training')
    parser.add_argument('--config-file', default=ds_config, type=str)
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    # Use a context manager so the config file handle is closed promptly
    # instead of being left to the garbage collector.
    with open(args.config_file, 'r', encoding='utf-8') as config_file:
        args.deepspeed_config = json.load(config_file)
    dataloader, model = get_train_obj(args)
    trainer = Trainer(model, dataloader, args)
    # NOTE(review): no call to ``trainer.train(...)`` is visible in the pasted
    # script; presumably it followed here - confirm the epoch count and restore it.

if __name__ == "__main__":
    # Echo the DeepSpeed log level for debugging, then run the entry point
    # (without this call the script would exit without doing any work).
    os.system('echo $DEEPSPEED_LOG_LEVEL')
    main()

I launch it with torchrun as follows (script name omitted): torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=xx.xx.xx.xx --master_port=8080

tjruwase commented 1 year ago

Thanks for sharing these. However, I am running into this import error.

Traceback (most recent call last):
  File "", line 15, in <module>
    from utils import CustomDataset
ModuleNotFoundError: No module named 'utils'
[2023-06-07 18:29:56,799] [INFO] [] Killing subprocess 79135
[2023-06-07 18:29:56,799] [ERROR] [] ['/opt/conda/bin/python', '-u', '', '--local_rank=0'] exits with return code = 1
Muggle666 commented 1 year ago
class CustomDataset(Dataset):
    """Map-style dataset that pairs each sample with its label as long tensors.

    Args:
        data: indexable collection of samples.
        labels: indexable collection of labels, aligned with ``data``.
    """

    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(sample, label)`` at ``index``, both as ``torch.long`` tensors."""
        sample = self.data[index]
        label = self.labels[index]

        return torch.tensor(sample, dtype=torch.long), torch.tensor(label, dtype=torch.long)

This is the custom dataset class. To reproduce the issue, you can simply feed random tensors to the model; I think the result should be the same.

tjruwase commented 1 year ago

@Muggle666, apologies for the delay on this. I have some bandwidth to investigate but I am stuck on these file paths image

If you still need this investigated, can you please share the sample content to unblock me? Thanks!

OsebeSammi commented 5 months ago

I faced a similar issue when offloading to CPU with this config:

"zero_optimization": { "stage": 3, "offload_param": { "device": "cpu", "buffer_count": 5, "buffer_size": 1e8, "max_in_cpu": 1e9 }, "offload_optimizer": { "device": "cpu" } } }

After I removed DeepSpeed 0.14 and installed 0.13, it worked fine. Please investigate this issue with offloading parameters to CPU.

chchenhui commented 4 months ago

I encountered the same problem on deepspeed version 0.14.2.

wentinghome commented 4 months ago

I encountered the same problem on deepspeed version 0.14.2.

same here

tjruwase commented 4 months ago

@wentinghome, @chchenhui, @OsebeSammi since this is a long running issue, it would be great if you could provide a repro from your own failure. Thanks!