Hi, thank you for bringing this to our attention! I ran the experiments for the paper in my research codebase, which used skimage for reading in images, so the way we loaded and shaped the input was different from how it is done here. When writing the cleaner codebase for public release, I switched to torchvision for reading images because I found some memory-leak issues with skimage, and this bug slipped in.
This explains why, for some users, outputs have not been as high-quality as the ones from the paper and the Satlas website.
I have pushed the fix. Thanks again!
Hi Authors,
Thanks for making this public! I am following up on your work and have found an incorrect piece of the implementation that may invalidate some conclusions and findings about the number of Sentinel-2 images in your research paper. Specifically, at lines 190-191 in s2-naip_dataset.py, where the data is read and reshaped, the resulting s2_img (line 191) blends the channel and time dimensions together, which makes s2_tensor unusable for splitting the time series by index. As a result, the model always sees all of the time steps but only a few of the channels.
Here is an example. Say you have a stacked RGB time series with dimensions 3x128x32 (C x T*H x W, with T=4 time steps stacked along the rows), which becomes 4x3x32x32 (TxCxHxW) after reshaping (T: Time, C: Channel, H: Height, W: Width). The desired output is that each T_i in the first dimension (each 3-channel 3x32x32 image) is built from (T_i^R, T_i^G, T_i^B), i.e. a single Sentinel-2 image at time i.
However, because torch.reshape fills the target shape in row-major order, the current implementation actually produces:

- for the 1st T_i (i=0), its 3x32x32 is (T_0^R, T_1^R, T_2^R);
- for the 2nd T_i (i=1), it is (T_3^R, T_0^G, T_1^G);
- for the 3rd T_i (i=2), it is (T_2^G, T_3^G, T_0^B);
- for the 4th T_i (i=3), it is (T_1^B, T_2^B, T_3^B).

In every case, the 3x32x32 array is not a Sentinel-2 image at a specific time but a meaningless mix of channels and time steps.
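To make the blending concrete, here is a minimal, self-contained sketch. The toy tensor below is an assumption for illustration only: each (channel, time) block is filled with the constant 10*c + t, so we can trace exactly where it lands after the buggy reshape.

```python
import torch

# Toy stand-in for the stacked Sentinel-2 tensor: C=3 channels, T=4 time
# steps of H=W=32 stacked along the rows, i.e. shape (3, 128, 32).
# Block (c, t) is filled with the value 10*c + t so we can see where each
# block ends up after reshaping.
C, T, H, W = 3, 4, 32, 32
s2_img = torch.stack([
    torch.cat([torch.full((H, W), 10 * c + t) for t in range(T)], dim=0)
    for c in range(C)
])  # shape (3, 128, 32)

# The reshape currently in the dataset code:
buggy = torch.reshape(s2_img, (-1, s2_img.shape[0], 32, 32))  # (4, 3, 32, 32)
for t in range(T):
    # If time and channel were separated, row t would read [t, 10+t, 20+t].
    print(t, [int(buggy[t, c, 0, 0]) for c in range(C)])
# Prints:
# 0 [0, 1, 2]     <- (T_0^R, T_1^R, T_2^R): three *times* of one channel
# 1 [3, 10, 11]   <- (T_3^R, T_0^G, T_1^G)
# 2 [12, 13, 20]  <- (T_2^G, T_3^G, T_0^B)
# 3 [21, 22, 23]  <- (T_1^B, T_2^B, T_3^B)
```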
The correct way to address this is to first reshape to CxTxHxW, and then move the axes with permute to get TxCxHxW:
Replace

```python
s2_img = torch.reshape(s2_img, (-1, s2_img.shape[0], 32, 32))
```

with

```python
s2_img = torch.reshape(s2_img, (s2_img.shape[0], -1, 32, 32)).permute(1, 0, 2, 3)
```
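A quick self-contained check of the proposed fix, reusing the same toy encoding as the sketch above, shows that each time slice now holds exactly one time step across all channels:

```python
import torch

# Rebuild the toy (C, T*H, W) = (3, 128, 32) tensor from the sketch above,
# where block (c, t) is filled with the constant value 10*c + t.
C, T, H, W = 3, 4, 32, 32
s2_img = torch.stack([
    torch.cat([torch.full((H, W), 10 * c + t) for t in range(T)], dim=0)
    for c in range(C)
])

# Proposed fix: keep the channel axis, split time out of the row axis to get
# (C, T, H, W), then swap the first two axes to get (T, C, H, W).
fixed = torch.reshape(s2_img, (s2_img.shape[0], -1, 32, 32)).permute(1, 0, 2, 3)
for t in range(T):
    print(t, [int(fixed[t, c, 0, 0]) for c in range(C)])
# Prints:
# 0 [0, 10, 20]   <- (T_0^R, T_0^G, T_0^B): one time step, all three channels
# 1 [1, 11, 21]
# 2 [2, 12, 22]
# 3 [3, 13, 23]
```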