twitter-archive / torch-autograd

Autograd automatically differentiates native Torch code
Apache License 2.0
559 stars 114 forks

FloatTensor.cdata not currently supported by autograd when implementing SSIM loss #177

Closed farleylai closed 5 years ago

farleylai commented 5 years ago

I am working on an SSIM loss with autograd. SSIM requires several Gaussian convolutions to compute the local statistics. In Torch7, a simple way to do this is to use nn.SpatialConvolution with weights that should be treated as constant. However, autograd raises the following error while building the computation graph, due to the call to forward() on the spatial convolution module:

```
stack traceback:
    [C]: in function 'error'
    ...install/share/lua/5.1/autograd/runtime/codegen/Graph.lua:22: in function 'cdata'
    ...kg/torch/install/share/lua/5.1/nn/SpatialConvolution.lua:80: in function 'forward'
    ./ssim.lua:58: in function 'conv2d'
    ./ssim.lua:79: in function 'ssim'
    simple.lua:30: in function 'fn'
    ...install/share/lua/5.1/autograd/runtime/codegen/Graph.lua:353: in function 'protectedFn'
    ...install/share/lua/5.1/autograd/runtime/codegen/Graph.lua:383: in function 'record'
    .../install/share/lua/5.1/autograd/runtime/codegen/init.lua:44: in function 'generateFn'
    .../install/share/lua/5.1/autograd/runtime/codegen/init.lua:140: in function 'df'
    simple.lua:39: in main chunk
    [C]: in function 'dofile'
    .../pkg/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:150: in main chunk
    [C]: at 0x00405d50
```

ssim() is the loss function and conv2d() is a helper that internally calls forward() on the spatial convolution module.

Any ideas? What is the right way (or an alternative) to write a loss function that uses nn modules internally?
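For reference, a minimal sketch of the failing pattern (the actual ssim.lua is not shown in this issue, so the names and kernel size here are hypothetical). Calling :forward() on a stock nn module inside the function being differentiated passes autograd's Node wrappers into nn code, which then fails when nn touches the tensor's cdata:

```lua
local autograd = require 'autograd'
local nn = require 'nn'

-- Hypothetical fixed Gaussian kernel held in an nn module;
-- its weights are meant to be constants, not trained.
local gauss = nn.SpatialConvolution(1, 1, 11, 11)

local function conv2d(x)
   return gauss:forward(x)  -- raises the 'cdata' error under autograd
end

local function ssim(params, input, target)
   local mu_x = conv2d(input)
   -- ... remaining SSIM statistics elided ...
   return torch.sum(mu_x)
end

local dssim = autograd(ssim)
-- Evaluating dssim (i.e. recording the graph) triggers the traceback above.
```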

farleylai commented 5 years ago

The way to address this seems to be to use the functionalized version of the module. However, when optimize=true is set, training crashes; with the default it works as expected. That feels like a separate issue.
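A hedged sketch of that workaround, assuming autograd.functionalize accepts an nn module and returns a function plus its parameters (per the torch-autograd README); the SSIM body itself is still hypothetical:

```lua
local autograd = require 'autograd'
local nn = require 'nn'

local conv = nn.SpatialConvolution(1, 1, 11, 11)
-- functionalize turns the module into a pure function of (params, input),
-- so autograd can record the convolution instead of choking on cdata.
local convFn, convParams = autograd.functionalize(conv)

local function ssim(params, input, target)
   local mu_x = convFn(convParams, input)  -- differentiable convolution
   -- ... SSIM statistics as before ...
   return torch.sum(mu_x)
end

-- Setting optimize = true reportedly crashes during training;
-- leaving it off works as expected.
local dssim = autograd(ssim, { optimize = false })
```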