I am working on an SSIM loss with autograd. SSIM requires several Gaussian convolutions to compute its local statistics. In Torch7, a simple way to do this is to use `nn.SpatialConvolution` with weights that are treated as constants. However, autograd raises the following error while building the computation graph, because of the call to the spatial convolution module's `forward()`:
```
stack traceback:
[C]: in function 'error'
...install/share/lua/5.1/autograd/runtime/codegen/Graph.lua:22: in function 'cdata'
...kg/torch/install/share/lua/5.1/nn/SpatialConvolution.lua:80: in function 'forward'
./ssim.lua:58: in function 'conv2d'
./ssim.lua:79: in function 'ssim'
simple.lua:30: in function 'fn'
...install/share/lua/5.1/autograd/runtime/codegen/Graph.lua:353: in function 'protectedFn'
...install/share/lua/5.1/autograd/runtime/codegen/Graph.lua:383: in function 'record'
.../install/share/lua/5.1/autograd/runtime/codegen/init.lua:44: in function 'generateFn'
.../install/share/lua/5.1/autograd/runtime/codegen/init.lua:140: in function 'df'
simple.lua:39: in main chunk
[C]: in function 'dofile'
.../pkg/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:150: in main chunk
[C]: at 0x00405d50
```
`ssim()` is the loss function, and `conv2d()` is a helper function that internally calls `forward()` on the spatial convolution module.
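For reference, the constant weights of such a convolution are, in standard SSIM, a normalized 2-D Gaussian window. The sketch below builds one in plain Lua (no Torch). The 11×11 size and sigma = 1.5 are the commonly used SSIM defaults and are assumptions here, not values taken from `ssim.lua`; `gaussian_window` is a hypothetical name.

```lua
-- Hypothetical helper (not from ssim.lua): build a normalized
-- size x size Gaussian window as a table of rows.
local function gaussian_window(size, sigma)
  local half = (size - 1) / 2
  local w, total = {}, 0
  for i = 1, size do
    w[i] = {}
    for j = 1, size do
      -- Offsets from the window centre.
      local x, y = i - 1 - half, j - 1 - half
      local v = math.exp(-(x * x + y * y) / (2 * sigma * sigma))
      w[i][j] = v
      total = total + v
    end
  end
  -- Normalize so the weights sum to 1; filtering then yields a local weighted mean.
  for i = 1, size do
    for j = 1, size do
      w[i][j] = w[i][j] / total
    end
  end
  return w
end

-- Assumed SSIM defaults: 11x11 window, sigma = 1.5.
local window = gaussian_window(11, 1.5)
```

Because the weights sum to 1, filtering an image with this window gives local weighted means, which is exactly what the SSIM statistics are built from.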
Any ideas?
What is the right way (or an alternative) when the loss function uses Torch modules?
It seems the way to address this issue is to use the functionalized version of the module.
However, when I set `optimize=true`, training crashes.
With optimization disabled, it works as expected.
This looks like a separate issue.
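For context on what the Gaussian convolutions compute, here is a plain-Lua sketch (no Torch, no autograd) of the standard SSIM map. The local means, variances and covariance come from filtering five images (x, y, x², y², x·y) with the same constant kernel, and C1 = (0.01 L)², C2 = (0.03 L)², with L the dynamic range of the pixel values. The kernel is passed as an explicit argument to a plain function instead of living inside a module object; whether this matches how autograd's functionalized modules work internally is an assumption. `filter2d`, `mul` and `ssim_map` are hypothetical names, and `filter2d` computes a "valid" cross-correlation, which equals a convolution for a symmetric Gaussian kernel.

```lua
-- Hypothetical helper: "valid" 2-D cross-correlation with the kernel passed
-- explicitly (functional style). Images and kernels are tables of rows.
local function filter2d(weight, img)
  local kh, kw = #weight, #weight[1]
  local out = {}
  for i = 1, #img - kh + 1 do
    out[i] = {}
    for j = 1, #img[1] - kw + 1 do
      local s = 0
      for u = 1, kh do
        for v = 1, kw do
          s = s + weight[u][v] * img[i + u - 1][j + v - 1]
        end
      end
      out[i][j] = s
    end
  end
  return out
end

-- Elementwise product of two same-sized 2-D tables.
local function mul(a, b)
  local out = {}
  for i = 1, #a do
    out[i] = {}
    for j = 1, #a[1] do
      out[i][j] = a[i][j] * b[i][j]
    end
  end
  return out
end

-- SSIM map from five filtered images; L is the dynamic range (e.g. 1 or 255).
local function ssim_map(weight, x, y, L)
  local C1, C2 = (0.01 * L) ^ 2, (0.03 * L) ^ 2
  local mu_x, mu_y = filter2d(weight, x), filter2d(weight, y)
  local exx = filter2d(weight, mul(x, x))
  local eyy = filter2d(weight, mul(y, y))
  local exy = filter2d(weight, mul(x, y))
  local map = {}
  for i = 1, #mu_x do
    map[i] = {}
    for j = 1, #mu_x[1] do
      local mx, my = mu_x[i][j], mu_y[i][j]
      -- Local variances and covariance: E[ab] - E[a]E[b].
      local sxx = exx[i][j] - mx * mx
      local syy = eyy[i][j] - my * my
      local sxy = exy[i][j] - mx * my
      local num = (2 * mx * my + C1) * (2 * sxy + C2)
      local den = (mx * mx + my * my + C1) * (sxx + syy + C2)
      map[i][j] = num / den
    end
  end
  return map
end
```

The loss would then typically be derived from the mean of this map (for example 1 minus the mean), and the kernel never changes, so it behaves as a constant input rather than a trainable parameter.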