Oneflow-Inc / swin-transformer

0 stars 0 forks source link

Swin-Transformer Loss对齐实验 #2

Open rentainhe opened 2 years ago

rentainhe commented 2 years ago

Swin-Transformer Loss对齐实验结果汇总

实验

遇见的问题以及相关解决的PR

Problems Fixed or not Fixed PR
LayerNorm可能存在问题 Not Fixed -
内置的GELU可能存在问题 Fixed https://github.com/Oneflow-Inc/oneflow/pull/7103
Mixup中的inplace操作替换 Fixed https://github.com/Oneflow-Inc/vision/pull/80

Training Experiments

Model Optimizer Dataset Framework Acc@1 Log
Swin-Tiny AdamW CIFAR100 Pytorch 76.77% -
Swin-Tiny AdamW CIFAR100 OneFlow 67.29% -
Swin-Tiny 将 LayerNorm替换为python层拼接的Module AdamW CIFAR100 OneFlow 66.44% log
Swin-Tiny 将激活函数替换为ReLU AdamW CIFAR100 Pytorch - -
Swin-Tiny 将激活函数替换为ReLU AdamW CIFAR100 OneFlow 66.39% log
Swin-Tiny 使用拼的LayerNorm + ReLU AdamW CIFAR100 OneFlow - -

基于 https://github.com/Oneflow-Inc/oneflow/pull/7103 编译的环境进行的实验,关闭FAST_MATH,修复GELU计算

Model Optimizer Dataset Framework Acc@1 Log
Swin-Tiny + ReLU AdamW CIFAR100 Pytorch 75.95% log
Swin-Tiny + ReLU AdamW CIFAR100 OneFlow 74.74% log
Swin-Tiny + GELU AdamW CIFAR100 Pytorch 75.41% log
Swin-Tiny + GELU AdamW CIFAR100 OneFlow 74.45% log
Model Optimizer Dataset Framework Acc@1 Log
Swin-Tiny with error mixup AdamW CIFAR100 OneFlow 67.80% log
Swin-Tiny with fixed mixup AdamW CIFAR100 OneFlow 75.68% log
Swin-Tiny with mixup AdamW CIFAR100 Pytorch 77.54% log
rentainhe commented 2 years ago

SGD优化器下添加grad-norm后的Loss对齐结果

MARD1NO commented 2 years ago

实验均开启模型eval模式,cudnn_determinstic=True

实验1

载入相同的权重,模型开启eval模式,使用随机数据,使用AdamW优化器,Oneflow和Pytorch的结果是完全能对上的 image

实验2

在实验1的基础上使用Cifar100,关闭shuffle,Loss对不齐 image

实验3

固定权重,使用 AdamW 优化器,模型是eval模式,在伪造数据情况下,oneflow前后两次loss一样 image

实验4

固定权重,使用 AdamW 优化器,模型是eval模式,在真实数据情况下,oneflow前后两次loss不一样 image

个人推断:

  1. Adam系列优化器没问题,因为如果存在问题,那无论是真实数据还是伪造数据,都应该出现对不齐的情况
  2. 在使用相同权重,相同数据下,oneflow多次的结果并不是一致,建浩那边也遇到了类似的现象,我认为还是某些地方选择算法有差,导致误差
rentainhe commented 2 years ago

之前的实验结果

实验条件

Results

SGD优化器下的Loss对比

优化器设置

AdamW优化器下的Loss对比

优化器设置

rentainhe commented 2 years ago

loss_compare

Ldpe2G commented 2 years ago

将 flowvision 中的模型基本都过了一遍,每一类模型选一个,做与 torch 的前100次 iter loss 对齐实验

实验设置

基于 loss_compare 分支,loss_compare_other_nets.py 脚本 AdamW(eps=1e-8, betas=(0.9, 0.999), lr=0.001, weight_decay=0.05) eval 模式

loss 基本能完全和torch对齐的模型:

resnet50
alexnet
vgg16
squeezenet1_0
densenet161
inception_v3
googlenet
shufflenet_v2_x1_0
mobilenet_v2
mobilenet_v3_small
resnext50_32x4d
wide_resnet50_2
mnasnet1_0
conv_mixer
res_mlp

不能对齐的模型,一般是几十个iter之后差异会逐渐增大:

vit
swin
cswin
crossformer
pvt

mlp_mixer 

根据实验,可以观察到貌似都是 vit 这一系列的模型会遇到这个问题

Ldpe2G commented 2 years ago

修改 swin 中 layernorm 的实现

实验设置

基于 loss_compare 分支,在 swin_oneflow.py 文件开头加一段代码:

class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."

    def __init__(self, features, eps=1e-8):
        super(LayerNorm, self).__init__()
        self.eps = eps
        self.weight = nn.Parameter(flow.Tensor(flow.ones(features, dtype=flow.float32)))
        self.bias = nn.Parameter(flow.Tensor(flow.zeros(features, dtype=flow.float32)))
        self.features = features

    def forward(self, x):
        mean = x.sum(-1, keepdim=True)
        std = x.sum(dim=-1, keepdim=True)
        return self.weight * (x - mean) / (std + self.eps) + self.bias

nn.LayerNorm = LayerNorm

同样在 swin_pytorch.py 文件开头加一段代码:


class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."

    def __init__(self, features, eps=1e-8):
        super(LayerNorm, self).__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.Tensor(torch.ones(features, dtype=torch.float32)))
        self.bias = nn.Parameter(torch.Tensor(torch.zeros(features, dtype=torch.float32)))
        self.features = features

    def forward(self, x):
        mean = x.sum(-1, keepdim=True)
        std = x.sum(dim=-1, keepdim=True)
        return self.weight * (x - mean) / (std + self.eps) + self.bias

nn.LayerNorm = LayerNorm

就是用 拼的 layernorm 替换内置的 layernorm 并且将 mean 和 std 都替换成 sum

实验结果如下,AdamW 的 loss 就能完全对齐了

image

MARD1NO commented 2 years ago

我个人觉得这个实验可能不太具备说服力:

更改了实现以后,这部分数据范围也改变了,可能没到可能产生误差的范围。就像你觉得伪数据没说服力一样,可能伪数据的情况下,没到那个真实数据产生误差的那个范围里。

我做了另外两个实验:

开启FAST_MATH

image

关闭FAST_MATH

image

L1aoXingyu commented 2 years ago

之前在和 megatron 对齐 loss 采用的标准的训练策略,即 warmup + decay,采用 adamw optimizer,同时 base-lr 也比较小,为 1e-4

下面补充两个实验结果,采用固定 lr,另外需要注意的是 megatron 的 layernorm 是自己实现的 fuse 版本

lr=1e-4

image

lr=1e-2

image

可以看到在小学习率下几乎是可以对齐的,这应该是之前 glm 和 bert 可以对齐的原因,但是在学习率大了之后就无法对齐了,而大家使用 adamw 基本不会使用太大的学习率,这应该是一直没有发现问题的原因。

MARD1NO commented 2 years ago

其实SGD在学习率比较大的情况下,在伪造数据上也是对不齐的

L1aoXingyu commented 2 years ago

所以我觉得还是应该采用真实的数据集和训练策略去对齐 loss,自己造的 case 并不能代表真实的数据分布,如果真实数据下,采用相同策略无法对齐,那必然影响最后的精度,这个是需要去解决的

MARD1NO commented 2 years ago

实验1

  1. 把layernorm的elementwise_Affine= False
  2. 使用relu替换gelu
  3. 关闭oneflow的FAST_MATH
  4. batchsize=8 能初步对齐 image

实验2

  1. 把layernorm的elementwise_Affine= True
  2. 使用relu替换gelu
  3. 关闭oneflow的FAST_MATH
  4. batchsize=16

image

实验3

  1. 把layernorm的elementwise_Affine= True
  2. 将gelu的公式替换为和Torch一样的方式
  3. 关闭oneflow的FAST_MATH
  4. batchsize=16

image

实验4

  1. 把layernorm的elementwise_Affine= False
  2. 将gelu的公式替换为和Torch一样的方式
  3. 关闭oneflow的FAST_MATH
  4. batchsize=16 image

修改后的GELU能够正常通过和Torch的自动测试

rentainhe commented 2 years ago

实验1

  1. 把layernorm的elementwise_Affine= False
  2. 使用relu替换gelu
  3. 关闭oneflow的FAST_MATH
  4. batchsize=8 能初步对齐 image

lr设置的是多大,lr可以调到比较夸张的一个值看看

MARD1NO commented 2 years ago

实验1

  1. 把layernorm的elementwise_Affine= False
  2. 使用relu替换gelu
  3. 关闭oneflow的FAST_MATH
  4. batchsize=8 能初步对齐 image

lr设置的是多大,lr可以调到比较夸张的一个值看看

统一用的是代码里swin的AdamW配置,没做任何修改

MARD1NO commented 2 years ago

从dataloader保存出真实数据前一百个iter的数值分布:

image

其中

print(np.array(np.abs(processed_data[processed_data != 0]).min()))

获取最小数值为 0.00490195

MARD1NO commented 2 years ago

关闭FAST_MATH,去除LayerNorm,使用修复后的GELU

使用Adamw优化器

of_optim = flow.optim.AdamW(of_model.parameters(), eps=1e-8, betas=(0.9, 0.999), lr=不同实验不一样 weight_decay=0.05)

lr=1e-6

image

lr=1e-5

image

lr=1e-4

image

lr=1e-3

Torch训练正常,oneflow出现nan image

Ldpe2G commented 2 years ago

精度问题定位实验

实验设置

OneFlow swin 仓库代码分支:swin_clean_ldp,Pytorch 直接采用官方的仓库。

OneFlow 框架的改动和分支 https://github.com/Oneflow-Inc/oneflow/pull/7103/ 类似,却别在于 gelu 实现中采用的是 normcdf 而不是 normcdff,同时关闭 layernorm 和 softmax 中的 fast math 是通过设置 cmake 选项的方式,而不是直接注释代码。

训练的模型:swin_small_patch4_window7_224.yaml, 数据集: Cifar100,Optimizer: AdamW, 8卡 DDP

实验分为4组

实验1,关闭 mixup,关闭 clip_grad,Loss 采用 CrossEntropyLoss

Framework Epoch Top1 Acc Log
OneFlow 95 69.53% cifar100_mixup_off_clip_grad_off.log
Pytorch 95 70.81% cifar100_mixup_off_clip_grad_off.txt

实验2,关闭 mixup,打开 clip_grad,Loss 采用 CrossEntropyLoss

Framework Epoch Top1 Acc Log
OneFlow 95 69.77% cifar100_mixup_off_clip_grad_on.log
Pytorch 95 70.47% cifar100_mixup_off_clip_grad_on.txt

实验3,关闭 mixup,打开 clip_grad,Loss 采用 SoftTargetCrossEntropy

Framework Epoch Top1 Acc Log
OneFlow 95 69.35% cifar100_mixup_off_clip_grad_on_oneehot_softcross.log
Pytorch 95 70.45% cifar100_mixup_off_clip_grad_on_oneehot_softcross.txt

实验4,打开 mixup,关闭 clip_grad,Loss 采用 SoftTargetCrossEntropy

Framework Epoch Top1 Acc Log
OneFlow 95 55.11% cifar100_mixup_on_clip_grad_off.log
Pytorch 95 63.81% cifar100_mixup_on_clip_grad_off.txt

实验结论

上面的4组实验因为是没有完全跑完(完整300个epoch), 所以绝对数值不需要太关注,关注点在于 Pytorch 和 OneFlow 之间的差距。

实验1~3,OneFlow 和 Pytorch 的差距都不大,OneFlow 都是比 Pytorch 要低 1% 左右。

但是实验4在打开了 mixup之后, 虽然 Pytorch 和 OneFlow 的收敛速度都下降了,但是差距却拉大了有 9% 左右的差距。

这里很有可能就是导致最终收敛的精度比 Pytorch 低很多的原因,还需要进一步定位问题。

Ldpe2G commented 2 years ago

上面实验3,训到 epoch 254 的时候触发,clip_grad 算 Norm 的时候出现了 nan

Traceback (most recent call last):
  File "main.py", line 327, in <module>
    main(config)
  File "main.py", line 125, in main
    train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler)
  File "main.py", line 189, in train_one_epoch
    grad_norm = flow.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
  File "/home/ldp/miniconda3/envs/oneflow-dev-gcc7/lib/python3.6/site-packages/oneflow/nn/utils/clip_grad.py", line 113, in clip_grad_norm_
    f"The total norm of order {norm_type} for gradients from "
RuntimeError: The total norm of order 2.0 for gradients from `parameters` is non-finite, so it cannot be clipped. To disable this error and scale the gradients by the non-finite norm anyway, set `error_if_nonfinite=False`
MARD1NO commented 2 years ago

上面实验3,训到 epoch 254 的时候触发,clip_grad 算 Norm 的时候出现了 nan

Traceback (most recent call last):
  File "main.py", line 327, in <module>
    main(config)
  File "main.py", line 125, in main
    train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler)
  File "main.py", line 189, in train_one_epoch
    grad_norm = flow.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
  File "/home/ldp/miniconda3/envs/oneflow-dev-gcc7/lib/python3.6/site-packages/oneflow/nn/utils/clip_grad.py", line 113, in clip_grad_norm_
    f"The total norm of order {norm_type} for gradients from "
RuntimeError: The total norm of order 2.0 for gradients from `parameters` is non-finite, so it cannot be clipped. To disable this error and scale the gradients by the non-finite norm anyway, set `error_if_nonfinite=False`

https://github.com/Oneflow-Inc/oneflow/pull/7106 可能这个PR修复了pow就能解决nan问题

Ldpe2G commented 2 years ago

精度问题定位实验

基于实验组 https://github.com/Oneflow-Inc/swin-transformer/issues/2#issuecomment-1001159164

修复了 mixup 之后 https://github.com/Oneflow-Inc/vision/pull/80 增加一个实验

实验5,打开 mixup,打开 clip_grad

Framework Epoch Top1 Acc Log
OneFlow 95 63.00% cifar100_mixup_on_clip_grad_on.log
Pytorch 95 63.80% cifar100_mixup_on_clip_grad_on.txt

实验结论

mixup 实现上有问题是 oneflow 在 cifar100 上训练比 pytorch 低很多的主要原因,但是从5组的实验结果上看 pytorch 的收敛精度平均都要比 oneflow 高 1% 点多左右,这是进一步要去定位问题的地方。

yuanms2 commented 2 years ago

理论上精度、数值计算结果应该是一样的,数据加载是不是一样,超参是不是对齐了,有没有做模块剖分的实验,前向计算结果一样吗? 如果前向不一样,就找前向的问题,前向一样,应该就是看反向计算和optimizer的问题

yuanms2 commented 2 years ago

当实锤找到一个问题之后,在修复的版本上是不是在上面列出的实验结果也重复一遍

Ldpe2G commented 2 years ago

当实锤找到一个问题之后,在修复的版本上是不是在上面列出的实验结果也重复一遍

mixup 是 data augmentation 方法,包含随机性,是完整训练才会用的,单卡的 Loss 不会打开。之前是发现完整训练 oneflow 比 torch 低很多,才进行了上面多组实验去定位。 https://github.com/Oneflow-Inc/swin-transformer/issues/2#issuecomment-1001159164

https://github.com/Oneflow-Inc/swin-transformer/issues/2#issuecomment-1001805751

chengtbf commented 2 years ago
  1. mix up 是纯 python 端的 op? 不需要对 oneflow 框架做改动吗。
  2. 我看了一下修改内容,核心是替换了 x.mul(lam).add(x_flipped) -> x.mul_(lam).add_(x_flipped) 吗。 为什么换成 Inplace 版本,会对精度有明显影响呢 😂
Ldpe2G commented 2 years ago
  1. mix up 是纯 python 端的 op? 不需要对 oneflow 框架做改动吗。
  2. 我看了一下修改内容,核心是替换了 x.mul(lam).add(x_flipped) -> x.mul_(lam).add_(x_flipped) 吗。 为什么换成 Inplace 版本,会对精度有明显影响呢 😂

因为原来没有用 inplace 版本的实现,并没有将 做变换之后的结果 return 出去,导致 网络的 输入和 label 对不上,之前在搬运 mixup 的时候还没有实现对应的 Inplace 的算子,所以又写了个 todo ,但是后来估计是忘记了这里。。。

rentainhe commented 2 years ago

实验

验证LayerNorm

实验环境

1. 使用GELU激活函数,将LayerNorm替换为Identity no layernorm

2. 使用GELU激活函数,使用LayerNorm(elementwise_affine=True) LayerNorm

3. 使用GELU激活函数,使用LayerNorm(elementwise_affine=False) elementwise_off

kaijieshi7 commented 2 years ago

用下面的代码,可以发现用layernorm梯度是不一样的,如果换成BN1D没多大问题

import oneflow as flow
import oneflow.nn as flownn
import torch as torch
import torch.nn as torchnn
import numpy as np

class N1(flownn.Module):
    def __init__(self):
        super(N1, self).__init__()
        self.reduction = flownn.Linear(96, 96, bias=False)
        # 切换norm方式,默认BN1D没问题
        self.norm = flownn.LayerNorm(96)
        # self.norm = flownn.BatchNorm1d(96)
        flownn.init.constant_(self.reduction.weight.data, 0.5)
        flownn.init.constant_(self.norm.weight.data, 0.9)
        flownn.init.constant_(self.norm.bias.data, 0.1)

    def forward(self, x):
        x = self.reduction(x)
        x = self.norm(x)
        return x

class N2(torchnn.Module):
    def __init__(self):
        super(N2, self).__init__()
        self.reduction = torchnn.Linear(96, 96, bias=False)
        # 切换norm方式,默认BN1D没问题
        self.norm = torchnn.LayerNorm(96)
        # self.norm = torchnn.BatchNorm1d(96)
        torchnn.init.constant_(self.reduction.weight.data, 0.5)
        torchnn.init.constant_(self.norm.weight.data, 0.9)
        torchnn.init.constant_(self.norm.bias.data, 0.1)

    def forward(self, x):
        x = self.reduction(x)
        x = self.norm(x)
        return x

#初始化网络
n_flow = N1().cuda()
n_torch = N2().cuda()
# 相同的输入
n = np.random.rand(2, 96).astype('float32')
x_flow = flow.tensor(n).cuda()
x_torch = torch.from_numpy(n).cuda()
# 相同的输出
n_label = np.random.rand(2, 96).astype('float32')
label_flow = flow.tensor(n_label).cuda()
label_torch = torch.from_numpy(n_label).cuda()
# 相同的损失函数
loss_fn_FLOW = flownn.MSELoss()
# loss_fn_FLOW = flownn.BCELoss()
loss_fn_torch = torchnn.MSELoss()
# loss_fn_torch = torchnn.BCELoss()
# 相同的优化器
optimizer_flow = flow.optim.SGD(n_flow.parameters(), lr=0.001, momentum=0.9, weight_decay=0.05)
optimizer_torch = torch.optim.SGD(n_torch.parameters(), lr=0.001, momentum=0.9, weight_decay=0.05)
for i in range(10):
    # 查看norm的可学习参数
    # print(n_flow.norm.weight)
    # print(n_flow.norm.bias)
    optimizer_flow.zero_grad()
    y = n_flow(x_flow)
    loss_fn_FLOW(y, label_flow).backward()

    # print(n_torch.norm.weight)
    # print(n_torch.norm.bias)
    optimizer_torch.zero_grad()
    y = n_torch(x_torch)
    loss_fn_torch(y, label_torch).backward()
    # 查看梯度
    for i, j in zip(optimizer_flow.param_groups[0].parameters, optimizer_torch.param_groups[0]['params']):
        ii = i.grad
        jj = j.grad
    optimizer_flow.step()
    optimizer_torch.step()
yuanms2 commented 2 years ago

是不是意味着layer norm 算子是没对齐的

yuanms2 commented 2 years ago

我们layer norm 是哪位写的,@郑泽康,@郭冉

kaijieshi7 commented 2 years ago

又修改了一下,这段代码的输出结果越小越好,有下面几个实验。这里的误差我从数量级分为3种:1e-5,0.几,几.0

  1. 用BN1d,可以看到输出误差很小:1e-5数量级
  2. 用layernorm(elementwise_affine=False),输出误差依然很小:0.几(误差趋势减小)
  3. 用layernorm(elementwise_affine=True),输出误差比上面两个大很多数量级:几.0 (误差趋势变大)
  4. 用layernorm(elementwise_affine=False)加上自定义的scale和bias,误差比layernorm(elementwise_affine=True)小,但是还是比bn1d大。:0.几(误差趋势减小)

所以暂时找到了一个问题,反向传播时layernorm的可学习参数更新和torch的更新不一致。但应该还有其他问题,因为layernorm(elementwise_affine=False)的误差从数量级上比bn1d大

import oneflow as flow
import oneflow.nn as flownn
import torch as torch
import torch.nn as torchnn
import numpy as np

np.random.seed(1)
linear = np.random.rand(96, 96).astype('float32')
class N1(flownn.Module):
    def __init__(self):
        super(N1, self).__init__()
        self.reduction = flownn.Linear(96, 100, bias=False)
        self.reduction2 = flownn.Linear(100, 100, bias=False)
        # 切换norm方式,默认BN1D没问题
        self.norm = flownn.LayerNorm(100, elementwise_affine=False)
        # self.norm = flownn.BatchNorm1d(100)
        flownn.init.constant_(self.reduction.weight.data, 0.5)
        flownn.init.constant_(self.reduction2.weight.data, 0.5)
        # self.reduction.weight.data = flow.tensor(linear)
        # flownn.init.constant_(self.norm.weight.data, 0.9)
        # flownn.init.constant_(self.norm.bias.data, 0.1)
        # self.p = flownn.Parameter(flow.zeros(96, 96, dtype=flow.float32, requires_grad=True).cuda() + 0.5)
        self.beta = flownn.Parameter(flow.ones(1, 100))
        self.gamma = flownn.Parameter(flow.zeros(1, 100))

    def forward(self, x):
        x = self.reduction(x)
        # x = flow.matmul(x, self.p)
        x = self.norm(x)
        x = x*self.beta + self.gamma
        x = self.reduction2(x)
        return x

class N2(torchnn.Module):
    def __init__(self):
        super(N2, self).__init__()
        self.reduction = torchnn.Linear(96, 100, bias=False)
        self.reduction2 = torchnn.Linear(100, 100, bias=False)
        # 切换norm方式,默认BN1D没问题
        self.norm = torchnn.LayerNorm(100, elementwise_affine=True)
        # self.norm = torchnn.BatchNorm1d(100)
        torchnn.init.constant_(self.reduction.weight.data, 0.5)
        torchnn.init.constant_(self.reduction2.weight.data, 0.5)
        # self.reduction.weight.data = torch.from_numpy(linear)
        # torchnn.init.constant_(self.norm.weight.data, 0.9)
        # torchnn.init.constant_(self.norm.bias.data, 0.1)

    def forward(self, x):
        x = self.reduction(x)
        x = self.norm(x)
        x = self.reduction2(x)
        return x

#初始化网络
n_flow = N1().cuda()
n_torch = N2().cuda()

# 相同的损失函数
loss_fn_FLOW = flownn.MSELoss()
loss_fn_FLOW = flownn.L1Loss()
# loss_fn_FLOW = flownn.BCELoss()

loss_fn_torch = torchnn.MSELoss()
loss_fn_torch = torchnn.L1Loss()
# loss_fn_torch = torchnn.BCELoss()
# 相同的优化器
optimizer_flow = flow.optim.SGD(n_flow.parameters(), lr=0.01,  weight_decay=0.05)
optimizer_torch = torch.optim.SGD(n_torch.parameters(), lr=0.01,  weight_decay=0.05)

for i in range(10000):
    # 相同的输入
    n = np.random.rand(10, 96).astype('float32')
    x_flow = flow.tensor(n).cuda()
    x_torch = torch.from_numpy(n).cuda()
    # 相同的label
    n_label = np.random.rand(10, 100).astype('float32')
    label_flow = flow.tensor(n_label).cuda()
    label_torch = torch.from_numpy(n_label).cuda()
    # 查看norm的可学习参数
    # print(n_flow.norm.weight)
    # print(n_flow.norm.bias)
    optimizer_flow.zero_grad()
    y1 = n_flow(x_flow)
    loss_fn_FLOW(y1, label_flow).backward()

    # print(n_torch.norm.weight)
    # print(n_torch.norm.bias)
    optimizer_torch.zero_grad()
    y2 = n_torch(x_torch)
    loss_fn_torch(y2, label_torch).backward()
    # 查看梯度
    for i, j in zip(optimizer_flow.param_groups[0].parameters, optimizer_torch.param_groups[0]['params']):
        ii = i.grad
        jj = j.grad

    optimizer_flow.step()
    optimizer_torch.step()
    print(abs(y1.numpy()-y2.detach().cpu().numpy()).sum().item())
guo-ran commented 2 years ago

我看了一下pytorch的代码,它的(a / b)除法是用的* (1.f / b),本地修改了一下代码,oneflow不使用FAST_MATH,除法换成乘倒数 可以把diff降到


diff max 1.475215e-06
diff min -1.2516975e-06
guo-ran commented 2 years ago

我先拿修改过的代码重新跑一下swin-transformer loss曲线看看

guo-ran commented 2 years ago

layernorm+ GELU formula and Trun OFF FASTMATH oneflow#7103 分支 的gelu改动的曲线 sgd: image adamw: loss_compare

把softmax的除法也转乘法,不用fast math, AdamW loss_compare

MARD1NO commented 2 years ago

暂时能想到的四组实验:

四组实验

  1. 使用SGD分别跑oneflow和 torch。(天和已经完成,oneflow和torch均不收敛)
  2. 使用Apex版本的elementwise_affine=False的layernorm一起跑,用Adam
  3. 使用郭冉修改的除法,Pytorch版使用torch的layernorm跑,等郭冉提一个分支
  4. 用 BN + reshape 模拟layernorm
guo-ran commented 2 years ago

dev_debug_swin_transformer_gr 这个分支可以试试跑完整收敛训练,跑loss曲线比较接近。 主要改了: 1、layernorm除法转乘法 2、softmax除法转乘法 3、上面zzk分支里gelu的改动 swin-transformer仓库loss_compare分支的AdamW收敛曲线: image

rentainhe commented 2 years ago

dev_debug_swin_transformer_gr 这个分支可以试试跑完整收敛训练,跑loss曲线比较接近。 主要改了: 1、layernorm除法转乘法 2、softmax除法转乘法 3、上面zzk分支里gelu的改动 swin-transformer仓库loss_compare分支的AdamW收敛曲线: image

好的,我这边去尝试一下

guo-ran commented 2 years ago

dev_debug_swin_transformer_gr 这个分支可以试试跑完整收敛训练,跑loss曲线比较接近。 主要改了: 1、layernorm除法转乘法 2、softmax除法转乘法 3、上面zzk分支里gelu的改动 swin-transformer仓库loss_compare分支的AdamW收敛曲线: image 多次实验发现存在一定随机性,有时能完全对齐,有时有差异

Ldpe2G commented 2 years ago

Cifar100完整训练实验

实验设置

oneflow 基于分支: https://github.com/Oneflow-Inc/oneflow/pull/7103

oneflow swin 仓库代码分支:https://github.com/Oneflow-Inc/swin-transformer/pull/4

模型 swin small: swin_small_patch4_window7_224.yaml

layernorm 采用 拼的方式,超参与 torch 对齐, mixup和clip grad均打开。

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-5):
        super(LayerNorm, self).__init__()
        self.eps = eps
        self.weight = nn.Parameter(flow.ones(features, dtype=flow.float32))
        self.bias = nn.Parameter(flow.zeros(features, dtype=flow.float32))

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True)
        return self.weight * (x - mean) * flow.rsqrt(var + self.eps) + self.bias

实验结果

Framework Epoch Top1 Acc Log
OneFlow,加载torch的初始化模型 300 78.7% cifar100_mixup_on_clip_grad_on_torch_init_custom_layernorm_eps_1e-5.log
OneFlow,随机初始化 300 77.72% cifar100_mixup_on_clip_grad_on_custom_layernorm_eps_1e-5.log
Pytorch,随机初始化 300 79% cifar100_mixup_on_clip_grad_on.txt

实验结论

layernorm 采用拼的方式且加载 torch 的初始化模型最终 top1 acc 和 torch 还差 0.3%, 而如果是随机初始化则和 torch 差 ~1.2%。

yuanms2 commented 2 years ago

0.3% 的差距,总算差距没那么大

rentainhe commented 2 years ago

CIFAR100 完整训练实验

实验设置

实验结果

Model Optimizer Framework Acc@1 Log
Swin-Tiny AdamW OneFlow 76.52% log
Swin-Tiny AdamW Pytorch 77.31% log

简单结论

yuanms2 commented 2 years ago

任天和这个实验和上面梁德澎那个实验关系是什么

rentainhe commented 2 years ago

任天和这个实验和上面梁德澎那个实验关系是什么

主要的区别有以下几个:

但是都可以证明对LayerNorm的改动所带来的结果是积极的,精度差异都有在缩小

MARD1NO commented 2 years ago

CIFAR100 完整训练实验

实验设置

实验结果

Model Optimizer Framework Acc@1 Log Swin-Tiny AdamW OneFlow 76.52% log Swin-Tiny AdamW Pytorch 77.31% log

简单结论

  • 相较于之前的分支精度差异从2%缩小到0.7%,这个分支进行的改动带来了积极结果 @MARD1NO @guo-ran

我的一些做实验设置建议:

  1. 固定权重,这里不是说oneflow导入torch的,而是说把torch的模型固定保留一份,后续实验都在这个权重的基础上来做。因为这两周不光是我自己做实验,郭冉昨天做实验,都有遇到某次实验loss曲线能一摸一样,有时候就会有些差距。 在固定权重的情况下做实验,我觉得是来帮助排查算子改动是否积极的。甚至能避免你们需要做多次实验取平均这个步骤(既然权重固定,其他 一致的情况下,多次做实验效果应该是一样的) 到后面要整体训的话,那可以每次都随机初始化权重来训,观察整体
  2. 如果 mixup 和 clip_grad 不影响收敛精度来看实验结果的话,我觉得当目的是排除算子的时候,可以先关了,方便快速看结果。等确定完全没问题了。我们可以用一次随机权重,mixup clip_grad什么的都打开,完整的来对比一次
rentainhe commented 2 years ago

CIFAR100 完整训练实验

实验设置

实验结果

Model Optimizer Framework Acc@1 Log Swin-Tiny AdamW OneFlow 76.52% log Swin-Tiny AdamW Pytorch 77.31% log

简单结论

  • 相较于之前的分支精度差异从2%缩小到0.7%,这个分支进行的改动带来了积极结果 @MARD1NO @guo-ran

我的一些做实验设置建议:

  1. 固定权重,这里不是说oneflow导入torch的,而是说把torch的模型固定保留一份,后续实验都在这个权重的基础上来做。因为这两周不光是我自己做实验,郭冉昨天做实验,都有遇到某次实验loss曲线能一摸一样,有时候就会有些差距。 在固定权重的情况下做实验,我觉得是来帮助排查算子改动是否积极的。甚至能避免你们需要做多次实验取平均这个步骤(既然权重固定,其他 一致的情况下,多次做实验效果应该是一样的) 到后面要整体训的话,那可以每次都随机初始化权重来训,观察整体
  2. 如果 mixup 和 clip_grad 不影响收敛精度来看实验结果的话,我觉得当目的是排除算子的时候,可以先关了,方便快速看结果。等确定完全没问题了。我们可以用一次随机权重,mixup clip_grad什么的都打开,完整的来对比一次

可以的,我这边和德澎统一一个输入权重,上传到阿里云,你们后续对齐也可以用这个权重,统一成Tiny模型,mixup和clip-grad主要是为了和德澎的实验对应,看看结果

Ldpe2G commented 2 years ago

oss://oneflow-static/swin_init_models/ 上传了 swin_{tiny, small, base, large} 的 pytorch 的初始化模型,oneflow 和 torch 的都有保存,对应 swin_{tiny, small, base, large}_patch4_window7_224 这四个配置

@rentainhe @MARD1NO @guo-ran

kaijieshi7 commented 2 years ago

下面的代码测试了一下AdaptiveAvgPool1d。发现flow和torch的误差会有点大。

  1. 用mean代替了AdaptiveAvgPool1d,可以发现误差比直接用AdaptiveAvgPool1d小(但mean偶尔也会突然变大,概率没flow大)。具体代码在flow网络的forward函数里面注释了,直接修改就能用。
import oneflow as flow
import oneflow.nn as flownn
import torch as torch
import torch.nn as torchnn
import numpy as np

np.random.seed(1)
linear = np.random.rand(96, 96).astype('float32')
class N1(flownn.Module):
    def __init__(self):
        super(N1, self).__init__()
        self.reduction = flownn.Linear(768, 768, bias=False)
        self.reduction2 = flownn.Linear(768, 1000, bias=False)
        self.avgpool = flownn.AdaptiveAvgPool1d(1)
        flownn.init.constant_(self.reduction.weight.data, 0.5)
        flownn.init.constant_(self.reduction2.weight.data, 0.5)

    def forward(self, x):
        x = self.reduction(x)
        x = self.avgpool(x.transpose(1, 2))
        # x = flow.mean(x.transpose(1, 2), -1)
        x = flow.flatten(x, 1)
        x = self.reduction2(x)
        return x

class N2(torchnn.Module):
    def __init__(self):
        super(N2, self).__init__()
        self.reduction = torchnn.Linear(768, 768, bias=False)
        self.reduction2 = torchnn.Linear(768, 1000, bias=False)
        self.avgpool = torchnn.AdaptiveAvgPool1d(1)
        torchnn.init.constant_(self.reduction.weight.data, 0.5)
        torchnn.init.constant_(self.reduction2.weight.data, 0.5)

    def forward(self, x):
        x = self.reduction(x)
        x = self.avgpool(x.transpose(1, 2))
        x = torch.flatten(x, 1)
        x = self.reduction2(x)
        return x

#初始化网络
n_flow = N1().cuda()
n_torch = N2().cuda()

# 相同的损失函数
# loss_fn_FLOW = flownn.MSELoss()
loss_fn_FLOW = flownn.L1Loss()
# loss_fn_FLOW = flownn.BCELoss()

# loss_fn_torch = torchnn.MSELoss()
loss_fn_torch = torchnn.L1Loss()
# loss_fn_torch = torchnn.BCELoss()
# 相同的优化器
optimizer_flow = flow.optim.SGD(n_flow.parameters(), lr=0.01,  weight_decay=0)
optimizer_torch = torch.optim.SGD(n_torch.parameters(), lr=0.01,  weight_decay=0)

for i in range(10000):
    # 相同的输入
    n = np.random.rand(1, 49, 768).astype('float32')
    x_flow = flow.tensor(n).cuda()
    x_torch = torch.from_numpy(n).cuda()
    # 相同的label
    n_label = np.random.rand(1, 1000).astype('float32')
    label_flow = flow.tensor(n_label).cuda()
    label_torch = torch.from_numpy(n_label).cuda()
    # 查看norm的可学习参数
    # print(n_flow.norm.weight)
    # print(n_flow.norm.bias)
    optimizer_flow.zero_grad()
    y1 = n_flow(x_flow)
    loss_fn_FLOW(y1, label_flow).backward()

    # print(n_torch.norm.weight)
    # print(n_torch.norm.bias)
    optimizer_torch.zero_grad()
    y2 = n_torch(x_torch)
    loss_fn_torch(y2, label_torch).backward()
    # 查看梯度
    for i, j in zip(optimizer_flow.param_groups[0].parameters, optimizer_torch.param_groups[0]['params']):
        ii = i.grad
        jj = j.grad

    optimizer_flow.step()
    optimizer_torch.step()
    print(abs(y1.numpy()-y2.detach().cpu().numpy()).sum().item())
guo-ran commented 2 years ago

我跑了一下,确实差异好大,需要看看是不是有bug,这个op有没有可以换的其他版本op?比如调cudnn的实现的pooling op。

yuanms2 commented 2 years ago

swin transformer 里用了AdaptiveAvgPool1d 这个kernel吗,我们的单测为什么没有发现这个问题

guo-ran commented 2 years ago

每个数据上是微小的计算差异,这里print出来很大是因为print的sum(diff),应该是我们的pool操作和pytorch的操作的reduce顺序有区别,需要看看具体代码区别

MARD1NO commented 2 years ago

如果怀疑 AdaptivePool 有问题的话,我觉得可以用你们说的mean来都替换下,做个实验。

另外我可以用一段代码来解释下为啥都是求均值,不同结果是不一样的,以numpy为例子:

import numpy as np

x = np.random.randn(100).astype(np.float32)
out = np.mean(x)
np.random.shuffle(x)
out2 = np.mean(x)
np.random.shuffle(x)
out3 = np.mean(x)
print(out)
print(out2)
print(out3)

运行这段代码,可以看到每个out输出结果都是存在一定微小的差距的。而到 AdaptivePool 里,我们设置的线程块和线程和torch的不一样。每个线程做求和的时候,顺序都是不一样的,也就会造成 1e-4~1e-5 的差别,再累积起来就显得比较大。

再从整体框架来说,我,俊丞,郭冉都觉得 矩阵乘,卷积(这类直接调用cublas, cudnn的),elementwise计算部分 我们可以保证一模一样,但是涉及到reduce操作,比如说 layernorm的weight求反向梯度的时候,AdaptivePool,最后loss.sum(),这部分保证不了一模一样,我认为这不能算是算法差别,而是计算机浮点数的误差。

如果每个Kernel都要求小数点这样的话,我觉得不太现实,也没人手(说到底现在就我一个人)来去支持这么替换Kernel实现做实验

kaijieshi7 commented 2 years ago

有空的可以运行下面trauncnormal的代码,因为是正态分布,所以我用了很多数据在显卡上,2080ti单卡是能跑的,画图有点卡,我没化。

  1. 调整std的参数,oneflow的最大值最小是都是std参数的两倍。可以把of的std换成0.97,0.96。因为生成了很多很多的数据,所以是理论肯定有数据等于上下限2的(当std不是很小时)。
from timm.models.layers import trunc_normal_
import torch.nn as torchnn
import torch
import oneflow as flow
import oneflow.nn as flownn

data = torchnn.Parameter(torch.zeros((2 * 7 - 1) * (2 * 7 - 1), 1000000, device='cuda:0'))
trunc_normal_(data, std=0.99)
# plt.plot(relative_position_bias_table.detach().cpu().numpy())
# plt.show()
print(data.max())
print(data.min())
del data
torch.cuda.empty_cache()
data = flownn.Parameter(flow.zeros((2 * 7 - 1) * (2 * 7 - 1), 1000000, device='cuda:0', dtype=flow.float32))
data.trunc_normal_(std=0.9)
print(data.max())
print(data.min())
flownn.init.trunc_normal_(data, std=0.98)
print(data.max())
print(data.min())