apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

Got "kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2" when training with amp. #18902

Closed kohillyang closed 3 years ago

kohillyang commented 4 years ago

Description

Hello, I'm trying to reproduce FCOS. In order to use a larger batch size, I'm trying to train my model with FP16. I based my code on https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py, but after training started, I got the following error. One possible reason is that I have used some operators that AMP does not support, but the only contrib operators I use are Deformable Convolution and BilinearResize2D, and I'm not sure whether they are supported by AMP.

My code is at https://github.com/kohillyang/mx-detection/blob/master/scripts/train_fcos.py. Thanks.
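
In case it helps triage, below is a rough way to check whether those two contrib operators appear in AMP's operator lists. This is a sketch only: it assumes the lists live under mxnet.contrib.amp.lists.symbol (as in MXNet 1.6), and the operator and list names are my guesses at the relevant entries.

from mxnet.contrib.amp.lists import symbol as amp_lists

# Check which AMP list, if any, mentions the two contrib operators in question.
for op in ('_contrib_DeformableConvolution', '_contrib_BilinearResize2D'):
    found = False
    for list_name in ('FP16_FUNCS', 'FP16_FP32_FUNCS', 'FP32_FUNCS', 'WIDEST_TYPE_CASTS'):
        if op in getattr(amp_lists, list_name, []):
            print('{} is in {}'.format(op, list_name))
            found = True
    if not found:
        print('{} is not in any of the checked lists'.format(op))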

Error Message

  0%|          | 0.00/19.5k [00:04<?, ?it/s]
Traceback (most recent call last):
  File "/data2/kohill/mx-detection/scripts/train_fcos.py", line 584, in <module>
    main()
  File "/data2/kohill/mx-detection/scripts/train_fcos.py", line 512, in main
    train_net(config)
  File "/data2/kohill/mx-detection/scripts/train_fcos.py", line 399, in train_net
    trainer.step(batch_size)
  File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/gluon/trainer.py", line 335, in step
    self._update(ignore_stale_grad)
  File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/contrib/amp/amp.py", line 321, in new_update
    self._old_update(ignore_stale_grad)
  File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/gluon/trainer.py", line 437, in _update
    updater(i, w, g)
  File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/optimizer/optimizer.py", line 2071, in __call__
    states)
  File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/contrib/amp/amp.py", line 313, in new_update_multi_precision
    if not skip_update():
  File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/contrib/amp/loss_scaler.py", line 64, in wait_and_update
    self._has_overflow = not bool(self.output.asnumpy())
  File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 2535, in asnumpy
    ctypes.c_size_t(data.size)))
  File "/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/base.py", line 255, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [16:57:52] include/mxnet/././tensor_blob.h:215: Check failed: mshadow::DataType<DType>::kFlag == type_flag_: TBlob.get_with_shape: data type do not match specified type.Expected: 0 v.s. given 2
Stack trace:
  [bt] (0) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x6b8b5b) [0x7ffaebc7cb5b]
  [bt] (1) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x4295e78) [0x7ffaef859e78]
  [bt] (2) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x42c01d1) [0x7ffaef8841d1]
  [bt] (3) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(mxnet::imperative::PushFCompute(std::function<void (nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)> const&, nnvm::Op const*, nnvm::NodeAttrs const&, mxnet::Context const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource, std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int, std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&)::{lambda(mxnet::RunContext)#1}::operator()(mxnet::RunContext) const+0x307) [0x7ffaeee5d377]
  [bt] (4) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x37b68d4) [0x7ffaeed7a8d4]
  [bt] (5) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x37c4961) [0x7ffaeed88961]
  [bt] (6) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x37c7ea0) [0x7ffaeed8bea0]
  [bt] (7) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x37c8136) [0x7ffaeed8c136]
  [bt] (8) /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x37c3114) [0x7ffaeed87114]

Process finished with exit code 1

Environment

We recommend using our script for collecting the diagnostic information. Run the following command and paste the outputs below:

curl --retry 10 -s https://raw.githubusercontent.com/dmlc/gluon-nlp/master/tools/diagnose.py | python

----------Python Info----------
Version      : 3.6.5
Compiler     : GCC 7.2.0
Build        : ('default', 'Apr 29 2018 16:14:56')
Arch         : ('64bit', '')
------------Pip Info-----------
Version      : 10.0.1
Directory    : /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/pip
----------MXNet Info-----------
/data/kohill/anaconda3_jyw/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
Version      : 1.6.0
Directory    : /data/kohill/anaconda3_jyw/lib/python3.6/site-packages/mxnet
Num GPUs     : 4
Commit Hash   : 6eec9da55c5096079355d1f1a5fa58dcf35d6752
----------System Info----------
Platform     : Linux-4.15.0-107-generic-x86_64-with-debian-stretch-sid
system       : Linux
node         : ubuntu
release      : 4.15.0-107-generic
version      : #108~16.04.1-Ubuntu SMP Fri Jun 12 02:57:13 UTC 2020
----------Hardware Info----------
machine      : x86_64
processor    : x86_64
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                48
On-line CPU(s) list:   0-47
Thread(s) per core:    2
Core(s) per socket:    12
Socket(s):             2
NUMA node(s):          2
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 63
Model name:            Intel(R) Xeon(R) CPU E5-2680 v3 @ 2.50GHz
Stepping:              2
CPU MHz:               1735.211
CPU max MHz:           3300.0000
CPU min MHz:           1200.0000
BogoMIPS:              5001.81
Virtualization:        VT-x
L1d cache:             32K
L1i cache:             32K
L2 cache:              256K
L3 cache:              30720K
NUMA node0 CPU(s):     0-11,24-35
NUMA node1 CPU(s):     12-23,36-47
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm cpuid_fault epb invpcid_single pti intel_ppin ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm xsaveopt cqm_llc cqm_occup_llc dtherm ida arat pln pts md_clear flush_l1d
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 10.0385 sec, LOAD: 8.7852 sec.
Timing for GluonNLP GitHub: https://github.com/dmlc/gluon-nlp, DNS: 0.0294 sec, LOAD: 9.1138 sec.
Timing for GluonNLP: http://gluon-nlp.mxnet.io, DNS: 5.0335 sec, LOAD: 22.5200 sec.
Timing for D2L: http://d2l.ai, DNS: 0.0282 sec, LOAD: 2.4723 sec.
Timing for D2L (zh-cn): http://zh.d2l.ai, DNS: 0.0285 sec, LOAD: 0.4769 sec.
Timing for FashionMNIST: https://repo.mxnet.io/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 11.8055 sec, LOAD: 16.3673 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0289 sec, LOAD: 10.1921 sec.
Error open Conda: https://repo.continuum.io/pkgs/free/, HTTP Error 403: Forbidden, DNS finished in 0.02894139289855957 sec.
szha commented 4 years ago

Hi @kohillyang. Thanks for reporting the issue. Would you mind creating a small reproducible example? Others who help would still need to trigger this issue first in order to debug it.

kohillyang commented 4 years ago

@szha The following code reproduces the above error.

from __future__ import print_function

import mxnet as mx
import mxnet.autograd as ag
import numpy as np
import gluoncv

class resnet(mx.gluon.nn.HybridBlock):
    def __init__(self):
        super(resnet, self).__init__()
        self.feat = gluoncv.model_zoo.resnet50_v1b(pretrained=False)

    def hybrid_forward(self, F, x):
        input = F.transpose(x, (0, 3, 1, 2))
        x = input / 255.0
        x = self.feat.conv1(x)
        x = self.feat.bn1(x)
        x = self.feat.relu(x)
        x = self.feat.maxpool(x)

        res2 = self.feat.layer1(x)
        res3 = self.feat.layer2(res2)
        res4 = self.feat.layer3(res3)
        res5 = self.feat.layer4(res4)

        return res5

def train_net():
    mx.random.seed(3)
    np.random.seed(3)

    ctx_list = [mx.gpu(0)]
    net = resnet()
    net.initialize()
    net.collect_params().reset_ctx(list(set(ctx_list)))
    if True:
        from mxnet.contrib import amp
        amp.init()
        net.cast("float16")
        # net.collect_params('.*batchnorm.*').setattr('dtype', 'float32')

    trainer = mx.gluon.Trainer(
        net.collect_params(),  # fix batchnorm, fix first stage, etc...
        'sgd',
        {'wd': 1e-4,
         'momentum': .9,
         'clip_gradient': None,
         'lr_scheduler': None,
         'multi_precision': True,
         },
        update_on_kvstore=(False if True else None), kvstore=mx.kvstore.create('local')
    )
    if True:
        amp.init_trainer(trainer)

    with ag.record():
        data = mx.nd.zeros(shape=(1, 368, 368, 3), ctx=ctx_list[0])
        fpn_predictions = net(data)
        preds = mx.nd.concat(*[x.reshape((0, 0, -1)) for x in fpn_predictions], dim=2)
        with amp.scale_loss(preds.sum(), trainer) as scaled_losses:
            scaled_losses.backward()
    trainer.step(1, ignore_stale_grad=True)

if __name__ == '__main__':
    train_net()
kohillyang commented 4 years ago

Can I get some advice? I have no idea how to solve this problem. Thanks.

szha commented 4 years ago

So the error is saying that it's expecting float32 while getting float16. You may want to try using the NaiveEngine to pinpoint where this came from: https://mxnet.apache.org/api/dev-guide/debugging_and_performance_optimization_tips
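
For reference, a minimal sketch of switching to the NaiveEngine from Python: the engine type is controlled by the MXNET_ENGINE_TYPE environment variable, which needs to be set before MXNet creates its engine, so the simplest approach is to set it before importing mxnet.

import os
# With NaiveEngine, operators run synchronously, so a failure surfaces in the
# Python traceback at the call site instead of at a later asnumpy() call.
os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'

import mxnet as mx
# ... run the failing training code from here ...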

Type code can be found here: https://github.com/apache/incubator-mxnet/blob/master/3rdparty/mshadow/mshadow/base.h#L327-L342

/*! \brief data type flag */
enum TypeFlag {
  kFloat32 = 0,
  kFloat64 = 1,
  kFloat16 = 2,
  kUint8 = 3,
  kInt32 = 4,
  kInt8  = 5,
  kInt64 = 6,
  kBool = 7,
  kInt16 = 8,
  kUint16 = 9,
  kUint32 = 10,
  kUint64 = 11,
  kBfloat16 = 12
};
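
So "Expected: 0 v.s. given 2" decodes to "expected float32 (kFloat32) but was given float16 (kFloat16)". For illustration only, the same mapping can be seen from Python through an internal table of mxnet.ndarray; this is an implementation detail that may change between releases.

from mxnet.ndarray.ndarray import _DTYPE_MX_TO_NP  # internal, shown only for illustration

print(_DTYPE_MX_TO_NP[0])  # numpy float32
print(_DTYPE_MX_TO_NP[2])  # numpy float16
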
kohillyang commented 4 years ago

I found that the cause is that amp.init() should be called before the network is created, because it patches some functions in the mx.nd and mx.sym namespaces. And since those functions are already patched, the network does not need to be cast to float16.

However, in https://github.com/dmlc/gluon-cv/blob/e3513064244f3f987f699ac43781d40ad01e144a/scripts/detection/faster_rcnn/train_faster_rcnn.py#L655 the network is cast to float16 when AMP is used, so I'm not sure whether that cast is needed.
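
For readers hitting the same error, below is a minimal sketch of the ordering that follows this finding: amp.init() is called before the network is constructed, and the network is not explicitly cast to float16. The hyperparameters, input shape, and GPU context are placeholders, not taken from the scripts above.

import mxnet as mx
from mxnet import autograd, gluon, nd
from mxnet.contrib import amp
import gluoncv

amp.init()                      # patch mx.nd / mx.sym before building the network
ctx = mx.gpu(0)                 # AMP targets GPU training

net = gluoncv.model_zoo.resnet50_v1b(pretrained=False)   # created after amp.init()
net.initialize(ctx=ctx)
# note: no net.cast("float16") here; the patched operators handle the casting

trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.01, 'momentum': 0.9, 'wd': 1e-4},
                        update_on_kvstore=False)
amp.init_trainer(trainer)

with autograd.record():
    loss = net(nd.zeros((1, 3, 368, 368), ctx=ctx)).sum()
    with amp.scale_loss(loss, trainer) as scaled_loss:
        scaled_loss.backward()
trainer.step(1)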

szha commented 4 years ago

I found that the cause is that amp.init() should be called before the network is created

Looks like at minimum we need to document this, or better yet it would be great if AMP could handle the necessary changes transparently. @mk-61 it would be great if you could take this into account in https://github.com/apache/incubator-mxnet/issues/18896

kohillyang commented 4 years ago

And it seems that when AMP is used, the trainer is not able to load trainer states that were saved previously.

mk-61 commented 4 years ago

For loading states: have you tried after https://github.com/apache/incubator-mxnet/pull/18959? It was supposed to fix this.
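
For context, a compact, self-contained sketch of the save/load path being discussed. The file name is a placeholder, and the load_states call is the step that reportedly failed with AMP enabled before that PR.

import mxnet as mx
from mxnet import autograd, gluon, nd
from mxnet.contrib import amp

amp.init()
ctx = mx.gpu(0)
net = gluon.nn.Dense(10)
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01},
                        update_on_kvstore=False)
amp.init_trainer(trainer)

# One training step so that optimizer states exist before saving.
with autograd.record():
    loss = net(nd.ones((2, 4), ctx=ctx)).sum()
    with amp.scale_loss(loss, trainer) as scaled_loss:
        scaled_loss.backward()
trainer.step(2)

trainer.save_states('trainer.states')   # placeholder path
trainer.load_states('trainer.states')   # the call that used to fail with AMP enabled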

kohillyang commented 4 years ago

@mk-61 It works, thank you very much.

ptrendx commented 3 years ago

Closing based on https://github.com/apache/incubator-mxnet/issues/18902#issuecomment-680788665; the warning for too-late initialization of AMP was introduced in #19036.