squidszyd opened 7 years ago

I'm currently working on a project in which two modules need to be optimized, and the two modules are related to each other. I'm wondering if it is possible to optimize them together using optim? For example, could I write a feval function whose input is a table of parameters, { paramFromModule1, paramFromModule2 }, and which returns a table of grads, { gradsFromModule1, gradsFromModule2 }?
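Roughly, I imagine something like the following, where module1, module2, criterion, input, and target are placeholders:

-- hypothetical shape of such a feval: module2 consumes module1's output
local function feval(paramsTable)  -- paramsTable = { paramFromModule1, paramFromModule2 }
   module1:zeroGradParameters()
   module2:zeroGradParameters()
   -- forward pass through both modules
   local h = module1:forward(input)
   local out = module2:forward(h)
   local loss = criterion:forward(out, target)
   -- backward pass through both modules
   local dout = criterion:backward(out, target)
   local dh = module2:backward(h, dout)
   module1:backward(input, dh)
   return loss, { gradsFromModule1, gradsFromModule2 }
end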
Yes. I guess what you need is the following, where x is a table of parameter tensors and opfunc returns the loss together with a matching table of gradient tensors.
local function sgd_custom(opfunc, x, config, state)
   -- (0) get/update state
   local config = config or {}
   local state = state or config
   local lr = config.learningRate or 1e-3
   local lrd = config.learningRateDecay or 0
   local wd = config.weightDecay or 0
   local mom = config.momentum or 0
   local damp = config.dampening or mom
   local nesterov = config.nesterov or false
   -- lrs/wds, if given, are assumed to be tables of per-parameter tensors,
   -- one entry per tensor in x
   local lrs = config.learningRates
   local wds = config.weightDecays
   state.evalCounter = state.evalCounter or 0
   local nevals = state.evalCounter
   assert(not nesterov or (mom > 0 and damp == 0),
          "Nesterov momentum requires a momentum and zero dampening")

   -- (1) evaluate f(x) and df/dx
   local fx, dfdx = opfunc(x)

   -- (2) weight decay with single or individual parameters
   if wd ~= 0 then
      for i = 1, #dfdx do
         dfdx[i]:add(wd, x[i])
      end
   elseif wds then
      if not state.decayParameters then
         state.decayParameters = {}
         for i = 1, #dfdx do
            state.decayParameters[i] = torch.Tensor():typeAs(x[i]):resizeAs(dfdx[i])
         end
      end
      for i = 1, #dfdx do
         state.decayParameters[i]:copy(wds[i]):cmul(x[i])
         dfdx[i]:add(state.decayParameters[i])
      end
   end

   -- (3) apply momentum
   if mom ~= 0 then
      if not state.dfdx then
         state.dfdx = {}
         for i = 1, #dfdx do
            state.dfdx[i] = torch.Tensor():typeAs(dfdx[i]):resizeAs(dfdx[i]):copy(dfdx[i])
         end
      else
         for i = 1, #dfdx do
            state.dfdx[i]:mul(mom):add(1 - damp, dfdx[i])
         end
      end
      if nesterov then
         for i = 1, #dfdx do
            dfdx[i]:add(mom, state.dfdx[i])
         end
      else
         for i = 1, #dfdx do
            dfdx[i] = state.dfdx[i]
         end
      end
   end

   -- (4) learning rate decay (annealing)
   local clr = lr / (1 + nevals*lrd)

   -- (5) parameter update with single or individual learning rates
   if lrs then
      if not state.deltaParameters then
         state.deltaParameters = {}
         for i = 1, #x do
            state.deltaParameters[i] = torch.Tensor():typeAs(x[i]):resizeAs(dfdx[i])
         end
      end
      for i = 1, #x do
         state.deltaParameters[i]:copy(lrs[i]):cmul(dfdx[i])
         x[i]:add(-clr, state.deltaParameters[i])
      end
   else
      for i = 1, #x do
         x[i]:add(-clr, dfdx[i])
      end
   end

   -- (6) update evaluation counter
   state.evalCounter = state.evalCounter + 1

   -- return x*, f(x) before optimization
   return x, {fx}
end
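A minimal usage sketch, assuming two modules m1 and m2 and a placeholder feval; Module:parameters() gives the per-layer tensor tables that sgd_custom expects:

-- gather each module's parameter and gradient tensors into flat tables
local p1, g1 = m1:parameters()
local p2, g2 = m2:parameters()
local params, grads = {}, {}
for _, p in ipairs(p1) do table.insert(params, p) end
for _, p in ipairs(p2) do table.insert(params, p) end
for _, g in ipairs(g1) do table.insert(grads, g) end
for _, g in ipairs(g2) do table.insert(grads, g) end

-- placeholder feval: run forward/backward, then return loss and the grad table
local function feval(x)
   m1:zeroGradParameters()
   m2:zeroGradParameters()
   local loss = 0  -- forward/backward over a batch goes here
   return loss, grads
end

local config = { learningRate = 1e-2, momentum = 0.9 }
local state = {}
for step = 1, 100 do
   sgd_custom(feval, params, config, state)
end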
Or you may just put these models under the same custom nn.Container.
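For instance, a sketch of that approach (module1/module2 are placeholders); once both modules live in one container, getParameters() yields a single flat parameter/gradient pair, so stock optim.sgd applies:

require 'nn'
require 'optim'

-- wrap both modules in one container so they share a flat parameter view
local container = nn.Container()
container:add(module1)
container:add(module2)
local params, gradParams = container:getParameters()

local function feval(x)
   if x ~= params then params:copy(x) end
   gradParams:zero()
   local loss = 0  -- forward/backward through both modules goes here
   return loss, gradParams
end

optim.sgd(feval, params, { learningRate = 1e-2 })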
@taineleau Thank you! I've also found that combine_all_parameters could solve my issue.
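A rough sketch of that route, assuming a model_utils.combine_all_parameters helper (the require path follows karpathy/char-rnn's layout and may differ in your setup):

-- flatten both modules' parameters into one pair of tensors;
-- afterwards the stock optim routines apply, exactly as in the
-- container example above
local model_utils = require 'util.model_utils'
local params, gradParams = model_utils.combine_all_parameters(module1, module2)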