dfdx / Avalon.jl

Starter kit for legendary models
MIT License

How to freeze parameters #7

Open lbotecur opened 3 years ago

lbotecur commented 3 years ago

Hello,

I would like to know how to freeze parameters in a model, that is, how to train only a subset of the parameters.

Thank you.

dfdx commented 3 years ago

The update!() function takes an optional argument ignore - a set of field paths that should not be updated. A field path is a tuple of symbols representing the path to a specific parameter. For example, if your model looks like this:

mutable struct Foo
    x
    y
end

mutable struct Bar
    foo::Foo
    z
end

m = Bar(Foo(1, 2), 3)

And you want to ignore Foo's x and Bar's z, use it like this:

ignore = Set([
    (:foo, :x),
    (:z,)
])

update!(m, gm, ignore=ignore)

Note that this is a low-level and unstable API. I'm currently working on such small things, including this very specific task - freezing the parameters - but I have several use cases and no specific design yet. I'd be grateful if you could describe your use case so that I can make the API more convenient.
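To make the field-path idea concrete, here is a minimal, self-contained sketch of how a path-based ignore set could be honored during a parameter update. The function `toy_update!`, its `lr` and `path` keywords, and the gradient values are all hypothetical and only for illustration; this is not Avalon's update!() implementation. It assumes the gradient object mirrors the structure of the model and that parameters are plain numbers.

```julia
mutable struct Foo
    x
    y
end

mutable struct Bar
    foo::Foo
    z
end

# Hypothetical stand-in for an optimizer step (NOT Avalon's update!).
# Walks the model recursively and subtracts lr * gradient from every
# leaf field whose full path is not listed in `ignore`.
# `g` is assumed to have the same structure as `m`.
function toy_update!(m, g; ignore=Set{Tuple}(), lr=0.1, path=())
    for p in propertynames(m)
        field_path = (path..., p)
        v = getproperty(m, p)
        if isempty(propertynames(v))
            # leaf parameter: update it only if its full path is not ignored
            field_path in ignore || setproperty!(m, p, v - lr * getproperty(g, p))
        else
            # nested struct: descend, extending the path
            toy_update!(v, getproperty(g, p); ignore=ignore, lr=lr, path=field_path)
        end
    end
    return m
end

m = Bar(Foo(1.0, 2.0), 3.0)
gm = Bar(Foo(10.0, 10.0), 10.0)   # made-up gradient values
ignore = Set([(:foo, :x), (:z,)])

toy_update!(m, gm; ignore=ignore)
# m.foo.x and m.z are unchanged; only m.foo.y has moved
```

Because the check happens only at leaf fields, a partial path such as (:foo,) matches nothing in this sketch, which is why full paths are needed.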

lbotecur commented 3 years ago

Thank you for the answer. This solution is great. My use case is exactly the one you described: use a pretrained model (Foo) as part of a new model (Bar) and train the new model with Foo's parameters frozen. After that, fine-tune the whole model with Foo's parameters unfrozen.

I don't know if there is any possibility of passing to Yötä only the parameters for which gradients should be calculated, in a way similar to how JAX does it.

Thanks.

dfdx commented 3 years ago

Great, pretraining is a very important use case for Avalon, so we will definitely have a more concise syntax for freezing parameters, but the exact API will arrive later, perhaps shortly after the high-level training API.

Please note that the ignore list expects full field paths, so using just (:foo,) in the example above won't have any effect. To recursively collect the list of field paths, you can use the following:

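# Recursively collect the full field path of every leaf in `obj`.
# A leaf is any value that has no properties of its own (e.g. a number).
function collect_fields(obj)
    paths = Vector{Symbol}[]
    for p in propertynames(obj)
        subpaths = collect_fields(getproperty(obj, p))
        if isempty(subpaths)
            # `p` is a leaf: its path is just its own name
            push!(paths, [p])
        else
            # prepend `p` to every path found underneath it
            for subpath in subpaths
                push!(paths, [p; subpath...])
            end
        end
    end
    # convert each path to a tuple of symbols, as expected by `ignore`
    return [(path...,) for path in paths]
end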
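For the pretraining use case, the collected paths need to be prefixed with the name of the field that holds the pretrained sub-model. The following self-contained sketch shows one way to do that; `leaf_paths` is a hypothetical helper (a compact variant of the recursion above that carries the prefix along), and the toy model uses plain numbers as parameters. For real models with array-valued parameters you would probably want to treat arrays as leaves explicitly.

```julia
# Toy model from the example earlier in the thread
mutable struct Foo
    x
    y
end

mutable struct Bar
    foo::Foo
    z
end

# Hypothetical helper: full paths of all leaves under `obj`,
# each starting with `prefix`. A leaf is a value with no properties.
function leaf_paths(obj, prefix::Tuple=())
    names = propertynames(obj)
    isempty(names) && return [prefix]
    paths = Tuple[]
    for p in names
        append!(paths, leaf_paths(getproperty(obj, p), (prefix..., p)))
    end
    return paths
end

m = Bar(Foo(1, 2), 3)

# Freeze everything under `foo`, leave `z` trainable
ignore = Set(leaf_paths(m.foo, (:foo,)))
# ignore now contains (:foo, :x) and (:foo, :y)
```

The resulting set can then be passed as the ignore argument during the first training phase and replaced with an empty set for fine-tuning.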

I don't know if there is any possibility of passing to Yötä only the parameters for which gradients should be calculated, in a way similar to how JAX does it.

I'm not sure I understood you correctly, but if you are looking for semantics like:

f(x) = ...
gf = grad(f)
gf(x)

Unfortunately it's not possible out of the box because without concrete arguments Yota doesn't really know which method of f() to trace. Yet it should be possible to make a simple wrapper, something like:

grad_fn(f) = (args...) -> grad(f, args...)

Is this what you were asking about?