Hi @kohillyang. Thanks for reporting the issue. Would you mind creating a small reproducible example? Anyone who wants to help would first need to trigger the issue themselves in order to debug it.
@szha The following code reproduces the above error.
from __future__ import print_function

import mxnet as mx
import mxnet.autograd as ag
import numpy as np
import gluoncv


class resnet(mx.gluon.nn.HybridBlock):
    def __init__(self):
        super(resnet, self).__init__()
        self.feat = gluoncv.model_zoo.resnet50_v1b(pretrained=False)

    def hybrid_forward(self, F, x):
        input = F.transpose(x, (0, 3, 1, 2))
        x = input / 255.0
        x = self.feat.conv1(x)
        x = self.feat.bn1(x)
        x = self.feat.relu(x)
        x = self.feat.maxpool(x)
        res2 = self.feat.layer1(x)
        res3 = self.feat.layer2(res2)
        res4 = self.feat.layer3(res3)
        res5 = self.feat.layer4(res4)
        return res5


def train_net():
    mx.random.seed(3)
    np.random.seed(3)
    ctx_list = [mx.gpu(0)]
    net = resnet()
    net.initialize()
    net.collect_params().reset_ctx(list(set(ctx_list)))
    if True:
        from mxnet.contrib import amp
        amp.init()
        net.cast("float16")
        # net.collect_params('.*batchnorm.*').setattr('dtype', 'float32')
    trainer = mx.gluon.Trainer(
        net.collect_params(),  # fix batchnorm, fix first stage, etc...
        'sgd',
        {'wd': 1e-4,
         'momentum': .9,
         'clip_gradient': None,
         'lr_scheduler': None,
         'multi_precision': True,
         },
        update_on_kvstore=(False if True else None), kvstore=mx.kvstore.create('local')
    )
    if True:
        amp.init_trainer(trainer)
    with ag.record():
        data = mx.nd.zeros(shape=(1, 368, 368, 3), ctx=ctx_list[0])
        fpn_predictions = net(data)
        preds = mx.nd.concat(*[x.reshape((0, 0, -1)) for x in fpn_predictions], dim=2)
        with amp.scale_loss(preds.sum(), trainer) as scaled_losses:
            scaled_losses.backward()
        trainer.step(1, ignore_stale_grad=True)


if __name__ == '__main__':
    train_net()
Can I get some advice? I have no idea how to solve this problem. Thanks.
So the error is saying that an operator expects float32 but is getting float16. You may want to use the NaiveEngine to pinpoint where this comes from: https://mxnet.apache.org/api/dev-guide/debugging_and_performance_optimization_tips
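For reference, the engine is selected through an environment variable that is read when MXNet starts, so it has to be set before the import (or exported in the shell); a minimal sketch:

import os

# The engine type is read at MXNet initialization, so set it before importing.
os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'

import mxnet as mx  # operators now execute synchronously, so the failing
                    # operator shows up directly in the stack trace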
The type codes can be found here: https://github.com/apache/incubator-mxnet/blob/master/3rdparty/mshadow/mshadow/base.h#L327-L342
/*! \brief data type flag */
enum TypeFlag {
  kFloat32 = 0,
  kFloat64 = 1,
  kFloat16 = 2,
  kUint8 = 3,
  kInt32 = 4,
  kInt8 = 5,
  kInt64 = 6,
  kBool = 7,
  kInt16 = 8,
  kUint16 = 9,
  kUint32 = 10,
  kUint64 = 11,
  kBfloat16 = 12
};
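So a numeric dtype code in an error message (e.g. "expected 0, got 2") can be decoded against this enum; a small Python helper that simply mirrors the TypeFlag values above:

# Mapping mirroring the mshadow TypeFlag enum, for decoding dtype codes
# that appear in type-mismatch error messages.
TYPE_FLAGS = {
    0: 'float32', 1: 'float64', 2: 'float16', 3: 'uint8',
    4: 'int32', 5: 'int8', 6: 'int64', 7: 'bool',
    8: 'int16', 9: 'uint16', 10: 'uint32', 11: 'uint64', 12: 'bfloat16',
}

print(TYPE_FLAGS[0], '->', TYPE_FLAGS[2])  # "expected 0, got 2": float32 vs float16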
I found the reason: amp.init should be called before the network is created, because it patches some functions in the mx.nd and mx.sym namespaces. And since those functions are already patched, the network does not need to be cast to float16.
However, in https://github.com/dmlc/gluon-cv/blob/e3513064244f3f987f699ac43781d40ad01e144a/scripts/detection/faster_rcnn/train_faster_rcnn.py#L655 the network is cast to float16 when using amp, so I'm not sure whether the casting is needed.
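In other words, the ordering that works looks roughly like this (a sketch distilled from the repro above, with the cast removed; not an official recipe):

from mxnet.contrib import amp
import mxnet as mx

amp.init()  # patches mx.nd/mx.sym, so it must run before any block is built

net = resnet()  # the resnet block from the repro above
net.initialize()
net.collect_params().reset_ctx([mx.gpu(0)])
# Note: no net.cast("float16") here -- the patched operators insert the casts.

trainer = mx.gluon.Trainer(net.collect_params(), 'sgd',
                           {'momentum': 0.9, 'wd': 1e-4, 'multi_precision': True},
                           update_on_kvstore=False)
amp.init_trainer(trainer)  # the trainer is wrapped after it exists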
I found the reason: amp.init should be called before the network is created
Looks like at minimum we need to document this; better yet, it would be great if AMP could handle the necessary changes transparently. @mk-61 it would be great if you could help take this into account in https://github.com/apache/incubator-mxnet/issues/18896
Also, it seems that if AMP is used, the trainer is not able to load trainer states that were saved previously.
For loading states - have you tried after https://github.com/apache/incubator-mxnet/pull/18959? It was supposed to fix this.
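For context, the round trip in question is the standard Gluon Trainer state API; a minimal sketch of the pattern that was failing with an AMP-initialized trainer before that fix (the file name is just an example):

# Save and later restore optimizer state (momenta, loss-scale bookkeeping, ...)
trainer.save_states('model-0001.states')
# ... later, e.g. when resuming training ...
trainer.load_states('model-0001.states')  # failed with AMP before the fix above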
@mk-61 It works, thank you very much.
Closing based on https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-680788665; the warning for too-late initialization of AMP was introduced in #19036.
Description
Hello, I'm trying to reproduce FCOS. In order to use a larger batch size, I'm trying to train with FP16. I modified my code based on https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py, but after training started, I got the following error. One possible reason is that I used some operators that AMP does not support, but the only contrib operators I use are Deformable Convolution and BilinearResize2D, and I'm not sure whether they are supported by AMP.
My code is at https://github.com/kohillyang/mx-detection/blob/master/scripts/train_fcos.py. Thanks.
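As an aside, whether AMP covers an operator can be checked against its casting lists; a hypothetical sketch, assuming the mxnet.contrib.amp.lists.symbol module layout and these contrib op names (both may differ across MXNet versions):

from mxnet.contrib.amp.lists import symbol as amp_lists

for op in ('_contrib_DeformableConvolution', '_contrib_BilinearResize2D'):
    print(op,
          'fp16 list:', op in amp_lists.FP16_FUNCS,   # always cast to float16
          'fp32 list:', op in amp_lists.FP32_FUNCS)   # always kept in float32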