torch / nngraph

Graph Computation for nn

avoid nnNode:label() error for string data #152

Open: kruus opened this pull request 7 years ago

kruus commented 7 years ago

A simple one-line "robustness" PR, with an associated test.

With LuaJIT, the original nngraph gave me the following error when labelling a node whose data is a table containing a string value:

kruus@snake10$ th test_stringlabel.lua
mod_out[1] type is torch.DoubleTensor   
mod_out[2] type is string   
/local/kruus/torch/install/bin/luajit: /local/kruus/torch/install/share/lua/5.1/nngraph/node.lua:143:
----------> bad argument #2 to 'insert' (number expected, got string)
stack traceback:
    [C]: in function 'insert'
    /local/kruus/torch/install/share/lua/5.1/nngraph/node.lua:143: in function 'getstr'
    /local/kruus/torch/install/share/lua/5.1/nngraph/node.lua:169: in function 'label'
    /local/kruus/torch/install/share/lua/5.1/graph/init.lua:242: in function 'todot'
    test_stringlabel.lua:18: in function 'test_table_string'
       etc.

The following test program triggers the error, and a similar test has been added to tests/test_nngraph.lua:

require 'nngraph'
function test_table_string()
  local inp = nn.Identity()()
  local in1 = nn.SelectTable(1)(inp)
  local in2 = nn.SelectTable(2)(inp)
  local out = nn.Linear(10,10)(in1)
  -- in2 propagates 'as is':   it could be, say, a string debug tag
  local mod = nn.gModule({inp}, {out,in2})

  local inp_tensor = torch.Tensor(10)
  local inp_string = 'Hello'
  local mod_out = mod:forward{
    torch.Tensor(10),
    "nnNode:label() should handle a string without bad argument #2 to 'insert' (number expected, got string) error"
  }
  print('mod_out[1] type is '..torch.type(mod_out[1]))
  print('mod_out[2] type is '..torch.type(mod_out[2])) -- string
  local dot0 = mod.fg:todot()
  --print(dot0)
end

test_table_string()

LuaJIT runs this fine if the result of getstr('some string') is first stored in a local variable and only then passed to table.insert.

Not sure why.
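A likely explanation, assuming getstr's string branch ends by returning the result of a `string.gsub` call (an assumption, since the getstr source isn't shown here): `gsub` returns two values, the new string and the substitution count. When a function call is the last argument in a call, Lua passes all of its results, so `table.insert(t, getstr(s))` becomes a three-argument `table.insert(t, str, count)`. Lua then treats `str` as the insert position, which matches "bad argument #2 to 'insert' (number expected, got string)". Assigning to a local keeps only the first result, and so does wrapping the call in parentheses. A minimal standalone sketch, with a hypothetical `getstr` stand-in rather than nngraph's actual code:

```lua
-- Hypothetical stand-in for nngraph's getstr(): ASSUMES the string branch
-- returns a gsub result, i.e. TWO values (new_string, substitution_count).
local function getstr(data)
  return tostring(data):gsub('\n', '\\l')
end

local tstr = {}

-- Failing form: the call is the last argument, so all of its results are
-- passed along. This is effectively table.insert(tstr, 'Hello', 0), and
-- 'Hello' is taken as the (non-numeric) insert position.
local ok, err = pcall(function() table.insert(tstr, getstr('Hello')) end)
print(ok, err)

-- Working form 1: a local variable keeps only the first result.
local s = getstr('Hello')
table.insert(tstr, s)

-- Working form 2: parentheses also truncate the call to a single value.
table.insert(tstr, (getstr('World')))

print(table.concat(tstr, ','))  -- Hello,World
```

The parenthesized form is an alternative one-line fix that avoids the extra local.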