facebookresearch / higher

higher is a PyTorch library that allows users to obtain higher-order gradients over losses spanning training loops rather than individual training steps.
Apache License 2.0
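For context, here is a minimal self-contained sketch of the core pattern the library documents: an inner training loop is unrolled inside higher.innerloop_ctx, and a loss computed at the end is backpropagated through every unrolled step. The toy model and data below are made up for illustration and are not taken from this issue.

import torch
import higher

# Toy setup, purely for illustration.
torch.manual_seed(0)
xs, ys = torch.randn(16, 4), torch.randn(16, 1)

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# copy_initial_weights=False keeps the original parameters on the tape, so the
# backward() call below fills in .grad on model.parameters() (MAML-style usage).
with higher.innerloop_ctx(model, opt, copy_initial_weights=False) as (fmodel, diffopt):
    for _ in range(3):                           # unrolled inner training steps
        loss = ((fmodel(xs) - ys) ** 2).mean()
        diffopt.step(loss)                       # differentiable update; no loss.backward() here

    meta_loss = ((fmodel(xs) - ys) ** 2).mean()  # loss on the final fast weights
    meta_loss.backward()                         # gradient flows back through all unrolled steps

print(all(p.grad is not None for p in model.parameters()))  # True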

Using higher for hyperparameter optimization #135

Open aruniyer opened 1 year ago

aruniyer commented 1 year ago

Hi,

I believe that higher could be used for hyperparameter optimization via bilevel programming. I have attempted to adapt the given meta-learning example accordingly, but I am somewhat unsure whether I have done it correctly. Here is the general structure of what I have done:

import torch
import higher

# Get optimizers
inner_optim = torch.optim.Adam(params=model.parameters(), lr=args.learning_rate)
outer_optim = torch.optim.Adam(params=hp.parameters(), lr=args.learning_rate)

# Training loop
num_inner_iter = args.inner_loop
for epoch in range(args.epochs):
    outer_optim.zero_grad()
    with higher.innerloop_ctx(
        model=model,
        opt=inner_optim,
        copy_initial_weights=False,
        track_higher_grads=False,
    ) as (fmodel, diffopt):
        for _ in range(num_inner_iter):
            # Forward pass
            train_out = fmodel(transformed_features, hp)
            train_loss = custom_loss(predicted=train_out, actual=train_labels)
            diffopt.step(train_loss)

        val_out = fmodel(transformed_features_val, hp)
        val_loss = custom_loss(predicted=val_out, actual=val_labels)
        val_loss.backward()
    outer_optim.step()

Does the above look correct? Or am I misunderstanding something?
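For comparison, here is a minimal sketch, not a definitive answer, of the same loop with the two flags that usually decide whether hyperparameter gradients are obtained: track_higher_grads is left at its default of True, so that val_loss.backward() can reach hp through the unrolled inner updates (with False, only hp's direct use in the validation forward pass contributes), and copy_initial_weights=True keeps the initial network weights off the gradient tape when only hp is meta-optimized. All names (model, hp, inner_optim, outer_optim, custom_loss, the feature and label tensors, args, num_inner_iter) are the placeholders from the snippet above.

for epoch in range(args.epochs):
    outer_optim.zero_grad()
    with higher.innerloop_ctx(
        model=model,
        opt=inner_optim,
        copy_initial_weights=True,    # initial weights are detached copies; only hp is meta-learned
        track_higher_grads=True,      # the default; keeps the graph through each diffopt.step
    ) as (fmodel, diffopt):
        for _ in range(num_inner_iter):
            train_out = fmodel(transformed_features, hp)
            train_loss = custom_loss(predicted=train_out, actual=train_labels)
            diffopt.step(train_loss)  # differentiable inner update of fmodel's fast weights

        val_out = fmodel(transformed_features_val, hp)
        val_loss = custom_loss(predicted=val_out, actual=val_labels)
        val_loss.backward()           # accumulates d(val_loss)/d(hp) into hp.grad
    outer_optim.step()                # outer update touches hp only

Note that in both versions the inner updates are applied to fmodel's fast weights; model's own parameters are never modified, so every epoch the inner loop retrains from the same initialization.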

NoraAl commented 1 year ago

I am working on something related to bilevel (two-level) optimization, and higher could help simplify the code. I would like to see some examples of meta-optimizers built with higher.
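For what it is worth, here is a self-contained toy sketch of one kind of meta-optimizer that can be expressed with higher: learning the inner-loop learning rate by passing a tensor through innerloop_ctx's override argument. Everything below (the regression task, model, and tensor names) is made up for illustration, and the override usage should be checked against higher's documentation.

import torch
import higher

torch.manual_seed(0)
x_train, y_train = torch.randn(64, 8), torch.randn(64, 1)
x_val, y_val = torch.randn(64, 8), torch.randn(64, 1)

model = torch.nn.Linear(8, 1)
meta_lr = torch.tensor(0.05, requires_grad=True)            # the meta-learned hyperparameter

inner_optim = torch.optim.SGD(model.parameters(), lr=0.05)  # this lr is overridden below
outer_optim = torch.optim.Adam([meta_lr], lr=1e-2)

for outer_step in range(50):
    outer_optim.zero_grad()
    with higher.innerloop_ctx(
        model,
        inner_optim,
        override={'lr': [meta_lr]},   # the differentiable SGD uses meta_lr in its update rule
        copy_initial_weights=True,    # keep the network's initial weights off the tape
    ) as (fmodel, diffopt):
        for _ in range(5):            # unrolled inner loop
            train_loss = ((fmodel(x_train) - y_train) ** 2).mean()
            diffopt.step(train_loss)

        val_loss = ((fmodel(x_val) - y_val) ** 2).mean()
        val_loss.backward()           # d(val_loss)/d(meta_lr) through the unrolled updates
    outer_optim.step()                # update the learning rate itself

According to the override docstring, other optimizer settings (e.g. momentum or weight_decay) can be passed as differentiable tensors in the same way; in practice one would probably also reparameterize meta_lr (e.g. through a softplus) to keep it positive.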