DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

RuntimeError: Trying to backward through the graph a second time #141

Closed · josyulakrishna closed this issue 2 years ago

josyulakrishna commented 2 years ago

Hello, I'm trying to run the following code for my project and I'm encountering this error:

Error:


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
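
For reference, here is a minimal standalone snippet (separate from my project code) that raises the same RuntimeError simply by calling backward() twice through a single graph:

import torch

w = torch.tensor(2.0, requires_grad=True)
y = (w * 3) ** 2        # builds a graph; intermediate values are saved for backward
y.backward()            # frees the saved buffers of that graph
# Passing retain_graph=True to the first call would avoid this, at the cost of keeping the buffers.
y.backward()            # second backward through the same (freed) graph -> the RuntimeError above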

Code:

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad

# from params import set_initial_conditions

def set_initial_conditions(n_agents):
    if n_agents == 12:
        #px,py,qx,qy
        x0 = torch.tensor([[0], [0.5], [-3], [5],
                           [0], [0.5], [-3], [3],
                           [0], [0.5], [-3], [1],
                           [0], [-0.5], [-3], [-1],
                           [0], [-0.5], [-3], [-3],
                           [0], [-0.5], [-3], [-5],
                           # second column
                           [-0], [0.5], [3], [5],
                           [-0], [0.5], [3], [3],
                           [-0], [0.5], [3], [1],
                           [0], [-0.5], [3], [-1],
                           [0], [-0.5], [3], [-3],
                           [0], [-0.5], [3], [-5],
                           ])
        xbar = torch.tensor([[0], [0], [3], [-5],
                             [0], [0], [3], [-3],
                             [0], [0], [3], [-1],
                             [0], [0], [3], [1],
                             [0], [0], [3], [3],
                             [0], [0], [3], [5],
                             # second column
                             [0], [0], [-3], [-5],
                             [0], [0], [-3], [-3],
                             [0], [0], [-3], [-1],
                             [0], [0], [-3], [1],
                             [0], [0], [-3], [3],
                             [0], [0], [-3], [5.0],
                             ])
    else:
        x0 = (torch.rand(4*n_agents, 1)-0.5)*10
        xbar = (torch.rand(4*n_agents, 1)-0.5)*10
    return x0, xbar

# X  = (J-R)*dV/dx + Fy + Gu
# Y = G*dV/dx
# forward = return (J-R)*dV/dx

class SystemEnv(nn.Module):
    def __init__(self, V, K, n_agents=1, xbar=None, ctls=None, batch_size=1, **kwargs):
        """ Initialize the environment. Here we represent the system.
        """
        super().__init__()
        self.K = K
        self.V = V
        self.k = torch.tensor(1.0)
        self.b = torch.tensor(0.2)
        self.m = torch.tensor(1.0)
        J = torch.tensor([[0, 0, -1, 0],
                          [0, 0, 0, -1],
                          [1., 0, 0, 0],
                          [0, 1., 0, 0]])
        R = torch.tensor([[self.b, 0, 0, 0],
                          [0, self.b, 0, 0],
                          [0, 0, 0, 0],
                          [0, 0, 0, 0]])
        #number of agents
        self.n_agents = n_agents
        #dimension of state space q,p = (2,2)
        self.ni = 4
        self.n = self.ni * n_agents
        if ctls is None:
            ctls = torch.ones(1, n_agents)
            # n_of_inputs x n_masses
        self.interconnection = ctls
        self.J = torch.zeros((self.n, self.n))
        self.R = torch.zeros((self.n, self.n))
        for i in range(0, n_agents):
            self.J[self.ni * i:self.ni * (i + 1), self.ni * i:self.ni * (i + 1)] = J
            self.R[self.ni * i:self.ni * (i + 1), self.ni * i:self.ni * (i + 1)] = R
        if xbar is None:
            xbar = torch.zeros(self.n, 1)
        self.xbar = xbar
        self.B = torch.tensor([[1.0, 0], [0, 1.0]])
        self.batch_size = batch_size

    def g(self, t, x):
        # g = torch.zeros((self.n, 2*int(self.interconnection.sum())))
        # idx = 0
        # for i, j in self.interconnection.nonzero():
        #     g[(4*j), idx] = 1
        #     g[(4*j)+1, idx+1] = 1
        #     idx += 2
        # return g
        g_agent = torch.tensor([[1.0, 0], [0, 1.0], [0, 0], [0, 0]])
        self.g_agent = copy.deepcopy(g_agent)
        g = torch.zeros(0, 0)
        for i in range(self.n_agents):
            g = torch.block_diag(g, g_agent)
        return g

    def H(self, t, x):
        delta_x = x - self.xbar
        Q_agent = torch.diag(torch.tensor([1 / self.m, 1 / self.m, self.k, self.k]))
        Q = torch.zeros((self.n, self.n))
        for i in range(self.n_agents):
            Q[self.ni * i:self.ni * (i + 1), self.ni * i:self.ni * (i + 1)] = Q_agent
        R_agent = torch.diag(torch.tensor([1 / self.m, 1 / self.m, self.k, self.k]))
        return 0.5 * F.linear(F.linear(delta_x.T, Q), delta_x.T)

    def gradH(self, t, x):
        x = x.requires_grad_(True)
        return torch.autograd.grad(self.H(t, x), x, allow_unused=False, create_graph=True)[0]

    def f(self, t, x):
        dHdx = self.gradH(t, x)
        return F.linear(dHdx.T, self.J - self.R)

    def _dynamics(self, t, x, u):
        # p = torch.cat(x[0::4], x[1::4], 0)
        q = torch.stack((x[:, 2::4], x[:, 3::4]), dim=2).view(x.shape[0],self.n_agents*2).to(0)
        p = torch.stack((x[:, 0::4], x[:, 1::4]), dim=2).view(x.shape[0], self.n_agents * 2).to(0)
        # [p;q] = [J-R]*delV+[B,0]*u
        delVq = self._energy_shaping(q)
        # delVp = p from formulation
        r1 = torch.stack((delVq, p), dim=2).to(0)
        r1 = r1.view(x.shape[0], self.n_agents*4)
        JR = (self.J - self.R).to(0)
        # input, matrix
        result  = torch.zeros(x.shape[0], self.n_agents*4).to(0)
        u = u.view(x.shape[0], self.n_agents * 2).to(0)
        for i in range(r1.shape[0]):
            par1 = torch.matmul(JR, r1[i, :])
            g = self.g(t, x).to(0)
            # uθ = −BInv*∇qV(q) − K∗(t, q, p)B*qdot
            par2 = F.linear(g, u[i,:])
            result[i,:] = torch.add(par1, par2).to(0)
        return result

    def _energy_shaping(self,q):
        # dVdx = grad(self.V(q).sum(), q, create_graph=True)[0]
        dVdx = grad(self.V(q).sum(), q, create_graph=True, retain_graph=True)[0]
        return -1*dVdx

    # def _energy(self,t,x):
    #     Q_agent = torch.diag(torch.tensor([1 / 2*self.m, self.k]))
    #     temp_x = torch.zeros(self.n_agents*2)
    #     temp_p = torch.zeros(self.n_agents*2)
    #     x_temp = x
    #     x_temp = x.view(self.n_agents, 4)
    #     for i in range(self.n_agents):
    #         for j in range(len(x[i])):
    #             temp_x[i]=(0.5*self.m)*()
    #             temp_p[i] = (0.5 * self.m)
    #
    #     F.linear(Q_agent, torch.cat(torch.cdist(x[:,2:...],torch.zeros_like(x[:,2:...])),torch.cdist(x[...,:2]-xbar)))

    def _damping_injection(self,x):
        # x = [pdot, qdot]
        q = torch.stack((x[:, 2::4], x[:, 3::4]), dim=2).view(x.shape[0],self.n_agents*2).requires_grad_(True)
        Kmat = torch.diag(self.K(x.to(0)).ravel())
        return -1*F.linear(Kmat, q.view(1, x.shape[0]*self.n_agents*2).to(0))

    def forward(self, t, x):
        # x = [p,q]
        print("in forward")
        x = x.requires_grad_(True)
        #batch_size, n_agents*4, n_agents = x
        q = torch.stack((x[:, 2::4], x[:, 3::4]), dim=2).view(x.shape[0],self.n_agents*2).requires_grad_(True).to(0)
        u1 = self._energy_shaping(q)
        u1 = u1.view(x.shape[0]*self.n_agents*2, 1 )
        u2 = self._damping_injection(x)
        u = u1+u2
        return  (self._dynamics(t,x,u),u)
        # return self.f(t, x).T

class AugmentedDynamics(nn.Module):
    # "augmented" vector field to take into account integral loss functions
    def __init__(self, f, int_loss):
        super().__init__()
        self.f = f
        self.int_loss = int_loss
        self.nfe = 0.

    def forward(self, t, x):
        self.nfe += 1
        x = x[:, :self.f.n_agents * 4]
        (dxdt,u) = self.f.forward(t, x)
        dldt =   self.int_loss.int_loss(x, u, self.f)
        return torch.cat([dxdt, dldt], 1).cpu()

class ControlEffort(nn.Module):
    # control effort integral cost
    def __init__(self, f, x0, xbar, dt):
        super().__init__()
        self.f = f

        self.x = torch.cat((x0[0::4], x0[1::4]), dim=1)
        self.x = self.x.repeat(f.batch_size, 1)
        self.x = self.x.reshape(f.batch_size,f.n_agents*2)

        self.xbar = torch.cat((xbar[0::4], xbar[1::4]), dim=1)
        self.xbar = self.xbar.repeat(f.batch_size, 1)
        self.xbar = self.xbar.reshape(f.batch_size, f.n_agents*2)

        self.dt = dt

    def forward(self, t, x):
        with torch.set_grad_enabled(True):
            q = torch.cat((x[2::4],x[3::4]),0).requires_grad_(True)
            # q = torch.transpose(q, 0, 1)
            u1 = torch.transpose(self.f._energy_shaping(q), 0, 1)
            u2 = self.f._damping_injection(x).to(0)
            u = u1+u2
        return u

    def int_loss(self, x, u, clsys):
        x = x.reshape(x.shape[0],self.f.n_agents,2,2)
        vel = torch.index_select(x, 2, torch.tensor([1]))
        vel = vel.reshape(x.shape[0],self.f.n_agents*2)
        self.x = self.x.cpu()+torch.mul(vel,self.dt)
        self.x = self.x.to(0)
        self.xbar = self.xbar.to(0)
        self.u = u.reshape(self.f.batch_size,self.f.n_agents*2)
        self.clsys = clsys
        lx = self.f_loss_states().reshape(self.f.batch_size,1).to(0)
        lu = self.f_loss_u().reshape(self.f.batch_size,1).to(0)
        lca = self.f_loss_ca()
        loss = lx+lu+lca
        return loss.to(0)

    def f_loss_states(self, test=False):
        # clsys = SystemEnv
        loss_function = nn.MSELoss(reduction='none')
        xbar = self.clsys.xbar
        # steps = t.shape[0]
        if test:
            gamma = 1
        else:
            gamma = 0.95
        loss = loss_function(self.x, self.xbar)
        loss = loss.view(self.f.batch_size, 2*self.f.n_agents)
        loss = loss.sum(dim=1)
        return loss

    def f_loss_u(self):
        loss_u = ((self.u*self.clsys.b) ** 2).sum(dim=1)
        return loss_u

    def f_loss_ca(self, min_dist=0.5):
        steps = self.x.shape[0]
        min_sec_dist = 1.4 * min_dist
        # for i in range(steps):
        #     for j in range(i+1, steps):
        #         dist = torch.norm(self.x[i, :] - self.x[j, :])
        #         if dist < min_sec_dist:
        #             return torch.tensor(1e10)
        # return torch.tensor(0)
        loss_ca_ = torch.zeros(self.f.batch_size,1).to(0)
        for i in range(self.x.shape[0]):
            x = self.x[i,:].view(self.f.n_agents*2,1).to(0)
            clsys = self.clsys
            # collision avoidance:
            # deltax = x[:, 2::4].repeat(1, 1, clsys.n // 4) - x[:, 2::4].transpose(1, 2).repeat(1, clsys.n // 4, 1)
            deltax = x[0::2].repeat(1, clsys.n // 4) - x[0::2].transpose(0, 1).repeat( clsys.n // 4, 1)
            # deltay = x[:, 3::4].repeat(1, 1, clsys.n // 4) - x[:, 3::4].transpose(1, 2).repeat(1, clsys.n // 4, 1)
            deltay = x[1::2].repeat(1, clsys.n // 4) - x[1::2].transpose(0, 1).repeat(clsys.n // 4, 1)
            distance_sq = deltax ** 2 + deltay ** 2
            mask = torch.logical_not(torch.eye(clsys.n // 4)).unsqueeze(0).repeat(steps, 1, 1)
            mask = mask.to(0)
            loss_ca_[i,:] = (1 / (distance_sq + 1e-3) * (distance_sq.detach() < (min_sec_dist ** 2)) * mask).sum() / 2
        return loss_ca_

import pytorch_lightning as pl
import torch.utils.data as data

def weighted_log_likelihood_loss(x, target, weight):
    # weighted negative log likelihood loss
    log_prob = target.log_prob(x)
    weighted_log_p = weight * log_prob
    return -torch.mean(weighted_log_p.sum(1))

class EnergyShapingLearner(pl.LightningModule):
    def __init__(self, model: nn.Module, prior_dist, target_dist, t_span, batch_size=1, sensitivity='autograd', n_agents=1):
        super().__init__()
        self.model = model
        self.prior, self.target = prior_dist, target_dist
        self.t_span = t_span
        self.batch_size = batch_size
        self.lr = 5e-3
        self.n_agents = n_agents
        self.weight = torch.ones(n_agents * 4).reshape(1, n_agents * 4)

    def forward(self, x):
        return self.model.odeint(x, self.t_span)

    def training_step(self, batch, batch_idx):
        # sample a batch of initial conditions
        # x0 = self.prior.sample((self.batch_size,))
        n_agents = self.n_agents
        x0 = torch.rand((self.batch_size,n_agents*4))
        # x0, _ = set_initial_conditions(n_agents)
        # Integrate the model
        x0 = torch.cat([x0, torch.zeros(self.batch_size, 1)], -1).to(x0)
        xs, xTl = self(x0)
        xT, l = xTl[-1, :, :2], xTl[-1, :, -1:]

        # Compute loss
        # terminal_loss = weighted_log_likelihood_loss(xT, self.target, self.weight.to(xT))
        integral_loss = torch.mean(l)
        loss = 0 + 0.01 * integral_loss
        return {'loss': loss.cpu()}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=.999)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        dummy_trainloader = data.DataLoader(
            data.TensorDataset(torch.Tensor(1, 1), torch.Tensor(1, 1)),
            batch_size=1)
        return dummy_trainloader

# # # # # # # # Parameters # # # # # # # #
n_agents = 2  # agents are not interconnected (even with a controller); each acts independently
t_end = 5
steps = 101
min_dist = 0.5  # min distance for collision avoidance
# (px, py, qx, qy) - for each agent
hdim = 64
V = nn.Sequential(
    nn.Linear(n_agents*2, hdim),
    nn.Softplus(),
    nn.Linear(hdim, hdim),
    nn.Tanh(),
    nn.Linear(hdim, 1))
K = nn.Sequential(
    nn.Linear(n_agents*4, hdim),
    nn.Softplus(),
    nn.Linear(hdim, (n_agents*2)),
    nn.Softplus())

from torch.distributions import Uniform, Normal

def prior_dist(q_min, q_max, p_min, p_max, device='cpu'):
    # uniform "prior" distribution of initial conditions x(0)=[q(0),p(0)]
    lb = torch.Tensor([q_min, p_min]).to(device)
    ub = torch.Tensor([q_max, p_max]).to(device)
    return Uniform(lb, ub)

def target_dist(mu, sigma, device='cpu'):
    # normal target distribution of terminal states x(T)
    mu, sigma = torch.Tensor(mu).reshape(1, 2).to(device), torch.Tensor(sigma).reshape(1, 2).to(device)
    return Normal(mu, torch.sqrt(sigma))


from torchdyn.models import ODEProblem

# choose solver and sensitivity method
solver = 'rk4'
sensitivity = 'autograd'

# init to zero par.s of the final layer
# for p in V[-1].parameters(): torch.nn.init.zeros_(p)
# for p in K[-2].parameters(): torch.nn.init.zeros_(p)

# define controlled system dynamics
x0, xbar = set_initial_conditions(n_agents)
batch_size = 4

f = SystemEnv(V.to(0), K.to(0), n_agents=n_agents, xbar=xbar, batch_size=batch_size)

t_span = torch.linspace(0, 3, 30)
dt = t_span[1]-t_span[0]
aug_f = AugmentedDynamics(f, ControlEffort(f,x0,xbar,dt))
# define time horizon

prob = ODEProblem(aug_f, sensitivity=sensitivity, solver=solver)

# train (it can be very slow on CPU)
# (don't be scared if the loss starts very high)
prior = prior_dist(-1, 1, -1, 1) # Uniform "prior" distribution of initial conditions x(0)
target = target_dist([0, 0], [.001, .001]) # Normal target distribution for x(T)
learn = EnergyShapingLearner(prob, prior, target, t_span, batch_size=batch_size, n_agents=n_agents)
# trainer = pl.Trainer(accelerator="gpu", devices=1, strategy="ddp",max_epochs=650).fit(learn)
trainer = pl.Trainer(accelerator="gpu", devices=1, strategy="ddp",max_epochs=650)
trainer.fit(learn)

Any help is deeply appreciated, thank you.

josyulakrishna commented 2 years ago

This can be resolved with the following two changes:

return Variable(dyn).requires_grad_(True)
x = Variable(x.data, requires_grad=True)
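
For anyone hitting the same error, here is a minimal standalone sketch of the idea behind this fix (not my project code). torch.autograd.Variable is deprecated in current PyTorch; x.detach().requires_grad_(True) is the equivalent of Variable(x.data, requires_grad=True). The state carried between steps keeps a reference to the previous step's graph, and cutting that reference means backward() never has to revisit buffers that were already freed:

import torch

w = torch.tensor(1.0, requires_grad=True)   # a trainable parameter
h = torch.zeros(1)                          # state carried across steps

for step in range(3):
    # Without this line, the second backward() tries to walk back through the
    # previous step's graph, whose buffers were already freed -> the RuntimeError above.
    h = h.detach().requires_grad_(True)     # modern form of Variable(h.data, requires_grad=True)
    h = torch.tanh(w * (h + 1))             # this step's graph
    loss = h.sum()
    loss.backward()                         # only walks this step's graph

print(w.grad)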