Closed — conceptofmind closed this issue 3 months ago
@conceptofmind Hi again Enrico, and thank you for running this experiment
Was the above run done in f32 or f16?
@lucidrains I forgot to mention that I used fp16 in the training above, which is likely one of the causes of the numerical instability and NaN in this experiment. Since I did not want to alter the training script, I did not apply any stabilization techniques to shrink the token embedding gradient, for example `x = self.token_emb(x) * self.alpha + self.token_emb(x).detach() * (1 - self.alpha)`. This is what Tsinghua did to help stabilize fp16 training for GLM-130B. I can add this to a new script and post the results for fp16 training again.
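For reference, here is a minimal demonstration of that gradient-shrink trick (a sketch with my own names; `alpha = 0.1` follows the GLM-130B paper): the forward value is unchanged, but the gradient reaching the embedding is scaled by `alpha`.

```python
import torch

# Sketch of the GLM-130B embedding gradient shrink (alpha = 0.1 in the
# paper; names here are illustrative). The forward output equals the
# input, but the gradient flowing back into the embedding is scaled
# by alpha, since the detached branch contributes no gradient.
def shrink_embedding_grad(emb_out, alpha = 0.1):
    return emb_out * alpha + emb_out.detach() * (1 - alpha)

emb_out = torch.ones(4, requires_grad = True)
y = shrink_embedding_grad(emb_out)
y.sum().backward()

print(y.detach())    # forward value is unchanged
print(emb_out.grad)  # gradient is scaled down to alpha
```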
I will run the normal script again with fp32 on an A100 and document the results here as well.
Thank you,
Enrico
@conceptofmind ohh yes the fp16 is the likely cause, as i was in the middle of fixing an underflow issue with the way cosine sim attention was approached in the CUDA kernel (which should be fixed in 0.1.38). If you have time to retry it on the latest version, that would be greatly appreciated!
are you using character level enwik8 for training from this repo, or modifying another gpt2 codebase? if modifying another codebase, could you share the code you have?
there is no need for the gradient shrinking technique from Tsinghua, as the whole idea behind the repository was to explore whether cosine sim attention can bring about greater stability without any cost
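To make the stability argument concrete, here is a rough, self-contained sketch (my own simplification, not the repository's CUDA kernel) of what cosine-sim attention computes: with q and k l2-normalized, every pre-softmax logit lies in `[-scale, scale]`, so the softmax cannot overflow no matter how large the activations grow.

```python
import torch
import torch.nn.functional as F

# Rough sketch of (non-flash) cosine-sim attention, not the CUDA kernel:
# l2-normalizing q and k bounds every logit to [-scale, scale], so the
# softmax stays finite even for very large activations.
def cosine_sim_attention_sketch(q, k, v, scale = 8):
    q, k = map(lambda t: F.normalize(t, dim = -1), (q, k))
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = torch.randn(1, 2, 5, 16) * 1e4   # huge activations on purpose
v = torch.randn(1, 2, 5, 16)
out = cosine_sim_attention_sketch(q, k, v)
```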
@lucidrains I will update the repository and rerun the test with fp16 enabled. I will post the new fp16 training results. I am using the character-level enwik8 training from this repository for these tests. I have not made any alterations to the training script except for logging the loss with wandb. I did not want to change anything from your original work, in order to remain consistent.
Here are the training results with fp32 on an A100 (40GB) for almost 40k steps:
The training has remained more stable.
I am additionally testing flash-cosine-sim-attention in another GPT-like model, a Vision Transformer, and a PaLM model. I will post all of the code and results for these additional tests when I am confident everything meets a certain level of correctness. I will not apply the gradient shrinking technique from Tsinghua to any of these additional tests.
Thank you,
Enrico
Thanks Enrico! Definitely also compare the run to non-cosine sim attention, and obtain a validation curve while you are at it :pray:
@conceptofmind is the GPT-2 run using the pre-layernorm architecture?
@lucidrains I am using the current CosineSimCausalTransformer available in the repository for the GPT-2 run. I believe the architecture used post-norm layers with DeepNorm. I did not see a specific place for a PreNorm wrapper or where pre-layernorm was explicitly defined. I saw that DeepNorm was applied to attn.to_v, attn.to_out, ff[0], and ff[2].
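For context, here is a hedged sketch of what DeepNorm-style initialization looks like for a decoder-only stack, as I understand the DeepNet paper (the alpha/beta formulas are for N decoder layers; the helper names are mine, not the repository's):

```python
import torch
from torch import nn

# Hedged sketch of DeepNorm-style init for a decoder-only model of N
# layers (helper names are mine, formulas per the DeepNet paper):
# residual branches are scaled by alpha inside the post-LayerNorm, and
# projections such as attn.to_v, attn.to_out, ff[0], ff[2] get xavier
# init with gain beta < 1.
def deepnorm_alpha_beta(num_layers):
    return (2 * num_layers) ** 0.25, (8 * num_layers) ** -0.25

def deepnorm_init(linear, beta):
    nn.init.xavier_normal_(linear.weight, gain = beta)

alpha, beta = deepnorm_alpha_beta(12)
to_v = nn.Linear(512, 512, bias = False)
deepnorm_init(to_v, beta)
```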
@conceptofmind ah ok! thanks for clearing that up!
@lucidrains If I missed something, or if you would like me to add a PreNorm to the Attention layer, I am more than willing to test that as well.
@conceptofmind that would actually be great! i'll add a prenorm option tomorrow morning :pray:
@lucidrains Here are the results for fp16 training without pre-layernorm for 30k steps on an A100 (40GB). The recent update greatly improved numerical stability for fp16 training.
Training loss:
Validation Loss (Validating every 100 steps):
I will provide an update for fp16 with pre-layernorm when it gets around 30k steps. I will also train one model with fp32 and pre-layernorm as well as one model with non-cosine sim attention. So an additional 3 baseline tests!
Thank you,
Enrico
@lucidrains Here are the results for fp16 training with pre-layernorm for 30k steps on an A100 (40GB). Training remained more stable as well. I changed to validating every 10 steps as that gave a better idea of the results.
Training loss:
Validation Loss (Validating every 10 steps):
Here is the code for the slight change made to attention to include Pre-LayerNorm:
```python
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        scale = 8,
        l2norm_groups = 1,
        use_cuda_kernel = False,
        **kwargs
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = scale
        self.heads = heads
        self.norm = nn.LayerNorm(dim)
        self.l2norm_groups = l2norm_groups

        self.attn_fn = plain_cosine_sim_attention if not use_cuda_kernel else partial(flash_cosine_sim_attention, **kwargs)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        h, scale, l2norm_groups = self.heads, self.scale, self.l2norm_groups

        # pre layernorm
        x = self.norm(x)

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        o = self.attn_fn(q, k, v, causal = True, scale = scale, groups = l2norm_groups)

        o = rearrange(o, 'b h n d -> b n (h d)')
        return self.to_out(o)
```
I am training the model with fp32 and pre-layernorm now.
Thank you,
Enrico
Here are the results for fp32 training with pre-layernorm for 30k steps on an A100 (40GB).
Training loss:
Validation Loss (Validating every 10 steps):
Sidenote: I am going to spend the next few weeks working on a Triton version of Flash Cosine Similarity Attention as well. I think it would make an interesting comparative benchmark!
Nice! I added the prenorm option in the transformer this morning as a simple flag https://github.com/lucidrains/flash-cosine-sim-attention/commit/0c260d1d1aae0e2152b7509ca7aa2940f1f2d1cc
Is it possible to move the plots you have above into the same graph for comparison?
@lucidrains Of course! Here are the grouped graphs.
fp16 training without pre-layernorm (validating every 100)
fp16 training with pre-layernorm (validating every 10)
fp32 training with pre-layernorm (validating every 10)
I can start validating every step if that would be better as well.
I will post the results of training a PaLM-like model soon.
Thank you,
Enrico
Results for training standard PaLM on an A100 (40 GB) for 30k steps:
@lucidrains Here is the code for the PaLM model with flash cosine sim attention. The model is currently training and I will update the results likely later tonight or tomorrow morning.
```python
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum, nn
from functools import partial

from flash_cosine_sim_attention import flash_cosine_sim_attention

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)

def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())

class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

class ParallelTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=64,
        heads=8,
        scale = 8,
        l2norm_groups = 1,
        ff_mult=4,
        **kwargs
    ):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.attn_fn = partial(flash_cosine_sim_attention, **kwargs)

        self.heads = heads
        self.scale = scale
        self.l2norm_groups = l2norm_groups
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # for caching rotary embeddings
        self.register_buffer("pos_emb", None, persistent=False)

    def get_rotary_embedding(self, n, device):
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device, h, scale, l2norm_groups = x.shape[1], x.device, self.heads, self.scale, self.l2norm_groups

        # pre layernorm
        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner
        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
        # they found no performance loss past a certain scale, and more efficient decoding obviously
        # https://arxiv.org/abs/1911.02150
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings
        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # flash cosine similarity attention
        out = self.attn_fn(q, k, v, causal = True, scale = scale, groups = l2norm_groups)

        # merge heads
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)

def PaLM_flash(
    *,
    dim,
    num_tokens,
    depth,
    attn_scale = 8,
    attn_l2norm_groups = 1,
    dim_head=64,
    heads=8,
    ff_mult=4,
    **kwargs
):
    net = nn.Sequential(
        nn.Embedding(num_tokens, dim),
        *[
            Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult, scale=attn_scale, l2norm_groups=attn_l2norm_groups, **kwargs))
            for _ in range(depth)
        ],
        LayerNorm(dim),
        nn.Linear(dim, num_tokens, bias=False)
    )

    # they used embedding weight tied projection out to logits, not common, but works
    net[-1].weight = net[0].weight
    nn.init.normal_(net[0].weight, std=0.02)
    return net
```
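To illustrate the fused multi-query projection in the block above (the dimensions here are examples only): a single matmul yields the multi-head queries, one shared key head, one shared value head, and the feedforward inner activations, all recovered with one split.

```python
import torch
from torch import nn

# Example dimensions only: one fused projection produces multi-head q,
# a single shared k head, a single shared v head (multi-query attention),
# and the 2x-wide SwiGLU feedforward input, recovered with one split.
dim, dim_head, heads, ff_mult = 512, 64, 8, 4
fused_dims = (dim_head * heads, dim_head, dim_head, dim * ff_mult * 2)

fused_proj = nn.Linear(dim, sum(fused_dims), bias = False)
x = torch.randn(2, 10, dim)
q, k, v, ff = fused_proj(x).split(fused_dims, dim = -1)
```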
@conceptofmind yup that looks good! so for PaLM, because it uses rotary embeddings, the l2 normalization needs to come before the rotation of the queries and keys
something like this in the readme

```python
q, k = l2norm_tensors(q, k)

positions = self.get_rotary_embedding(n, device)
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

out = self.attn_fn(q, k, v, causal = True, scale = scale, groups = l2norm_groups, **kwargs)
```
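A quick sanity check of why this ordering works (a self-contained sketch with a torch-only `rotate_half`, equivalent to the einops version): rotary embedding is a pure per-pair rotation, so it preserves the l2 norm, and queries/keys that are unit-norm before the rotation are still unit-norm after it.

```python
import torch
import torch.nn.functional as F

# Rotary embedding rotates each (x1_i, x2_i) pair by the same angle, so
# it preserves l2 norms: l2-normalized q/k stay unit-norm after rotation.
def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())

freqs = torch.randn(5, 4)
pos = torch.cat((freqs, freqs), dim = -1)     # same angle for both halves
t = F.normalize(torch.randn(5, 8), dim = -1)  # unit-norm vectors
rotated = apply_rotary_pos_emb(pos, t)
```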
@lucidrains I will make the adjustments to the model to do the l2 normalization before the rotation of the queries and keys, and post the results.
Thank you,
Enrico
@lucidrains Here are the runs for PaLM with flash-cosine-sim-attention.
For 14k steps:
And for the whole training run:
I am working on the ViT now.
@lucidrains Here is the code for a ViT-16 with flash-cosine-sim-attention:
```python
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from flash_cosine_sim_attention import flash_cosine_sim_attention

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        # softmax / dropout kept from the original ViT attention; the flash
        # kernel path below handles attention internally and does not use them
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        out = flash_cosine_sim_attention(q, k, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
```
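As a quick check of the shapes above (pure arithmetic): at `image_size = 224` and `patch_size = 16`, the sequence the transformer sees is 197 tokens long.

```python
# Token count for the ViT-16 above at image_size = 224:
image_size, patch_size = 224, 16
num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196
seq_len = num_patches + 1                      # + cls token = 197
```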
Here is the training script for the ViT on CIFAR10:
```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms as T
import tqdm
import wandb

from vit_cosine_sim_flash import ViT

wandb.init(project="my-test-project")

DEVICE = 'cuda'
IMAGE_SIZE = 224
BATCH_SIZE = 4
LEARNING_RATE = 6e-4
EPOCHS = 100

train_transform = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.AutoAugment(policy = T.AutoAugmentPolicy.CIFAR10),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

test_transform = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

train_dataset = CIFAR10(
    root = './cifar_data_train/',
    train = True,
    download = True,
    transform = train_transform,
)

test_dataset = CIFAR10(
    root = './cifar_data_train/',
    train = False,
    download = True,
    transform = test_transform,
)

train_loader = DataLoader(
    train_dataset,
    shuffle = True,
    batch_size = BATCH_SIZE,
)

test_loader = DataLoader(
    test_dataset,
    batch_size = BATCH_SIZE,
)

model = ViT(
    image_size = IMAGE_SIZE,
    patch_size = 16,
    num_classes = 10,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr = LEARNING_RATE,
)

for epoch in tqdm.tqdm(range(EPOCHS), desc='training'):
    model.train()
    epoch_loss = 0
    epoch_acc = 0

    for images, labels in train_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (outputs.argmax(dim = 1) == labels).float().mean()
        epoch_acc += acc.item() / len(train_loader)
        # .item() so the computation graph is not retained across the epoch
        epoch_loss += loss.item() / len(train_loader)

    model.eval()  # disable dropout for validation
    with torch.no_grad():
        epoch_val_acc = 0
        epoch_val_loss = 0

        for images, labels in test_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            val_output = model(images)
            val_loss = criterion(val_output, labels)

            acc = (val_output.argmax(dim=1) == labels).float().mean()
            epoch_val_acc += acc.item() / len(test_loader)
            epoch_val_loss += val_loss.item() / len(test_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_acc:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_acc:.4f}\n"
    )

    wandb.log({"epoch": epoch, "train loss": epoch_loss, "train acc": epoch_acc, "val loss": epoch_val_loss, "val acc": epoch_val_acc})
```
I will post the training results soon.
@lucidrains Here are the results for training the ViT-16 with flash-cosine-sim-attention on CIFAR10 for 100 epochs.
Train and Validation loss:

Train and Validation accuracy:

Training performance and accuracy were great. Validation shows possible overfitting, which is to be expected.
What do you think about trying flash-cosine-sim-attention in MEGA-pytorch?
Best,
Enrico
:pray: thank you for running these experiments Enrico
do you think you could also run the same experiments against regular attention and compare the curves side by side? i am concerned about expressiveness issues after Robin's experiments
@conceptofmind you can name your wandb runs by doing

```python
wandb.run.name = 'regular attention'
wandb.run.save()
```

right after `wandb.init`
@lucidrains Of course. I will train a ViT-16 with regular attention on CIFAR10 for 100 epochs and compare the curves side by side now. I will update you when everything is done running for the ViT experiments. All of these experiments were conducted on an A100 (40GB).
Here is a chart of the train and val losses for PaLM with and without flash-cosine-sim-attention:

Here is a link to that chart: https://wandb.ai/please/my-test-project/reports/val-loss-train-loss-22-11-14-23-08-14---VmlldzoyOTcxODQx?accessToken=733c0o8hpgpklqq42phj7rem7rznxcny5tca2q7v7bnbxm9gwppa1p0gv5fbo19n

With label smoothing:
@lucidrains Here are the results for the ViT-16 experiments with and without flash-cosine-sim attention.
For regular attention I used a learning rate of 2e-4. For flash-cosine-sim I tested with a learning rate of 6e-4.
Train/validation loss:
Train/validation accuracy:
I am running flash-cosine-sim again with a learning rate of 2e-4 this time instead. I will provide an update with that soon.
Best,
Enrico
@lucidrains Here are the results for ViT-16 with and without flash-cosine-sim attention on CIFAR10 for 100 epochs, with the same learning rate of 2e-4. I am using an A100 (40 GB). Definitely good news!
Train/Validation Loss:
Train/Validation Accuracy:
So far in my testing, I have seen better accuracy and faster convergence with flash-cosine-sim attention. I will need to keep training more models, and I am thinking about including some other improvements too, possibly FastLayerNorm from Apex. I am getting everything set up for a training run on ImageNet and will post the code for that soon.
I am still working on the Triton version (it is my first time using Triton) as well.
@conceptofmind thank you Enrico
the results actually look more promising than i expected, if conditions are held equal between the two runs
@conceptofmind if you have some time, do you think you could try flash cosine sim attention on some generative models and see if the FID scores between regular and cosine sim attention differ at all? that is what i worry about, as @rromb showed evidence that loss curves are not everything. however, the ViT accuracy curves you got from above do look good 🙏
@lucidrains Absolutely! I am currently finishing setting everything up for a run on ImageNet with flash-cosine-sim-attention and will hopefully have the results for that soon. Hugging Face hosts a standard version of ImageNet which fortunately can be downloaded in a reasonable amount of time.
Here is the training script for the run on ImageNet with no data augmentation:
```python
import torch
import tqdm
import argparse
import wandb

from datasets import load_dataset
from transformers import AutoFeatureExtractor

from vit import ViT

wandb.init(project="my-test-project")
wandb.run.name = 'regular attention - imagenet'
wandb.run.save()

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default = 4, type = int)
args = parser.parse_args()

BATCH_SIZE = args.batch_size
DEVICE = 'cuda'
EPOCHS = 100

# load the train and test splits directly
imagenet_1k_train = load_dataset('imagenet-1k', split='train')
imagenet_1k_test = load_dataset('imagenet-1k', split='test')

model_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)

def preprocess_function(images):
    image_tensors = [image.convert("RGB") for image in images['image']]
    inputs = feature_extractor(image_tensors, return_tensors="pt")
    inputs['label'] = images['label']
    return inputs

def collate_function(batches):
    pixel_values = torch.stack([batch['pixel_values'] for batch in batches])
    label = torch.tensor([batch['label'] for batch in batches])
    return {'pixel_values': pixel_values, 'label': label}

train_dataset = imagenet_1k_train.with_transform(preprocess_function)
test_dataset = imagenet_1k_test.with_transform(preprocess_function)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_function,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_function,
)

model = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
).to(DEVICE)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 2e-4)

for epoch in tqdm.tqdm(range(EPOCHS), desc='training'):
    model.train()
    epoch_loss = 0
    epoch_acc = 0

    for batch in train_loader:
        images = batch['pixel_values'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (outputs.argmax(dim = 1) == labels).float().mean()
        epoch_acc += acc.item() / len(train_loader)
        # .item() so the computation graph is not retained across the epoch
        epoch_loss += loss.item() / len(train_loader)

    model.eval()  # disable dropout for validation
    with torch.no_grad():
        epoch_val_acc = 0
        epoch_val_loss = 0

        for batch in test_loader:
            images = batch['pixel_values'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            val_output = model(images)
            val_loss = criterion(val_output, labels)

            acc = (val_output.argmax(dim=1) == labels).float().mean()
            epoch_val_acc += acc.item() / len(test_loader)
            epoch_val_loss += val_loss.item() / len(test_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_acc:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_acc:.4f}\n"
    )

    wandb.log({"epoch": epoch, "train loss": epoch_loss, "train acc": epoch_acc, "val loss": epoch_val_loss, "val acc": epoch_val_acc})
```
I am looking into testing flash-cosine-sim-attention with pytorch ddp, deepspeed or oslo for distributed computing.
Is there a specific diffusion or generative model which you want to be run on CIFAR10? I can do a wide range of them as well.
Best,
Enrico
Hi @lucidrains,
Here are the results for training the GPT-2 model on an A100 (40 GB). This is a different A100 than I have used before. I left everything the same other than logging the loss. After around 65k steps there appears to be an exploding/vanishing gradient and the loss goes to NaN. Training became more unstable around the 20k step mark across my few runs.
I will have to test training on A100 (80 GB) as well.
Thank you,
Enrico