snap-stanford / relbench

RelBench: Relational Deep Learning Benchmark
https://relbench.stanford.edu
MIT License

Add `--include_label_tables` argument to optionally include (time-censored) labels as features in the db #272

Open adobles96 opened 3 weeks ago


This PR adds the `--include_label_tables` argument, which takes one of three values:

  1. `"none"` (default): Label tables are not added to the db, so the resulting graph is unchanged.
  2. `"task_only"`: Includes label information for the current task only. For example, when training a GNN for the hm item-sales task, only the item-sales labels are included.
  3. `"all"`: Includes label information for all tasks defined on the dataset. For example, when training a GNN for the hm item-sales task, the labels for item-sales, item-churn, user-churn, etc. are all included.
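As a rough illustration of the three modes, the sketch below maps a flag value to the list of tasks whose label tables would be added. The helper `select_label_tasks` is hypothetical (it is not relbench API), and the hm task list is a partial one taken from the examples above; only the flag name and its three values come from this PR.

```python
import argparse
from typing import List

VALID_MODES = ["none", "task_only", "all"]

# The flag itself, restricted to the three documented values.
parser = argparse.ArgumentParser()
parser.add_argument("--include_label_tables", choices=VALID_MODES, default="none")


# Hypothetical helper: name and signature are illustrative only.
def select_label_tasks(mode: str, task_name: str, all_task_names: List[str]) -> List[str]:
    """Return the names of the tasks whose label tables get added to the db."""
    if mode == "none":
        return []  # no label tables: the graph is exactly as before
    if mode == "task_only":
        return [task_name]  # labels of the task being trained only
    if mode == "all":
        return list(all_task_names)  # labels of every task defined on the dataset
    raise ValueError(f"unknown mode: {mode!r}")


# Partial hm task list, for illustration.
hm_tasks = ["item-sales", "item-churn", "user-churn"]
args = parser.parse_args(["--include_label_tables", "task_only"])
print(select_label_tasks(args.include_label_tables, "item-sales", hm_tasks))  # ['item-sales']
print(select_label_tasks("all", "item-sales", hm_tasks))  # ['item-sales', 'item-churn', 'user-churn']
```

Using `choices=` lets argparse reject unknown values up front, so the helper's final `ValueError` only guards direct calls.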

Implementation details

The implementation is straightforward. For each label table we want to include, we load it as a pandas DataFrame, shift its timestamp column forward by the corresponding task's timedelta, and add it as a new table to the relbench `Database` object.
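A minimal, self-contained sketch of that step follows. To keep it runnable without relbench installed, `Table` and `Database` here are hypothetical stand-ins modeled loosely on relbench's classes, and `add_label_table`, the `"<task>_labels"` table name, and the `article`/`article_id` columns are all assumptions for illustration, not the code in this PR. The one-month timedelta from the example below is approximated as 31 days.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional

import pandas as pd


# Hypothetical stand-ins for relbench's Table/Database, kept minimal so the sketch runs alone.
@dataclass
class Table:
    df: pd.DataFrame
    fkey_col_to_pkey_table: Dict[str, str]
    pkey_col: Optional[str] = None
    time_col: Optional[str] = None


@dataclass
class Database:
    table_dict: Dict[str, Table] = field(default_factory=dict)


def add_label_table(db, label_df, task_name, entity_col, entity_table, time_col, timedelta):
    """Add a copy of a task's label table to db, with timestamps shifted by the task's timedelta."""
    df = label_df.copy()  # leave the task's own label table untouched
    # Move each row to the time at which its label has been fully observed.
    df[time_col] = df[time_col] + timedelta
    db.table_dict[f"{task_name}_labels"] = Table(  # hypothetical naming scheme
        df=df,
        fkey_col_to_pkey_table={entity_col: entity_table},  # links each label row to its entity
        pkey_col=None,
        time_col=time_col,
    )
    return db


labels = pd.DataFrame({
    "timestamp": pd.to_datetime(["2020-01-01", "2020-01-01"]),
    "article_id": [10, 11],
    "sales": [3.0, 0.0],
})
db = add_label_table(Database(), labels, "item-sales", "article_id", "article",
                     "timestamp", pd.Timedelta(days=31))
print(db.table_dict["item-sales_labels"].df["timestamp"].tolist())
```

Copying the DataFrame before shifting matters: the task still needs the unshifted timestamps to generate its training examples.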

Adding the timedelta to the timestamp column is crucial to avoid leakage. Consider a task that predicts an item's sales over the next month. In the original label table, the timestamp column holds the date as of which the prediction should be made; say this is Jan-01, so the label is the number of sales during January. Adding the timedelta (one month in this case) moves the timestamp to Feb-01, which means the label information (the January sales count) is now censored before Feb-01. In other words, the GNN only has access to January's sales starting on Feb-01, once that month has actually been observed. Note that this only works as long as we uphold the convention that the timedelta is constant within a given task.
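The Jan-01 example can be checked numerically. This sketch assumes a time-aware sampler only lets a seed at time `cutoff` see rows with `timestamp <= cutoff` (an assumption about the sampling convention), treats "one month" as 31 days so that Jan-01 maps exactly to Feb-01, and uses an arbitrary year and a hypothetical `item_id` column.

```python
import pandas as pd

# Label row from the example: prediction made on Jan-01 for January's sales.
timedelta = pd.Timedelta(days=31)  # assumption: "one month", exact for January
labels = pd.DataFrame({
    "timestamp": [pd.Timestamp("2024-01-01")],
    "item_id": [7],
    "sales": [42],
})
shifted = labels.assign(timestamp=labels["timestamp"] + timedelta)


def visible_rows(df, cutoff):
    """Rows a seed at `cutoff` may see, assuming an inclusive `timestamp <= cutoff` rule."""
    return df[df["timestamp"] <= cutoff]


for cutoff in ["2024-01-15", "2024-02-01"]:
    n_raw = len(visible_rows(labels, pd.Timestamp(cutoff)))
    n_shifted = len(visible_rows(shifted, pd.Timestamp(cutoff)))
    print(cutoff, n_raw, n_shifted)
# 2024-01-15 1 0   <- unshifted table leaks January's sales mid-January
# 2024-02-01 1 1   <- shifted table exposes them only once January is over
```

Without the shift, a seed on Jan-15 could read the full January sales count, half of which has not happened yet.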

The end result is a graph with new "label" nodes that hold label values from earlier points in time, each connected by an edge to the relevant entity node, which makes that information available to the GNN within a single hop.
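To make the single-hop claim concrete, the sketch below turns a shifted label table into label-to-entity edges by resolving each label row's foreign key to the entity's node index. It is a plain pandas/numpy illustration under assumed toy data (`items`, `item_id`, `sales` are hypothetical), not relbench's graph-construction code; a heterogeneous GNN would typically also add the reverse direction.

```python
import numpy as np
import pandas as pd

# Entity table: row i becomes entity node i.
items = pd.DataFrame({"item_id": [100, 200, 300]})
# Shifted label table: row j becomes label node j, holding a past label value.
label_nodes = pd.DataFrame({
    "timestamp": pd.to_datetime(["2024-02-01", "2024-02-01", "2024-03-01"]),
    "item_id": [200, 300, 200],
    "sales": [5, 0, 8],
})

# Map entity primary keys to node indices, then resolve each label row's foreign key.
item_index = pd.Series(np.arange(len(items)), index=items["item_id"])
dst = item_index.loc[label_nodes["item_id"]].to_numpy()  # entity node per label row
src = np.arange(len(label_nodes))                        # one label node per row
edge_index = np.stack([src, dst])  # shape (2, num_label_nodes): label node -> entity node

print(edge_index.tolist())  # [[0, 1, 2], [1, 2, 1]]
```

Here item 200 (node 1) receives edges from two label nodes, one per past period, so a single message-passing step aggregates its label history.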