Navidfoumani / ConvTran

This is a PyTorch implementation of ConvTran
MIT License

Conv2d? #2

Closed: freebiesoft closed this issue 1 year ago

freebiesoft commented 1 year ago

Hi, I just have one question.

I'm curious why you use nn.Conv2d instead of nn.Conv1d in the embedding layers, i.e., here: https://github.com/Navidfoumani/ConvTran/blob/main/Models/model.py#L97

If you used nn.Conv1d, you wouldn't need the awkward and computationally wasteful unsqueezing and squeezing logic around here: https://github.com/Navidfoumani/ConvTran/blob/main/Models/model.py#L134
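A note for readers on the cost of that reshaping: in PyTorch, unsqueeze and squeeze return views of the same storage, so they copy no data. A minimal check, using a hypothetical batch shape that is not taken from the repository:

```python
import torch

# Hypothetical batch: 4 series, 3 channels, 50 time steps.
x = torch.randn(4, 3, 50)

x4d = x.unsqueeze(1)   # (4, 1, 3, 50): adds a singleton "image channel" axis
back = x4d.squeeze(1)  # (4, 3, 50): removes it again

# Both results are views, so all three tensors point at the same memory.
same_storage = x.data_ptr() == x4d.data_ptr() == back.data_ptr()
print(x4d.shape, back.shape, same_storage)
```

So any extra cost of the Conv2d route would come from the convolutions themselves, not from the unsqueeze and squeeze calls.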

Navidfoumani commented 1 year ago

Thank you for your question. The choice of using nn.Conv2d over nn.Conv1d in our embedding layers comes from careful considerations outlined in Section 4.3 of our paper.

In summary:

Accuracy with efficiency: By breaking down 1D convolution kernels into separate temporal and spatial components, we achieve enhanced accuracy without substantial computational overhead. Similar to the Inverted Bottleneck concept, the embedding expands the input channels via temporal convolutions and then projects the result back to its original size with spatial convolutions.

Inspired by the FFN in Transformers: Our strategy of expanding and projecting hidden states is also inspired by the Feed Forward Network (FFN) in Transformers, which is designed to capture spatial interactions effectively.

This convolution type excels at capturing channel interactions in multivariate time series, as highlighted in the Dis-JointCNN paper.
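To make the temporal-then-spatial idea concrete, here is a minimal sketch of such an embedding. It is an illustration, not the repository's code: the class name DisjointEmbedding, the kernel size, the expansion factor, and the tensor sizes are all assumptions chosen for the example.

```python
import torch
import torch.nn as nn


class DisjointEmbedding(nn.Module):
    """Hypothetical sketch: temporal convolution, then spatial (cross-channel) convolution."""

    def __init__(self, num_channels, emb_size, kernel_size=7, expansion=4):
        super().__init__()
        hidden = emb_size * expansion  # assumed expansion width
        # Temporal step: a (1, k) kernel slides along time only, so each input
        # channel is filtered on its own and expanded to `hidden` feature maps.
        self.temporal = nn.Sequential(
            nn.Conv2d(1, hidden, kernel_size=(1, kernel_size), padding="same"),
            nn.BatchNorm2d(hidden),
            nn.GELU(),
        )
        # Spatial step: a (num_channels, 1) kernel mixes all input channels at
        # each time step and projects the features down to emb_size.
        self.spatial = nn.Sequential(
            nn.Conv2d(hidden, emb_size, kernel_size=(num_channels, 1)),
            nn.BatchNorm2d(emb_size),
            nn.GELU(),
        )

    def forward(self, x):
        # x: (batch, num_channels, seq_len)
        x = x.unsqueeze(1)      # (batch, 1, num_channels, seq_len)
        x = self.temporal(x)    # (batch, hidden, num_channels, seq_len)
        x = self.spatial(x)     # (batch, emb_size, 1, seq_len)
        return x.squeeze(2)     # (batch, emb_size, seq_len)


model = DisjointEmbedding(num_channels=3, emb_size=16)
out = model(torch.randn(4, 3, 50))
print(out.shape)
```

A single nn.Conv1d(num_channels, emb_size, kernel_size) would produce the same output shape, but it mixes time and channels in one kernel. The two-step version keeps them separate and places a normalization and nonlinearity between them, which is the expand-then-project pattern described above.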