dmlc / minpy

NumPy interface with mixed backend execution

Zero gradient for concatenate #154

Open wddabc opened 7 years ago

wddabc commented 7 years ago

It looks like the computation graph breaks at the concatenation operation: the gradient flowing through np.concatenate comes out as zero. MWE:

import minpy.numpy as np
from minpy.core import grad

def foo_nocat(x):
    return 3*x 

def foo_cat(x):
    catx = np.concatenate([x, x], axis=1)
    # [x, x] . [[1], [2]] == 1*x + 2*x == 3*x, so the gradient should match foo_nocat
    return np.dot(catx, np.array([[1], [2]]))

test_x = np.array([[3]]) 
print(grad(foo_nocat)(test_x))  # correct output
print(grad(foo_cat)(test_x))  # should be the same, but comes out as zero
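For reference, the expected gradient can be worked out without minpy. The sketch below uses plain NumPy rather than minpy, and assumes the non-scalar output is reduced with a sum (an upstream gradient of ones). The backward pass of concatenate splits the upstream gradient back into one slice per input along the concatenation axis; because x is concatenated with itself, its two slices add up, giving 1 + 2 = 3, the same as foo_nocat. The helpers concatenate_vjp, foo_cat_grad and numeric_grad are hypothetical names for illustration, not part of minpy's API.

```python
import numpy as np

def foo_cat(x):
    # Same computation as the MWE, in plain NumPy: [x, x] . [[1], [2]] == 3*x
    catx = np.concatenate([x, x], axis=1)
    return np.dot(catx, np.array([[1.0], [2.0]]))

def concatenate_vjp(g, shapes, axis):
    # Split the upstream gradient g into one slice per concatenated input
    sizes = [s[axis] for s in shapes]
    split_points = np.cumsum(sizes)[:-1]
    return np.split(g, split_points, axis=axis)

def foo_cat_grad(x):
    # Hand-written backward pass, assuming loss = sum(foo_cat(x))
    w = np.array([[1.0], [2.0]])
    g_out = np.ones((x.shape[0], 1))   # d(sum(out)) / d(out)
    g_catx = np.dot(g_out, w.T)        # gradient w.r.t. catx, shape (n, 2)
    g_x1, g_x2 = concatenate_vjp(g_catx, [x.shape, x.shape], axis=1)
    return g_x1 + g_x2                 # x is used twice, so its gradients add

def numeric_grad(f, x, eps=1e-6):
    # Central finite differences of sum(f(x)), used as an independent check
    x = x.astype(float)
    g = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp).sum() - f(xm).sum()) / (2 * eps)
    return g

test_x = np.array([[3.0]])
print(foo_cat_grad(test_x))           # analytic gradient
print(numeric_grad(foo_cat, test_x))  # should agree up to rounding
```

A finite-difference check like numeric_grad is a cheap way to express the expected value in a regression test, since it does not depend on the framework's own gradient registration for concatenate.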
jermainewang commented 7 years ago

@ZihengJiang Could you have a look? Also, please add this case to the unit tests.

Taco-W commented 7 years ago

@ZihengJiang Any follow-up on this?