apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

row_sparse_pull / push of row_sparse gradients is too slow, with a 10+ times difference #11299

Open coldsheephot opened 6 years ago

coldsheephot commented 6 years ago

When I use the row_sparse API to develop a deep gradient compression algorithm, I find it is much slower than the dense gradient path. Can you help me reduce the time?

Results:

kv: local, ctx=mx.cpu()
row_sparse push and pull time: 12.5247938633
default push and pull time: 0.194673061371

kv: device, ctx=mx.gpu(0)
row_sparse push and pull time: 3.08180809021
default push and pull time: 0.197321891785

In this code I don't even use merge and update, and tostype is only called once, yet it is still much slower than the dense gradient path.

Code:


import time

import mxnet as mx
import numpy as np

shape = (512 * 512 * 3 * 3, )
key = 3

name = 'local'
kv = mx.kv.create(name)
context = mx.cpu()

# keys 0-99 hold row_sparse arrays, keys 100-199 hold dense arrays
for i in range(100):
    kv.init(i, mx.nd.zeros(shape=shape, ctx=context).tostype('row_sparse'))

for i in range(100, 200):
    kv.init(i, mx.nd.zeros(shape=shape, ctx=context))

a = mx.nd.ones(shape=shape, ctx=context).tostype('row_sparse')
all_row_ids = mx.nd.array(np.arange(shape[0]), dtype=np.int64)

def test_row_sparse_pull():
    # push and pull 100 row_sparse keys, requesting every row id
    out = a
    for i in range(100):
        kv.push(i, a)
        kv.row_sparse_pull(i, out=out, priority=i, row_ids=all_row_ids)
        # out.wait_to_read()
        # mx.base._LIB.MXNDArrayWaitToWrite(a.handle)
    mx.nd.waitall()

b = mx.nd.ones(shape=shape, ctx=context)

def test_default_pull():
    # push and pull 100 dense keys for comparison
    out = b
    for i in range(100, 200):
        kv.push(i, b)
        kv.pull(i, out=out, priority=i)
        # mx.base._LIB.MXNDArrayWaitToWrite(b.handle)
        # out.wait_to_read()
    mx.nd.waitall()

t1 = time.time()
test_row_sparse_pull()
t2 = time.time()
print("row_sparse push and pull time:", t2 - t1)

t1 = time.time()
test_default_pull()
t2 = time.time()

print("default push and pull time:", t2 - t1)
coldsheephot commented 6 years ago

And there is another problem: when I delete kv.push(i, a) in the test_row_sparse_pull() function, it takes more time than when push is used, which is very strange!!

row_sparse push and pull time: 8.14685201645

def test_row_sparse_pull():
    out = a
    for i in range(100):
        # kv.push(i, a)
        kv.row_sparse_pull(i, out=out, priority=i, row_ids=all_row_ids)
        # out.wait_to_read()
        # mx.base._LIB.MXNDArrayWaitToWrite(a.handle)
    mx.nd.waitall()

eric-haibin-lin commented 6 years ago

row_sparse_pull will only be fast when the number of row ids is very small, e.g. kv.row_sparse_pull(i, out=out, priority=i, row_ids=mx.nd.array([10])).
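
A minimal sketch of that suggestion (the key, shape, and row ids below are illustrative, not taken from the benchmark above):

import mxnet as mx

kv = mx.kv.create('local')
shape = (1000, 4)
kv.init(0, mx.nd.zeros(shape).tostype('row_sparse'))

out = mx.nd.zeros(shape).tostype('row_sparse')
few_rows = mx.nd.array([10, 42, 97], dtype='int64')  # fetch only these rows
kv.row_sparse_pull(0, out=out, row_ids=few_rows)
out.wait_to_read()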

coldsheephot commented 6 years ago

@eric-haibin-lin When I push the sparse gradient, I can get the weight after merge and update, but I cannot tell which rows have values and which do not. On a single machine I can read the indices from the sparse gradient, but in distributed training each worker has different indices, so when I pull from the parameter server I still cannot know the final row_ids.

kalyc commented 6 years ago

Thanks for submitting this issue @coldsheephot. @sandeep-krishnamurthy could you add the labels "Performance" and "Sparse" to this?

eric-haibin-lin commented 6 years ago

For minibatch training you usually derive the row ids from the sparse data in the minibatch. For checkpointing, you need to pull all row ids. RowSparseNDArray doesn't have a prune method yet; it may be worth adding one.
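
A short sketch of deriving row ids from the minibatch, assuming the batch carries integer feature ids (e.g. token ids feeding an embedding); the key, sizes, and names here are hypothetical:

import mxnet as mx
import numpy as np

kv = mx.kv.create('local')
vocab, dim = 2000, 8  # illustrative sizes
kv.init(0, mx.nd.zeros((vocab, dim)).tostype('row_sparse'))

# the touched rows are exactly the ids that appear in the current minibatch
batch_token_ids = mx.nd.array([[3, 17, 3], [512, 17, 0]])
row_ids = mx.nd.array(np.unique(batch_token_ids.asnumpy()), dtype='int64')

weight = mx.nd.zeros((vocab, dim)).tostype('row_sparse')
kv.row_sparse_pull(0, out=weight, priority=0, row_ids=row_ids)
weight.wait_to_read()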

coldsheephot commented 6 years ago

In distributed training, can I get the actual indices of the rows that have values? Could you add such a method? @eric-haibin-lin @kalyc @sandeep-krishnamurthy

eric-haibin-lin commented 6 years ago

@coldsheephot are you working on the multi-device or the multi-machine case? I have a plan to extend it for multi-device mode.

coldsheephot commented 6 years ago

@eric-haibin-lin yes. I want to use the feature for both the multi-device and multi-machine cases. Thanks.

coldsheephot commented 6 years ago

How long will it take to solve these problems??? I am very anxious about this.

pinaraws commented 5 years ago

@mxnet-label-bot add[Distributed]