necludov / jam

Implementation of Action Matching
https://arxiv.org/abs/2210.06662
MIT License

p(t) in Algorithm 2 #3

Open boxaio opened 9 months ago

boxaio commented 9 months ago

In Algorithm 2 (in practice), you sample the time points t \sim p(t), where p(t) can be viewed as a proposal distribution for importance sampling. As you mention in Appendix C, p(t) can be estimated using Eq. (85). But in this code repository (see losses.py) you have p_t = time_sampler.invdensity(t), which corresponds to a uniform distribution (see dynamics.utils). So I wonder how you actually implemented the scheme described in Algorithm 2 and Appendix C.
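For reference, here is a minimal sketch of what the current uniform sampler amounts to (the names t0, t1, and per_sample_loss are illustrative, and I am assuming time_sampler.invdensity(t) returns 1/p(t)):

import numpy as np

t0, t1 = 0.0, 1.0
t = np.random.uniform(t0, t1, size=128)      # t ~ p(t) = Uniform[t0, t1]
inv_density = (t1 - t0)*np.ones_like(t)      # 1/p(t), cf. time_sampler.invdensity
# with a constant 1/p(t), the importance-weighted estimate
#   loss = (per_sample_loss(t)*inv_density).mean()
# reduces to a plain Monte Carlo average over uniformly drawn t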

necludov commented 9 months ago

Hi @boxaio! I haven't implemented this in the JAX version. I have a PyTorch implementation of the same algorithm, which is not public. Here is the code for the class implementing the loss function:

import numpy as np
import scipy.interpolate
import torch

# get_rank, get_world_size, get_q, get_s, DDPAverageMeter, and gather are
# helpers from the non-public training codebase.

class AdaptiveLoss:
    def __init__(self, net, config, n=100, beta=0.99):
        self.t0, self.t1 = config.model.t0, config.model.t1
        self.alpha, self.beta = config.train.alpha, beta
        # uniform grid of n time points on [t0, t1] and its spacing
        self.timesteps = np.linspace(self.t0, self.t1, n)
        self.dt = (self.t1 - self.t0)/(n - 1)
        self.rank = get_rank()
        self.ws = get_world_size()

        # q_t(x, t) produces samples x_t ~ q_t; w(t) is the time weighting and
        # dwdt its derivative. A boundary term is needed at each endpoint
        # where w does not vanish.
        self.q_t, self.w, self.dwdt = get_q(config)
        self.boundary_conditions = (self.w(torch.tensor(self.t0)).item() != 0.0,
                                    self.w(torch.tensor(self.t1)).item() != 0.0)
        print('boundary conditions are:', self.boundary_conditions)
        config.train.boundary_conditions = self.boundary_conditions

        self.s = get_s(net, config)

        # running statistics of the per-time loss on the grid; 'p' is the
        # current proposal density over t and 'u0' seeds the quasi-random
        # sequence used in sample_t
        self.buffer = {'values': [],
                       'times': [],
                       'size': 100,
                       'mean': np.zeros_like(self.timesteps),
                       'var': np.ones_like(self.timesteps),
                       'p': np.ones_like(self.timesteps),
                       'u0': 0.5}
        self.construct_dist()

        # distributed running averages for logging
        meters = [DDPAverageMeter('train_loss'),
                  DDPAverageMeter('dsdx_std'),
                  DDPAverageMeter('dsdt_std'),
                  DDPAverageMeter('s_1_std'),
                  DDPAverageMeter('s_0_std'),
                  DDPAverageMeter('s_std')]
        self.meters = dict((m.name,m) for m in meters)

    def load_state_dict(self, buffer_dict):
        self.buffer = buffer_dict
        self.construct_dist()

    def state_dict(self):
        return self.buffer

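    # construct_dist builds a piecewise-linear density p(t) from the grid
    # values in self.buffer['p'], together with its exact CDF F (quadratic on
    # each cell) and the inverse CDF F_inv used for inverse-transform sampling.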
    def construct_dist(self):
        dt, t = self.dt, self.timesteps
        p = self.buffer['p']
        # piecewise-linear density and its piecewise-constant slope
        # (forward differences, with the last slope repeated)
        self.fp = scipy.interpolate.interp1d(t, p, kind='linear')
        self.dpdt = scipy.interpolate.interp1d(t, np.concatenate([p[1:]-p[:-1], p[-1:]-p[-2:-1]])/dt, kind='zero')
        intercept = lambda t: self.fp(t) - self.dpdt(t)*t
        # left edge of the grid cell containing t, and the mass accumulated up
        # to that edge (trapezoid rule, forced to end exactly at 1)
        t0_interval = scipy.interpolate.interp1d(t, t, kind='zero')
        mass = np.concatenate([np.zeros([1]), ((p[1:]+p[:-1])*dt/2).cumsum()[:-1], np.ones([1])])
        F0_interval = scipy.interpolate.interp1d(t, mass, kind='zero')
        F0_inv = scipy.interpolate.interp1d(mass, t, kind='zero')

        def F(t):
            # CDF: mass at the cell edge plus the integral of the linear
            # density k*t + b over [t0_, t]
            t0_ = t0_interval(t)
            F0_ = F0_interval(t)
            k, b = self.dpdt(t), intercept(t)
            output = 0.5*k*(t**2 - t0_**2) + b*(t - t0_)
            return F0_ + output

        def F_inv(y):
            # inverse CDF: locate the cell, then solve the per-cell quadratic
            # 0.5*k*(x**2 - t0_**2) + b*(x - t0_) = y - F0_ for x
            t0_ = F0_inv(y)
            F0_ = F0_interval(t0_)
            k, b = self.dpdt(t0_), intercept(t0_)
            c = y - F0_
            c = c + 0.5*k*t0_**2 + b*t0_
            D = np.sqrt(b**2 + 2*k*c)
            # quadratic root where the cell has nonzero slope, linear otherwise
            output = (-b + D)*(np.abs(k) > 0) + c/b*(np.abs(k) == 0.0)
            output[np.abs(k) > 0] /= k[np.abs(k) > 0]
            return output

        self.F_inv = F_inv

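    # sample_t draws n time points per rank from the current proposal via a
    # low-discrepancy (Weyl) sequence pushed through the inverse CDF, and
    # returns them together with p(t) and dp/dt for reweighting.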
    def sample_t(self, n, device):
        # additive-recurrence sequence on [0, 1), sliced per DDP rank so that
        # ranks draw disjoint strata
        u = (self.buffer['u0'] + np.sqrt(2)*np.arange(n*self.ws)) % 1
        self.buffer['u0'] = (self.buffer['u0'] + np.sqrt(2)*n*self.ws) % 1
        u = u[self.rank*n:(self.rank+1)*n]
        # inverse-transform sampling through the piecewise CDF
        t = self.F_inv(u)
        assert ((t < self.t0).sum() == 0) and ((t > self.t1).sum() == 0)
        p_t, dpdt = self.fp(t), self.dpdt(t)
        t = torch.from_numpy(t).to(device).float()
        p_t, dpdt = torch.from_numpy(p_t).to(device).float(), torch.from_numpy(dpdt).to(device).float()
        return t, p_t, dpdt

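    # update_history refreshes the running loss statistics with a new batch of
    # per-sample losses and rebuilds the proposal from them.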
    def update_history(self, new_p, t, p_t):
        new_p, t, p_t = new_p.cpu().numpy().flatten(), t.cpu().numpy().flatten(), p_t.cpu().numpy().flatten()
        # soft assignment of each sampled t to nearby grid points
        weights = np.exp(-np.abs(self.timesteps.reshape(-1, 1) - t.reshape(1, -1))*1e2)
        weights = weights/weights.sum(1, keepdims=True)
        # EMA updates of the per-time loss mean and variance
        self.buffer['mean'] = self.beta*self.buffer['mean'] + (1-self.beta)*(weights@new_p)
        mean_func = scipy.interpolate.interp1d(self.timesteps, self.buffer['mean'], kind='linear')
        self.buffer['var'] = self.beta*self.buffer['var'] + (1-self.beta)*(weights@((mean_func(t) - new_p)**2))
        # proposal proportional to the loss std, normalized with the trapezoid
        # rule and mixed with a uniform component of weight alpha
        p = np.sqrt(self.buffer['var'])
        p = (1.0-self.alpha)*p/((p[1:]+p[:-1])*self.dt/2).sum() + self.alpha/(self.t1-self.t0)
        self.buffer['p'] = p
        self.construct_dist()

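    # eval_loss computes the action matching objective on a batch: the
    # importance-weighted time integral plus the two boundary terms, and, as a
    # side effect, updates the adaptive time proposal.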
    def eval_loss(self, x):
        q_t, w, dwdt, s = self.q_t, self.w, self.dwdt, self.s
        assert (2 == x.dim())
        t_0, t_1 = self.t0, self.t1
        device = x.device
        bs = x.shape[0]
        # sample t from the adaptive proposal, with p(t) for reweighting
        t, p_t, dpdt = self.sample_t(bs, device)
        while (x.dim() > t.dim()): t = t.unsqueeze(-1)
        x_t, _ = q_t(x, t)
        x_t.requires_grad, t.requires_grad = True, True
        s_t = s(t, x_t)
        assert (2 == s_t.dim())
        dsdt, dsdx = torch.autograd.grad(s_t.sum(), [t, x_t], create_graph=True, retain_graph=True)
        x_t.requires_grad, t.requires_grad = False, False

        # integrand of the action matching objective:
        # (0.5*|ds/dx|^2 + ds/dt)*w(t) + s*dw/dt
        loss = (0.5*(dsdx**2).sum(1, keepdim=True) + dsdt.sum(1, keepdim=True))*w(t)
        self.meters['dsdx_std'].update((0.5*(dsdx**2).sum(1)*w(t).squeeze()).detach().cpu().std())
        self.meters['dsdt_std'].update((dsdt.sum(1)*w(t).squeeze()).detach().cpu().std())
        loss = loss + s_t*dwdt(t)
        self.meters['s_std'].update((s_t*dwdt(t)).squeeze().detach().cpu().std())
        # divide by p(t) for an unbiased importance-sampled estimate;
        # time_loss undoes the weighting to feed the raw values back into the
        # proposal update
        loss = loss.squeeze()/p_t
        time_loss = loss.detach()*p_t

        # boundary terms +s(t_0)*w(t_0) and -s(t_1)*w(t_1), added only where
        # w does not vanish at the endpoint
        if self.boundary_conditions[0]:
            t_0 = t_0*torch.ones([bs, 1], device=device)
            x_0, _ = q_t(x, t_0)
            left_bound = (s(t_0, x_0)*w(t_0)).squeeze()
            loss = loss + left_bound
            self.meters['s_0_std'].update(left_bound.detach().cpu().std())
        if self.boundary_conditions[1]:
            t_1 = t_1*torch.ones([bs, 1], device=device)
            x_1, _ = q_t(x, t_1)
            right_bound = (-s(t_1, x_1)*w(t_1)).squeeze()
            loss = loss + right_bound
            self.meters['s_1_std'].update(right_bound.detach().cpu().std())

        self.meters['train_loss'].update(loss.detach().mean().cpu())
        # gather across DDP ranks and refresh the time proposal
        self.update_history(gather(time_loss), gather(t), gather(p_t))
        return loss.mean(), self.meters

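    # the learned velocity field for sampling is the spatial gradient of the
    # action s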
    def get_dxdt(self):
        def dxdt(t, x):
            return torch.autograd.grad(self.s(t, x).sum(), x, create_graph=True, retain_graph=True)[0]
        return dxdt
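
For completeness, here is a hypothetical usage sketch (net, config, loader, and device are placeholders for objects from the non-public codebase, and the optimizer choice is illustrative; this is not verbatim training code):

loss_fn = AdaptiveLoss(net, config)
optimizer = torch.optim.Adam(net.parameters(), lr=config.train.lr)
for x in loader:
    loss, meters = loss_fn.eval_loss(x.to(device))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()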