twitter-archive / torch-autograd

Autograd automatically differentiates native Torch code
Apache License 2.0

Nil arguments prevent parsing later non-nil arguments #148

Open allanzelener opened 8 years ago

allanzelener commented 8 years ago

A nil argument in a function seems to make all arguments after it nil when evaluating its derivative.

grad = require 'autograd'
x = torch.randn(5,10)
p = {W = torch.randn(10,2) }
f = function(p, x, opt1, opt2)
  y = torch.sum(x * p.W) + opt2
  if opt1 then y = y + opt1 end
  return y
end
df = grad(f)
df(p, x, 3, 5) -- works
df(p, x, nil, 5) -- opt2 is nil: autograd/support.lua:68: attempt to perform arithmetic on a nil value
bartvm commented 8 years ago

For me this example actually works fine, but in a bunch of places (like here) the arguments are stored in a table and then the length is taken. Lua's strange definition of the length operator means that the length at that point could be 1 or 3, so nothing makes sense anymore...
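A minimal illustration of the length-operator behavior described above (not autograd code, just plain Lua): when varargs containing a nil are packed with {...}, the resulting table has a hole, and # on a table with holes may return any "border".

```lua
-- Naive vararg packing, as in args = {...}
local function pack_naive(...)
  return {...}
end

local t = pack_naive(1, nil, 3)
-- t has a hole at index 2, so #t is any border: per the Lua manual
-- it may evaluate to 1 or to 3 depending on the implementation.
print(#t)
```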

allanzelener commented 8 years ago

Ah you're right, packing the args into a table is probably the issue.

The simple workaround is to never use nil as a default, or to pack all options into a single table argument. However, if evaluating f this way works deterministically in plain Lua, then evaluating df should work as well.

It looks like the best solution is to replace args = {...} with args = table.pack(...). This adds a field args.n holding the actual total number of arguments passed to the function, including nils. (table.pack uses select('#', ...) underneath, which gives the correct count for a vararg list; select can also be used to pick out specific values from a vararg list without packing them into a table.)
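A short sketch of the table.pack approach, assuming Lua 5.2+ or a LuaJIT build with 5.2 compatibility enabled (as Torch ships):

```lua
-- table.pack records the true argument count in the `n` field,
-- so interior nils no longer confuse length queries.
local function pack_args(...)
  local args = table.pack(...)  -- args.n == select('#', ...)
  return args
end

local args = pack_args(1, nil, 3)
print(args.n)                -- 3: counts the nil, unlike #args
print(args[2])               -- nil is preserved at its position
print(select(3, 1, nil, 3))  -- select can also read one vararg directly
```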

alexbw commented 8 years ago

For args = ..., are you referring to something in the autograd source, or in the user's source?

bartvm commented 8 years ago

autograd source code, see the link in my previous comment.

allanzelener commented 8 years ago

Yes, everywhere that autograd parses the args of a function. There's also a similar issue for return values, which should use table.unpack on the result of table.pack.

f = function(p,x,y,z)
  return torch.sum(x * p.W), y, z
end
print(f(p,x,nil,2)) -- -2.2027326816008      nil    2
df = grad(f)
print(df(p,x,nil,2)) -- -2.2027326816008, remaining return values are truncated
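A sketch of the proposed fix for return values (call_and_forward is a hypothetical helper, not autograd's actual wrapper): pack the results, then unpack with explicit bounds so interior and trailing nils survive.

```lua
-- Forward all return values of fn, nils included.
local function call_and_forward(fn, ...)
  local results = table.pack(fn(...))
  -- table.unpack(results) alone may stop at the first hole;
  -- passing 1, results.n forwards every value, nils included.
  return table.unpack(results, 1, results.n)
end

local function g(y, z) return 42, y, z end
print(call_and_forward(g, nil, 2))  -- 42    nil    2
```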

These issues with varargs and nil values are documented here and here on the lua-users wiki.