SciML / juliatorch

Convert Julia functions to PyTorch autograd functions
MIT License

Unlikely Feature Request: Support for Impure Functions (Not Blocking) #13

Open · azane opened 11 months ago

azane commented 11 months ago

In the current form, we cannot get gradients with respect to parameters referenced from an outside scope. For example:

import torch
from juliatorch import JuliaFunction

m = torch.tensor(1., requires_grad=True)
x = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)

def f(v):
    x, b = v
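    # m comes from the enclosing scope; .item() detaches it from the torch graph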
    return m.item() * x + b

y = JuliaFunction.apply(f, torch.cat(torch.atleast_1d(x, b)))
y.backward()

print(m.grad, x.grad, b.grad)

>>> None tensor(1.) tensor(1.)

This is, of course, unsurprising: the .item() detaches m from the graph. Writing m * x + b directly also fails, again unsurprisingly, with TypeError: unsupported operand type(s) for *: 'Tensor' and 'RealValue'.

Despite wanting this feature, it's unclear how it could work: ForwardDiff won't ever know about m and therefore cannot run it through the function as a Dual. There may be a way to execute f, recover m from the torch graph, and automatically construct the pure version of the function, but the torch graph doesn't expose intermediate values on the Python side, merely the operations in the graph. The values of leaf nodes are, however, exposed, so the computation could be reconstructed from the leaves, but this would very frequently duplicate computation.
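For concreteness, here is a minimal sketch of the manual "purification" that any automatic version would have to produce: m is promoted to an explicit input so ForwardDiff can run it through the function as a Dual. The name f_pure is hypothetical, and the expected gradients in the comment just follow from y = m * x + b with the fresh leaf tensors defined here:

import torch
from juliatorch import JuliaFunction

m = torch.tensor(1., requires_grad=True)
x = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)

# Hypothetical pure version: m is threaded through the input vector
# instead of being closed over, so it is differentiated like x and b.
def f_pure(v):
    m, x, b = v
    return m * x + b

y = JuliaFunction.apply(f_pure, torch.cat(torch.atleast_1d(m, x, b)))
y.backward()

print(m.grad, x.grad, b.grad)  # expected: tensor(2.) tensor(1.) tensor(1.)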

Maybe juliacall could be extended to support operations with Tensors, but y would then somehow need to retain its connection to m in the graph despite being returned from Julia.

It's not clear to me that ReverseDiff would help resolve this either.

For context, chirho expects dynamics to be Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: the input is a dictionary of states, the output is a dictionary of dstates, and parameters are expected to be referenced from within the callable. For example:

import torch
import pyro
from pyro.distributions import constraints
from chirho.dynamical.ops import State


class SIRDynamics(pyro.nn.PyroModule):
    def __init__(self):
        super().__init__()

        self.beta = pyro.param("beta", torch.tensor(0.5), constraints.positive)
        self.gamma = pyro.param("gamma", torch.tensor(0.7), constraints.positive)

    def forward(self, X: State[torch.Tensor]):
        dX: State[torch.Tensor] = State()

        dX["S"] = -self.beta * X["S"] * X["I"]
        dX["I"] = self.beta * X["S"] * X["I"] - self.gamma * X["I"] 
        dX["R"] = self.gamma * X["I"]

        return dX
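Absent support for impure functions, dynamics like these would have to be restated in the pure form juliatorch handles today, with the parameters flattened into the input vector alongside the state. A rough sketch of that contortion follows; sir_pure and the flat (beta, gamma, S, I, R) packing are hypothetical, and it assumes v unpacks as in the first example and that JuliaFunction accepts a Julia vector as a return value:

import juliacall

jl = juliacall.Main

# Hypothetical pure restatement: beta and gamma ride along in the input
# vector with the state, so ForwardDiff sees them as Duals too.
def sir_pure(v):
    beta, gamma, S, I, R = v
    dS = -beta * S * I
    dI = beta * S * I - gamma * I
    dR = gamma * I
    # Repack just the state derivatives as a Julia vector; the parameter
    # entries of v have no dynamics of their own.
    return jl.vcat(dS, dI, dR)

Every parameter the module closes over would have to be threaded through (and later stripped back out of) the flat vector this way, which is exactly the awkwardness this issue is about.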

Note that this is non-blocking, and we're proceeding as if this won't be possible.

azane commented 11 months ago

See the end of #15, and the comments there, for why implementing this feature maybe wouldn't play to the strengths of a diffeqpy backend for chirho.