NESAP-STL / nesap-stl

Repository for the NESAP Extreme Spatio-Temporal Learning project
1 stars 3 forks source link

reshape_patch_back 함수를 사용했을 때 이미지가 깨지는 문제에 대한 건 #21

Open ks1101 opened 3 years ago

ks1101 commented 3 years ago

이미 아시다시피 reshape_patch 함수를 사용후 reshape_patch_back 함수를 사용해서 원래 크기로 돌려놓으면 이미지가 깨지는 문제가 있었습니다. 이 문제의 원인은 reshape_patch, reshape_patch_back 함수를 거치면서 데이터들이 다른 위치로 이동하는 것입니다. [[0,1],[2,3]] -> [[0],[1],[2],[3]] -> [[0,2],[1,3]] 이런 식으로 말입니다.

왜 다른 위치로 이동하는 가에 대한 원인이 reshape_patch, reshape_patch_back 함수에 있을 것이라고 생각해서 두 함수를 조사했습니다. 처음에는 사용하지 않는 batch를 없애볼까 했는데 안 되더라고요. 또 두 함수를 보면 numpy의 np.reshape, np.transpose 함수를 사용하는데, 이 함수들을 잘못 사용하고 있나 해서 코드를 찾아봐도 해결이 안 되었습니다.

알고보니까 원인은 np.transpose에 있었는데요. reshape_patch에서 transpose는 (0,1,2,4,6,5,3)으로 되어있는데, reshape_patch_back에서 transpose는 (0,1,2,5,4,6,3)으로 되어있어서 데이터가 어먼 곳으로 가고 이미지가 깨졌던 것 같습니다. 그래서 reshape_patch_back의 transpose를 reshape_patch의 transpose에 맞게 (0,1,2,5,3,6,4)로 변경했더니 이미지가 깨지는 현상이 없어졌습니다. 주석을 조금 더 빨리 읽어볼걸 그랬습니다...

def reshape_patch_back(patch_tensor, patch_size):
    # 중략
    a = np.reshape(patch_tensor, [batch_size, seq_length, img_channels,
                                  patch_size, patch_size,
                                  patch_height, patch_width])

    #b = np.transpose(a, [0,1,2,5,4,6,3])가 아닌
    b = np.transpose(a, [0,1,2,5,3,6,4])

    img_tensor = np.reshape(b, [batch_size, seq_length, img_channels,
                                patch_height * patch_size,
                                patch_width * patch_size])
    return img_tensor

revise_sample

그리해서 이미지가 깨지는 현상은 일단 해결했다고 생각이 듭니다만 새로운 문제가 생겼습니다. 평가와 시각화를 한 동영상, 즉 visual/visualize.py에서 만든 동영상에서 우측, 동유럽 지역의 이미지가 깨지는 문제가 있습니다. 이 경우는 reshape_patch_back의 문제로 이미지가 깨지는 것과는 다른 느낌이라 다른 해결책을 찾아야 할 것 같습니다.

0th_output 0th_target 위가 output, 아래가 target입니다.

이 문제가 왜 생기는지 알아내려고 이것저것 해봤습니다. 혹시 reshape_patch_back을 수정해서 이런 문제가 생겼나해서 수정 전 상태로 돌려놓고도 해봤는데 이 경우에는 이미지가 깨지는 와중에 한 번 더 깨지기 때문에 이 문제는 아니라는 생각을 하게 되었습니다.

https://user-images.githubusercontent.com/55044759/139203011-2f0e6010-7118-470f-8a32-246e934a9417.mp4

https://user-images.githubusercontent.com/55044759/139203016-08774eb0-5a0f-4023-8764-c4c552367f0d.mp4

output과 target, 양 쪽에 모두 문제가 있어서 원인이 무엇인지 찾기 어려운 것 같습니다. 일단 추측으로는 trainer/base.py의 evaluation 함수에서 enumerate(data_loader)가 문제인가 싶은 생각이 듭니다. output과 target이 동시에 영향을 받을 만한 구석이 없어보여서 enumerate(data_loader)를 지목했지만 이에 대해 잘 알지는 못하기에 확신은 못 하겠습니다.

조심히 다녀오세요. 기대하고 있겠습니다.