joschu / cgt

Computation Graph Toolkit

Error on tutorial #44

Closed avivt closed 8 years ago

avivt commented 8 years ago

I get an AssertionError (assert newnewnode.typ == orig.typ) from the simplify call in the tutorial code: cgt.print_expr(cgt.simplify([dLdw])[0]);

I am running on Windows with only Python installed (no Cython or CUDA).

This is the code I am running:

    import cgt
    a = cgt.scalar(name='a') # float-valued scalar, with optional name provided
    b = cgt.scalar(name='b')
    n = cgt.scalar(name='n', dtype='int64') # integer scalar

    c = (a**n + b**n)**(1.0/n)

    f = cgt.function([a,b,n], c)
    print f(8,15,2)

    X_nk = cgt.matrix("X")
    y_n = cgt.vector("y")
    w_k = cgt.vector("w")
    b = cgt.scalar("b")
    ypred_n = X_nk.dot(w_k) + b
    L = cgt.sum(cgt.square(ypred_n - y_n))
    print "L = ", cgt.print_expr(L)
    print X_nk.ndim, str(X_nk.shape), X_nk.dtype
    grads = dLdw, dLdb = cgt.grad(L, [w_k, b])
    print "Loss and gradient objects", dLdw, dLdb
    print "Pretty-printed gradient: ", cgt.print_expr(cgt.simplify([dLdw])[0]);
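For comparison, here is a rough plain-Python sketch of what the two tutorial expressions should evaluate to, without going through CGT at all. The helper names p_norm and loss_and_grad_w are made up for illustration and are not part of CGT; the gradient is the hand-derived one for L = sum((X.dot(w) + b - y)**2), not CGT's simplified output.

```python
# Plain-Python reference values for the tutorial expressions.
# p_norm and loss_and_grad_w are hypothetical helper names, not CGT functions.

def p_norm(a, b, n):
    """Evaluate (a**n + b**n)**(1.0/n), the first tutorial expression."""
    return (a ** n + b ** n) ** (1.0 / n)


def loss_and_grad_w(X, y, w, b):
    """Squared-error loss and its hand-derived gradient with respect to w.

    X is a list of rows, y and w are lists, b is a scalar.
    dL/dw_k = 2 * sum_n X[n][k] * (ypred[n] - y[n])
    """
    # residual for each row: (x . w + b) - y
    residuals = [
        sum(x_nk * w_k for x_nk, w_k in zip(row, w)) + b - y_n
        for row, y_n in zip(X, y)
    ]
    loss = sum(r * r for r in residuals)
    grad_w = [
        2.0 * sum(row[k] * r for row, r in zip(X, residuals))
        for k in range(len(w))
    ]
    return loss, grad_w


if __name__ == "__main__":
    print(p_norm(8, 15, 2))
    print(loss_and_grad_w([[1.0, 2.0], [3.0, 4.0]], [1.0, 2.0], [0.5, -0.5], 0.0))
```

With the inputs from the script, p_norm(8, 15, 2) is the square root of 64 + 225 = 289, so f(8,15,2) should come out at about 17.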

And this is the error I get:

    Traceback (most recent call last):
      File "C:/test_cgt.py", line 23, in <module>
        cgt.print_expr(cgt.simplify([dLdw])[0]);
      File "C:\CGT\cgt\core.py", line 2688, in simplify
        return simplify_and_analyze(xs)[0]
      File "C:\CGT\cgt\core.py", line 2533, in simplify_and_analyze
        for output in outputs: update_simplify_map(output, analysis, repl)
      File "C:\CGT\cgt\core.py", line 2600, in update_simplify_map
        maybe_pair = process_top_stack_item_and_maybe_get_replacement(stack, analysis, repl)
      File "C:\CGT\cgt\core.py", line 2567, in process_top_stack_item_and_maybe_get_replacement
        assert newnewnode.typ == orig.typ
    AssertionError

I also set a breakpoint on the assertion line and found that the two types differ: newnewnode.typ = Tensor(i4,0), while orig.typ = Tensor(i8,0).
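One thing that may be worth checking, though nothing in this report confirms it as the cause: assuming the i4 and i8 in those type strings follow the NumPy convention of integer kind plus byte width, a 4-byte versus 8-byte mismatch is what you would expect if some integer takes the platform's C long width. On 64-bit Windows a C long is 4 bytes, while on 64-bit Linux and macOS it is 8, and NumPy releases before 2.0 used C long as their default integer type. A minimal standard-library sketch to see which width the current machine has (dtype_code_for_c_long is a made-up helper name):

```python
import ctypes

# Width in bytes of the platform's C long.
# 64-bit Windows: 4 bytes; 64-bit Linux and macOS: 8 bytes.
C_LONG_BYTES = ctypes.sizeof(ctypes.c_long)


def dtype_code_for_c_long():
    """Return a NumPy-style code such as 'i4' or 'i8' for the C long width.

    Assumption: the codes in the Tensor(...) repr use the same
    kind-plus-byte-width convention as NumPy dtype strings.
    """
    return "i%d" % C_LONG_BYTES


if __name__ == "__main__":
    print(dtype_code_for_c_long())
```

If this prints i4 on the Windows machine, integer constants created with the default integer type would be 32-bit there, which would not match a node explicitly declared as int64.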