chaoshangcs / GTS

Discrete Graph Structure Learning for Forecasting Multiple Time Series, ICLR 2021.
Apache License 2.0

Question about Gumbel sampling #10

Open: ThinkNaive opened this issue 3 years ago

ThinkNaive commented 3 years ago

I read in your paper that you apply bivariate Gumbel sampling using the generalized form of the Gumbel softmax. The Gumbel softmax takes logits (log-probabilities) as input, but you pass the learned structure theta directly as input: `adj = gumbel_softmax(x, temperature=temp, hard=True)` (line 234 of GTS/model/pytorch/model.py). Why does this work? Thank you.
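For reference, here is a minimal sketch of Gumbel-softmax sampling for a single edge, written in plain Python rather than PyTorch so it runs on its own. It assumes a two-class logit vector per edge ([edge, no edge]), a temperature `tau`, and a `hard` flag that returns the one-hot argmax. The function names are hypothetical, and this is not the GTS or PyTorch implementation: PyTorch's hard mode additionally uses a straight-through gradient, which is omitted here.

```python
import math
import random


def sample_gumbel(n, rng, eps=1e-20):
    """Draw n i.i.d. standard Gumbel samples via g = -log(-log(u)), u ~ U(0, 1)."""
    return [-math.log(-math.log(rng.random() + eps) + eps) for _ in range(n)]


def softmax(v):
    """Numerically stable softmax over a list of floats."""
    m = max(v)
    exps = [math.exp(a - m) for a in v]
    s = sum(exps)
    return [e / s for e in exps]


def gumbel_softmax(logits, tau=0.5, hard=False, noise=None, rng=None):
    """Relaxed one-hot sample: softmax((logits + g) / tau).

    If `noise` is given, it is used as the Gumbel perturbation g, which makes
    the result deterministic. With hard=True, the argmax is returned as a
    one-hot list (forward pass only; no straight-through gradient).
    """
    if noise is None:
        noise = sample_gumbel(len(logits), rng or random.Random())
    y = softmax([(l + g) / tau for l, g in zip(logits, noise)])
    if hard:
        k = max(range(len(y)), key=y.__getitem__)
        return [1.0 if i == k else 0.0 for i in range(len(y))]
    return y


# Hypothetical logits for one edge, ordered as [edge, no edge].
rng = random.Random(0)
sample = gumbel_softmax([2.0, -1.0], tau=0.5, hard=True, rng=rng)
print(sample)  # a one-hot list over [edge, no edge]
```

Lowering `tau` pushes the soft sample closer to one-hot, while `hard=True` makes the forward output exactly discrete.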

chaoshangcs commented 3 years ago

Hi, thanks for your great question. Here we treat the output of the neural network as (unnormalized) logits. This follows the same implementation as the NRI code. We'd also like to offer another option: the logits can be computed explicitly as log-probabilities:

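```python
import torch

# Normalize the network outputs into probabilities (assuming the classes lie
# along the last dimension), then take the log; 1e-20 guards against log(0).
x = torch.softmax(x, dim=-1)
logits = torch.log(x + 1e-20)
```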
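Why raw network outputs can serve as logits: a Gumbel-softmax sample depends on its input only through a softmax, and softmax is unchanged when the same constant is added to every entry. Because log(softmax(x)) = x - logsumexp(x), the raw outputs and the log-probabilities differ only by a per-row constant (ignoring the 1e-20 guard), so with the same Gumbel noise they produce the same sample. The sketch below checks this in plain Python; the helper names and input values are hypothetical.

```python
import math
import random


def softmax(v):
    """Numerically stable softmax over a list of floats."""
    m = max(v)
    exps = [math.exp(a - m) for a in v]
    s = sum(exps)
    return [e / s for e in exps]


def gumbel_softmax_with_noise(logits, noise, tau=0.5):
    """Soft Gumbel-softmax sample softmax((logits + g) / tau) for fixed noise g."""
    return softmax([(l + g) / tau for l, g in zip(logits, noise)])


# Hypothetical unnormalized network outputs for one edge: [edge, no edge].
x = [2.0, -1.0]
# Normalized alternative: log-probabilities, mirroring the softmax + log snippet.
log_probs = [math.log(p + 1e-20) for p in softmax(x)]

# Share one draw of Gumbel noise between both inputs.
rng = random.Random(0)
noise = [-math.log(-math.log(rng.random() + 1e-20) + 1e-20) for _ in x]

y_raw = gumbel_softmax_with_noise(x, noise)
y_log = gumbel_softmax_with_noise(log_probs, noise)
# The two samples agree up to floating-point error.
print(max(abs(a - b) for a, b in zip(y_raw, y_log)))
```

Normalizing first is therefore optional: the network is free to learn any per-row offset, and the sample is the same either way.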

If you have any other questions, please let me know. :)