joschu / cgt

Computation Graph Toolkit

Parameter indexing leads to an assertion error in module definition #56

Closed: sbos closed this issue 8 years ago

sbos commented 8 years ago

It seems that one cannot define a module that computes the gradient of an expression with respect to a parameter when that parameter is indexed inside the expression. In this code example, one of two mathematically equivalent module definitions fails with an AssertionError.

import cgt
import cgt.nn as nn
import numpy as np

# works
z = cgt.scalar()
theta = nn.parameter(np.random.rand(2) - 0.5)
x = cgt.sqrt(z + cgt.dot(theta, theta))

g = cgt.grad(x, [theta])
m = nn.Module([z], g)

# fails
z = cgt.scalar()
theta = nn.parameter(np.random.rand(2) - 0.5)
x = cgt.sqrt(z + cgt.square(theta[0]) + cgt.square(theta[1]))

g = cgt.grad(x, [theta])
# cgt.function can be created successfully...
f = cgt.function([z], g)
# ...but constructing the Module raises AssertionError
m = nn.Module([z], g)
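
Both definitions of x compute the same function, sqrt(z + theta[0]^2 + theta[1]^2), so the gradient with respect to theta should be theta / sqrt(z + theta . theta) in both cases. The following plain NumPy sketch (independent of CGT; the helper names f_dot, f_indexed, analytic_grad and numeric_grad are illustrative, and z = 1.0 is an assumed value that keeps the square root well defined) checks this with central finite differences.

```python
import numpy as np

def f_dot(z, theta):
    # First formulation: sqrt(z + theta . theta)
    return np.sqrt(z + np.dot(theta, theta))

def f_indexed(z, theta):
    # Second formulation: sqrt(z + theta[0]**2 + theta[1]**2)
    return np.sqrt(z + np.square(theta[0]) + np.square(theta[1]))

def analytic_grad(z, theta):
    # d/dtheta sqrt(z + |theta|^2) = theta / sqrt(z + |theta|^2)
    return theta / np.sqrt(z + np.dot(theta, theta))

def numeric_grad(f, z, theta, eps=1e-6):
    # Central finite differences, one coordinate of theta at a time
    grad = np.zeros_like(theta)
    for i in range(theta.size):
        step = np.zeros_like(theta)
        step[i] = eps
        grad[i] = (f(z, theta + step) - f(z, theta - step)) / (2 * eps)
    return grad

z = 1.0  # assumed value, not from the issue
theta = np.random.rand(2) - 0.5
print(np.allclose(numeric_grad(f_dot, z, theta), analytic_grad(z, theta)))
print(np.allclose(numeric_grad(f_indexed, z, theta), analytic_grad(z, theta)))
```

Both checks should print True, which supports the point that the failure is in how the graph is built, not in the math.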

Error details:

Traceback (most recent call last):
  File "module.py", line 22, in <module>
    m = nn.Module([z], g)
  File "/Users/sbos/projects/cgt/cgt/nn.py", line 16, in __init__
    self.c = core.Composition(inputs, outputs)
  File "/Users/sbos/projects/cgt/cgt/core.py", line 2358, in __init__
    dio = set(differentiably_influences(outputs))
  File "/Users/sbos/projects/cgt/cgt/core.py", line 615, in differentiably_influences
    for (p,d) in utils.safezip(node.parents, node.get_diff()):
  File "/Users/sbos/projects/cgt/cgt/utils.py", line 43, in safezip
    assert len(x) == len(y)
AssertionError
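
The traceback shows where the check fails: differentiably_influences zips each node's parents with the per-parent flags returned by node.get_diff(), and safezip asserts that the two sequences have the same length. The sketch below is a hypothetical, simplified reconstruction of that mechanism, not CGT's actual source. It assumes (without confirmation from this thread) that the indexing node returns one flag fewer than it has parents; the Node class, the "idx" parent and the traversal order are all illustrative.

```python
# Hypothetical, simplified reconstruction; not CGT's actual code.

def safezip(x, y):
    # Like zip(), but refuses to silently truncate mismatched sequences
    assert len(x) == len(y)
    return zip(x, y)

class Node:
    # Minimal stand-in for a graph node: its parents plus one
    # "is this input differentiable?" flag per parent.
    def __init__(self, name, parents=(), diff=()):
        self.name = name
        self.parents = list(parents)
        self._diff = list(diff)

    def get_diff(self):
        return self._diff

def differentiably_influences(outputs):
    # Collect every node reachable from outputs through differentiable edges
    seen = set()
    stack = list(outputs)
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        for p, d in safezip(node.parents, node.get_diff()):
            if d:
                stack.append(p)
    return seen

theta = Node("theta")
idx = Node("idx")  # hypothetical non-differentiable index input
good = Node("theta[0]", parents=[theta, idx], diff=[True, False])
bad = Node("theta[0]", parents=[theta, idx], diff=[True])  # one flag missing

print(sorted(n.name for n in differentiably_influences([good])))  # ['theta', 'theta[0]']
try:
    differentiably_influences([bad])
except AssertionError:
    print("AssertionError, as in the traceback")
```

With consistent flags the traversal succeeds; with a flag list shorter than the parent list, safezip raises exactly the bare AssertionError seen above.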

joschu commented 8 years ago

Fixed by b340ff2