In the current form, we cannot get gradients with respect to parameters referenced from an outside scope. For example:
import torch

m = torch.tensor(1., requires_grad=True)
x = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)

def f(v):
    # m is referenced from the enclosing scope (and detached via .item());
    # only x and b are passed in through v.
    x, b = v
    return m.item() * x + b

# JuliaFunction is the custom autograd Function (defined elsewhere in this issue) that evaluates f via Julia.
y = JuliaFunction.apply(f, torch.cat(torch.atleast_1d(x, b)))
y.backward()
print(m.grad, x.grad, b.grad)
>>> None tensor(1.) tensor(1.)
This is, of course, unsurprising due to the .item() detaching m from the graph, but doing m * x + b directly also — again unsurprisingly — fails with TypeError: unsupported operand type(s) for *: 'Tensor' and 'RealValue'.
Despite wanting this feature, it's unclear how it could work: ForwardDiff won't ever know about m and therefore cannot run it through the function as a Dual. There may be a way to execute f, recover m from the torch graph, and automatically construct the pure version of the function, but the torch graph doesn't expose intermediate values on the Python side, merely the operations in the graph. The values of leaf nodes are, however, exposed, so the computation could be reconstructed from the leaves, but this would very often duplicate computation.
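For illustration, here is a minimal sketch of what such a manually constructed "pure" version could look like, assuming JuliaFunction.apply propagates gradients to every element of its tensor input (f_pure and the expected output are assumptions, not tested behavior):

# Hypothetical workaround: make m an explicit element of v so ForwardDiff sees it.
def f_pure(v):
    m_, x_, b_ = v  # all parameters are explicit inputs
    return m_ * x_ + b_

y = JuliaFunction.apply(f_pure, torch.cat(torch.atleast_1d(m, x, b)))
y.backward()
print(m.grad, x.grad, b.grad)
# expected, if gradients flow as they do for x and b above: tensor(2.) tensor(1.) tensor(1.)

The open question is whether that rewrite can be derived automatically rather than written by hand.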
Maybe juliacall could be extended to support operations with Tensors, but y would then somehow need to retain its connection to m in the graph despite being returned from Julia.
It's not clear to me that ReverseDiff would help resolve this either.
For context, chirho expects dynamics to be Callable[[Dict[Tensor]], Dict[Tensor]], where the input is a dictionary of states, the output is a dictionary of dstates, and parameters are expected to be referenced from within the callable. For example:
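A minimal illustrative sketch of that shape (the state key "x", the parameter name beta, and the type Dict[str, Tensor] are assumptions here, not chirho's actual API):

from typing import Dict
import torch

beta = torch.tensor(0.5, requires_grad=True)  # parameter captured from the enclosing scope

def dynamics(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Exponential decay: dx/dt = -beta * x. beta is referenced from the outer
    # scope rather than passed in through the state dictionary.
    return {"x": -beta * state["x"]}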
Note that this is non-blocking, and we're proceeding as if this won't be possible.