M-Nauta / TCDF

Temporal Causal Discovery Framework (PyTorch): discovering causal relationships between time series
GNU General Public License v3.0

HardSoftmax missing? #4

Closed: bartbussmann closed this issue 5 years ago

bartbussmann commented 5 years ago

Hello, thanks for the interesting paper and codebase.

In the paper (Section 4.2) you state that:

After training network N_j, we apply our straightforward semi-binarization function HardSoftmax that truncates all attention scores that fall below a threshold τ_j to zero.

However, in your code, you seem to omit this and use regular soft attention instead. This might have an impact on the Causal Validation process. Am I missing something?

M-Nauta commented 5 years ago

During network training, we indeed use soft attention since hard attention is not differentiable. This is also described in Section 4.2 of the paper:

We therefore first use the soft attention approach by applying the Softmax function s to each a in each training epoch

However, in the function findcauses in TCDF.py you can see that we subsequently apply a threshold and only consider time series with an attention score above it. That's why we talk about semi-binarization: all time series with attention scores below the threshold are excluded as potential causes, and all other time series are kept.
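
To make this concrete, the thresholding step can be sketched roughly as follows. This is a minimal illustration and not the actual findcauses code: the function name, the tensor shapes, and the way the threshold tau is chosen are all assumptions.

```python
import torch

def semi_binarize(attention_scores, tau):
    """HardSoftmax-style semi-binarization (sketch): softmax-normalized
    attention scores below the threshold tau are truncated to zero; only
    the surviving time series are kept as potential causes."""
    soft = torch.softmax(attention_scores, dim=0)   # soft attention, as used during training
    hard = torch.where(soft >= tau, soft, torch.zeros_like(soft))
    potential_causes = torch.nonzero(hard, as_tuple=True)[0].tolist()
    return hard, potential_causes

# Example: attention scores learned by a network N_j for 4 input series
scores = torch.tensor([0.1, 2.0, 0.2, 1.8])
hard_scores, causes = semi_binarize(scores, tau=0.2)
print(causes)   # indices of the series whose attention survived the threshold: [1, 3]
```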

bartbussmann commented 5 years ago

Ah okay, I see! From the paper, I understood that HardSoftmax was also used in the Causal Validation step (such that only the 'potential causes' are used to predict the value of interest during PIVM). It's probably not a big difference, but it might be interesting to experiment with.

M-Nauta commented 5 years ago

I'm not sure if I understand you correctly, but we indeed apply PIVM only to the potential causes. This is described in Section 4.3.1 of the paper. To clarify, let's take an example. Suppose we have four time series: the prices of apples, butter, cheese and milk. For predicting the price of cheese, the attention mechanism will probably find that apples are not related to cheese. By using a threshold, butter and milk might be selected as potential causes. Secondly, we apply PIVM: we shuffle the values of both the price of butter and the price of milk and predict the price of cheese again. The result will probably be that only milk is a true cause of cheese.
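
As an illustration of that validation step, a rough sketch could look like the one below. The interface is an assumption made for the example (a predict callable standing in for the trained network, data as a 2-D array of series, and a simple loss-increase criterion); it is not TCDF's actual API or its exact significance test.

```python
import numpy as np

def pivm_validate(predict, data, target_idx, potential_causes, seed=0):
    """Permutation Importance Validation Method (sketch).

    For each potential cause, shuffle only that series, keep every other
    series (including the non-causes) unchanged, predict the target again,
    and compare the loss with the unshuffled baseline. A clear loss increase
    marks the shuffled series as a true cause."""
    rng = np.random.default_rng(seed)
    baseline = np.mean((predict(data) - data[target_idx]) ** 2)
    true_causes = []
    for cause in potential_causes:
        permuted = data.copy()
        permuted[cause] = rng.permutation(permuted[cause])   # destroy the temporal order of one series
        loss = np.mean((predict(permuted) - data[target_idx]) ** 2)
        if loss > 1.1 * baseline:   # illustrative criterion; TCDF uses its own comparison of loss increases
            true_causes.append(cause)
    return true_causes

# Toy usage: 4 series (apples, butter, cheese, milk), where cheese depends on lagged milk only.
rng = np.random.default_rng(0)
data = rng.standard_normal((4, 200))
data[2, 1:] = 0.9 * data[3, :-1] + 0.1 * rng.standard_normal(199)   # cheese follows milk
toy_predict = lambda d: np.concatenate(([0.0], 0.9 * d[3, :-1]))    # stand-in for the trained network
print(pivm_validate(toy_predict, data, target_idx=2, potential_causes=[1, 3]))  # -> [3] (milk)
```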

bartbussmann commented 5 years ago

I'm not sure if I understand you correctly, but we indeed apply PIVM only to the potential causes. This is described in Section 4.3.1 of the paper. To clarify, let's take an example. Suppose we have four time series: the prices of apples, butter, cheese and milk. For predicting the price of cheese, the attention mechanism will probably find that apples are not related to cheese. By using a threshold, butter and milk might be selected as potential causes.

Yes, great example!

Secondly, we apply PIVM: we shuffle the values of both the price of butter and the price of milk and predict the price of cheese again.

This is where the unclarity lies. Suppose we want to find out whether butter is a true cause. We therefore shuffle the values of butter and use these shuffled values, together with the 'original' past values of cheese, milk, and apples, to predict the price of cheese again. So in this prediction, the soft attention value of apples is used, instead of the HardSoftmax value.
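
To make the distinction concrete, here is a rough sketch. Everything in it is illustrative: the nn.Linear layer only stands in for the trained network, and hard_attention_mask and the input-masking trick are assumptions, not TCDF code.

```python
import numpy as np
import torch
import torch.nn as nn

rng = np.random.default_rng(0)
APPLES, BUTTER, CHEESE, MILK = 0, 1, 2, 3

# Toy stand-ins: 4 series x 50 time steps and a placeholder for the trained network for cheese.
data = rng.standard_normal((4, 50)).astype(np.float32)
network = nn.Linear(4, 1)

# Hypothetical semi-binarized attention: apples fell below the threshold tau.
hard_attention_mask = torch.tensor([0.0, 1.0, 1.0, 1.0])

# Current behaviour: to validate butter, only butter is shuffled; apples, cheese and milk
# keep their original values, so the soft attention learned for apples still contributes.
shuffled = data.copy()
shuffled[BUTTER] = rng.permutation(shuffled[BUTTER])
x = torch.from_numpy(shuffled.T.copy())            # shape (time, series)
prediction_soft = network(x)

# Reading discussed here: additionally zero out the series whose attention was truncated
# (apples). Note that, as pointed out in the reply below, the network was trained with all
# series as input, so properly excluding a series would really require retraining.
prediction_hard = network(x * hard_attention_mask)
```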

M-Nauta commented 5 years ago

Ah, I get your point. It would indeed be a small change (and maybe an improvement!) to exclude the non-causal time series (in this case, the apple time series). The disadvantage, however, is that you would then need to retrain the neural network with only the potentially causal time series, since the trained network expects the apple time series as input as well. This would increase the computational cost substantially. But it's definitely an interesting experiment.