amazon-science / earth-forecasting-transformer

Official implementation of Earthformer
Apache License 2.0

Question about the input shape of 'CuboidTransformerEncoder' in cuboid_transformer.py #61

Open paradigm21c opened 7 months ago

paradigm21c commented 7 months ago

I am currently working with multiband (4-band) satellite images of shape (B, T, H, W, C=4), and I have a question regarding the code in line 2930:

self.encoder = CuboidTransformerEncoder(input_shape=(T_in, H_in, W_in, base_units), ...

I was under the impression that the input shape for 'CuboidTransformerEncoder' is (T, H, W, C), as mentioned in line 1696. Could you clarify why the shape parameters differ between these two lines?

This difference caused an error at the assertion in line 1898: assert (T, H, W, C_in) == self.input_shape

Should I follow cuboid_transformer_unet_dec.py instead?

gaozhihan commented 7 months ago

Thank you for your question. The input shape asserted in CuboidTransformerEncoder at line 1898 (https://github.com/amazon-science/earth-forecasting-transformer/blob/7732b03bdb366110563516c3502315deab4c2026/src/earthformer/cuboid_transformer/cuboid_transformer.py#L1898) should match the shape of x at line 3184 (https://github.com/amazon-science/earth-forecasting-transformer/blob/7732b03bdb366110563516c3502315deab4c2026/src/earthformer/cuboid_transformer/cuboid_transformer.py#L3184), which has already been downsampled by the initial downsampler.
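To make the point above concrete, here is a rough sketch of the shape bookkeeping in plain Python. The helper name encoder_input_shape and the single spatial_scale parameter are hypothetical and not part of the Earthformer code. It assumes the initial downsampler keeps T, divides H and W by one integer factor, and projects the raw channels to base_units; the real downsampler may be configured differently.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]  # (T, H, W, C); the batch axis is omitted


def encoder_input_shape(raw_shape: Shape, base_units: int, spatial_scale: int = 1) -> Shape:
    """Hypothetical helper: the shape the encoder would see after an initial downsampler.

    Assumptions (not taken from the repository): T is unchanged, H and W are
    divided by `spatial_scale`, and the channel axis becomes `base_units`.
    """
    T, H, W, _ = raw_shape  # the raw channel count (e.g. 4 bands) is projected away
    if H % spatial_scale != 0 or W % spatial_scale != 0:
        raise ValueError("this sketch assumes H and W are divisible by spatial_scale")
    return (T, H // spatial_scale, W // spatial_scale, base_units)


raw = (10, 64, 64, 4)  # 4-band input, example sizes only
enc = encoder_input_shape(raw, base_units=128, spatial_scale=2)
print(enc)  # (10, 32, 32, 128)
```

Under these assumptions the encoder never sees C=4 directly, which would explain why its input_shape ends in base_units and why feeding it the raw tensor trips the assertion.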

You may want to refer to the simplest test case to verify that the shapes are aligned correctly. In my testing, it works for multi-channel inputs. You can modify input_shape and target_shape on lines 24 and 25 for your own use. Please note that this test script is from my fork and has not been merged into this repo.
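When checking shapes by hand, a small helper that names the offending axis can be easier to read than a bare assert. The following is only an illustrative sketch: describe_mismatch is a hypothetical function, not something provided by the repository or the test script, and it assumes both shapes are 4-tuples in (T, H, W, C) order with the batch axis already removed.

```python
AXES = ("T", "H", "W", "C")


def describe_mismatch(actual, expected):
    """Hypothetical helper: list the axes on which two (T, H, W, C) shapes differ."""
    if len(actual) != len(AXES) or len(expected) != len(AXES):
        raise ValueError("both shapes must have exactly four axes: (T, H, W, C)")
    return [
        f"{name}: got {a}, expected {e}"
        for name, a, e in zip(AXES, actual, expected)
        if a != e
    ]


# Example sizes only: a raw 4-band tensor compared with an encoder
# input_shape whose last axis is base_units=128.
problems = describe_mismatch((10, 64, 64, 4), (10, 64, 64, 128))
print(problems)  # ['C: got 4, expected 128']
```

An empty list means the shapes agree; a single entry on the C axis is the pattern described in the original question.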