PaddlePaddle / Paddle

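PArallel Distributed Deep LEarning: Machine Learning Framework from Industrial Practice (core framework of 『飞桨』 (PaddlePaddle): high-performance single-machine and distributed training and cross-platform deployment for deep learning & machine learning)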
http://www.paddlepaddle.org/
Apache License 2.0
21.79k stars 5.47k forks

Precision of the conv2d layers provided by paddle and torch cannot be aligned #52599

Open xxxqhloveu opened 1 year ago

xxxqhloveu commented 1 year ago

Please ask your question

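When the input dimensions are large, the precision gap between torch.nn.Conv2d and paddle.nn.Conv2D is too big. I used the precision alignment tool described at https://github.com/PaddlePaddle/models/tree/release/2.3/tutorials/reprod_log. The required precision is 1e-5, but the test below gives 0.000186... (on the order of 1e-4).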

import torch
import paddle
import numpy as np
from reprod_log.reprod_log import ReprodDiffHelper
def result_diff(info1, info2):
    diff_helper = ReprodDiffHelper()
    # compare result and produce log
    diff_helper.compare_info(info1, info2)
    diff_helper.report(
        path="/opt/forward_diff.log", diff_threshold=1e-5)

def compulte_conv2d():
    paddle.disable_static()
    # input = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
    input = paddle.rand((1, 2048, 28, 50))
    conv = paddle.nn.Conv2D(2048, 256, 1)
    conv_state_dict = conv.state_dict()
    torch_state_dict = {}
    for item in conv_state_dict:
        torch_state_dict[item] = torch.tensor(conv_state_dict[item].numpy())
    output = conv(input)
    
    conv_torch = torch.nn.Conv2d(2048, 256, 1)
    with torch.no_grad():
        for name, param in conv_torch.named_parameters():
            if name in torch_state_dict.keys():
                param.copy_(torch_state_dict[name])
            else:
                pass
    input_torch = torch.tensor(input.detach().numpy())
    output_torch = conv_torch(input_torch)
    result_diff({"logits":output.numpy()}, {"logits":output_torch.detach().numpy()})
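
A single figure such as 0.000186 is easier to interpret next to the maximum difference and the size of the outputs themselves. The following is a minimal NumPy sketch of such a summary; `diff_stats` is a hypothetical helper introduced here for illustration and is not part of reprod_log.

```python
import numpy as np

def diff_stats(a, b):
    """Summarise the difference between two arrays (hypothetical helper)."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    abs_diff = np.abs(a - b)
    # Largest magnitude seen in either array; guards against division by zero.
    peak = max(np.abs(a).max(), np.abs(b).max(), np.finfo(np.float32).tiny)
    return {
        "mean_abs": float(abs_diff.mean()),
        "max_abs": float(abs_diff.max()),
        "max_rel_to_peak": float(abs_diff.max() / peak),
    }

# Toy data only; in the test above these would be output.numpy() and
# output_torch.detach().numpy().
a = np.array([1.0, 2.0, 4.0], dtype=np.float32)
b = np.array([1.0, 2.0, 4.5], dtype=np.float32)
stats = diff_stats(a, b)
print(stats)
```

An absolute threshold of 1e-5 means something different for outputs near 0.01 than for outputs near 10, so the relative figure helps when comparing runs with different inputs.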

LielinJiang commented 1 year ago

conv_torch = torch.nn.Conv2d(2048, 256, 1): by default torch places the layers and tensors it creates on the CPU. Try adding .cuda().

xxxqhloveu commented 1 year ago

@paddle.no_grad()
@torch.no_grad()
def compulte_conv2d_t2p():
    input_np = np.random.random((1, 2048, 28, 50))
    # input_np = np.load("/data/xiaoqihang/myproject/race/paper/SPTS_paddle/save_test_result/diff/b_output.npy")
    input_torch = torch.tensor(input_np, dtype=torch.float32)
    conv_torch = torch.nn.Conv2d(2048, 256, 1)
    device = torch.device('cuda')
    input_torch = input_torch.to(device)
    conv_torch = conv_torch.to(device)
    output_torch = conv_torch(input_torch)
    torch_state_dict = conv_torch.state_dict()

    conv_paddle = paddle.nn.Conv2D(2048, 256, 1)
    paddle_state_dict = conv_paddle.state_dict()
    for name, param in conv_torch.named_parameters():
        paddle_state_dict[name] = paddle.to_tensor(param.detach().cpu().numpy()).astype("float32")
    conv_paddle.set_state_dict(paddle_state_dict)
    paddle_state_dict = conv_paddle.state_dict()
    # input_paddle = paddle.to_tensor(input_torch.detach().cpu().numpy())
    input_paddle = paddle.to_tensor(input_np).astype("float32")
    paddle.set_device("gpu")
    output_paddle = conv_paddle(input_paddle)
    result_diff({"logits":input_paddle.numpy()}, {"logits":input_torch.detach().cpu().numpy()})
    result_diff({"logits":output_paddle.numpy()}, {"logits":output_torch.detach().cpu().numpy()})

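It still does not align when I use the GPU, and the result depends on the input data. This time the precision gap is 7e-5, and for one saved forward-pass tensor it reaches 1e-4.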

xxxqhloveu commented 1 year ago

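conv_torch = torch.nn.Conv2d(2048, 256, 1): by default torch places the layers and tensors it creates on the CPU. Try adding .cuda().

I have posted the GPU code as well (in the comment below), and the GPU does not align either. With rand data the precision gap is 7e-5, and for one saved forward-pass tensor it reaches 1e-4. I have screenshots but cannot post them here. The card is a 3090.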
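
One possible contributor, which this thread does not confirm, is reduced-precision math on the GPU. Ampere cards such as the 3090 support TF32, which keeps only 10 mantissa bits of each input, and PyTorch allows TF32 for cuDNN convolutions by default. The NumPy sketch below only emulates that effect on the CPU by truncating inputs. The weight scale and the truncation rounding are assumptions, and `truncate_to_tf32` is a hypothetical helper.

```python
import numpy as np

def truncate_to_tf32(x):
    """Zero the low 13 mantissa bits of float32 values, leaving 10 (TF32-like).

    Truncation is a simplification; real hardware rounding may differ.
    """
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

rng = np.random.default_rng(0)
c_in, c_out, n_pix = 2048, 256, 64   # channel counts from the thread, fewer pixels
x = rng.random((c_in, n_pix), dtype=np.float32)   # values in [0, 1), like paddle.rand
bound = 1.0 / np.sqrt(c_in)                       # assumed weight scale
w = rng.uniform(-bound, bound, (c_out, c_in)).astype(np.float32)

# A 1x1 convolution is a matrix product over the channel axis.
ref = w.astype(np.float64) @ x.astype(np.float64)   # float64 reference
out_fp32 = w @ x                                    # plain float32 product
# Truncated inputs, accumulated in float64 to isolate the input rounding.
out_tf32 = (truncate_to_tf32(w).astype(np.float64)
            @ truncate_to_tf32(x).astype(np.float64))

err_fp32 = float(np.abs(out_fp32 - ref).mean())
err_tf32 = float(np.abs(out_tf32 - ref).mean())
print(f"mean |fp32 - ref|      = {err_fp32:.2e}")
print(f"mean |tf32-like - ref| = {err_tf32:.2e}")
```

If the truncated-input error lands well above 1e-5 while the plain float32 error stays far below it, a 1e-5 threshold would not be reachable whenever one side of the comparison runs in TF32.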

Yang-Changhui commented 2 months ago

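conv_torch = torch.nn.Conv2d(2048, 256, 1): by default torch places the layers and tensors it creates on the CPU. Try adding .cuda().

@LielinJiang Hello, I am using paddle 2.6. paddle runs on a V100 and torch is loaded on a 3060. The input is [16,256,16,16] and the conv layer is [256,32,1,1]. With the same input and the same weights, reprod_log reports an error of 1e-4. The maximum value is 2.5983 and the minimum value is -2.5591.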
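
In a cross-GPU comparison like this one, the V100 has no TF32 tensor cores while the 3060 does, so one thing worth ruling out is TF32 on the PyTorch side. The sketch below turns it off before rerunning the comparison. `force_fp32_conv` is a hypothetical helper; the two `torch.backends` flags are existing PyTorch switches. Whether this closes the gap reported here is not established in this thread.

```python
def force_fp32_conv():
    """Ask PyTorch to avoid TF32 (hypothetical helper).

    Returns True if torch was importable and the flags were set.
    """
    try:
        import torch
    except ImportError:
        return False
    # cuDNN convolutions: TF32 is allowed by default in PyTorch.
    torch.backends.cudnn.allow_tf32 = False
    # cuBLAS matmuls: set explicitly so both paths use full FP32.
    torch.backends.cuda.matmul.allow_tf32 = False
    return True

# Call this before building conv_torch and running the forward pass.
applied = force_fp32_conv()
print("TF32 disabled in torch:", applied)
```

NVIDIA libraries also read the environment variable NVIDIA_TF32_OVERRIDE=0, which has to be set before the process starts and would cover the paddle side too if it runs on an Ampere card.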