yuss01 closed this issue 1 month ago.
Thanks for your interest! You can find the prediction implementation for DeepSTPP at this link.
NSTPP prediction requires two additional functions: one to sample the time interval and one to sample the spatial location. For temporal sampling, add the following functions to class `NeuralPointProcess(TemporalPointProcess)` (NSTPP):
```python
# Add to class NeuralPointProcess(TemporalPointProcess); assumes `import torch` at module top.

def get_intensity(self, state):
    return self.ode_solver.func.get_intensity(state)

def compute_intensity_given_past(self, last_time, current_time, last_state, nlinspace=1):
    # Integrate the ODE state from last_time to current_time and cache the intensity there.
    state_traj = self.ode_solver.integrate(last_time, current_time, last_state, nlinspace=nlinspace, method="dopri5")
    state = tuple(s[-1] for s in state_traj)
    Lambda, tpp_state = state
    self.intensity = self.get_intensity(tpp_state)
    return state

def sample_time(self, last_time, last_state, intensity, input_mask):
    device = last_time.device
    NN = last_time.shape[0]
    last_state = (torch.zeros(NN).to(last_time), last_state)
    intensity_hazard = intensity  # intensity used to propose candidate intervals
    # Padding rows (input_mask == 0) are marked as done from the start.
    mask = torch.zeros([NN, 1], dtype=torch.bool, device=device)
    assert mask.shape == input_mask.shape
    mask = mask | ~input_mask.bool()
    # Clone so the in-place update below does not also modify last_time.
    time_current = last_time.clone()
    while (~mask).any():
        E = torch.distributions.Exponential(torch.tensor(1.0)).sample((NN, 1)).to(device)
        U = torch.rand(NN, 1, device=device)
        assert E.shape == intensity_hazard.shape
        interval = E / intensity_hazard
        assert time_current.shape == interval.shape and mask.shape == interval.shape
        # Advance only the rows that have not accepted a sample yet.
        time_current += interval * (~mask).float()
        last_state = self.compute_intensity_given_past(last_time.clone().detach(), time_current.clone().detach(), last_state)
        last_time = time_current.clone()
        assert self.intensity.shape == U.shape == intensity_hazard.shape
        # Thinning test: accept where U < intensity(t) / intensity_hazard.
        u = U * intensity_hazard / self.intensity
        intensity_hazard = self.intensity
        mask = mask | (u < 1)
    return time_current
```
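The `sample_time` loop is a form of thinning: it proposes an interval `E / intensity_hazard`, re-evaluates the intensity at the candidate time, and accepts with probability `intensity / intensity_hazard`. To try that logic without an ODE solver, here is a hedged plain-Python sketch. `sample_next_time` and `decaying_intensity` are hypothetical names, and the closed-form intensity mu + alpha * exp(-beta * (t - t0)) stands in for NSTPP's ODE-based intensity. Because this intensity never increases after t0, the intensity at the latest candidate is a valid upper bound for later times, which is what makes the bound update `intensity_hazard = self.intensity` sound in this setting.

```python
import math
import random
from typing import Callable


def decaying_intensity(mu: float, alpha: float, beta: float, t0: float) -> Callable[[float], float]:
    """Hypothetical intensity mu + alpha * exp(-beta * (t - t0)), non-increasing for t >= t0."""
    return lambda t: mu + alpha * math.exp(-beta * (t - t0))


def sample_next_time(intensity: Callable[[float], float], t0: float,
                     rng: random.Random, max_iter: int = 10_000) -> float:
    """Thinning sampler for the next event time after t0.

    Assumes intensity(t) > 0 and non-increasing for t >= t0.
    """
    t = t0
    lam_bar = intensity(t)  # plays the role of intensity_hazard
    for _ in range(max_iter):
        t += rng.expovariate(lam_bar)  # candidate step, like E / intensity_hazard
        lam_t = intensity(t)  # like self.intensity after the ODE solve
        if rng.random() * lam_bar < lam_t:  # accept with prob. lam_t / lam_bar
            return t
        lam_bar = lam_t  # tighten the bound, like intensity_hazard = self.intensity
    raise RuntimeError("no candidate accepted within max_iter proposals")
```

With alpha = 0 the intensity is constant, every candidate is accepted, and the interval reduces to a single Exponential(mu) draw.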
For spatial sampling, here are two examples: attncnf and cond_gmm.
For attncnf, add the following sampling function to class `SelfAttentiveCNF(nn.Module)` (SelfAttentiveCNF):
```python
# Add to class SelfAttentiveCNF(nn.Module); `math`, `torch`, and `max_rms_norm`
# are assumed to be available in that module already.

def sample_spatial(self, last_times, current_times, spatial_locations, input_mask=None, aux_state=None):
    """
    Args:
        last_times: (N, T)
        current_times: (N, T)
        spatial_locations: (N, T, D)
        input_mask: (N, T) or None (not used here)
        aux_state: (N, T, D_a) or None
    Returns:
        The first self.dim features of the CNF state after integrating from
        last_times to current_times, of shape (T * N, D), time-major
        (row t * N + n).
    """
    N, T = last_times.shape
    t_embed = self.t_embedding(last_times) / math.sqrt(self.t_embedding_dim)
    if aux_state is not None:
        inputs = [spatial_locations, aux_state, t_embed]
    else:
        inputs = [spatial_locations, t_embed]
    # The attention layer uses (T, N, D) ordering.
    inputs = [inp.transpose(0, 1) for inp in inputs]
    norm_fn = max_rms_norm([a.shape for a in inputs])
    x = torch.cat(inputs, dim=-1)
    self.odefunc.set_shape(x.shape)
    x = x.reshape(T * N, -1)
    last_times = last_times.transpose(0, 1).reshape(T * N)
    current_times = current_times.transpose(0, 1).reshape(T * N)
    t0 = last_times + self.time_offset
    t1 = current_times + self.time_offset
    assert (t1 >= t0).all()
    z, _ = self.cnf.integrate(t0, t1, x, torch.zeros_like(last_times), norm=norm_fn)
    z = z[:, :self.dim]  # (T * N, D)
    return z
```
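The `self.cnf.integrate(t0, t1, x, ...)` call moves each flattened row from its own start time to its own end time. As a rough, hypothetical illustration of that per-row integration, here is a fixed-step RK4 sketch in plain Python. It is not NSTPP's solver and does not track the second output that `sample_spatial` discards as `_`; `rk4_integrate`, `integrate_rows`, the dynamics `f`, and the step count are all assumptions.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def rk4_integrate(f: Callable[[float, Vector], Vector], t0: float, t1: float,
                  x0: Sequence[float], steps: int = 100) -> Vector:
    """Integrate dx/dt = f(t, x) from t0 to t1 with fixed-step RK4."""
    h = (t1 - t0) / steps
    t, x = t0, list(x0)
    for _ in range(steps):
        k1 = f(t, x)
        k2 = f(t + h / 2, [xi + h / 2 * ki for xi, ki in zip(x, k1)])
        k3 = f(t + h / 2, [xi + h / 2 * ki for xi, ki in zip(x, k2)])
        k4 = f(t + h, [xi + h * ki for xi, ki in zip(x, k3)])
        x = [xi + h / 6 * (a + 2 * b + 2 * c + d)
             for xi, a, b, c, d in zip(x, k1, k2, k3, k4)]
        t += h
    return x


def integrate_rows(f, t0s, t1s, xs, steps=100):
    """Each row gets its own interval [t0, t1], like the flattened (T * N) batch."""
    assert all(b >= a for a, b in zip(t0s, t1s))  # mirrors assert (t1 >= t0).all()
    return [rk4_integrate(f, a, b, x, steps) for a, b, x in zip(t0s, t1s, xs)]
```

A row with t0 == t1 is returned unchanged, which matches rows whose current time equals their last time.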
For cond_gmm, you can directly use the `sample_spatial` function already defined in class `ConditionalGMM(nn.Module)` (cond_gmm).
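`ConditionalGMM.sample_spatial` itself is not reproduced here. As a hedged illustration of the general idea of drawing a location from a Gaussian mixture with diagonal covariance, here is a stdlib-only sketch. `sample_gmm` and its `weights`/`means`/`stds` arguments are hypothetical and do not reflect how NSTPP's ConditionalGMM parameterizes its components.

```python
import random
from typing import Sequence, Tuple


def sample_gmm(weights: Sequence[float], means: Sequence[Sequence[float]],
               stds: Sequence[Sequence[float]], rng: random.Random) -> Tuple[float, ...]:
    """Draw one point from a diagonal-covariance Gaussian mixture."""
    # Pick a component with probability proportional to its weight.
    k = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    # Sample each coordinate independently around that component's mean.
    return tuple(rng.gauss(m, s) for m, s in zip(means[k], stds[k]))
```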
Dear author, I was deeply inspired by your research on the DSTPP model. However, for the baseline models NSTPP and DeepSTPP that you used, I noticed that neither their papers nor their code cover prediction. How did you implement prediction and compute the prediction metrics for these two models, including time prediction (RMSE) and spatial prediction (Euclidean distance)? If convenient, could you share the complete code for the two baseline models (NSTPP and DeepSTPP) you used? Thank you!
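The question leaves the exact metric definitions open. Under the common definitions (an assumption here, not necessarily what the DSTPP authors used), time RMSE is the square root of the mean squared difference between predicted and true event times, and the spatial error is the mean Euclidean distance between predicted and true locations. A stdlib-only sketch; `time_rmse` and `mean_euclidean_distance` are hypothetical helper names.

```python
import math
from typing import Sequence


def time_rmse(pred: Sequence[float], true: Sequence[float]) -> float:
    """Root mean squared error between predicted and true event times."""
    assert len(pred) == len(true) > 0
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, true)) / len(pred))


def mean_euclidean_distance(pred: Sequence[Sequence[float]],
                            true: Sequence[Sequence[float]]) -> float:
    """Average Euclidean distance between predicted and true locations."""
    assert len(pred) == len(true) > 0
    return sum(math.dist(p, t) for p, t in zip(pred, true)) / len(pred)
```

For example, predicted locations (0, 0) and (1, 1) against true locations (3, 4) and (1, 1) give distances 5 and 0, so the mean is 2.5.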