Ah yes, this happens with odd-length signals. Your input is quite small, particularly for a 3-scale transform, so let's consider a larger example with fewer scales:
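import torch
from pytorch_wavelets import DWTForward, DWTInverse

j = 1  # number of decomposition scales
wave = 'db1'
mode = 'symmetric'
layer0 = DWTForward(J=j, wave=wave, mode=mode)
layer1 = DWTInverse(wave=wave, mode=mode)
# 9x9 spatial size, i.e. odd length in both dimensions
test_input = torch.arange(27 * 9).reshape(1, 3, 9, 9).to(torch.float32)
low, high = layer0(test_input)
test_output = layer1((low, high))
print(test_input.shape, test_output.shape)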
>> torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 10, 10])
Let's look at the size of low and high:
print(low.shape)
>> torch.Size([1, 3, 5, 5])
print(high[0].shape)
>> torch.Size([1, 3, 3, 5, 5])
What's happening here? The transform decimates by two, so it needs even-length signals. An odd-length input is therefore effectively padded, using the padding mode you've selected (symmetric). In this example the 9x9 input is treated as 10x10, which gives 5x5 coefficients and a 10x10 reconstruction.
If you print the reconstructed tensor, you can see the padding: the last row and the last column of each channel are duplicated:
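To make the padding step concrete, here is a minimal 1D sketch in plain Python. It is not the library's implementation: `haar_forward` and `haar_inverse` are hypothetical helpers that assume a one-level Haar (db1) transform in which an odd-length signal is extended by repeating its last sample.

```python
import math

def haar_forward(x):
    """One-level Haar (db1) DWT of a 1D sequence (hypothetical helper).

    Odd-length inputs are extended by repeating the last sample,
    which is what a symmetric extension does at the right edge.
    """
    x = list(x)
    if len(x) % 2 == 1:
        x.append(x[-1])
    s = math.sqrt(2.0)
    low = [(x[i] + x[i + 1]) / s for i in range(0, len(x), 2)]
    high = [(x[i] - x[i + 1]) / s for i in range(0, len(x), 2)]
    return low, high

def haar_inverse(low, high):
    """Inverse of haar_forward; always returns an even-length sequence."""
    s = math.sqrt(2.0)
    out = []
    for a, d in zip(low, high):
        out.append((a + d) / s)
        out.append((a - d) / s)
    return out

signal = [float(v) for v in range(9)]   # odd length, like one row of the 9x9 input
low, high = haar_forward(signal)
recon = haar_inverse(low, high)

print(len(signal), len(low), len(recon))   # 9 5 10
print([round(v, 6) for v in recon])        # the last sample (8.0) appears twice
```

The reconstruction has length 10 because the inverse only sees the 5 coefficients and has no record that the original length was 9.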
import numpy as np
np.set_printoptions(linewidth=120, suppress=True, precision=2)
print(test_output.numpy())
[[[[ -0. 1. 2. 3. 4. 5. 6. 7. 8. 8.]
[ 9. 10. 11. 12. 13. 14. 15. 16. 17. 17.]
[ 18. 19. 20. 21. 22. 23. 24. 25. 26. 26.]
[ 27. 28. 29. 30. 31. 32. 33. 34. 35. 35.]
[ 36. 37. 38. 39. 40. 41. 42. 43. 44. 44.]
[ 45. 46. 47. 48. 49. 50. 51. 52. 53. 53.]
[ 54. 55. 56. 57. 58. 59. 60. 61. 62. 62.]
[ 63. 64. 65. 66. 67. 68. 69. 70. 71. 71.]
[ 72. 73. 74. 75. 76. 77. 78. 79. 80. 80.]
[ 72. 73. 74. 75. 76. 77. 78. 79. 80. 80.]]
[[ 81. 82. 83. 84. 85. 86. 87. 88. 89. 89.]
[ 90. 91. 92. 93. 94. 95. 96. 97. 98. 98.]
[ 99. 100. 101. 102. 103. 104. 105. 106. 107. 107.]
[108. 109. 110. 111. 112. 113. 114. 115. 116. 116.]
[117. 118. 119. 120. 121. 122. 123. 124. 125. 125.]
[126. 127. 128. 129. 130. 131. 132. 133. 134. 134.]
[135. 136. 137. 138. 139. 140. 141. 142. 143. 143.]
[144. 145. 146. 147. 148. 149. 150. 151. 152. 152.]
[153. 154. 155. 156. 157. 158. 159. 160. 161. 161.]
[153. 154. 155. 156. 157. 158. 159. 160. 161. 161.]]
[[162. 163. 164. 165. 166. 167. 168. 169. 170. 170.]
[171. 172. 173. 174. 175. 176. 177. 178. 179. 179.]
[180. 181. 182. 183. 184. 185. 186. 187. 188. 188.]
[189. 190. 191. 192. 193. 194. 195. 196. 197. 197.]
[198. 199. 200. 201. 202. 203. 204. 205. 206. 206.]
[207. 208. 209. 210. 211. 212. 213. 214. 215. 215.]
[216. 217. 218. 219. 220. 221. 222. 223. 224. 224.]
[225. 226. 227. 228. 229. 230. 231. 232. 233. 233.]
[234. 235. 236. 237. 238. 239. 240. 241. 242. 242.]
[234. 235. 236. 237. 238. 239. 240. 241. 242. 242.]]]]
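If you have to use odd-length signals, you can crop the output back to the input size by keeping its top-left region (here, dropping the extra last row and column). Otherwise, it's best to use an input whose height and width are integer multiples of 2^J.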
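Both options can be sketched with a little index arithmetic. The helpers below, `pad_amounts` and `crop_slices`, are hypothetical names introduced for illustration and are not part of the library; the commented lines at the end assume the `test_output` tensor from the example above.

```python
def pad_amounts(height, width, J):
    """Extra rows/columns needed to round each size up to a multiple of 2**J."""
    step = 2 ** J
    return (-height) % step, (-width) % step

def crop_slices(height, width):
    """Index that keeps the top-left height x width region of a (..., H, W) array."""
    return (Ellipsis, slice(0, height), slice(0, width))

print(pad_amounts(9, 9, J=1))    # (1, 1)
print(pad_amounts(9, 9, J=3))    # (7, 7)
print(pad_amounts(16, 16, J=3))  # (0, 0)

# Cropping the reconstruction back to the original 9x9 size would then be:
#   test_output = test_output[crop_slices(9, 9)]
# which is equivalent to:
#   test_output = test_output[..., :9, :9]
```

Cropping is the simpler fix when the input size is fixed; padding up front keeps every scale even, at the cost of carrying the extra samples through the transform.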
Good day! Thanks a lot for your efforts with this lib! Recently I ran into a problem with preserving the spatial dimensions of a tensor.
Expected to get (1, 3, 3, 3), but got (1, 3, 4, 4).