ant-research / EasyTemporalPointProcess

EasyTPP: Towards Open Benchmarking Temporal Point Processes
https://ant-research.github.io/EasyTemporalPointProcess/
Apache License 2.0

Rationale of multiple (num_exp) draws of the exponential time delta per event in the draw_next_time_one_step() thinning algorithm #34

Closed HaochenWang1243 closed 2 months ago

HaochenWang1243 commented 2 months ago

Hi. I'm looking at the draw_next_time_one_step method of the EventSampler class in torch_thinning.py. I see that for any single event that occurred at time $t$ in the batch, num_exp samples of an exponential random variable $\Delta \sim \mathrm{Exp}(\lambda^*)$ are drawn. Then the sample_accept method finds, among all the samples of $\Delta$, the one that gives the smallest criterion $\frac{u \lambda^*}{\lambda_k(t+\Delta)}$. If that criterion is smaller than 1, the sample is accepted as the next-event-time sample; otherwise, the largest sampled $\Delta$ plus 1 is taken as the sampled next-event time.

My doubt is that drawing multiple one-step-ahead time deltas and accepting the one that gives the smallest criterion no longer seems to simulate a non-homogeneous Poisson process. For example, if the model is a plain Hawkes process (assumed for simplicity), I think that if we increase num_exp while keeping everything else unchanged, we will certainly observe smaller accepted one-step-ahead $\Delta$ samples, because such samples make the criterion $\frac{u \lambda^*}{\lambda_k(t+\Delta)}$ smaller by making the denominator (which decays as $\Delta$ grows) bigger. For models whose intensity changes non-monotonically between events, such as NHP, it won't be such a simple convergence, and I'm unsure whether this method of multiple exp draws works there.
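To make the concern concrete, here is a small pure-Python sketch of the acceptance rule described above, for a single event. It is not the EasyTPP code: the Hawkes-style intensity $\lambda(\Delta) = 0.2 + e^{-\Delta}$ after one event, its bound $\lambda^* = 1.2$, and the helper names min_criterion_delta, decaying_intensity and mean_delta are all assumptions for illustration.

```python
import math
import random


def min_criterion_delta(intensity, lam_star, num_exp, rng):
    """One next-event delta via the min-criterion rule described above (sketch).

    intensity: hypothetical callable, total intensity at offset delta after t
    lam_star:  assumed upper bound of the intensity over the window
    """
    # num_exp independent Exp(lam_star) candidate deltas (not accumulated)
    deltas = [rng.expovariate(lam_star) for _ in range(num_exp)]
    # One uniform per candidate: criterion = u * lam_star / lambda(t + delta)
    crits = [rng.random() * lam_star / intensity(d) for d in deltas]
    best = min(range(num_exp), key=crits.__getitem__)
    if crits[best] < 1.0:
        return deltas[best]
    # Fallback described above: the largest sampled delta plus 1
    return max(deltas) + 1.0


def decaying_intensity(delta, mu=0.2, alpha=1.0, beta=1.0):
    # Hawkes-style intensity right after a single event: decays from mu + alpha to mu
    return mu + alpha * math.exp(-beta * delta)


def mean_delta(num_exp, trials, seed=0):
    rng = random.Random(seed)
    lam_star = decaying_intensity(0.0)  # the intensity peaks at delta = 0
    total = sum(min_criterion_delta(decaying_intensity, lam_star, num_exp, rng)
                for _ in range(trials))
    return total / trials


print(mean_delta(1, 2000), mean_delta(200, 300))
```

The second mean should come out noticeably smaller than the first, i.e. the average accepted $\Delta$ shrinks as num_exp grows, even though the model's own next-event distribution does not depend on num_exp at all.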

So why is it OK to do the thinning like this, instead of drawing consecutive $\Delta$ samples and adding them to the last event time $t$ until the criterion is $\leq 1$, as in the "repeat...until" procedure of Algorithm 2 in the NHP paper? Is it an approximation (like how the intensity upper bounds are computed)? My understanding of the math or the code may be wrong, but I have been confused by this for several days, so I'm asking here. Thank you for your help!
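For comparison, here is a sequential sketch of the "repeat...until" procedure as I read it, for a single event and only for the time part (not the event-type draw). The callable intensity, the bound lam_star and the function name sequential_thinning_delta are assumptions; lam_star must dominate the intensity over the sampled window for thinning to be exact. The example again assumes $\lambda(\Delta) = 0.2 + e^{-\Delta}$ with $\lambda^* = 1.2$.

```python
import math
import random


def sequential_thinning_delta(intensity, lam_star, rng, dtime_max=math.inf):
    """Draw one next-event delta by "repeat ... until" thinning (sketch).

    intensity: hypothetical callable giving lambda(t + delta); assumed to satisfy
               intensity(delta) <= lam_star over the sampled window
    dtime_max: optional cap, loosely analogous to EasyTPP's dtime_max
    """
    delta = 0.0
    while True:
        # Advance by an Exp(lam_star) gap: candidates accumulate from t
        delta += rng.expovariate(lam_star)
        if delta > dtime_max:
            return dtime_max
        # Accept when u * lam_star <= lambda(t + delta); otherwise keep going
        if rng.random() * lam_star <= intensity(delta):
            return delta


def intensity(delta):
    # Assumed Hawkes-style intensity after one event: 0.2 + exp(-delta) <= 1.2
    return 0.2 + math.exp(-delta)


rng = random.Random(0)
draws = [sequential_thinning_delta(intensity, 1.2, rng) for _ in range(20000)]
print(sum(draws) / len(draws))
```

The printed mean should agree, up to Monte Carlo error, with the model's mean waiting time $\int_0^\infty \exp\left(-\int_0^{\Delta} \lambda(s)\,ds\right) d\Delta$, and there is no num_exp-like knob that could shift it.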

iLampard commented 2 months ago

Hi,

I think your comment makes sense. I believe this implementation is a kind of approximation, and the code is inherited from AttNHP.

If you have time, could you improve this code to follow what is proposed in NHP, and check whether the results change (I guess the change is marginal) and how efficient it is (maybe much slower)?

If the modification does not cause efficiency issues, I will be happy to merge a PR.

HaochenWang1243 commented 2 months ago

Hi, thank you for your response! I'm happy to make the change and test it, but I'm still figuring out how to properly set the other config parameters, so hopefully I can come back later when I'm more confident :)

iLampard commented 2 months ago

No problem. The thinning algorithm is the most challenging part to implement. You can also see there is a multi-step generation issue that I still have not fixed; there is probably some problem in this implementation for batch-wise sampling.

ivan-chai commented 2 months ago

Hi! I can suggest an alternative implementation of thinning at HoTPP

iLampard commented 2 months ago

Great work. We will update the code based on yours.

ivan-chai commented 2 months ago

Thank you! Please cite the original code in a docstring or comment if you use the HoTPP implementation in some way.

HaochenWang1243 commented 2 months ago

Hi. I know HoTPP is probably the more promising option, but I need to use EasyTPP now, so I spent some time changing the current thinning code to what is described in Algorithm 2 of the NHP paper myself. I've made a PR. My tests on the provided datasets show that my implementation consistently outperforms the current one in both one-step-ahead prediction RMSE and efficiency. For example, using the current implementation (picking the exp sample with the smallest criterion), running the taxi dataset with the NHP_train config in experiment_config.yaml gives the results below (as you can see, I'm printing how long each draw_next_time_one_step call takes):

hhhqss@hhhs-MacBook-Air examples % python train_nhp.py
2024-08-08 20:39:36,296 - config.py[pid:18330;line:33:build_from_yaml_file] - CRITICAL: Load pipeline config class RunnerConfig
2024-08-08 20:39:36,302 - runner_config.py[pid:18330;line:164:update_config] - CRITICAL: train model NHP using CPU with torch backend
2024-08-08 20:39:36,314 - runner_config.py[pid:18330;line:36:__init__] - INFO: Save the config to ./checkpoints/18330_140704444761728_240808-203936/NHP_train_output.yaml
2024-08-08 20:39:36,315 - base_runner.py[pid:18330;line:170:save_log] - INFO: Save the log to ./checkpoints/18330_140704444761728_240808-203936/log
2024-08-08 20:39:38,512 - tpp_runner.py[pid:18330;line:60:_init_model] - INFO: Num of model parameters 59146
/Users/hhhqss/Library/Python/3.8/lib/python/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020
  warnings.warn(
2024-08-08 20:39:46,705 - base_runner.py[pid:18330;line:92:train] - INFO: Data 'taxi' loaded...
2024-08-08 20:39:46,705 - base_runner.py[pid:18330;line:97:train] - INFO: Start NHP training...
2024-08-08 20:39:49,654 - tpp_runner.py[pid:18330;line:96:_train_model] - INFO: [ Epoch 0 (train) ]: train loglike is -1.8306055527528788, num_events is 50454
2024-08-08 20:39:59,254 - torch_basemodel.py[pid:18330;line:182:predict_one_step_at_every_event] - CRITICAL: draw_next_time_one_step takes time (seconds):9.396914958953857
2024-08-08 20:39:59,464 - tpp_runner.py[pid:18330;line:107:_train_model] - INFO: [ Epoch 0 (valid) ]:  valid loglike is -1.6224344222827596, num_events is 7204, acc is 0.8430038867295947, rmse is 0.362656055077529
2024-08-08 20:40:32,508 - torch_basemodel.py[pid:18330;line:182:predict_one_step_at_every_event] - CRITICAL: draw_next_time_one_step takes time (seconds):32.72161293029785
2024-08-08 20:40:35,118 - torch_basemodel.py[pid:18330;line:182:predict_one_step_at_every_event] - CRITICAL: draw_next_time_one_step takes time (seconds):2.1267929077148438
2024-08-08 20:40:35,219 - tpp_runner.py[pid:18330;line:122:_train_model] - INFO: [ Epoch 0 (test) ]: test loglike is -1.6084953487994105, num_events is 14420, acc is 0.8522884882108183, rmse is 0.3709726238506591
2024-08-08 20:40:35,219 - tpp_runner.py[pid:18330;line:124:_train_model] - CRITICAL: current best loglike on valid set is -1.6224 (updated at epoch-0), best updated at this epoch
2024-08-08 20:40:38,282 - tpp_runner.py[pid:18330;line:96:_train_model] - INFO: [ Epoch 1 (train) ]: train loglike is -1.493789516192968, num_events is 50454
2024-08-08 20:40:46,362 - torch_basemodel.py[pid:18330;line:182:predict_one_step_at_every_event] - CRITICAL: draw_next_time_one_step takes time (seconds):7.884713888168335
2024-08-08 20:40:46,514 - tpp_runner.py[pid:18330;line:107:_train_model] - INFO: [ Epoch 1 (valid) ]:  valid loglike is -1.246168578220433, num_events is 7204, acc is 0.8629927817878956, rmse is 0.36263942498794405
2024-08-08 20:41:13,904 - torch_basemodel.py[pid:18330;line:182:predict_one_step_at_every_event] - CRITICAL: draw_next_time_one_step takes time (seconds):27.073447942733765
2024-08-08 20:41:30,474 - torch_basemodel.py[pid:18330;line:182:predict_one_step_at_every_event] - CRITICAL: draw_next_time_one_step takes time (seconds):15.272824764251709
2024-08-08 20:41:30,886 - tpp_runner.py[pid:18330;line:122:_train_model] - INFO: [ Epoch 1 (test) ]: test loglike is -1.2258233397083045, num_events is 14420, acc is 0.8748266296809986, rmse is 0.37095806510138896
2024-08-08 20:41:30,887 - tpp_runner.py[pid:18330;line:124:_train_model] - CRITICAL: current best loglike on valid set is -1.2462 (updated at epoch-1), best updated at this epoch
2024-08-08 20:41:30,898 - base_runner.py[pid:18330;line:104:train] - INFO: End NHP train! Cost time: 1.736m

Using my implementation of Algorithm 2 under the same config, the result is:

hhhqss@hhhs-MacBook-Air examples % python train_nhp.py
2024-08-08 20:34:38,080 - config.py[pid:18006;line:33:build_from_yaml_file] - CRITICAL: Load pipeline config class RunnerConfig
2024-08-08 20:34:38,081 - runner_config.py[pid:18006;line:164:update_config] - CRITICAL: train model NHP using CPU with torch backend
2024-08-08 20:34:38,086 - runner_config.py[pid:18006;line:36:__init__] - INFO: Save the config to ./checkpoints/18006_140704444761728_240808-203438/NHP_train_output.yaml
2024-08-08 20:34:38,087 - base_runner.py[pid:18006;line:170:save_log] - INFO: Save the log to ./checkpoints/18006_140704444761728_240808-203438/log
2024-08-08 20:34:39,924 - tpp_runner.py[pid:18006;line:60:_init_model] - INFO: Num of model parameters 59146
/Users/hhhqss/Library/Python/3.8/lib/python/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020
  warnings.warn(
2024-08-08 20:34:48,340 - base_runner.py[pid:18006;line:92:train] - INFO: Data 'taxi' loaded...
2024-08-08 20:34:48,341 - base_runner.py[pid:18006;line:97:train] - INFO: Start NHP training...
2024-08-08 20:34:50,675 - tpp_runner.py[pid:18006;line:96:_train_model] - INFO: [ Epoch 0 (train) ]: train loglike is -1.8306055527528788, num_events is 50454
2024-08-08 20:34:57,606 - torch_basemodel.py[pid:18006;line:182:predict_one_step_at_every_event] - CRITICAL: draw_next_time_one_step takes time (seconds):6.782813787460327
2024-08-08 20:34:57,824 - tpp_runner.py[pid:18006;line:107:_train_model] - INFO: [ Epoch 0 (valid) ]:  valid loglike is -1.6224344222827596, num_events is 7204, acc is 0.8428650749583565, rmse is 0.33080250566940167
2024-08-08 20:35:17,381 - torch_basemodel.py[pid:18006;line:182:predict_one_step_at_every_event] - CRITICAL: draw_next_time_one_step takes time (seconds):19.326049089431763
2024-08-08 20:35:20,364 - torch_basemodel.py[pid:18006;line:182:predict_one_step_at_every_event] - CRITICAL: draw_next_time_one_step takes time (seconds):2.4894449710845947
2024-08-08 20:35:20,542 - tpp_runner.py[pid:18006;line:122:_train_model] - INFO: [ Epoch 0 (test) ]: test loglike is -1.6084953487994105, num_events is 14420, acc is 0.8521497919556172, rmse is 0.3386021943378324
2024-08-08 20:35:20,543 - tpp_runner.py[pid:18006;line:124:_train_model] - CRITICAL: current best loglike on valid set is -1.6224 (updated at epoch-0), best updated at this epoch
2024-08-08 20:35:23,670 - tpp_runner.py[pid:18006;line:96:_train_model] - INFO: [ Epoch 1 (train) ]: train loglike is -1.493789516192968, num_events is 50454
2024-08-08 20:35:30,518 - torch_basemodel.py[pid:18006;line:182:predict_one_step_at_every_event] - CRITICAL: draw_next_time_one_step takes time (seconds):6.638453960418701
2024-08-08 20:35:30,677 - tpp_runner.py[pid:18006;line:107:_train_model] - INFO: [ Epoch 1 (valid) ]:  valid loglike is -1.246168578220433, num_events is 7204, acc is 0.862576346474181, rmse is 0.33324353030724874
2024-08-08 20:35:52,634 - torch_basemodel.py[pid:18006;line:182:predict_one_step_at_every_event] - CRITICAL: draw_next_time_one_step takes time (seconds):21.6252920627594
2024-08-08 20:35:56,763 - torch_basemodel.py[pid:18006;line:182:predict_one_step_at_every_event] - CRITICAL: draw_next_time_one_step takes time (seconds):3.633074998855591
2024-08-08 20:35:57,085 - tpp_runner.py[pid:18006;line:122:_train_model] - INFO: [ Epoch 1 (test) ]: test loglike is -1.2258233397083045, num_events is 14420, acc is 0.8746879334257975, rmse is 0.34193676160808867
2024-08-08 20:35:57,086 - tpp_runner.py[pid:18006;line:124:_train_model] - CRITICAL: current best loglike on valid set is -1.2462 (updated at epoch-1), best updated at this epoch
2024-08-08 20:35:57,095 - base_runner.py[pid:18006;line:104:train] - INFO: End NHP train! Cost time: 1.146m

The NHP_train_output.yaml generated in the checkpoints folder after the runs (identical for the two experiments above):

data_config:
  train_dir: easytpp/taxi
  valid_dir: easytpp/taxi
  test_dir: easytpp/taxi
  data_format: easytpp/taxi
  data_specs:
    num_event_types: 10
    pad_token_id: 10
    padding_side: right
    truncation_side: null
    padding_strategy: null
    truncation_strategy: null
    max_len: null
base_config:
  stage: train
  backend: torch
  dataset_id: taxi
  runner_id: std_tpp
  model_id: NHP
  base_dir: ./checkpoints/
  specs:
    log_folder: ./checkpoints/88267_140704444761728_240806-135500
    saved_model_dir: ./checkpoints/88267_140704444761728_240806-135500/models/saved_model
    saved_log_dir: ./checkpoints/88267_140704444761728_240806-135500/log
    output_config_dir: ./checkpoints/88267_140704444761728_240806-135500/NHP_train_output.yaml
model_config:
  rnn_type: LSTM
  hidden_size: 64
  time_emb_size: 16
  num_layers: 2
  mc_num_sample_per_step: 20
  sharing_param_layer: false
  loss_integral_num_sample_per_step: 20
  dropout_rate: 0.0
  use_ln: false
  thinning:
    num_seq: 10
    num_sample: 1
    num_exp: 500
    look_ahead_time: 10
    patience_counter: 5
    over_sample_rate: 5
    num_samples_boundary: 5
    dtime_max: 5
    num_step_gen: 1
  num_event_types_pad: 11
  num_event_types: 10
  event_pad_index: 10
  model_id: NHP
  pretrained_model_dir: null
  gpu: -1
  model_specs:
    beta: 1
    bias: true
trainer_config:
  seed: 2019
  gpu: -1
  batch_size: 256
  max_epoch: 2
  shuffle: false
  optimizer: adam
  learning_rate: 0.001
  valid_freq: 1
  use_tfb: false
  metrics:
  - acc
  - rmse

ivan-chai commented 2 months ago

Hi! Algorithm 2 from NHP includes the repeat loop, but I don't see a loop in the PR version. Do you think the implemented approach is a valid alternative? Can it be proved?

HaochenWang1243 commented 2 months ago

Hi, I think its correctness can be proved. I did avoid for loops because they're usually inefficient. I came up with a small trick that makes torch.argmax play the role of a for loop in searching for the first accepted exp (delta) sample. Here's a walkthrough of my changes. First, https://github.com/ant-research/EasyTemporalPointProcess/blob/6e8c634437adfc5c20f32f304cc9a8bd84bc3e90/easy_tpp/model/torch_model/torch_thinning.py#L198 accumulates the exp (delta) samples along the last dimension of the exp_numbers tensor. I copied this line from https://github.com/ant-research/EasyTemporalPointProcess/issues/13#issuecomment-1837317744.
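As a rough illustration of what the accumulation does, here is a NumPy sketch (NumPy rather than torch so it runs without PyTorch). The function name accumulated_candidates and the shapes, sample_rate as [batch_size, seq_len] and the output as [batch_size, seq_len, num_sample, num_exp], are assumptions rather than the EasyTPP signature. A running sum of i.i.d. $\mathrm{Exp}(\lambda^*)$ gaps gives increasing offsets from $t$, i.e. the successive candidate times that the repeat loop in Algorithm 2 would visit one at a time.

```python
import numpy as np


def accumulated_candidates(sample_rate, num_sample, num_exp, rng):
    """Candidate deltas as arrivals of a homogeneous Poisson process (sketch).

    sample_rate: [batch_size, seq_len] intensity upper bounds (assumed shape)
    returns:     [batch_size, seq_len, num_sample, num_exp]
    """
    batch_size, seq_len = sample_rate.shape
    # Exp(1) draws divided by the rate are Exp(rate) inter-arrival gaps
    gaps = rng.exponential(1.0, size=(batch_size, seq_len, num_sample, num_exp))
    gaps = gaps / sample_rate[:, :, None, None]
    # Running sum along num_exp: increasing offsets from the last event time
    return np.cumsum(gaps, axis=-1)


rng = np.random.default_rng(0)
cands = accumulated_candidates(np.full((2, 3), 2.0), num_sample=4, num_exp=5, rng=rng)
print(cands.shape)  # (2, 3, 4, 5)
```

Because the candidates are sorted along num_exp, picking the first accepted one reproduces the loop's stopping rule, as long as num_exp candidates are enough to reach an acceptance.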

Now we want to accept the first (in left-to-right order) of these accumulated exp samples that gives a criterion < 1, as required by Algorithm 2. If none is accepted, use self.dtime_max. I implemented this whole process in the sample_accept method, which now directly returns the [batch_size, seq_len, num_sample] tensor of next-event-time samples instead of what the original sample_accept returned. The inputs the method requires (unif_numbers, sample_rate, total_intensities) are already calculated by the original code, and I didn't change them. However, I did add exp_numbers as an input to the method to make things easier.

Now let me go over sample_accept line by line.

First, https://github.com/ant-research/EasyTemporalPointProcess/blob/6e8c634437adfc5c20f32f304cc9a8bd84bc3e90/easy_tpp/model/torch_model/torch_thinning.py#L147 computes the criterion; this calculation is from the original code.

Next, https://github.com/ant-research/EasyTemporalPointProcess/blob/6e8c634437adfc5c20f32f304cc9a8bd84bc3e90/easy_tpp/model/torch_model/torch_thinning.py#L150 creates a binary mask tensor of the same shape as criterion ([batch_size, seq_len, num_sample, num_exp]), where criterion values smaller than 1 are labeled 1 and the rest are labeled 0.

Next, https://github.com/ant-research/EasyTemporalPointProcess/blob/6e8c634437adfc5c20f32f304cc9a8bd84bc3e90/easy_tpp/model/torch_model/torch_thinning.py#L153 creates a binary mask tensor of shape [batch_size, seq_len, num_sample], where each entry shows whether all the exp samples drawn for a particular next-event-time sample give criterion >= 1, i.e. no exp sample can be accepted. We will use this mask later.

Next, https://github.com/ant-research/EasyTemporalPointProcess/blob/6e8c634437adfc5c20f32f304cc9a8bd84bc3e90/easy_tpp/model/torch_model/torch_thinning.py#L156 finds the index of the first accepted exp sample for each next-event-time sample we want. The trick here is that torch.argmax returns the index of the first occurrence of the largest value when multiple equally large values are present along the dimension it searches. Since masked_crit_less_than_1 is a binary mask tensor, torch.argmax finds the index of the first occurrence of 1 along the exp-sample dimension (the 4th dimension). This is exactly the sample we want to accept.

Next, https://github.com/ant-research/EasyTemporalPointProcess/blob/6e8c634437adfc5c20f32f304cc9a8bd84bc3e90/easy_tpp/model/torch_model/torch_thinning.py#L160 collects the exp samples at the indices we just found from the criterion mask.
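The first-occurrence behaviour of argmax can be checked on a toy mask. This sketch uses NumPy, where np.argmax documents the same first-occurrence rule and np.take_along_axis plays the role of torch.gather; the mask and delta values are made up for illustration.

```python
import numpy as np

# Toy acceptance mask: 2 next-event-time samples x 5 candidate deltas
# (1 = criterion < 1). The values are made up for illustration.
mask = np.array([[0, 0, 1, 0, 1],
                 [0, 0, 0, 0, 0]])
deltas = np.array([[0.2, 0.5, 0.9, 1.4, 2.0],
                   [0.1, 0.3, 0.6, 1.1, 1.7]])

# Ties for the maximum resolve to the first index, i.e. the first 1
first_idx = np.argmax(mask, axis=-1)
print(first_idx)  # [2 0]

# np.take_along_axis is NumPy's counterpart of torch.gather
picked = np.take_along_axis(deltas, first_idx[:, None], axis=-1)[:, 0]
print(picked)  # [0.9 0.1]
```

The second row has no accepted candidate, yet argmax still returns 0 and picks 0.1; that is exactly the case the fallback step has to filter out.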
Here's an explanation of torch.gather. There's one last thing to consider: if none of the exp samples drawn for a certain next-event time can be accepted, the argmax is taken over [0, 0, ..., 0] and returns index 0 anyway. We filter out the exp samples that correspond to such argmax results and replace them with self.dtime_max using torch.where: https://github.com/ant-research/EasyTemporalPointProcess/blob/6e8c634437adfc5c20f32f304cc9a8bd84bc3e90/easy_tpp/model/torch_model/torch_thinning.py#L163 Finally, the last line squeezes out the last dimension to form an output of shape [batch_size, seq_len, num_sample].

My explanation may not be the clearest, so it would be great if you could try running this sample_accept on a smaller 4D tensor, say of size [2,3,4,5], to see how it works!
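Along those lines, here is a NumPy sketch of the whole vectorized acceptance step run on a [2, 3, 4, 5] input. It mirrors the steps described above (criterion, binary mask, first-index argmax, gather, dtime_max fallback) but is not the PR code: the function name first_accepted_delta and the input shapes are assumptions, and np.argmax, np.take_along_axis and np.where stand in for torch.argmax, torch.gather and torch.where.

```python
import numpy as np


def first_accepted_delta(exp_numbers, unif_numbers, sample_rate,
                         total_intensities, dtime_max):
    """Vectorized 'first accepted candidate' thinning step (sketch).

    exp_numbers, unif_numbers, total_intensities: [batch, seq, num_sample, num_exp]
    sample_rate: [batch, seq]
    returns: [batch, seq, num_sample]
    """
    # Thinning criterion u * lambda* / lambda(t + delta) for every candidate
    criterion = unif_numbers * sample_rate[:, :, None, None] / total_intensities
    # Binary mask: 1 where the candidate would be accepted
    accepted = (criterion < 1.0).astype(np.int64)
    # True where no candidate along num_exp is accepted
    none_accepted = accepted.max(axis=-1) == 0
    # argmax returns the first maximal index, i.e. the first accepted candidate
    first_idx = np.argmax(accepted, axis=-1)
    # Gather the delta at that index (np.take_along_axis ~ torch.gather)
    picked = np.take_along_axis(exp_numbers, first_idx[..., None], axis=-1)[..., 0]
    # Where nothing was accepted, argmax gave 0; use dtime_max instead
    return np.where(none_accepted, dtime_max, picked)


# Toy [2, 3, 4, 5] input: candidate deltas 1..5 along the num_exp axis
exp_numbers = np.broadcast_to(np.arange(1.0, 6.0), (2, 3, 4, 5)).copy()
unif_numbers = np.full((2, 3, 4, 5), 0.5)
sample_rate = np.ones((2, 3))
total_intensities = np.full((2, 3, 4, 5), 0.1)  # criterion 5.0: rejected
total_intensities[..., 2:] = 1.0                # criterion 0.5: accepted
total_intensities[1, 2, 3, :] = 0.1             # one row with no acceptance

out = first_accepted_delta(exp_numbers, unif_numbers, sample_rate,
                           total_intensities, dtime_max=10.0)
print(out.shape)     # (2, 3, 4)
print(out[0, 0, 0])  # 3.0, the third candidate is the first accepted
print(out[1, 2, 3])  # 10.0, the dtime_max fallback
```

Every row except the one with no acceptance returns 3.0, the first candidate whose criterion drops below 1, even though later candidates are also acceptable.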

ivan-chai commented 2 months ago

Thank you for the detailed explanation! It seems convincing.