dfdx / XGrad.jl

eXpression gradients in Julia

Gradient function cache & memory manager #6

Closed: dfdx closed this 6 years ago

dfdx commented 6 years ago

Currently, gradient functions are compiled for specific argument types and sizes. If any of these change (e.g. the batch size), the compiled gradient expression becomes invalid. It would be nice to cache compiled gradients and derive new ones only when a new combination of argument types and sizes appears. This way we will be able to write:

f = ...
mem = ...
for x in batches
    df = xdiff(f; m=m, x=x)   # retrieve the gradient function from the cache, or compile a new one
    loss, dm, dx = df(m, x; mem=mem)
    ...
end

or even something like this:

f = ...
mem = ...
for x in batches
    loss, dm, dx = ddiff(f; m=m, x=x, mem=mem) 
    ...
end
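A minimal sketch of the cache described above, assuming a compiled gradient can be looked up by the function, the keyword names, and each argument's type and size. GradientCache, getgradient!, argsig and toycompile are hypothetical names, not part of XGrad.jl. The compile step is passed in as a plain function so the sketch does not depend on xdiff's actual signature.

```julia
# Hypothetical sketch, NOT XGrad.jl's API: cache compiled gradient functions
# keyed by (function, keyword names, type and size of every argument).

# Signature of one argument: its type plus its size (empty for scalars).
argsig(x::AbstractArray) = (typeof(x), size(x))
argsig(x) = (typeof(x), ())

struct GradientCache
    compile::Function               # (f, names, values) -> gradient function
    entries::Dict{Any,Function}     # lookup key -> compiled gradient
end

GradientCache(compile) = GradientCache(compile, Dict{Any,Function}())

# Return the cached gradient for these argument signatures, compiling on a miss.
function getgradient!(cache::GradientCache, f; kwargs...)
    names = Tuple(k for (k, _) in kwargs)
    vals = Tuple(v for (_, v) in kwargs)
    key = (f, names, map(argsig, vals))
    return get!(cache.entries, key) do
        cache.compile(f, names, vals)
    end
end

# Toy stand-in for real gradient compilation: counts how often it runs.
const ncompiled = Ref(0)
function toycompile(f, names, vals)
    ncompiled[] += 1
    return (args...) -> f(args...)
end

cache = GradientCache(toycompile)
loss(m, x) = sum(m * x)
m = rand(3, 4)
for batchsize in (10, 10, 20, 10)
    x = rand(4, batchsize)
    df = getgradient!(cache, loss; m=m, x=x)
    df(m, x)
end
println(ncompiled[])   # prints 2: one compilation per distinct batch size
```

Keying on size as well as type is what makes a batch-size change trigger a new compilation while repeated batches of the same size reuse the cached function.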

The main difficulty here is the memory manager (the mem argument). Currently it's just a Dict of pre-allocated buffers for all temporary variables. Since those buffers are sized for one particular set of arguments, this scenario will need something more sophisticated, perhaps a manager that evicts old buffers.
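One possible shape for such a memory manager, as a hedged sketch: a pool keyed by (variable name, element type, size) with a byte budget, which drops least-recently-used buffers once the budget is exceeded. BufferPool, getbuffer! and evict! are hypothetical names; the issue only describes the current mem as a Dict of buffers, and LRU is just one possible displacement policy.

```julia
# Hypothetical sketch, NOT XGrad.jl's API: a buffer pool with a byte budget
# and least-recently-used eviction.
mutable struct BufferPool
    maxbytes::Int                 # budget for all buffers together
    nbytes::Int                   # bytes currently held
    clock::Int                    # logical time, bumped on every request
    buffers::Dict{Any,Array}      # (name, eltype, dims) -> buffer
    lastuse::Dict{Any,Int}        # (name, eltype, dims) -> last request time
end

BufferPool(maxbytes) = BufferPool(maxbytes, 0, 0, Dict{Any,Array}(), Dict{Any,Int}())

# Return a buffer for `name` with element type T and size `dims`,
# allocating it on first use and then enforcing the byte budget.
function getbuffer!(pool::BufferPool, name::Symbol, ::Type{T}, dims::Dims) where T
    key = (name, T, dims)
    pool.clock += 1
    buf = get(pool.buffers, key, nothing)
    if buf === nothing
        buf = Array{T}(undef, dims)
        pool.buffers[key] = buf
        pool.nbytes += sizeof(buf)
    end
    pool.lastuse[key] = pool.clock
    evict!(pool, key)
    return buf
end

# Drop least recently used buffers (never `keep`) until within budget.
function evict!(pool::BufferPool, keep)
    while pool.nbytes > pool.maxbytes
        oldest, oldtime = nothing, typemax(Int)
        for (k, t) in pool.lastuse
            if k != keep && t < oldtime
                oldest, oldtime = k, t
            end
        end
        oldest === nothing && break      # only `keep` is left
        pool.nbytes -= sizeof(pool.buffers[oldest])
        delete!(pool.buffers, oldest)
        delete!(pool.lastuse, oldest)
    end
end

pool = BufferPool(160)                       # room for two 10-element Float64 buffers
a = getbuffer!(pool, :tmp1, Float64, (10,))
b = getbuffer!(pool, :tmp2, Float64, (10,))
getbuffer!(pool, :tmp1, Float64, (10,))      # reuses `a`; tmp1 is now most recent
c = getbuffer!(pool, :tmp3, Float64, (10,))  # over budget, so tmp2 is evicted
```

A gradient function compiled for a new batch size would simply request buffers under new keys; buffers for sizes that stop appearing age out instead of accumulating.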

dfdx commented 6 years ago

Closed by #9