Currently, gradient functions are compiled for specific argument types and sizes. If any of these change (e.g. the batch size), the gradient expression becomes invalid. It would be nice to cache gradients and compute new ones for each new set of arguments. This way we would be able to write:
```julia
f = ...
m = ...
mem = ...
for x in batches
    # retrieve the gradient function from the cache or compile a new one
    df = xdiff(f; m=m, x=x)
    loss, dm, dx = df(m, x; mem=mem)
    ...
end
```
or even something like this:
```julia
f = ...
m = ...
mem = ...
for x in batches
    # differentiation, caching and invocation in a single call
    loss, dm, dx = ddiff(f; m=m, x=x, mem=mem)
    ...
end
```
The main difficulty here is the memory manager (the `mem` argument). Currently it's just a `Dict` with pre-allocated buffers for all temporary variables, but in this scenario we would need something more sophisticated, perhaps with eviction of old buffers.
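One possible direction is a buffer pool that keys buffers by variable name and shape and evicts the least recently used ones once a byte budget is exceeded. A rough sketch (the `BufferPool` type and `getbuffer!` function are made up for illustration; `argmin` over a `Dict` requires Julia 1.7+):

```julia
# Hypothetical buffer pool with LRU eviction; not part of the current API.
mutable struct BufferPool
    buffers::Dict{Any,Array}
    last_use::Dict{Any,Int}   # logical timestamps for LRU eviction
    clock::Int
    max_bytes::Int            # byte budget for all cached buffers
end

BufferPool(max_bytes::Int) =
    BufferPool(Dict{Any,Array}(), Dict{Any,Int}(), 0, max_bytes)

total_bytes(pool::BufferPool) = sum(sizeof, values(pool.buffers); init=0)

function getbuffer!(pool::BufferPool, name::Symbol, T::Type, dims::Dims)
    pool.clock += 1
    key = (name, T, dims)
    buf = get!(pool.buffers, key) do
        Array{T}(undef, dims)           # allocate on first request
    end
    pool.last_use[key] = pool.clock
    # evict least recently used buffers while over the byte budget
    while total_bytes(pool) > pool.max_bytes && length(pool.buffers) > 1
        victim = argmin(pool.last_use)  # key with the oldest timestamp
        victim == key && break          # never evict the buffer just handed out
        delete!(pool.buffers, victim)
        delete!(pool.last_use, victim)
    end
    return buf
end

# dW = getbuffer!(pool, :dW, Float64, (128, 64))
```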