apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

Multiple trainers within a single worker using a distributed KVStore #13054

Open varunrajk opened 5 years ago

varunrajk commented 5 years ago

Description

The Gluon Trainer's step method uses enumeration indices (0, 1, 2, ...) as keys to push and pull gradients/parameters to and from the KVStore. Using two trainers within a single worker script in a distributed training setting therefore causes a conflict, because both trainers use the same set of keys on the distributed KVStore.
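A minimal, single-process toy model may help picture the key collision. ToyKVStore and register_by_index are hypothetical stand-ins, not MXNet APIs, and the "first value wins" rule for a re-initialised key is an assumption chosen only for illustration. In the real distributed setting the symptom reported below is a hang, not a silent value swap. The parameter names are illustrative as well.

```python
# Toy, single-process model of two trainers sharing one key space.
# ToyKVStore and register_by_index are hypothetical illustrations, not MXNet APIs.


class ToyKVStore:
    """Dict-backed stand-in for the one KVStore shared by all trainers in a worker."""

    def __init__(self):
        self.data = {}

    def init(self, key, value):
        # Assumption for illustration: an existing key keeps its first value.
        self.data.setdefault(key, value)

    def pull(self, key):
        return self.data[key]


def register_by_index(kv, params):
    """Mimic a trainer that keys each parameter by its enumeration index."""
    keys = []
    for index, (_name, value) in enumerate(params):
        kv.init(index, value)
        keys.append(index)
    return keys


kv = ToyKVStore()
# Two trainers, each owning one differently named parameter.
keys_a = register_by_index(kv, [("dense0_weight", 1.0)])
keys_b = register_by_index(kv, [("dense1_weight", 3.0)])

print(keys_a, keys_b)      # both trainers use key 0
print(kv.pull(keys_b[0]))  # trainer B reads trainer A's value (1.0)
```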

Minimum reproducible example

Here is a simple example script that demonstrates this issue:

import mxnet as mx
import numpy as np
from mxnet import autograd

def trainer_test(kvstore, problem_descr):
    # model
    m = mx.gluon.nn.Dense(1, use_bias=False)
    m.collect_params().initialize(mx.init.Constant(problem_descr['init']), ctx=mx.cpu())

    # trainer
    trainer = mx.gluon.Trainer(params=m.collect_params(),
                               optimizer='sgd',
                               optimizer_params={'learning_rate': problem_descr['lr']},
                               kvstore=kvstore)

    # update parameter
    with autograd.record():
        y = m(mx.nd.ones((1, 1)) * problem_descr['x'])
        loss_a = mx.nd.abs(problem_descr['target'] - y)

    loss_a.backward()

    trainer.step(1)

    # get new parameter value
    # np.asscalar is deprecated (removed in NumPy 1.23); .item() does the same
    v = m.collect_params().get('weight').list_data()[0].asnumpy().item()

    # get expected value: for loss |target - w*x|, grad below is -dL/dw,
    # and dist_sync sums the gradients of all workers
    grad = np.sign(problem_descr['target'] - (problem_descr['init'] * problem_descr['x'])) * problem_descr['x']
    exp_v = problem_descr['init'] + grad * problem_descr['lr'] * kvstore.num_workers

    # print the updated parameter value and the expected value
    if kvstore.rank == 0:
        print(f'updated parameter value: {np.round(v, 3)}, expected value: {exp_v}')

if __name__ == '__main__':
    kv = mx.kv.create('dist_sync')
    trainer_test(kvstore=kv, problem_descr={'lr': 0.1, 'init': 1., 'x': 2, 'target': 4})
    trainer_test(kvstore=kv, problem_descr={'lr': 0.1, 'init': 3., 'x': 2, 'target': 4})

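The trainer_test function performs a single update step on a simplified one-weight regression problem and compares the updated weight with its expected value. It takes as inputs a KVStore and a problem_descr dict with the keys lr (learning rate), init (initial weight), x (input) and target (regression target).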

Steps to reproduce

Save the script as trainer_test.py and run it with MXNet's launch.py tool, using 2 or more workers and the local launcher:

launch.py -n 2 --launcher local python3 trainer_test.py

Execution hangs after the first trainer_test call finishes; the second call never completes.

What have you tried to solve it?

Changing the Trainer to use unique parameter names (for example, param.name) as KVStore keys, instead of enumeration indices, solves this issue.
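A sketch of that workaround in the same toy style: each parameter is keyed by its name, so two trainers in one worker register disjoint keys as long as names are unique across both trainers. Param and register_by_name are hypothetical helpers, not the Trainer's actual code, and the names are illustrative. Name-based keys are possible because MXNet's KVStore documentation lists both int and str keys as valid.

```python
from dataclasses import dataclass


@dataclass
class Param:
    """Hypothetical stand-in for a Gluon Parameter: just a name and a value."""
    name: str
    value: float


def register_by_name(store, params):
    """Mimic a trainer that keys each parameter by its (unique) name."""
    for p in params:
        # Fail loudly instead of silently sharing a key between trainers.
        if p.name in store:
            raise KeyError(f"duplicate KVStore key: {p.name}")
        store[p.name] = p.value
    return [p.name for p in params]


store = {}  # shared by both trainers, like the one KVStore of a worker
keys_a = register_by_name(store, [Param("dense0_weight", 1.0)])
keys_b = register_by_name(store, [Param("dense1_weight", 3.0)])

print(keys_a, keys_b)    # ['dense0_weight'] ['dense1_weight']
print(store[keys_b[0]])  # 3.0: trainer B reads its own parameter
```

The design choice is simply that a name identifies a parameter globally, while an index only identifies it within one trainer's parameter list.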

frankfliu commented 5 years ago

@mxnet-label-bot [KVStore, Thread Safety]

vsuhasm commented 4 years ago

Hi @szha @frankfliu, is there any way to get around this? I have 2 sets of parameters in the same model, each needing a different optimizer and optimizer_params. When I create 2 Gluon trainers with the same kvstore, I run into this issue.

nicklhy commented 4 years ago

Any updates here? I have the same problem.