pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

How to implement parallel training across TPU devices with XLA 2.X #6766

Open Mon-ius opened 5 months ago

Mon-ius commented 5 months ago

I found that Gemma, Google's latest open-source LLM, comes with two versions of the model structure:

  1. https://github.com/google/gemma_pytorch/blob/main/gemma/model_xla.py
  2. https://github.com/google/gemma_pytorch/blob/main/gemma/model.py

The model_xla version, together with run_xla.sh and xla_model_parallel.py, seems to use the XLA 1.x API with a modified Transformer network.

Besides, the main modification seems to be replacing the standard nn.Linear (and embedding) layers with:

ColumnParallelLinear
ParallelEmbedding
RowParallelLinear
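For context, the pattern behind these layers can be simulated in a single process. A ColumnParallelLinear-style layer gives each device a slice of the weight's output rows, and a RowParallelLinear-style layer gives each device the matching slice of input columns, so each device's partial results only need to be summed (an all-reduce on real hardware). This is a minimal CPU sketch under that assumption, not the Gemma implementation; `world_size` and the weight names are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2                 # number of simulated devices (hypothetical)
x = torch.randn(3, 8)          # [batch, hidden]
w_up = torch.randn(16, 8)      # nn.Linear layout: [out_features, in_features]
w_down = torch.randn(8, 16)

# Unsharded reference: Linear -> GELU -> Linear, as in an MLP block.
reference = F.gelu(x @ w_up.T) @ w_down.T

# ColumnParallelLinear-style: each rank holds a slice of the output rows (axis 0).
up_shards = torch.chunk(w_up, world_size, dim=0)
# RowParallelLinear-style: each rank holds the matching slice of input columns (axis 1).
down_shards = torch.chunk(w_down, world_size, dim=1)

partials = []
for rank in range(world_size):
    # GELU is elementwise, so it can run on the local slice without communication.
    h_local = F.gelu(x @ up_shards[rank].T)          # [batch, 16 // world_size]
    partials.append(h_local @ down_shards[rank].T)   # partial sum, [batch, 8]

# Summing the partial outputs stands in for the all-reduce a real row-parallel layer issues.
output = torch.stack(partials).sum(dim=0)
print(torch.allclose(output, reference, atol=1e-4))
```

Because each device only ever holds a slice of each weight, the model code itself has to know about the sharding, which is why the XLA version swaps out the standard layers.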

Do we still need to do this kind of work to train our model on XLA devices?

Or does the XLA library already provide hooks, so that we can do something similar to what FSDP introduced 🤗:

import torch
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

# FLAGS, auto_wrap_policy and auto_wrapper_callable are defined elsewhere in the training script.
fsdp_wrap = lambda m: FSDP(
    m,
    compute_dtype=getattr(torch, FLAGS.compute_dtype),
    fp32_reduce_scatter=FLAGS.fp32_reduce_scatter,
    flatten_parameters=FLAGS.flatten_parameters,
    shard_param_on_dim_0=FLAGS.shard_param_on_dim_0,
    pin_layout_in_collective_ops=FLAGS.pin_layout_in_collective_ops,
    auto_wrap_policy=auto_wrap_policy,
    auto_wrapper_callable=auto_wrapper_callable)

model = fsdp_wrap(model)

Could we have documentation on implementing Gemma directly with the XLA PJRT runtime, without the heavy modifications that the Gemma XLA version makes?

JackCaoG commented 5 months ago

@alanwaketan can you take this one?

Mon-ius commented 5 months ago

Any progress?

Mon-ius commented 5 months ago

If it is possible to apply FSDP directly to the original model (Gemma, for example), what is the best practice for setting the parameters of FullyShardedDataParallel?

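from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

fsdp_model = FullyShardedDataParallel(
    GemmaModel(config, world_size, rank),
    # Current name of the older fsdp_auto_wrap_policy=default_auto_wrap_policy arguments
    auto_wrap_policy=size_based_auto_wrap_policy,
    cpu_offload=CPUOffload(offload_params=True),
)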

But with the sample pretrained checkpoint file, does FSDP have a hook that does what Gemma does manually here:

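import re

import torch

# Method of the model-parallel Gemma class: loads the full checkpoint and keeps
# only this rank's shard of each weight (uses self.world_size and self.rank).
def load_weights(self, model_path: str):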
        checkpoint = torch.load(model_path, weights_only=True)
        model_state_dict = checkpoint['model_state_dict']

        num_attn_heads = self.config.num_attention_heads
        num_kv_heads = self.config.num_key_value_heads
        head_dim = self.config.head_dim
        hidden_size = self.config.hidden_size

        def split(tensor: torch.Tensor, axis: int) -> torch.Tensor:
            axis_len = tensor.shape[axis]
            split_len = axis_len // self.world_size
            split_start = split_len * self.rank
            split_end = split_start + split_len
            tensor = torch.moveaxis(tensor, axis, 0)
            tensor = tensor[split_start:split_end, ...]
            tensor = torch.moveaxis(tensor, 0, axis)
            return tensor

        for k, v in model_state_dict.items():
            if k == 'freqs_cis':
                continue
            if (k == 'model.norm.weight' or re.fullmatch(
                    r'model.layers.\d+.input_layernorm.weight', k)
                    or re.fullmatch(
                        r'model.layers.\d+.post_attention_layernorm.weight',
                        k) or k.endswith('weight_scaler')):
                pass
            elif (k == 'embedder.weight' or re.fullmatch(
                    r'model.layers.\d+.mlp.down_proj.weight', k)):
                v = split(v, 1)
            elif (re.fullmatch(r'model.layers.\d+.mlp.gate_proj.weight', k)
                  or re.fullmatch(r'model.layers.\d+.mlp.up_proj.weight', k)):
                v = split(v, 0)
            elif re.fullmatch(r'model.layers.\d+.self_attn.qkv_proj.weight',
                              k):
                if num_kv_heads <= self.world_size:
                    num_replicas = self.world_size // num_kv_heads
                    v = v.reshape(num_attn_heads + num_kv_heads * 2, head_dim,
                                  hidden_size)
                    query = v[:num_attn_heads, ...]
                    key = v[num_attn_heads:num_attn_heads + num_kv_heads,
                            ...].repeat(num_replicas, 1, 1)
                    value = v[-num_kv_heads:, ...].repeat(num_replicas, 1, 1)
                    v = torch.cat(
                        (split(query, 0), split(key, 0), split(value, 0)),
                        dim=0)
                else:
                    v = v.reshape(3, num_attn_heads, head_dim, hidden_size)
                    v = split(v, 1)
                v = v.reshape(-1, hidden_size)
            elif re.fullmatch(r'model.layers.\d+.self_attn.o_proj.weight', k):
                v = v.reshape(hidden_size, num_attn_heads, head_dim)
                v = split(v, 1)
                v = v.reshape(hidden_size, -1)
            else:
                raise ValueError(f'Unrecognized key: {k}')
            self.state_dict()[k].copy_(v)
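The core of load_weights is the nested split helper: each rank keeps one contiguous slice of a weight along a chosen axis (axis 0 for gate_proj/up_proj, axis 1 for down_proj and the embedder). The sketch below lifts it out as a standalone function, with self.rank and self.world_size turned into arguments (an adaptation, not the original signature), and assumes the split dimension divides evenly, since `//` silently drops any remainder.

```python
import torch


def split(tensor: torch.Tensor, axis: int, rank: int, world_size: int) -> torch.Tensor:
    """Return this rank's contiguous slice of `tensor` along `axis` (same logic as load_weights)."""
    split_len = tensor.shape[axis] // world_size
    start = split_len * rank
    tensor = torch.moveaxis(tensor, axis, 0)      # bring the split axis to the front
    tensor = tensor[start:start + split_len, ...]
    return torch.moveaxis(tensor, 0, axis)        # restore the original axis order


world_size = 4
weight = torch.arange(8 * 12, dtype=torch.float32).reshape(8, 12)  # [out, in] like nn.Linear

# gate_proj/up_proj style: shard the output dimension (axis 0).
col_shards = [split(weight, 0, r, world_size) for r in range(world_size)]
# down_proj/embedder style: shard the input dimension (axis 1).
row_shards = [split(weight, 1, r, world_size) for r in range(world_size)]

print([tuple(s.shape) for s in col_shards])  # [(2, 12), (2, 12), (2, 12), (2, 12)]
print([tuple(s.shape) for s in row_shards])  # [(8, 3), (8, 3), (8, 3), (8, 3)]
```

Concatenating every rank's shard along the same axis recovers the full weight, which is the invariant FSDP-style sharding would otherwise handle automatically.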
alanwaketan commented 5 months ago

@Mon-ius Please take a look at FSDPv2 and use the HF Gemma for pre-training/fine-tuning: https://huggingface.co/blog/gemma-peft

Mon-ius commented 5 months ago

@alanwaketan Thanks for the information. Is a Gemma model trained with xla_fsdp_v2 compatible with both CUDA and TPU devices?

alanwaketan commented 5 months ago

@Mon-ius Yes, as long as you use the correct checkpoint format.
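The reply does not spell out which format is meant. One common device-neutral convention (an assumption here) is a state_dict whose tensors are all on CPU, loaded with map_location="cpu" and then moved to whichever device the job uses. This sketch uses a tiny nn.Linear and an in-memory buffer in place of a file; it does not cover gathering sharded FSDPv2 parameters.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
# Move every tensor to CPU before saving so the file does not reference XLA or CUDA devices.
cpu_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}

buffer = io.BytesIO()  # stands in for a checkpoint file path
torch.save(cpu_state, buffer)
buffer.seek(0)

# map_location="cpu" lets any host load it; move the model to CUDA or XLA afterwards.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buffer, map_location="cpu"))
```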

Mon-ius commented 5 months ago

@alanwaketan Is there a way to run FSDPv2 in a fully automatic mode? In the given example, we need to specify each individual module to be wrapped, for example "q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj". I searched for days, and no one seems to have mentioned this point.

Could we have a hook that runs a DFS/BFS over a given model? Or, going deeper, does PyTorch have a tree structure that stores the child nn.Modules?
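PyTorch does keep this tree: every nn.Module registers its child modules, and the built-in named_children() and named_modules() methods expose them, with named_modules() performing a depth-first, pre-order walk. A breadth-first walk is easy to build on top of named_children(). A minimal sketch; `bfs_modules` is a hypothetical helper name, not a PyTorch API.

```python
from collections import deque

import torch.nn as nn


def bfs_modules(root: nn.Module):
    """Breadth-first walk over the module tree, built on named_children()."""
    order = []
    queue = deque([("", root)])
    while queue:
        prefix, module = queue.popleft()
        order.append((prefix, module))
        for name, child in module.named_children():
            queue.append((f"{prefix}.{name}" if prefix else name, child))
    return order


model = nn.Sequential(
    nn.Sequential(nn.Linear(4, 8), nn.ReLU()),
    nn.Linear(8, 2),
)

# named_modules() walks the same tree depth-first; the root has the empty name "".
dfs_names = [name for name, _ in model.named_modules()]
bfs_names = [name for name, _ in bfs_modules(model)]
print(dfs_names)  # ['', '0', '0.0', '0.1', '1']
print(bfs_names)  # ['', '0', '1', '0.0', '0.1']
```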

Mon-ius commented 5 months ago

For a more practical case, consider the code snippet below. What is the best practice for leveraging FSDPv2/SPMD here?

import math
from abc import abstractmethod

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F  # F.interpolate lives in torch.nn.functional, not torch.functional

def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)

def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return nn.GroupNorm(32, channels)

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x

class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2):
        super().__init__()
        self.channels = channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, channels, channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x

class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2):
        super().__init__()
        self.channels = channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1)
        else:
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)

class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h

class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(self, channels, num_heads=1, use_checkpoint=False):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint

        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttention()
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
        h = self.attention(qkv)
        h = h.reshape(b, -1, h.shape[-1])
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)

class QKVAttention(nn.Module):
    """
    A module which performs QKV attention.
    """

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x C x T] tensor after attention.
        """
        ch = qkv.shape[1] // 3
        q, k, v = th.split(qkv, ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        return th.einsum("bts,bcs->bct", weight, v)

    @staticmethod
    def count_flops(model, _x, y):
        """
        A counter for the `thop` package to count the operations in an
        attention operation.

        Meant to be used like:

            macs, params = thop.profile(
                model,
                inputs=(inputs, timestamps),
                custom_ops={QKVAttention: QKVAttention.count_flops},
            )

        """
        b, c, *spatial = y[0].shape
        num_spatial = int(np.prod(spatial))
        # We perform two matmuls with the same number of ops.
        # The first computes the weight matrix, the second computes
        # the combination of the value vectors.
        matmul_ops = 2 * b * (num_spatial ** 2) * c
        model.total_ops += th.DoubleTensor([matmul_ops])

class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        num_heads=1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.num_heads = num_heads
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch, use_checkpoint=use_checkpoint, num_heads=num_heads
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                self.input_blocks.append(
                    TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims))
                )
                input_block_chans.append(ch)
                ds *= 2

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                layers = [
                    ResBlock(
                        ch + input_block_chans.pop(),
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                        )
                    )
                if level and i == num_res_blocks:
                    layers.append(Upsample(ch, conv_resample, dims=dims))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

    def forward(self, x, timesteps, y=None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            cat_in = th.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)
        h = h.type(x.dtype)
        return self.out(h)

def dfs(model, prefix=""):
    """
    Finds all sub-elements (nn.Module instances) within a PyTorch model, including the model itself.

    Args:
        model (nn.Module or list or dict or tuple): The PyTorch model or a container holding PyTorch models.
        prefix (str, optional): A prefix to prepend to the module names.

    Returns:
        list: A list of tuples, where each tuple contains the module prefix and the module instance.
    """
    modules = []
    stack = [(model, prefix)]

    while stack:
        curr_module, curr_prefix = stack.pop()

        if isinstance(curr_module, nn.Module):
            modules.append((curr_prefix, curr_module))

            for name, child in curr_module.named_children():
                child_prefix = f"{curr_prefix}.{name}" if curr_prefix else name
                stack.append((child, child_prefix))

        elif isinstance(curr_module, (list, tuple)):
            for i, item in enumerate(curr_module):
                item_prefix = f"{curr_prefix}[{i}]" if curr_prefix else str(i)
                stack.append((item, item_prefix))

        elif isinstance(curr_module, dict):
            for key, value in curr_module.items():
                value_prefix = f"{curr_prefix}.{key}" if curr_prefix else key
                stack.append((value, value_prefix))

    return modules

model = UNetModel(
    in_channels=3,
    model_channels=64,
    out_channels=3,
    num_res_blocks=2,
    attention_resolutions=(2, 4),
    dropout=0.1,
    channel_mult=(1, 2, 4, 8),
    num_classes=None,
    use_checkpoint=False,
    num_heads=4,
    num_heads_upsample=-1,
    use_scale_shift_norm=True,
)

all_modules = dfs(model)

for prefix, module in all_modules:
    print(f"{prefix}: {module}")
alanwaketan commented 5 months ago

@Mon-ius You can take a look at the FSDPv1 blog post on auto-wrapping. https://pytorch.org/blog/pytorch-2.0-xla/#fsdp-beta

FSDPv2 reuses the same auto-wrapping infrastructure, which should solve your problem. In the Gemma example, you just need to specify GemmaDecoderLayer so that every instance of it is auto-wrapped by FSDPv2.
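A class-based policy of this kind is usually a functools.partial over transformer_auto_wrap_policy. To preview on CPU which modules would become wrapped units, the sketch below uses the stock PyTorch version from torch.distributed.fsdp.wrap; on a TPU host the equivalent would come from torch_xla's own FSDP wrap module and be passed to FSDPv2 as its auto-wrap policy (not run here). ToyDecoderLayer and ToyModel are hypothetical stand-ins for GemmaDecoderLayer and the Gemma model.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyDecoderLayer(nn.Module):
    """Hypothetical stand-in for a decoder block such as GemmaDecoderLayer."""

    def __init__(self, dim):
        super().__init__()
        self.attn = nn.Linear(dim, dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))


class ToyModel(nn.Module):
    def __init__(self, dim=16, num_layers=3, vocab=100):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList(ToyDecoderLayer(dim) for _ in range(num_layers))
        self.norm = nn.LayerNorm(dim)


# Wrap every instance of the decoder-layer class.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={ToyDecoderLayer}
)

model = ToyModel()
# recurse=False asks "should this module itself become a wrapped unit?"
wrapped = [name for name, m in model.named_modules() if auto_wrap_policy(m, False, 0)]
print(wrapped)  # ['layers.0', 'layers.1', 'layers.2']
```

Only the layer class is named; every instance is found by walking the module tree, so nothing like "q_proj" has to be listed by hand for wrapping.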

Mon-ius commented 5 months ago

you just need to specify the GemmaDecoderLayer for every instance to be auto wrapped by FSDPv2

@alanwaketan That is what I mean: is there a helper function that can automatically detect layers such as GemmaDecoderLayer when a model is wrapped, for example with model = FSDPv2(model), instead of us having to specify them manually?

alanwaketan commented 5 months ago

@Mon-ius No. Doing that would be complex enough to warrant a compiler pass that determines the backbone of the model and then wraps it...
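As a rough stopgap, and not the compiler pass mentioned above nor anything in torch or torch_xla, a heuristic can guess the backbone class: take model-specific module classes (defined outside torch.nn) that own parameters and repeat several times, then drop any that only appear nested inside another such class. `guess_wrap_classes` and the toy model are hypothetical, and the guess should be checked by hand before being used as transformer_layer_cls.

```python
from collections import Counter

import torch.nn as nn


def guess_wrap_classes(model: nn.Module, min_repeats: int = 2) -> set:
    """Hypothetical heuristic: guess repeated 'backbone' block classes to auto-wrap."""
    counts = Counter(
        type(m)
        for m in model.modules()
        if not type(m).__module__.startswith("torch.nn")  # skip built-in layers
        and next(m.parameters(), None) is not None         # must own parameters
    )
    candidates = {cls for cls, n in counts.items() if n >= min_repeats}

    nested = set()

    def visit(module, enclosing):
        cls = type(module)
        if cls in candidates:
            # A candidate found inside a different candidate is not outermost.
            if any(c is not cls for c in enclosing):
                nested.add(cls)
            enclosing = enclosing + (cls,)
        for child in module.children():
            visit(child, enclosing)

    visit(model, ())
    return candidates - nested


class ToyMLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))


class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attn = nn.Linear(dim, dim)
        self.mlp = ToyMLP(dim)


class ToyNet(nn.Module):
    def __init__(self, dim=16, depth=4):
        super().__init__()
        self.embed = nn.Embedding(100, dim)
        self.blocks = nn.ModuleList(ToyBlock(dim) for _ in range(depth))
        self.head = nn.Linear(dim, 100)


guessed = guess_wrap_classes(ToyNet())
print(sorted(cls.__name__ for cls in guessed))  # ['ToyBlock']
```

The heuristic can easily pick the wrong level in real models, which is part of why a robust version would need deeper analysis of the model.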

Mon-ius commented 5 months ago

@alanwaketan Do we have such a static compiler pass? Or is there something similar that just isn't available to torch users?

alanwaketan commented 5 months ago

@Mon-ius Unfortunately no...