songjhaha / PaddleChainRules.jl

1 stars 0 forks source link

add constructor of FCNet and GeneralNet #2

Closed songjhaha closed 2 years ago

songjhaha commented 2 years ago

Add a constructor for FCNet:

PaddleFCNet(dim_ins, dim_outs, num_layers, hidden_size; dtype="float32", activation="sigmoid")

For a general network, a rough solution is to copy the parameter values into the Paddle model before every forward pass, like this:

"""
    PaddleStatelessGeneralNet(NN)

Stateless wrapper around an arbitrary Paddle model held as the Python object `NN`.

The wrapper stores no parameters of its own; parameter values are supplied by the
caller at call time.
"""
struct PaddleStatelessGeneralNet <: PaddleStatelessModule
    # Wrapped Python-side model; presumably a `paddle.nn.Layer` (TODO confirm).
    NN::PyObject
end

"""
    (stateless_module::PaddleStatelessGeneralNet)(params::Vector, inputs; kwinputs...)

Run a forward pass of the wrapped Paddle model using the parameter values in `params`.

Each element of `params` is written, in order, into the corresponding tensor returned
by `stateless_module.NN.parameters()` via `set_value`. The model is then called on
`inputs`, and any keyword arguments in `kwinputs` are forwarded to that call.

Note: this overwrites the parameters stored in the wrapped Python model (a side effect
visible to anyone else holding `stateless_module.NN`).

# Throws
- `DimensionMismatch` if `length(params)` differs from the number of model parameters.
"""
function (stateless_module::PaddleStatelessGeneralNet)(params::Vector, inputs; kwinputs...)
    model = stateless_module.NN
    model_params = model.parameters()

    # Fail loudly instead of silently updating only a prefix of the parameters.
    if length(model_params) != length(params)
        throw(DimensionMismatch(
            "expected $(length(model_params)) parameter values, got $(length(params))",
        ))
    end

    # `foreach`, not `map`: the calls are made purely for their side effect.
    foreach((p, p_new) -> p.set_value(p_new), model_params, params)

    return model(inputs; kwinputs...)
end