tsinghua-fib-lab / Spatio-temporal-Diffusion-Point-Processes

A diffusion-based framework for spatio-temporal point processes

Prediction Code for baseline models NSTPP and DeepSTPP #5

Closed · yuss01 closed 1 month ago

yuss01 commented 1 month ago

Dear authors, I have been deeply inspired by your work on the DSTPP model. However, regarding the baselines NSTPP and DeepSTPP, I noticed that neither their papers nor their released code covers prediction. How did you implement prediction and compute the prediction metrics for these two models, i.e., temporal prediction (RMSE) and spatial prediction (Euclidean distance)? If it's convenient, I would appreciate obtaining the complete code for the two baseline models (NSTPP and DeepSTPP) you used. Thank you!

YuanYuan98 commented 1 month ago

Thanks for your attention! You can find the prediction implementation for DeepSTPP at this link.
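
Regarding the metrics themselves: once a model outputs a predicted next-event time and location for each test event, both metrics are simple to compute. Below is a minimal sketch; tensor names such as pred_times and true_locs are placeholders, not identifiers from either codebase.

import torch

def temporal_rmse(pred_times, true_times, mask):
    # pred_times, true_times: (N, T); mask: (N, T), 1 for real events, 0 for padding.
    se = ((pred_times - true_times) ** 2) * mask
    return torch.sqrt(se.sum() / mask.sum())

def spatial_error(pred_locs, true_locs, mask):
    # pred_locs, true_locs: (N, T, D); mean Euclidean distance over real events.
    dist = torch.norm(pred_locs - true_locs, dim=-1) * mask
    return dist.sum() / mask.sum()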

For NSTPP, prediction requires two additional sampling steps: one for the next time interval and one for the spatial location. For temporal sampling, add the following functions to class NeuralPointProcess(TemporalPointProcess) (NSTPP):

def get_intensity(self, state):
    # Read the conditional intensity off the current ODE state.
    return self.ode_solver.func.get_intensity(state)

def compute_intensity_given_past(self, last_time, current_time, last_state, nlinspace=1):
    # Integrate the latent ODE from last_time to current_time, cache the
    # conditional intensity at current_time in self.intensity, and return
    # the final (Lambda, tpp_state) tuple.
    state_traj = self.ode_solver.integrate(last_time, current_time, last_state, nlinspace=nlinspace, method="dopri5")
    state = tuple(s[-1] for s in state_traj)
    Lambda, tpp_state = state
    self.intensity = self.get_intensity(tpp_state)
    return state

def sample_time(self, last_time, last_state, intensity, input_mask):
    # Ogata-style thinning: propose exponential waiting times under the current
    # intensity bound, then accept/reject against the true conditional intensity.
    # mask marks sequences whose next event time has been accepted (or is padding).
    device = last_time.device
    NN = last_time.shape[0]
    last_state = (torch.zeros(NN).to(last_time), last_state)

    intensity_hazard = intensity

    mask = torch.zeros([NN, 1]).to(device).bool()
    assert mask.shape == input_mask.shape
    mask = mask | ~input_mask.bool()

    # Clone so that advancing time_current does not modify last_time in place.
    time_current = last_time.clone()

    while (~mask).any():
        Exp = torch.distributions.Exponential(torch.tensor(1.0))
        E = Exp.sample((NN, 1)).to(device)
        Uni = torch.distributions.uniform.Uniform(torch.tensor(0.0), torch.tensor(1.0))
        U = Uni.sample((NN, 1)).to(device)

        assert E.shape == intensity_hazard.shape
        interval = E / intensity_hazard

        assert time_current.shape == interval.shape
        assert mask.shape == interval.shape
        # Only advance sequences whose sample has not been accepted yet.
        time_current = time_current + interval * (~mask).float()

        last_state = self.compute_intensity_given_past(last_time.clone().detach(), time_current.clone().detach(), last_state)

        last_time = time_current

        assert self.intensity.shape == U.shape == intensity_hazard.shape
        # Accept when U * lambda_bound < lambda(t), i.e., u < 1.
        u = U * intensity_hazard / self.intensity

        intensity_hazard = self.intensity

        mask = mask | (u < 1)

    return time_current
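
A rough sketch of how these pieces chain together at evaluation time; model, t_prev, t_last, state, input_mask, and true_time are illustrative placeholders assumed to come from running the model over the observed prefix of each sequence:

# Condition on the history up to the last observed event, then sample the next time.
state = model.compute_intensity_given_past(t_prev, t_last, state)
intensity = model.intensity                      # lambda at t_last, shape (N, 1)
pred_time = model.sample_time(t_last, state[1], intensity, input_mask)
rmse = torch.sqrt(((pred_time - true_time) ** 2).mean())   # temporal RMSE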

For spatial sampling, here are two examples: attncnf and cond_gmm.

For attncnf, add the following sampling function to class SelfAttentiveCNF(nn.Module) (SelfAttentiveCNF):

def sample_spatial(self, last_times, current_times, spatial_locations, input_mask=None, aux_state=None):
    """
    Args:
        nsamples: int
        event_times: (N, T)
        current_times: (N, T, D)
        input_mask: (N, T) or None
        aux_state: (N,T, D_a)

    Returns:
        Samples from the spatial distribution at event times, of shape (nsamples, N, T, D).
    """

    N, T = last_times.shape

    t_embed = self.t_embedding(last_times) / math.sqrt(self.t_embedding_dim)

    if aux_state is not None:
        inputs = [spatial_locations, aux_state, t_embed]
    else:
        inputs = [spatial_locations, t_embed]

    # attention layer uses (T, N, D) ordering.
    inputs = [inp.transpose(0, 1) for inp in inputs]
    norm_fn = max_rms_norm([a.shape for a in inputs])
    x = torch.cat(inputs, dim=-1)

    self.odefunc.set_shape(x.shape)

    x = x.reshape(T * N, -1)
    last_times = last_times.transpose(0, 1).reshape(T * N)
    current_times = current_times.transpose(0, 1).reshape(T * N)

    t0 = last_times + self.time_offset
    t1 = current_times + self.time_offset

    assert (t1 >= t0).all()

    z, _ = self.cnf.integrate(t0, t1, x, torch.zeros_like(last_times), norm=norm_fn)
    z = z[:, :self.dim]  # (T * N, D)

    return z
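
Note that the returned z is flattened to (T * N, D) in time-major order. A small illustrative sketch of restoring the batch layout and scoring it; true_locs is a placeholder for the ground-truth locations:

z = z.reshape(T, N, -1).transpose(0, 1)      # back to (N, T, D)
dist = torch.norm(z - true_locs, dim=-1)     # per-event Euclidean distance, (N, T)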

For cond_gmm, you can directly use the sample_spatial function already provided in ConditionalGMM(nn.Module) (cond_gmm).