fbcotter / pytorch_wavelets

Pytorch implementation of 2D Discrete Wavelet (DWT) and Dual Tree Complex Wavelet Transforms (DTCWT) and a DTCWT based ScatterNet

Not preserving spatial dimensions #39

Closed: denissimom closed this issue 2 years ago

denissimom commented 2 years ago

Good day! Thanks a lot for your efforts with this lib! Recently I encountered a problem with preserving the spatial dimensions of a tensor.

import torch
from pytorch_wavelets import DWTForward, DWTInverse

j = 3  
wave = 'db1'  
mode = 'symmetric'  
layer0 = DWTForward(J=j, wave=wave, mode=mode)  
layer1 = DWTInverse(wave=wave, mode=mode)  
test_input = torch.arange(27).reshape(1, 3, 3, 3).to(torch.float32)  
low, high = layer0(test_input)  
test_output = layer1((low, high))  
print(test_input.shape, test_output.shape)  

Expected to get (1, 3, 3, 3) but got (1, 3, 4, 4)

fbcotter commented 2 years ago

Ah yes, this will happen with odd-length signals. Your input is quite small, particularly for a 3-scale transform, so let's consider a larger example with fewer scales:

j = 1  
wave = 'db1'  
mode = 'symmetric'  
layer0 = DWTForward(J=j, wave=wave, mode=mode)  
layer1 = DWTInverse(wave=wave, mode=mode)  
test_input = torch.arange(27 * 9).reshape(1, 3, 9, 9).to(torch.float32)  
low, high = layer0(test_input)  
test_output = layer1((low, high))  
print(test_input.shape, test_output.shape)  
>> torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 10, 10])

Let's look at the size of low and high:

print(low.shape)
>> torch.Size([1, 3, 5, 5])
print(high[0].shape)
>> torch.Size([1, 3, 3, 5, 5])

What's happening here? Because the transform decimates by two, the signals need to be even length, so the input is effectively padded first, using the padding mode you've selected (symmetric).
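As a rough sketch of the bookkeeping (assuming db1's filter length of 2 and the usual single-level coefficient-length formula for non-periodized padding modes, floor((N + L - 1) / 2)), you can predict the shapes above:

import math

def dwt_len(n, filt_len=2):
    # single-level DWT coefficient length for non-periodized padding modes
    # (filt_len=2 corresponds to db1; illustrative assumption, not library API)
    return math.floor((n + filt_len - 1) / 2)

print(dwt_len(9))      # 5  -> matches low.shape[-1] and high[0].shape[-1]
print(2 * dwt_len(9))  # 10 -> the inverse transform reconstructs to this even length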

If we look at the reconstructed tensor from above, you can see what happens:

import numpy as np
np.set_printoptions(linewidth=120, suppress=True, precision=2)
print(test_output.numpy())

[[[[ -0.   1.   2.   3.   4.   5.   6.   7.   8.   8.]
   [  9.  10.  11.  12.  13.  14.  15.  16.  17.  17.]
   [ 18.  19.  20.  21.  22.  23.  24.  25.  26.  26.]
   [ 27.  28.  29.  30.  31.  32.  33.  34.  35.  35.]
   [ 36.  37.  38.  39.  40.  41.  42.  43.  44.  44.]
   [ 45.  46.  47.  48.  49.  50.  51.  52.  53.  53.]
   [ 54.  55.  56.  57.  58.  59.  60.  61.  62.  62.]
   [ 63.  64.  65.  66.  67.  68.  69.  70.  71.  71.]
   [ 72.  73.  74.  75.  76.  77.  78.  79.  80.  80.]
   [ 72.  73.  74.  75.  76.  77.  78.  79.  80.  80.]]

  [[ 81.  82.  83.  84.  85.  86.  87.  88.  89.  89.]
   [ 90.  91.  92.  93.  94.  95.  96.  97.  98.  98.]
   [ 99. 100. 101. 102. 103. 104. 105. 106. 107. 107.]
   [108. 109. 110. 111. 112. 113. 114. 115. 116. 116.]
   [117. 118. 119. 120. 121. 122. 123. 124. 125. 125.]
   [126. 127. 128. 129. 130. 131. 132. 133. 134. 134.]
   [135. 136. 137. 138. 139. 140. 141. 142. 143. 143.]
   [144. 145. 146. 147. 148. 149. 150. 151. 152. 152.]
   [153. 154. 155. 156. 157. 158. 159. 160. 161. 161.]
   [153. 154. 155. 156. 157. 158. 159. 160. 161. 161.]]

  [[162. 163. 164. 165. 166. 167. 168. 169. 170. 170.]
   [171. 172. 173. 174. 175. 176. 177. 178. 179. 179.]
   [180. 181. 182. 183. 184. 185. 186. 187. 188. 188.]
   [189. 190. 191. 192. 193. 194. 195. 196. 197. 197.]
   [198. 199. 200. 201. 202. 203. 204. 205. 206. 206.]
   [207. 208. 209. 210. 211. 212. 213. 214. 215. 215.]
   [216. 217. 218. 219. 220. 221. 222. 223. 224. 224.]
   [225. 226. 227. 228. 229. 230. 231. 232. 233. 233.]
   [234. 235. 236. 237. 238. 239. 240. 241. 242. 242.]
   [234. 235. 236. 237. 238. 239. 240. 241. 242. 242.]]]]

If you have to use odd-length signals, you can crop the top-left of the output back to the original size. Otherwise, it's best to use an input whose spatial dimensions are an integer multiple of 2^J.
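For example, the crop is just a slice back to the original height and width (plain tensor indexing, not part of the library API); since the padding only duplicates rows/columns at the bottom and right, the top-left crop recovers the original values:

_, _, H, W = test_input.shape
test_output_cropped = test_output[..., :H, :W]  # crop back to the original spatial size
print(test_output_cropped.shape)
>> torch.Size([1, 3, 9, 9])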