salesforce / ETSformer

PyTorch code for ETSformer: Exponential Smoothing Transformers for Time-series Forecasting

Error while running ETTm2.sh #4

Closed: sentialx closed this issue 2 years ago

sentialx commented 2 years ago

Hi, I'm trying to run the ETTm2.sh script, but this is what I get. I've put the ETTm2.csv file into dataset/ETT-small.

Traceback (most recent call last):
  File "C:\python3\lib\site-packages\einops\einops.py", line 410, in reduce
    return _apply_recipe(recipe, tensor, reduction_type=reduction)
  File "C:\python3\lib\site-packages\einops\einops.py", line 233, in _apply_recipe
    _reconstruct_from_shape(recipe, backend.shape(tensor))
  File "C:\python3\lib\site-packages\einops\einops.py", line 163, in _reconstruct_from_shape_uncached
    raise EinopsError('Expected {} dimensions, got {}'.format(len(self.input_composite_axes), len(shape)))      
einops.EinopsError: Expected 2 dimensions, got 4

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\Senti\Desktop\ETSformer\run.py", line 117, in <module>
    exp.train(setting)
  File "C:\Users\Senti\Desktop\ETSformer\exp\exp_main.py", line 144, in train
    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
  File "C:\python3\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Senti\Desktop\ETSformer\models\etsformer\model.py", line 72, in forward
    level, growths, seasons = self.encoder(res, x_enc, attn_mask=enc_self_mask)
  File "C:\python3\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Senti\Desktop\ETSformer\models\etsformer\encoder.py", line 169, in forward
    res, level, growth, season = layer(res, level, attn_mask=None)
  File "C:\python3\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Senti\Desktop\ETSformer\models\etsformer\encoder.py", line 142, in forward
    growth = self._growth_block(res)
  File "C:\Users\Senti\Desktop\ETSformer\models\etsformer\encoder.py", line 151, in _growth_block
    x = self.growth_layer(x)
  File "C:\python3\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Senti\Desktop\ETSformer\models\etsformer\encoder.py", line 38, in forward
    out = torch.cat([repeat(self.es.v0, 'h d -> b 1 h d', b=b), out], dim=1)
  File "C:\python3\lib\site-packages\einops\einops.py", line 537, in repeat
    return reduce(tensor, pattern, reduction='repeat', **axes_lengths)
  File "C:\python3\lib\site-packages\einops\einops.py", line 418, in reduce
    raise EinopsError(message + '\n {}'.format(e))
einops.EinopsError:  Error while processing repeat-reduction pattern "h d -> b 1 h d".
 Input tensor shape: torch.Size([1, 1, 8, 64]). Additional info: {'b': 32}.
 Expected 2 dimensions, got 4

This part of the code seems to cause the problem:

    def forward(self, inputs):
        """
        :param inputs: shape: (batch, seq_len, dim)
        :return: shape: (batch, seq_len, dim)
        """
        b, t, d = inputs.shape
        values = self.in_proj(inputs).view(b, t, self.nhead, -1)
        values = torch.cat([repeat(self.z0, 'h d -> b 1 h d', b=b), values], dim=1)
        values = values[:, 1:] - values[:, :-1]
        out = self.es(values)
        out = torch.cat([repeat(self.es.v0, 'h d -> b 1 h d', b=b), out], dim=1)
        out = rearrange(out, 'b t h d -> b t (h d)')
        return self.out_proj(out)
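
For reference, the error can be reproduced in isolation: repeat with the pattern 'h d -> b 1 h d' describes a 2-dimensional input, while the traceback shows self.es.v0 arriving with 4 dimensions. A minimal sketch (the shapes are taken from the traceback; the tensor below is just a stand-in for the actual parameter):

    import torch
    from einops import repeat

    # Stand-in for the exponential smoothing initial state; the traceback
    # reports it with shape torch.Size([1, 1, 8, 64]).
    v0 = torch.zeros(1, 1, 8, 64)

    # The pattern 'h d -> b 1 h d' describes a 2-D input tensor, so einops
    # stops with "Expected 2 dimensions, got 4" before it can tile the batch axis.
    repeat(v0, 'h d -> b 1 h d', b=32)  # raises einops.EinopsError
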
gorold commented 2 years ago

Thanks for spotting this! I've just pushed a fix!
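
(For anyone who hits this on an older checkout before pulling: the mismatch is between a 4-dimensional parameter and a pattern that only names two axes, so a local workaround is to name all four input axes, or to skip einops and expand the parameter directly. This is only a sketch under the assumption that v0 / z0 are stored with shape (1, 1, nhead, d_head), as the traceback suggests; it is not necessarily the change in the pushed fix.)

    import torch
    from einops import repeat

    b, nhead, d_head = 32, 8, 64
    v0 = torch.zeros(1, 1, nhead, d_head)          # assumed parameter shape, per the traceback

    # Name all four input axes so the pattern matches the tensor's dimensionality
    tiled = repeat(v0, '1 1 h d -> b 1 h d', b=b)  # -> torch.Size([32, 1, 8, 64])

    # Equivalent without einops
    tiled_alt = v0.expand(b, -1, -1, -1)           # -> torch.Size([32, 1, 8, 64])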

sentialx commented 2 years ago

Thanks! Btw does your model also support classification?

gorold commented 2 years ago

We developed the model for forecasting, but it is possible to explore using it for classification. One simple way is to keep only the encoder and feed its output representations into a classifier. Do let us know if you ever test it out for classification!
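
A minimal sketch of that idea, treating the ETSformer encoder as a black-box feature extractor (the wrapper, the pooling choice, and the dimensions below are illustrative assumptions, not part of the released code):

    import torch
    import torch.nn as nn

    class EncoderClassifier(nn.Module):
        # Wraps any module that maps (batch, seq_len, d_in) to per-timestep
        # representations of size d_model, pools over time, and classifies.
        def __init__(self, encoder, d_model, num_classes):
            super().__init__()
            self.encoder = encoder
            self.head = nn.Linear(d_model, num_classes)

        def forward(self, x):
            reps = self.encoder(x)         # (batch, seq_len, d_model), assumed
            pooled = reps.mean(dim=1)      # simple average pooling over time
            return self.head(pooled)       # (batch, num_classes) logits

    # Usage with a placeholder standing in for the ETSformer encoder's representations:
    clf = EncoderClassifier(nn.Linear(7, 64), d_model=64, num_classes=5)
    logits = clf(torch.randn(32, 96, 7))   # -> torch.Size([32, 5])

Mean pooling is just one choice; taking the representation at the last timestep or a learned attention pooling would work as well.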