facebookresearch / WSL-Images

Weakly Supervised Learning On Images
599 stars 63 forks source link

Significant numerical differences with torch.amp.autocast() compared with stock pretrained resnext101 #17

Open farleylai opened 4 years ago

farleylai commented 4 years ago

When running inference under torch.amp.autocast(), the forward results of the WSL model show significant numerical differences, unlike those of the pretrained resnext101_32x8d from torchvision, as the following sample outputs for the same input batch show:

Output from WSL pretrained resnext101_32x8d_wsl shows significant differences:

actual(w/ amp.autocast) = tensor([ 1.0162e-01,  3.0859e-01, -1.9760e-03,  3.7750e-02, -4.5996e-01,
        -7.2510e-01, -8.7402e-02, -9.4727e-01...698e-02, -1.4392e-01,
        -2.1533e-01, -5.7666e-01, -8.1787e-02,  1.8103e-01,  2.3596e-01],
expected(w/o amp.autocast) = tensor([ 1.2948e-01,  4.0339e-01,  6.4677e-02,  6.7963e-02, -3.6953e-01,
        -6.0408e-01, -1.5742e-01, -8.2637e-01...613e-02, -1.7330e-01,
        -2.1253e-01, -5.2314e-01, -1.2327e-01,  1.0499e-01,  1.7262e-01],

Output from torchvision pretrained resnext101_32x8d shows approximately equal numerical values:

actual(w/ amp.autocast) = tensor([-2.9844e+00, -6.8945e-01,  5.9668e-01, -1.2510e+00, -7.2168e-01,
        -2.1992e+00, -1.2686e+00, -5.0879e-01...953e+00, -4.4453e+00,
        -4.8984e+00, -3.2617e+00, -2.6641e+00, -2.2344e+00,  5.4922e+00],
expected(w/o amp.autocast) = tensor([-2.9887e+00, -6.8953e-01,  5.9514e-01, -1.2496e+00, -7.2139e-01,
        -2.2008e+00, -1.2737e+00, -5.1238e-01...971e+00, -4.4485e+00,
        -4.9002e+00, -3.2653e+00, -2.6683e+00, -2.2359e+00,  5.5001e+00],

Is it because the pretrained resnext101 from torchvision was already trained in mixed precision, or is it something else? Any clarification would be appreciated.

PS: sample pytest code to load the models and run the tests:

import logging

import pytest
import torch as th

def batch_size():
    """Return the number of samples in the generated input batch."""
    n_samples = 2
    return n_samples

def shape():
    """Return the per-sample input shape as a 3-tuple.

    Presumably (channels, height, width) -- confirm against the model input.
    """
    dims = (3, 720, 1280)
    return dims

def dev():
    """Return the device to run on: CUDA when available, otherwise CPU."""
    # torch is imported as `th` in this file; the previous CPU fallback used
    # the unbound name `torch` and raised NameError on machines without CUDA.
    return th.device('cuda' if th.cuda.is_available() else 'cpu')

def batch(batch_size, shape):
    """Return a uniformly random tensor of size ``(batch_size, *shape)``.

    Args:
        batch_size: number of samples along the leading dimension.
        shape: iterable of the remaining per-sample dimensions.
    """
    full_size = (batch_size, *shape)
    return th.rand(full_size)

def x101_32x8d(dev):
    """Build torchvision's ResNeXt-101 32x8d with frozen batch-norm layers.

    Args:
        dev: device the model is moved to. The tests send their inputs to the
            same device, so the model has to live there as well.

    Returns:
        The constructed model, placed on ``dev``.
    """
    from torchvision.models.resnet import Bottleneck
    from torchvision.models.resnet import _resnet
    from torchvision.ops.misc import FrozenBatchNorm2d

    groups = 32
    width_per_group = 8
    arch = f"resnext101_{groups}x{width_per_group}d"
    # The two positional True flags are forwarded exactly as before
    # (presumably pretrained=True, progress=True -- confirm against the
    # installed torchvision's private `_resnet` signature).
    model = _resnet(
        arch,
        Bottleneck,
        [3, 4, 23, 3],
        True,
        True,
        groups=groups,
        width_per_group=width_per_group,
        norm_layer=FrozenBatchNorm2d,
    )
    # Previously `dev` was accepted but ignored, leaving the model on the CPU
    # while the inputs were moved to `dev`.
    return model.to(dev)

def x101_32x8d_wsl(dev):
    """Load the WSL ResNeXt-101 32x8d from torch.hub with frozen batch norm.

    Args:
        dev: device the model is moved to. The tests send their inputs to the
            same device, so the model has to live there as well.

    Returns:
        The model loaded from the ``facebookresearch/WSL-Images`` hub entry
        ``resnext101_32x8d_wsl``, placed on ``dev``.
    """
    from torchvision.ops.misc import FrozenBatchNorm2d

    # Same keyword arguments the original forwarded to the hub entry point.
    model = th.hub.load(
        'facebookresearch/WSL-Images',
        'resnext101_32x8d_wsl',
        groups=32,
        width_per_group=8,
        norm_layer=FrozenBatchNorm2d,
    )
    # Previously `dev` was accepted but ignored, leaving the model on the CPU
    # while the inputs were moved to `dev`.
    return model.to(dev)

def _assert_amp_matches_fp32(model, dev, batch, B):
    """Compare a model's autocast outputs against its autocast-disabled outputs.

    Runs ``model`` on the first ``B`` samples of ``batch`` (moved to ``dev``)
    once with autocast disabled and once with autocast enabled, logs the
    per-sample output norms, and asserts the two results agree within
    ``rtol=1e-03`` / ``atol=3e-04``.

    Raises:
        AssertionError: if any sample's autocast output is outside tolerance.
    """
    # Both passes use the same input, so slice and transfer it only once.
    inputs = batch[:B].to(dev)
    with th.no_grad():
        with th.cuda.amp.autocast(enabled=False):
            outputs_fp32 = model(inputs).float()
        with th.cuda.amp.autocast():
            # Cast back to float32 so both results are compared in one dtype.
            outputs_amp = model(inputs).float()

    for i, (output_fp32, output_amp) in enumerate(zip(outputs_fp32, outputs_amp)):
        logging.info(f"output[{i}] shape={tuple(output_fp32.shape)}, norm_fp32={output_fp32.norm()}, norm_amp={output_amp.norm()}")
        th.testing.assert_allclose(output_amp, output_fp32, rtol=1e-03, atol=3e-04)


@pytest.mark.parametrize("B", [2])
def test_x101_amp(benchmark, x101_32x8d, dev, batch, B):
    """Check the torchvision ResNeXt-101 32x8d under autocast.

    ``benchmark`` is requested but not used by the test body.
    """
    _assert_amp_matches_fp32(x101_32x8d, dev, batch, B)


@pytest.mark.parametrize("B", [2])
def test_x101_wsl_amp(benchmark, x101_32x8d_wsl, dev, batch, B):
    """Check the WSL ResNeXt-101 32x8d under autocast.

    ``benchmark`` is requested but not used by the test body.
    """
    _assert_amp_matches_fp32(x101_32x8d_wsl, dev, batch, B)