allanzelener opened 8 years ago
For me this example actually works fine, but in a bunch of places (like here) the arguments are stored in a table and then its length is taken. Lua's length operator `#` only gives a unique answer for sequences (tables without nil holes); with a nil in the middle it may return any "border", so the length at that point could be 1 or 3, and nothing makes sense anymore...
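A small plain-Lua illustration of the problem (not autograd's code; the two helper names are made up for this sketch). Packing varargs with `{...}` and taking `#` is unreliable once a nil appears, while `select('#', ...)` always counts every argument:

```lua
-- Count arguments by packing them and using the length operator.
local function count_with_length(...)
  local args = {...}   -- a nil argument leaves a hole in the table
  return #args         -- any border is a valid answer when there are holes
end

-- Count arguments with select('#', ...), which includes embedded and trailing nils.
local function count_with_select(...)
  return select('#', ...)
end

print(count_with_length(1, nil, 3))  -- 1 or 3, depending on the implementation
print(count_with_select(1, nil, 3))  -- 3
```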
Ah you're right, packing the args into a table is probably the issue.
The simple workaround is to never use nil as a default, or to pack all options into a single table argument. However, I think that if evaluating f this way works deterministically in Lua, then evaluating df should work as well.
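A rough sketch of the second workaround in plain Lua, with numbers standing in for the tensors of the real example; `f` and the option names `y` and `z` are illustrative only. Optional values travel as keys of one table, so no positional argument or return value is ever nil:

```lua
-- Optional values live in an options table instead of positional arguments,
-- so a missing option is an absent key rather than a nil hole.
local function f(p, x, opts)
  opts = opts or {}
  -- z defaults to 0 when not given; y is simply passed through
  local result = p * x + (opts.z or 0)
  -- return extras in a table too, so no return value in the list is nil
  return result, { y = opts.y, z = opts.z }
end

local out, extra = f(2, 3, { z = 2 })
print(out, extra.y, extra.z)  -- 8 nil 2
```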
It looks like the best solution is to replace `args = {...}` with `args = table.pack(...)`. This adds a field `args.n` holding the actual total number of arguments passed to the function, including nils. (This is the same count as `select('#', ...)`, which gives the correct length of a vararg list and can also be used to select specific values from the list without having to pack them into a table.)
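A minimal sketch of the `table.pack` approach, assuming it may need to run on Lua 5.1 or LuaJIT, where `table.pack` is not guaranteed to exist; the fallback builds the same `{n = ..., ...}` table by hand. The helper `third_and_fourth` is hypothetical and only shows `select(i, ...)`:

```lua
-- Use table.pack when available (Lua 5.2+); otherwise build the same shape manually.
local pack = table.pack or function(...)
  return { n = select('#', ...), ... }
end

local args = pack('p', 'x', nil, 2)
print(args.n)            -- 4: the nil argument is counted
print(args[3], args[4])  -- nil 2

-- select(i, ...) returns the i-th argument onward without building a table.
local function third_and_fourth(...)
  return select(3, ...)
end
print(third_and_fourth('p', 'x', nil, 2))  -- nil 2
```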
For `args = ...`, are you referring to something in the autograd source, or in the user's source?
The autograd source code; see the link in my previous comment.
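Yes, everywhere that autograd parses the arguments of a function. There's also a similar issue for return values, which should use `table.unpack` on the result of `table.pack`, passing the stored count (`table.unpack(t, 1, t.n)`), since `table.unpack(t)` on its own stops at `#t`.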
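A minimal plain-Lua sketch of the return-value side; `forward` is a hypothetical helper, not part of autograd. It captures every result with a pack and hands `results.n` to unpack as the upper bound, so both embedded and trailing nils survive. The `unpack` and `pack` fallbacks assume the code may run on Lua 5.1 or LuaJIT:

```lua
local unpack = table.unpack or unpack  -- Lua 5.2+ vs Lua 5.1 / LuaJIT
local pack = table.pack or function(...)
  return { n = select('#', ...), ... }
end

-- Call fn and return all of its results, including nils,
-- using the explicit count n instead of relying on #.
local function forward(fn, ...)
  local results = pack(fn(...))
  return unpack(results, 1, results.n)
end

local function f(a, y, z)
  return a, y, z
end

print(forward(f, 1, nil, 2))               -- 1 nil 2
print(select('#', forward(f, 1, 2, nil)))  -- 3: trailing nil preserved too
```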
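```lua
local grad = require 'autograd'
-- p (a table with field W) and x are assumed to be defined as in the original example
f = function(p,x,y,z)
   return torch.sum(x * p.W), y, z
end
print(f(p,x,nil,2)) -- -2.2027326816008 nil 2
df = grad(f)
print(df(p,x,nil,2)) -- -2.2027326816008, remaining return values are truncated
```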
These issues with varargs and nil values are documented here and here on the lua-users wiki.
A nil argument to a function seems to make all arguments after it nil when evaluating its derivative.