wddabc opened this issue 7 years ago
It looks like the computation graph breaks at the concatenation operation. Minimal working example (MWE):
```python
import minpy.numpy as np
from minpy.core import grad

def foo_nocat(x):
    return 3 * x

def foo_cat(x):
    catx = np.concatenate([x, x], axis=1)
    return np.dot(catx, np.array([[1], [2]]))

test_x = np.array([[3]])
print(grad(foo_nocat)(test_x))  # correct output
print(grad(foo_cat)(test_x))    # should be the same
```
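For reference, the gradient `foo_cat` should produce can be worked out by hand. The sketch below uses plain NumPy rather than MinPy (an assumption, so it runs without MinPy installed) and writes the backward pass manually: the gradient of `np.dot` is split back into the two concatenated pieces, and because `x` appears twice, both pieces are summed into the gradient of `x`. With a seed gradient of ones this gives 1 + 2 = 3, matching `foo_nocat`. The helper names `foo_cat_forward` and `foo_cat_backward` are hypothetical.

```python
import numpy as np

W = np.array([[1.0], [2.0]])  # weight matrix from the MWE

def foo_cat_forward(x):
    catx = np.concatenate([x, x], axis=1)  # shape (n, 2)
    return np.dot(catx, W)                 # shape (n, 1)

def foo_cat_backward(x, g_out):
    # Backward of np.dot(catx, W) with respect to catx: g_out @ W.T
    g_catx = np.dot(g_out, W.T)            # shape (n, 2)
    # Backward of concatenate: split the gradient back into the input pieces
    g_parts = np.split(g_catx, [x.shape[1]], axis=1)
    # x appears twice in the concatenation, so both pieces accumulate into x
    return g_parts[0] + g_parts[1]

test_x = np.array([[3.0]])
g_out = np.ones_like(foo_cat_forward(test_x))  # seed gradient of ones
print(foo_cat_backward(test_x, g_out))  # [[3.]]
```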
@ZihengJiang Could you have a look? Also, please add this as a unit test.
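A possible shape for that unit test, sketched under assumptions: it uses plain NumPy and a central finite-difference reference instead of MinPy, so it only checks what the correct answer is. In the real regression test one would compare `minpy.core.grad(foo_cat)(test_x)` against `grad(foo_nocat)(test_x)` (or against this reference). The helper `numerical_grad` and the class name `ConcatenateGradTest` are hypothetical.

```python
import unittest
import numpy as np

def foo_nocat(x):
    return 3 * x

def foo_cat(x):
    catx = np.concatenate([x, x], axis=1)
    return np.dot(catx, np.array([[1.0], [2.0]]))

def numerical_grad(f, x, eps=1e-6):
    # Central finite differences of sum(f(x)), i.e. a seed gradient of ones
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp = x.copy()
        xp[idx] += eps
        xm = x.copy()
        xm[idx] -= eps
        g[idx] = (f(xp).sum() - f(xm).sum()) / (2 * eps)
    return g

class ConcatenateGradTest(unittest.TestCase):
    def test_concat_matches_nocat(self):
        test_x = np.array([[3.0]])
        # Gradient through concatenate should equal the gradient of 3 * x
        np.testing.assert_allclose(numerical_grad(foo_cat, test_x),
                                   numerical_grad(foo_nocat, test_x),
                                   rtol=1e-5)
```

Run it with `python -m unittest`.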
@ZihengJiang Any follow-up on this?