dr-aheydari / SoftAdapt

Implementation of the SoftAdapt paper (techniques for adaptive loss balancing of multi-tasking neural networks)
MIT License

Problem in the README usage example #5

Closed b-brebion closed 7 months ago

b-brebion commented 7 months ago

Hi,

Thank you for your implementation of the paper. There's an error in the usage example in the README, if I'm not mistaken: the weight update via get_component_weights() should take place outside the train_data_loader loop. Otherwise, adapt_weights is updated at every single batch of an epoch whenever the current_epoch % epochs_to_make_updates == 0 condition is met.

Current implementation:

...

# Main training loop:
for current_epoch in range(training_epochs):
    for batch_idx, data in enumerate(train_data_loader):
        ...

        if current_epoch % epochs_to_make_updates == 0 and current_epoch != 0:
            adapt_weights = softadapt_object.get_component_weights(...)              

            values_of_component_1 = []
            values_of_component_2 = []
            values_of_component_3 = []

        ...

My proposed modification:

...

# Main training loop:
for current_epoch in range(training_epochs):
    if current_epoch % epochs_to_make_updates == 0 and current_epoch != 0:
        adapt_weights = softadapt_object.get_component_weights(...)

        values_of_component_1 = []
        values_of_component_2 = []
        values_of_component_3 = []

    for batch_idx, data in enumerate(train_data_loader):
        ...
dr-aheydari commented 7 months ago

Hi @b-brebion,

Thank you for opening the issue! Your understanding is perfectly correct!

Generally, we have a training loop where we perform mini-batch gradient descent, something like:

for current_epoch in range(training_epochs):
    for batch_idx, data in enumerate(train_data_loader):
        optimizer.zero_grad()
        # Steps to pass the input to the model, get outputs, and compute the loss
        ...
        # The backward pass and optimizer step happen per mini batch
        loss.backward()
        optimizer.step()

With the example usage provided in the README, once current_epoch % epochs_to_make_updates == 0 is met, the weights are updated for every batch during that epoch, much like the optimizer takes a step per mini batch. The idea is that updating the adaptive weights per batch lets SoftAdapt learn from smaller chunks of data and potentially converge faster, although, as with mini-batch gradient descent, this may introduce noise compared to updating the weights once per epoch (before or after the batch loop). Our experiments showed that in most applications, updating the weights per batch improved convergence.

As you pointed out, one can also update the weights before the mini-batch optimization loop. The weights then remain fixed until the next epoch, which also makes sense.
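For readers who want to see the per-epoch variant end to end, here is a minimal, self-contained sketch. The softadapt_weights helper below is a hypothetical stand-in for softadapt_object.get_component_weights (it is not the library's implementation): following the SoftAdapt paper, it softmaxes the recent rate of change of each loss component, so components whose loss is rising get larger weights. The dummy loss histories and hyperparameter values are illustrative only.

```python
import math

def softadapt_weights(loss_histories, beta=0.1):
    # Hypothetical stand-in for softadapt_object.get_component_weights.
    # SoftAdapt assigns component k a weight proportional to exp(beta * s_k),
    # where s_k approximates that component's recent rate of change
    # (here crudely: last recorded value minus first recorded value).
    slopes = [h[-1] - h[0] for h in loss_histories]
    max_s = max(slopes)  # subtract the max for numerical stability
    exps = [math.exp(beta * (s - max_s)) for s in slopes]
    total = sum(exps)
    return [e / total for e in exps]

# Per-epoch variant: update the weights once, before the mini-batch loop.
training_epochs = 4
epochs_to_make_updates = 2
adapt_weights = [1.0, 1.0, 1.0]      # uniform until the first update
values_of_component_1 = [1.0, 0.9]   # dummy recorded loss histories
values_of_component_2 = [2.0, 1.5]
values_of_component_3 = [0.5, 0.6]

for current_epoch in range(training_epochs):
    if current_epoch % epochs_to_make_updates == 0 and current_epoch != 0:
        adapt_weights = softadapt_weights(
            [values_of_component_1, values_of_component_2, values_of_component_3]
        )
        values_of_component_1 = []
        values_of_component_2 = []
        values_of_component_3 = []
    # ... inner mini-batch loop runs here, weighting each loss component
    # by the (now fixed for this epoch) entries of adapt_weights ...
```

With the dummy histories above, component 3 is the only one whose loss increased, so it receives the largest weight after the first update, while the weights always sum to 1.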

You are more than welcome to submit a pull request with the additional explanation or another example of doing it outside of the inner training loop! I think it would be great to have that detail in the README as well.

Please let me know if what I wrote makes sense, or if you have any additional questions :)

Best, Ali

b-brebion commented 7 months ago

Thanks for your very clear answer, I understand your point of view. I will submit a PR later today to clarify this in the README.

dr-aheydari commented 7 months ago

Closing this issue now since @b-brebion submitted a PR to include an additional use case.