MegEngine / MegDiffusion

MegEngine implementation of Diffusion Models.
Apache License 2.0

Gradient clipping issues in MegEngine v1.9.x #1

Closed ChaiByte closed 2 years ago

ChaiByte commented 2 years ago

Description

Training on a single GPU with gradient clipping enabled in this codebase causes an error on MegEngine 1.9.x. After one iteration of autodiff and parameter update, the model's next forward pass breaks. Error message:

RuntimeError: assertion `filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4' failed at ../../../../../../imperative/src/impl/ops/convolution.cpp:61: megdnn::TensorLayout mgb::imperative::{anonymous}::convolution::do_shape_infer(const mgb::imperative::OpDef&, size_t, megdnn::TensorLayout, megdnn::TensorLayout)
extra message: bad filter ndim for dense convolution: spatial_ndim=2 filter_ndim=0

Here is a minimal example that reproduces the problem:

import megengine
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
import megengine.autodiff as autodiff

megengine.async_level = 0

class SimpleModel(M.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.conv1 = M.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.conv2 = M.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.nn.interpolate(x, scale_factor=1, mode="nearest")
        x = self.conv2(x)
        return x

if __name__ == "__main__":
    x = F.ones((1, 1, 2, 2))
    model = SimpleModel(in_ch = 1)

    optimizer = optim.SGD(model.parameters(), lr=1e-3)
    gm = autodiff.GradManager()
    gm.attach(model.parameters())

    with gm:
        loss = model(x) + 0
        gm.backward(loss)

    optim.clip_grad_norm(model.parameters(), max_norm=1.)
    optimizer.step()
    y = model(x)
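To see why the clipping step touches every parameter's gradient, here is a framework-free sketch of what `clip_grad_norm` typically does (an assumption about the standard algorithm, not MegEngine's exact implementation): compute the global L2 norm over all gradients, then scale each gradient in place if the norm exceeds `max_norm`.

```python
import math

def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Sketch of global-norm gradient clipping; grads is a list of float lists."""
    # global L2 norm over all gradients
    total = math.sqrt(sum(sum(g * g for g in grad) for grad in grads))
    scale = max_norm / (total + eps)
    if scale < 1.0:
        # mutate the gradients in place, as optimizers do
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= scale
    return total  # norm before clipping

grads = [[3.0, 4.0]]                        # global norm = 5
norm = clip_grad_norm(grads, max_norm=1.0)
print(norm)   # 5.0
print(grads)  # scaled down to (roughly) unit norm
```

The in-place mutation of every attached gradient is the step that interacts badly with shape inference in 1.9.x.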

Workaround

Issue Track

This problem was fixed in https://github.com/MegEngine/MegEngine/commit/df5ebd3da7495b8eb8f079651dbe980c5f4d7d37, so you can either wait for the MegEngine v1.10 release or build MegEngine from source at a commit newer than that one.

ChaiByte commented 2 years ago

The Python traceback shows that applying the dnn op for conv1(x) failed:

In MegEngine/MegDNN, it is common to perform shape inference when doing something like (output,) = apply(op, inp, weight) and dispatching the kernel to the MegDNN computing library. From the input descriptions, we can often (though not always) infer information about the output, such as its shape. This is convenient in some situations: for example, if we need the shape of the convolution filter in the i-th layer, we don't have to feed in data and execute the code until the filter tensor is actually computed -- we try to infer that information in advance, then read it directly when needed.

TensorShape ChannelImpl::get_shape(Handle handle) {
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        return info->desc.layout;
    }
    TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
    return ret;
}

DType ChannelImpl::get_dtype(Handle handle) {
    auto info = reinterpret_cast<TensorInfo*>(handle);
    MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
    auto ret = info->desc.layout.dtype;
    return ret;
}
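The fast-path/slow-path idea in get_shape above can be sketched in plain Python (a toy model of the behavior, not MegEngine's real API): an empty shape plays the role of ndim == 0, and the slow path waits for the producing computation instead of failing.

```python
from dataclasses import dataclass

@dataclass
class TensorInfo:
    # () plays the role of "ndim == 0": the layout is not known yet
    shape: tuple = ()

    def wait_and_compute(self):
        # stand-in for wait_tensor(): block until the kernel produces a layout
        self.shape = (1, 1, 3, 3)
        return self.shape

def get_shape(info):
    # fast path: the shape was already inferred statically
    if len(info.shape) != 0:
        return info.shape
    # slow path: wait for the actual computation rather than erroring out
    return info.wait_and_compute()

known = TensorInfo(shape=(1, 1, 2, 2))
unknown = TensorInfo()  # e.g. the conv filter whose shape inference failed here
print(get_shape(known))    # (1, 1, 2, 2)
print(get_shape(unknown))  # (1, 1, 3, 3)
```

The bug was that the convolution shape-infer path asserted on ndim == 0 instead of deferring to the slow path.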

In MegDNN, filter_ndim=0 means the filter's ndim is unknown (i.e., it could not be inferred in advance). But that SHOULD NOT break our programs, because the shape becomes available once the required computation has run. That is why https://github.com/MegEngine/MegEngine/commit/df5ebd3da7495b8eb8f079651dbe980c5f4d7d37 is the solution.

The story doesn't end here...

It seems the problem has been solved. But if you are still wondering why the workarounds mentioned above are valid, try debugging the example program. The RuntimeError is raised from MegDNN, so you will need gdb. Let's do it!

Build MegEngine from source with -DCMAKE_BUILD_TYPE=Debug. Run gdb python3, set catch throw, then run example.py (this might take a while depending on the machine):
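Concretely, the session looks something like the transcript below (prompts are illustrative; the build command is a generic CMake invocation, adjust to your setup):

```shell
# build a debug MegEngine first, e.g.:
#   cmake .. -DCMAKE_BUILD_TYPE=Debug && make -j"$(nproc)"
gdb python3
(gdb) catch throw          # break whenever a C++ exception is thrown
(gdb) run example.py       # reproduce the RuntimeError under the debugger
(gdb) backtrace            # see where do_shape_infer rejected the filter layout
```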

Who lost the shape info?

Verification

We can insert a reshape op between the interpolate and conv2 ops as follows:

class SimpleModel(M.Module):
    #...
    def forward(self, x):
        x = self.conv1(x)
        x = F.nn.interpolate(x, scale_factor=1, mode="nearest").reshape(1, 1, 2, 2)
        x = self.conv2(x)
        return x

Run the program again. Congratulations! Everything is fine.