twitter-archive / torch-autograd

Autograd automatically differentiates native Torch code
Apache License 2.0

Reusable AutoModules #160

Closed vadimkantorov closed 7 years ago

vadimkantorov commented 7 years ago

I'm learning to use autograd. I tried to define a simple reusable module:

autograd = require 'autograd'

function pdist(embeddings)
    local pdist = torch.mm(embeddings, embeddings:t())
    local norm = pdist:diag():view(pdist:size(1), 1):expandAs(pdist)
    return pdist:mul(-2.0):add(norm):add(norm:t()):sqrt()
end

autograd.nn.AutoModule('AutoPairwiseL2')(pdist)
m = autograd.auto.AutoPairwiseL2() -- fails 

input = torch.rand(50, 128)
print((pdist(input) - m:forward(input)):abs():sum())

It fails with:

function: 0x40bd8900
...wigwam/prefix/bin/luajit: ...wigwam/prefix/share/lua/5.1/autograd/auto/AutoModule.lua:22: An autograd function must be specified as input to AutoModule
stack traceback:
        [C]: in function 'error'
        ...wigwam/prefix/share/lua/5.1/autograd/auto/AutoModule.lua:22: in function '__init'
        ...a_gpu101_105/.wigwam/prefix/share/lua/5.1/torch/init.lua:91: in function <...a_gpu101_105/.wigwam/prefix/share/lua/5.1/torch/init.lua:87>
        [C]: in function 'AutoPairwiseL2'
        test.lua:11: in main chunk
        [C]: in function 'dofile'
        ...105/.wigwam/prefix/lib/luarocks/rocks/trepl/scm-1/bin/th:145: in main chunk
        [C]: at 0x00410a40

The example starts working if instead I use: m = autograd.nn.AutoModule('AutoPairwiseL2')(pdist). I thought that after the AutoModule call the module would be registered along with its forward function for later use, but it seems I have to pass the forward function every time. Am I missing something?
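For reference, here is the full variant that works, instantiating the module by passing the forward function directly rather than going through autograd.auto:

```lua
local autograd = require 'autograd'

local function pdist(embeddings)
    local d = torch.mm(embeddings, embeddings:t())
    local norm = d:diag():view(d:size(1), 1):expandAs(d)
    return d:mul(-2.0):add(norm):add(norm:t()):sqrt()
end

-- Pass pdist at construction time; AutoModule wraps it into an nn-style module.
local m = autograd.nn.AutoModule('AutoPairwiseL2')(pdist)

local input = torch.rand(50, 128)
print((pdist(input) - m:forward(input)):abs():sum()) -- should be ~0
```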

Thanks!

vadimkantorov commented 7 years ago

OK, it seems that autograd registers only the class under autograd.auto, not the forward function, so the class cannot be instantiated without passing the function again.
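Since autograd.auto only stores the class, one workaround is to keep your own registry that binds the class name to its forward function; registerAutoModule and registry below are hypothetical helper names, a minimal sketch rather than anything provided by autograd:

```lua
local autograd = require 'autograd'

-- Hypothetical user-side registry: remember the forward function
-- alongside the AutoModule class name so instances can be created later.
local registry = {}

local function registerAutoModule(name, fn)
    local class = autograd.nn.AutoModule(name)
    registry[name] = function()
        return class(fn) -- fn is already bound; callers need not pass it
    end
end

-- Usage (assuming pdist is defined as in the original example):
-- registerAutoModule('AutoPairwiseL2', pdist)
-- local m = registry['AutoPairwiseL2']()
-- m:forward(torch.rand(50, 128))
```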