Open · Mon-ius opened this issue 5 months ago
@alanwaketan can you take this one?
Any progress?
If it is possible to apply FSDP directly to the original model, for example Gemma,
what would the best-practice parameters for FullyShardedDataParallel
be?
fsdp_model = FullyShardedDataParallel(
GemmaModel(config, world_size, rank),
fsdp_auto_wrap_policy=default_auto_wrap_policy,
cpu_offload=CPUOffload(offload_params=True),
)
But given the same pretrained checkpoint
file, does FSDP provide a hook that can do what Gemma currently does manually:
def load_weights(self, model_path: str):
checkpoint = torch.load(model_path, weights_only=True)
model_state_dict = checkpoint['model_state_dict']
num_attn_heads = self.config.num_attention_heads
num_kv_heads = self.config.num_key_value_heads
head_dim = self.config.head_dim
hidden_size = self.config.hidden_size
def split(tensor: torch.Tensor, axis: int) -> torch.Tensor:
axis_len = tensor.shape[axis]
split_len = axis_len // self.world_size
split_start = split_len * self.rank
split_end = split_start + split_len
tensor = torch.moveaxis(tensor, axis, 0)
tensor = tensor[split_start:split_end, ...]
tensor = torch.moveaxis(tensor, 0, axis)
return tensor
for k, v in model_state_dict.items():
if k == 'freqs_cis':
continue
if (k == 'model.norm.weight' or re.fullmatch(
r'model.layers.\d+.input_layernorm.weight', k)
or re.fullmatch(
r'model.layers.\d+.post_attention_layernorm.weight',
k) or k.endswith('weight_scaler')):
pass
elif (k == 'embedder.weight' or re.fullmatch(
r'model.layers.\d+.mlp.down_proj.weight', k)):
v = split(v, 1)
elif (re.fullmatch(r'model.layers.\d+.mlp.gate_proj.weight', k)
or re.fullmatch(r'model.layers.\d+.mlp.up_proj.weight', k)):
v = split(v, 0)
elif re.fullmatch(r'model.layers.\d+.self_attn.qkv_proj.weight',
k):
if num_kv_heads <= self.world_size:
num_replicas = self.world_size // num_kv_heads
v = v.reshape(num_attn_heads + num_kv_heads * 2, head_dim,
hidden_size)
query = v[:num_attn_heads, ...]
key = v[num_attn_heads:num_attn_heads + num_kv_heads,
...].repeat(num_replicas, 1, 1)
value = v[-num_kv_heads:, ...].repeat(num_replicas, 1, 1)
v = torch.cat(
(split(query, 0), split(key, 0), split(value, 0)),
dim=0)
else:
v = v.reshape(3, num_attn_heads, head_dim, hidden_size)
v = split(v, 1)
v = v.reshape(-1, hidden_size)
elif re.fullmatch(r'model.layers.\d+.self_attn.o_proj.weight', k):
v = v.reshape(hidden_size, num_attn_heads, head_dim)
v = split(v, 1)
v = v.reshape(hidden_size, -1)
else:
raise ValueError(f'Unrecognized key: {k}')
self.state_dict()[k].copy_(v)
@Mon-ius Please take a look at FSDPv2 and use the HF Gemma for pre-training/fine-tuning: https://huggingface.co/blog/gemma-peft
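Roughly, the recipe in that post looks like the sketch below; the model id, dataset, output dir, and hyperparameters are placeholders you would swap for your own, and the SPMD/TPU environment is set up as described in the post:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

fsdp_config = {
    # Shard every decoder layer; this is the auto-wrap unit for FSDPv2.
    "fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"],
    "xla": True,
    "xla_fsdp_v2": True,
    "xla_fsdp_grad_ckpt": True,
}

args = TrainingArguments(
    output_dir="./out",               # placeholder
    per_device_train_batch_size=1,    # placeholder
    num_train_epochs=1,               # placeholder
    dataloader_drop_last=True,        # required for SPMD
    fsdp="full_shard",
    fsdp_config=fsdp_config,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=my_dataset,  # placeholder: your tokenized dataset
    tokenizer=tokenizer,
)
trainer.train()
```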
@alanwaketan Thanks for the information. Will a Gemma model trained
with xla_fsdp_v2
be compatible with both CUDA
and TPU
devices?
@Mon-ius Yea, as long as you have the correct checkpointing format.
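By checkpointing format I mean something like the sketch below: keep the weights as a plain CPU state_dict so the same file loads on either backend (a minimal sketch, not an official recipe; `model` is your Gemma instance):
```python
import torch

# Save: consolidate the (possibly sharded) weights into full CPU tensors first.
cpu_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
torch.save(cpu_state, "gemma_cpu_state.pt")  # placeholder path

# Load on either backend.
state = torch.load("gemma_cpu_state.pt", map_location="cpu")
model.load_state_dict(state)
model.to(device)  # device = "cuda" or torch_xla's xm.xla_device()
```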
@alanwaketan Is there a way to run FSDPv2 in a fully automatic mode? In the given example, we have to specify every single module to be wrapped, for example "q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj".
I have searched for days, and it seems no one has mentioned this point.
Can we have a hook
that applies a DFS
/BFS
over a given model?
Or, going a level deeper, does PyTorch already have a tree structure that stores the child nn.Module instances?
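To make the question concrete, here is a minimal sketch of the traversal I have in mind; if I understand correctly, nn.Module.named_modules() already does this depth-first walk:
```python
import torch.nn as nn

def list_submodules(model: nn.Module):
    # named_modules() recursively yields (qualified_name, module) pairs,
    # which is essentially the DFS I am asking about.
    return list(model.named_modules())

toy = nn.Sequential(nn.Linear(4, 8), nn.Sequential(nn.ReLU(), nn.Linear(8, 2)))
for name, module in list_submodules(toy):
    print(name or "<root>", "->", type(module).__name__)
```
If that is the intended tool, then the open question is only how to turn such a traversal into an automatic wrapping decision.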
For a more practical case, consider the code snippet below: what would be the best practice for leveraging FSDPv2
/SPMD
here?
import math

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F  # needed for F.interpolate in Upsample
from abc import abstractmethod
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return nn.GroupNorm(32, channels)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = th.exp(
-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
if dim % 2:
embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
return embedding
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
else:
x = layer(x)
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2):
super().__init__()
self.channels = channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, channels, channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2):
super().__init__()
self.channels = channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1)
else:
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(self, channels, num_heads=1, use_checkpoint=False):
super().__init__()
self.channels = channels
self.num_heads = num_heads
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
self.attention = QKVAttention()
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
h = self.attention(qkv)
h = h.reshape(b, -1, h.shape[-1])
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention.
"""
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
:return: an [N x C x T] tensor after attention.
"""
ch = qkv.shape[1] // 3
q, k, v = th.split(qkv, ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
return th.einsum("bts,bcs->bct", weight, v)
@staticmethod
def count_flops(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
"""
def __init__(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
num_heads=1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.num_heads = num_heads
self.num_heads_upsample = num_heads_upsample
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch, use_checkpoint=use_checkpoint, num_heads=num_heads
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
self.input_blocks.append(
TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims))
)
input_block_chans.append(ch)
ds *= 2
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
layers = [
ResBlock(
ch + input_block_chans.pop(),
time_embed_dim,
dropout,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = model_channels * mult
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
)
)
if level and i == num_res_blocks:
layers.append(Upsample(ch, conv_resample, dims=dims))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
def forward(self, x, timesteps, y=None):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
h = x
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
h = self.middle_block(h, emb)
for module in self.output_blocks:
cat_in = th.cat([h, hs.pop()], dim=1)
h = module(cat_in, emb)
h = h.type(x.dtype)
return self.out(h)
def dfs(model, prefix=""):
"""
Finds all sub-elements (nn.Module instances) within a PyTorch model, including the model itself.
Args:
model (nn.Module or list or dict or tuple): The PyTorch model or a container holding PyTorch models.
prefix (str, optional): A prefix to prepend to the module names.
Returns:
list: A list of tuples, where each tuple contains the module prefix and the module instance.
"""
modules = []
stack = [(model, prefix)]
while stack:
curr_module, curr_prefix = stack.pop()
if isinstance(curr_module, nn.Module):
modules.append((curr_prefix, curr_module))
for name, child in curr_module.named_children():
child_prefix = f"{curr_prefix}.{name}" if curr_prefix else name
stack.append((child, child_prefix))
elif isinstance(curr_module, (list, tuple)):
for i, item in enumerate(curr_module):
item_prefix = f"{curr_prefix}[{i}]" if curr_prefix else str(i)
stack.append((item, item_prefix))
elif isinstance(curr_module, dict):
for key, value in curr_module.items():
value_prefix = f"{curr_prefix}.{key}" if curr_prefix else key
stack.append((value, value_prefix))
return modules
model = UNetModel(
in_channels=3,
model_channels=64,
out_channels=3,
num_res_blocks=2,
attention_resolutions=(2, 4),
dropout=0.1,
channel_mult=(1, 2, 4, 8),
num_classes=None,
use_checkpoint=False,
num_heads=4,
num_heads_upsample=-1,
use_scale_shift_norm=True,
)
all_modules = dfs(model)
for prefix, module in all_modules:
print(f"{prefix}: {module}")
@Mon-ius You can take a look at the FSDPv1 blog post on auto-wrapping. https://pytorch.org/blog/pytorch-2.0-xla/#fsdp-beta
FSDPv2 re-uses the same auto-wrapping infrastructure. That should solve your problem. In the Gemma example, you just need to specify the GemmaDecoderLayer for every instance to be auto wrapped by FSDPv2.
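In code, that looks roughly like the sketch below; it assumes an SPMD mesh has already been set up (e.g. via xs.set_global_mesh) and that your torch_xla build exposes these modules:
```python
import functools
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)
from transformers.models.gemma.modeling_gemma import GemmaDecoderLayer

# Auto-wrap every GemmaDecoderLayer instance; no need to list q_proj/k_proj/... by hand.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={GemmaDecoderLayer},
)
model = FSDPv2(model, auto_wrap_policy=auto_wrap_policy)
```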
you just need to specify the GemmaDecoderLayer for every instance to be auto wrapped by FSDPv2
@alanwaketan That is what I mean: do we have a helper function that can automatically detect such a GemmaDecoderLayer
when a model is wrapped, for example model = FSDPv2(model)
, instead of having to specify it manually?
@Mon-ius No, doing that automatically would amount to building a compiler pass that determines the backbone of the model and then wraps those modules...
@alanwaketan Do we have such a static compiler? Or does something similar exist, just not for torch users?
@Mon-ius Unfortunately no...
I found that the latest open-source LLM from Google, Gemma, has two versions of the model structure,
where the
model_xla
version with run_xla.sh
and xla_model_parallel.py
seems to use an XLA
1.x style with a modified Transformer network. Besides, I found that the main modified part is replacing the official
nn.Linear
with:

Do we still need to perform this kind of work to make our model trainable on an
XLA
device? Or do such hooks already exist inside the XLA lib, so we can just do something similar to what FSDP introduced 🤗?
Can we have a doc on implementing Gemma directly with the XLA
pjrt
feature, without the heavy modification that Gemma_XLA did?
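What I am hoping for is something like the sketch below: keep the stock nn.Linear and only annotate its sharding under SPMD, instead of swapping in the parallel layers from xla_model_parallel.py. I am not sure the partition spec here is right:
```python
import numpy as np
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("model",))

layer = nn.Linear(4096, 4096, bias=False).to(xm.xla_device())
# Column-parallel style: shard the output dimension of the weight across devices.
xs.mark_sharding(layer.weight, mesh, ("model", None))
```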