twitter-archive / torch-autograd

Autograd automatically differentiates native Torch code
Apache License 2.0

Input is not a package name or nn object #131

Open aleSuglia opened 8 years ago

aleSuglia commented 8 years ago

Hi all,

I'm new to autograd and I'm trying to write a custom module that I want to integrate into an nngraph architecture. I've seen the test cases in which you use the AutoModule class to implement a Linear module. I've tried to replicate that by creating a function that encapsulates the logic of my module. Here is the code:

```lua
function repeat_no_copy(tensor, k)
    local tensor_size = tensor:size():totable()
    return torch.expand(tensor:view(1, unpack(tensor_size)), k, unpack(tensor_size))
end

function build_attentive_pooling(c)
   function attentive_pooling_fn(inputs, weight)
      local questions = inputs[1]
      local answers = inputs[2]

      assert(questions:nDimension() == 3 and answers:nDimension() == 3, "Supported batch mode only!")

      -- repeat weight matrix for each example in batch
      local repeated_weight = repeat_no_copy(weight, questions:size(1))

      -- G = tanh(Q^T U A)
      local mm_qw = A.nn.MM()
      local mm_qwa = A.nn.MM()
      local qw = mm_qw({questions, repeated_weight})
      local qwa = mm_qwa({qw, answers:transpose(2, 3)})
      local G = torch.tanh(qwa)
      local g_q = torch.max(G, 3)
      local g_a = torch.max(G, 2)

      local softmax_q = A.nn.SoftMax()
      local softmax_a = A.nn.SoftMax()
      return {softmax_q(g_q), softmax_a(g_a)}
   end

   local weight = torch.Tensor(c, c):normal()

   return A.nn.AutoModule('AttentivePooling')(attentive_pooling_fn, weight)
end
```
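For reference, here is the computation `attentive_pooling_fn` appears intended to perform, written out per batch element $b$. This assumes `questions` is $B \times L_q \times c$ and `answers` is $B \times L_a \times c$, which is what `answers:transpose(2, 3)` implies. The symbols $B$, $L_q$, $L_a$, $Q_b$, $A_b$ and $U$ (the `weight` matrix) are introduced here for illustration only. Because sequence positions are stored as rows, the code's comment $Q^T U A$ corresponds to $Q_b U A_b^\top$ in this layout. The softmax is taken over sequence positions, which is the presumed intent.

$$
\begin{aligned}
G_b &= \tanh\!\left(Q_b\, U\, A_b^{\top}\right) \in \mathbb{R}^{L_q \times L_a}
  && \text{(\texttt{mm\_qw}, then \texttt{mm\_qwa}, then \texttt{tanh})} \\
[g^{q}_b]_i &= \max_{1 \le j \le L_a} [G_b]_{ij}, \quad i = 1,\dots,L_q
  && \text{(\texttt{torch.max(G, 3)})} \\
[g^{a}_b]_j &= \max_{1 \le i \le L_q} [G_b]_{ij}, \quad j = 1,\dots,L_a
  && \text{(\texttt{torch.max(G, 2)})} \\
\sigma^{q}_b &= \operatorname{softmax}(g^{q}_b), \qquad
\sigma^{a}_b = \operatorname{softmax}(g^{a}_b),
  && \operatorname{softmax}(x)_k = \frac{e^{x_k}}{\sum_{m} e^{x_m}}
\end{aligned}
$$

`repeat_no_copy(weight, B)` supplies the same $U$ to every batch element, so `nn.MM` can run as a batched product without copying the weights.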

When I add it to an nn.Sequential and run a backward pass with model:backward(), I get the following error:

```
autograd/nnwrapper.lua:291: Input is not a package name or nn object
stack traceback:
        [C]: in function 'error'
        ...gresu/torch/install/share/lua/5.1/autograd/nnwrapper.lua:291: in function 'functionalize'
        ...gresu/torch/install/share/lua/5.1/autograd/nnwrapper.lua:308: in function 'MM'
        test_autograd.lua:21: in function 'fun'
        ...all/share/lua/5.1/autograd/runtime/direct/DirectTape.lua:113: in function 'funOnly'
        ...all/share/lua/5.1/autograd/runtime/direct/DirectTape.lua:217: in function 'b'
        ...torch/install/share/lua/5.1/autograd/auto/AutoModule.lua:52: in function 'updateGradInput'
        ...a2/gresu/torch/install/share/lua/5.1/nngraph/gmodule.lua:408: in function 'neteval'
        ...a2/gresu/torch/install/share/lua/5.1/nngraph/gmodule.lua:442: in function 'updateGradInput'
        ...ia/data2/gresu/torch/install/share/lua/5.1/nn/Module.lua:31: in function 'backward'
        test_autograd.lua:131: in main chunk
        [C]: in function 'dofile'
        ...resu/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:145: in main chunk
        [C]: at 0x00406670
```

Why am I getting this error? Is there something wrong with my code?

Thank you in advance, Alessandro

iaalm commented 8 years ago

I have a similar problem now, using CDivTable. The forward pass seems to work; the error appears only in the backward pass.

alexbw commented 8 years ago

cc @nkoumchatzky