microsoft / CNTK

Microsoft Cognitive Toolkit (CNTK), an open source deep-learning toolkit
https://docs.microsoft.com/cognitive-toolkit/
Other

debug_model does not work with batch normalization #2233

Open n17s opened 7 years ago

n17s commented 7 years ago

Continuing with the debug_model issues (cf. #2232 and #2231): debugging a model that contains batch normalization doesn't work. Here is a repro:

>>> x = C.sequence.input_variable(5)
>>> y = C.layers.Sequential([C.layers.BatchNormalization()])(x)
>>> y.eval({x:x0})
[array([[    0.        ,   316.22775269,   632.45550537,   948.68322754,
         1264.91101074],
       [ 1581.13879395,  1897.36645508,  2213.59423828,  2529.82202148,
         2846.04980469],
       [ 3162.27758789,  3478.50537109,  3794.73291016,  4110.9609375 ,
         4427.18847656],
       [ 4743.41650391,  5059.64404297,  5375.87158203,  5692.09960938,
         6008.32714844],
       [ 6324.55517578,  6640.78271484,  6957.01074219,  7273.23828125,
         7589.46582031],
       [ 7905.69384766,  8221.921875  ,  8538.14941406,  8854.37695312,
         9170.60449219]], dtype=float32)]
>>> w = C.debugging.debug_model(y)
>>> w.eval({x:x0})
Validating --> BatchNormalization5539 = BatchNormalization (UserDefinedFunction5518, UserDefinedFunction5515, UserDefinedFunction5509, UserDefinedFunction5506, UserDefinedFunction5512, UserDefinedFunction5503) : [5 x *23], [5], [5], [5], [5], [] -> [] FAILED
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\nikosk\Documents\mastercntk\CNTK\bindings\python\cntk\ops\functions.py", line 631, in eval
    _, output_map = self.forward(arguments, outputs, device=device, as_numpy=as_numpy)
  File "C:\Users\nikosk\Documents\mastercntk\CNTK\bindings\python\cntk\internal\swig_helper.py", line 69, in wrapper
    result = f(*args, **kwds)
  File "C:\Users\nikosk\Documents\mastercntk\CNTK\bindings\python\cntk\ops\functions.py", line 765, in forward
    keep_for_backward)
  File "C:\Users\nikosk\Documents\mastercntk\CNTK\bindings\python\cntk\cntk_py.py", line 1860, in _forward
    return _cntk_py.Function__forward(self, *args)
ValueError: Node 'BatchNormalization5539' (BatchNormalization operation): Inputs [3..6] must be learnable parameters.

[CALL STACK]
    > Microsoft::MSR::CNTK::InvalidArgument<wchar_t const * __ptr64,int,int>
    - Microsoft::MSR::CNTK::BatchNormalizationNode<float>::  Validate
    - Microsoft::MSR::CNTK::ComputationNetwork::  ValidateNode
    - Microsoft::MSR::CNTK::ComputationNetwork::  ValidateNodes
    - Microsoft::MSR::CNTK::ComputationNetwork::  ValidateNetwork
    - Microsoft::MSR::CNTK::ComputationNetwork::  CompileNetwork
    - CNTK::CompositeFunction::CreateComputationNetwork<float>
    - CNTK::CompositeFunction::GetComputationNetwork<float>
    - CNTK::CompositeFunction::  Forward
    - CNTK::Function::  Forward
    - _wrap_Function__forward__SWIG_1
    - _wrap_Function__forward
    - PyMapping_Length (x4)
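Reading the Validating line, debug_model appears to put a UserDefinedFunction in front of every input of the batch-normalization node. That includes the running mean, running variance and running count (the trailing `[5], [5], []` operands). The failing check seems to require those running-statistics inputs (0-based indices 3 to 5, printed as the half-open range `[3..6]`) to be parameter nodes themselves, presumably because the node updates them in place during the forward pass. The toy model below reproduces that interaction in plain Python. All classes and helpers in it (`Node`, `InputVariable`, `LearnableParameter`, `UserDefinedFunction`, `wrap_inputs`, `make_bn`) are hypothetical stand-ins, not CNTK APIs.

```python
from dataclasses import dataclass, field
from typing import List


# Hypothetical stand-ins for CNTK graph nodes; these are NOT CNTK classes.
@dataclass
class Node:
    name: str
    inputs: List["Node"] = field(default_factory=list)


class InputVariable(Node):
    """Data input such as x; never required to be a parameter."""


class LearnableParameter(Node):
    """Leaf node holding state that an operation may update in place."""


class UserDefinedFunction(Node):
    """Pass-through wrapper, standing in for what debug_model inserts."""


class BatchNormalization(Node):
    # Assumed input order: data, scale, bias, run_mean, run_variance, run_count
    RUN_MEAN = 3

    def validate(self) -> None:
        # Running statistics are written during the forward pass, so they
        # must be the parameter nodes themselves, not values computed from them.
        for i in range(self.RUN_MEAN, len(self.inputs)):
            if not isinstance(self.inputs[i], LearnableParameter):
                raise ValueError(
                    f"Node '{self.name}' (BatchNormalization operation): "
                    f"Inputs [{self.RUN_MEAN}..{len(self.inputs)}] "
                    "must be learnable parameters."
                )


def wrap_inputs(node: Node) -> Node:
    """Mimic debug_model: put a pass-through UDF in front of every input."""
    wrapped = [UserDefinedFunction(f"UDF({inp.name})", [inp]) for inp in node.inputs]
    return type(node)(node.name, wrapped)


def make_bn(name: str = "BatchNormalization5539") -> BatchNormalization:
    """Build a batch-norm node with the six inputs seen in the log."""
    params = [LearnableParameter(n) for n in
              ("scale", "bias", "aggregate_mean", "aggregate_variance", "aggregate_count")]
    return BatchNormalization(name, [InputVariable("x"), *params])


bn = make_bn()
bn.validate()  # original graph: inputs 3..5 are parameters, so this passes

try:
    wrap_inputs(bn).validate()  # every input is now a UDF, so the check fails
except ValueError as err:
    # prints: Node 'BatchNormalization5539' (BatchNormalization operation): Inputs [3..6] must be learnable parameters.
    print(err)
```

In this toy model, wrapping only the data input (index 0) would keep validation happy. Whether debug_model could be made to skip the running-statistics inputs is an open question for the CNTK side, not something this sketch establishes.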

tymokvo commented 6 years ago

Was this ever resolved? I am hitting the same issue in 2.6. The model trains fine through a Trainer object, but when I evaluate the forward pass for validation, BatchNorm produces NaN values. Running it through debug_model gives the same error as above.