apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

Dropout inconsistency bug #16705

Open sxjscience opened 5 years ago

sxjscience commented 5 years ago

In the following script, every iteration should produce the same dropout mask, but currently the result depends on nrepeat. Note that I've turned off cuDNN dropout by setting cudnn_off=True.

import mxnet as mx
import numpy as np
import random
from numpy.testing import assert_allclose

base_y_np = None

for nrepeat in [1, 2, 3, 4]:
    seed = 123
    mx.random.seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    x = mx.nd.ones((3, 3), ctx=mx.gpu())
    for _ in range(nrepeat):
        y = mx.nd.Dropout(x, cudnn_off=True)
    with mx.autograd.record():
        y = mx.nd.Dropout(x, cudnn_off=True)
        y_np = y.asnumpy()
    if base_y_np is None:
        base_y_np = y_np
    else:
        assert_allclose(base_y_np, y_np)

Output:

Not equal to tolerance rtol=1e-07, atol=0

Mismatch: 55.6%
Max absolute difference: 2.
Max relative difference: 1.
 x: array([[0., 2., 0.],
       [0., 0., 2.],
       [0., 2., 0.]], dtype=float32)
 y: array([[2., 2., 0.],
       [0., 2., 0.],
       [0., 0., 2.]], dtype=float32)

If we set nrepeat to the same value in every iteration, the result is consistent:

import mxnet as mx
import numpy as np
import random
from numpy.testing import assert_allclose

base_y_np = None
ctx = mx.gpu()

for nrepeat in [3, 3, 3, 3]:
    seed = 123
    mx.random.seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    x = mx.nd.ones((3, 3), ctx=ctx)
    for _ in range(nrepeat):
        y = mx.nd.Dropout(x, cudnn_off=True)
    with mx.autograd.record():
        y = mx.nd.Dropout(x, cudnn_off=True)
        y_np = y.asnumpy()
    if base_y_np is None:
        base_y_np = y_np
    else:
        assert_allclose(base_y_np, y_np)
sxjscience commented 5 years ago

Also, I've confirmed that the CPU side does not have this problem.

DickJC123 commented 5 years ago

What behavior do we expect from a model that has two Dropouts, where no seeds have been set explicitly in advance? Are the dropout patterns identical or different?

If the answer is 'different', then I would think that by setting the seeds in advance, the two-Dropout model would then have repeatable behavior, but the Dropouts would continue to be different.

Also, feel free @sxjscience to chime in on the discussion of PR https://github.com/apache/incubator-mxnet/pull/16532.

sxjscience commented 5 years ago

@DickJC123 The patterns should be different: the two dropouts share the same internal random number generator, and each call advances its random state.
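To make the expected behavior concrete, here is a minimal sketch in plain Python (not MXNet code). It assumes a single generator shared by both dropout layers; `dropout_mask` and `two_dropouts` are hypothetical helpers introduced only for illustration.

```python
import random


def dropout_mask(rng, n, p=0.5):
    """Draw n keep (1) / drop (0) flags from the given generator."""
    return [0 if rng.random() < p else 1 for _ in range(n)]


def two_dropouts(seed, n=64):
    rng = random.Random(seed)      # one generator shared by both layers
    first = dropout_mask(rng, n)   # advances the shared state
    second = dropout_mask(rng, n)  # continues from where the first call stopped
    return first, second


run_a = two_dropouts(123)
run_b = two_dropouts(123)
```

Under this assumption the two masks within one run differ from each other, while reseeding with the same value reproduces both masks exactly.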

For the inconsistency bug mentioned in this issue, it's not exactly related to the seeding problem.

For example, consider the following script:

import mxnet as mx
mx.random.seed(123)
x = mx.nd.ones((10, 10))

y = mx.nd.Dropout(x, cudnn_off=True)
with mx.autograd.record():
    y = mx.nd.Dropout(x, cudnn_off=True)

The first y = mx.nd.Dropout(x, cudnn_off=True) runs outside autograd.record(), i.e. in inference mode, and should not update the random state. However, in the current implementation (https://github.com/apache/incubator-mxnet/blob/bb6305d11d4383af2022e53ad94d6a1d5d93cb00/src/operator/nn/dropout-inl.h#L495), the rand() function is still called when the node is constructed. Thus, running y = mx.nd.Dropout(x, cudnn_off=True) outside the training loop still interferes with the random state.

This means the following two code snippets produce different results:

- Case 1
```python
# Setup assumed to match Case 2 (same seed, same 3x3 GPU input).
import mxnet as mx
mx.random.seed(123)
x = mx.nd.ones((3, 3), ctx=mx.gpu())

y = mx.nd.Dropout(x, cudnn_off=True)
with mx.autograd.record():
    y = mx.nd.Dropout(x, cudnn_off=True)
print(y)
```
```
[[0. 2. 0.]
 [0. 0. 2.]
 [0. 2. 0.]]
<NDArray 3x3 @gpu(0)>
```

- Case 2
```python
import mxnet as mx
mx.random.seed(123)
x = mx.nd.ones((3, 3), ctx=mx.gpu())

with mx.autograd.record():
    y = mx.nd.Dropout(x, cudnn_off=True)
print(y)
```
```
[[0. 0. 2.]
 [0. 0. 2.]
 [0. 2. 0.]]
<NDArray 3x3 @gpu(0)>
```

sxjscience commented 5 years ago

@DickJC123 You may see that I've manually set cudnn_off=True. Also, I think https://github.com/apache/incubator-mxnet/pull/16532 will solve this problem.

xidulu commented 5 years ago

Clearly, dropout in inference mode affects the random state:

>>> mx.random.seed(123)
>>> mx.nd.Dropout(x, cudnn_off=True)

[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
<NDArray 3x3 @gpu(0)>
>>> mx.random.uniform(shape=(2,2),ctx=mx.gpu(0))

[[0.6512425  0.11220306]
 [0.86499107 0.68052745]]
<NDArray 2x2 @gpu(0)>
>>> mx.random.seed(123)
>>> mx.random.uniform(shape=(2,2),ctx=mx.gpu(0))

[[0.9423294  0.68506277]
 [0.19981462 0.60299706]]
<NDArray 2x2 @gpu(0)>
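
The check in the session above (seed, run the operator, draw; then seed and draw directly; compare) can be written as a small reusable probe. The sketch below uses Python's standard `random` module as a stand-in for the framework's generator, so it does not call MXNet; `first_draw_after` and `op_touches_rng` are hypothetical names.

```python
import random


def first_draw_after(seed, op=None):
    """Seed the generator, optionally run op, then return the next draw."""
    random.seed(seed)
    if op is not None:
        op()
    return random.random()


def op_touches_rng(op, seed=123):
    """True if running op changes what the generator produces next."""
    return first_draw_after(seed, op) != first_draw_after(seed)


def pure_op():
    return sum([1, 2, 3])   # never touches the generator


def leaky_op():
    return random.random()  # consumes one draw as a side effect


pure_result = op_touches_rng(pure_op)
leaky_result = op_touches_rng(leaky_op)
```

An inference-mode operator is expected to behave like `pure_op`; the session above shows Dropout behaving like `leaky_op` on GPU.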
sxjscience commented 5 years ago

With the help of @xidulu, we have located the root cause of the issue:

The bug is triggered because we have multiple parallel GPU random resources: https://github.com/apache/incubator-mxnet/blob/c583e44816a5e383493f35e69daaa92a47e40e39/src/resource.cc#L93-L94

When we create a new Dropout Node, we will attach a random resource to the node: https://github.com/apache/incubator-mxnet/blob/c583e44816a5e383493f35e69daaa92a47e40e39/src/operator/nn/dropout.cc#L148-L164

Since there are multiple random resources, we select one in a round-robin fashion. Each resource has its own seed, so the mask depends on which resource the node happens to receive, which results in the inconsistent behavior. https://github.com/apache/incubator-mxnet/blob/c583e44816a5e383493f35e69daaa92a47e40e39/src/resource.cc#L344-L351
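
The round-robin effect can be reproduced with a toy model in plain Python. This is a sketch under stated assumptions, not MXNet's implementation: the pool holds four copies, reseeding gives copy i the seed `seed + i` and does not reset the round-robin cursor, and an inference-mode call requests a resource without drawing from it. `RandResourcePool`, `dropout`, and `run` are hypothetical names.

```python
import random


class RandResourcePool:
    """Toy pool of parallel random resources handed out round-robin."""

    def __init__(self, ncopy):
        self.generators = [random.Random() for _ in range(ncopy)]
        self.cursor = 0  # index of the next generator to hand out

    def seed(self, seed):
        # Assumption: each copy gets its own seed; the cursor is left untouched.
        for i, gen in enumerate(self.generators):
            gen.seed(seed + i)

    def request(self):
        index = self.cursor
        self.cursor = (self.cursor + 1) % len(self.generators)
        return index, self.generators[index]


def dropout(pool, n, train):
    """Every call requests a resource; only training mode draws a mask."""
    index, gen = pool.request()
    if not train:
        return index, [1] * n
    return index, [2 if gen.random() >= 0.5 else 0 for _ in range(n)]


def run(ncopy, repeats, seed=123, n=64):
    """Mimic the script in this issue: reseed, nrepeat inference calls, one training call."""
    pool = RandResourcePool(ncopy)
    used, masks = [], []
    for nrepeat in repeats:
        pool.seed(seed)
        for _ in range(nrepeat):
            dropout(pool, n, train=False)
        index, mask = dropout(pool, n, train=True)
        used.append(index)
        masks.append(mask)
    return used, masks


used_varying, masks_varying = run(ncopy=4, repeats=[1, 2, 3, 4])
used_fixed, masks_fixed = run(ncopy=4, repeats=[3, 3, 3, 3])
used_single, masks_single = run(ncopy=1, repeats=[1, 2, 3, 4])
```

In this model the training call lands on copies 1, 0, 0, 1 when nrepeat varies, so the masks disagree. With nrepeat fixed at 3 each iteration makes four requests and always ends on copy 3, and with a single copy every call uses copy 0, so in both cases the masks match.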

The simplest fix is to use a single GPU random generator. Setting os.environ['MXNET_GPU_PARALLEL_RAND_COPY'] = '1' before importing mxnet fixes this problem:


import os

os.environ['MXNET_GPU_PARALLEL_RAND_COPY'] = '1'

import mxnet as mx
import numpy as np
import random
from numpy.testing import assert_allclose

base_y_np = None

for nrepeat in [1, 2, 3, 4]:
    seed = 123
    mx.random.seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    x = mx.nd.ones((3, 3), ctx=mx.gpu())
    for _ in range(nrepeat):
        y = mx.nd.Dropout(x, cudnn_off=True)
    with mx.autograd.record():
        y = mx.nd.Dropout(x, cudnn_off=True)
        y_np = y.asnumpy()
    if base_y_np is None:
        base_y_np = y_np
    else:
        assert_allclose(base_y_np, y_np)
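
The script above sets the variable before `import mxnet`. Assuming the value is only read when the library initializes its random resources (an assumption, not something confirmed in this thread), a small guard can flag the case where the import has already happened. `force_single_gpu_rand_copy` is a hypothetical helper.

```python
import os
import sys


def force_single_gpu_rand_copy():
    """Set MXNET_GPU_PARALLEL_RAND_COPY=1.

    Returns True if mxnet had not been imported yet, i.e. the setting is
    expected to take effect under the assumption stated above.
    """
    already_imported = "mxnet" in sys.modules
    os.environ["MXNET_GPU_PARALLEL_RAND_COPY"] = "1"
    return not already_imported


set_in_time = force_single_gpu_rand_copy()
```

Calling the helper at the very top of a test module, before any MXNet import, mirrors the ordering used in the script above.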
samskalicky commented 5 years ago

@zachgk assign @larroy

zachgk commented 5 years ago

@larroy Please comment if you want to take this issue.

larroy commented 5 years ago

Hi