Hi @b-brebion,
Thank you for opening the issue! Your understanding is perfectly correct!
Generally, we have a training loop where we perform mini-batch gradient descent, something like:
```python
for current_epoch in range(training_epochs):
    for batch_idx, data in enumerate(train_data_loader):
        optimizer.zero_grad()
        # Steps to pass the input to the model, get outputs, and compute the loss
        ...
        # The backward pass and the optimizer step happen per mini batch
        loss.backward()
        optimizer.step()
```
With the example usage provided in the README, once the `current_epoch % epochs_to_make_updates == 0` condition is met, the weights are updated for every batch during that epoch, similar to how the optimizer takes a step. As with mini-batch gradient descent, the idea is that updating the adaptive weights per batch allows us to learn from smaller chunks of data and potentially converge faster. However, again as with mini-batch GD, this can introduce noise compared to updating the weights before (or after) the inner loop. Our experiments showed that in most applications, updating the weights per batch improved convergence.
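Concretely, the README-style placement looks roughly like the sketch below. Names such as `softadapt_object`, `compute_loss_components` and the two loss components are illustrative placeholders rather than the exact README code, and the `get_component_weights()` call signature is assumed; the point is only where the call sits:

```python
import torch

component_1_history, component_2_history = [], []  # recorded loss values
adapt_weights = torch.tensor([1.0, 1.0])            # start with equal weights
epochs_to_make_updates = 5

for current_epoch in range(training_epochs):
    for batch_idx, data in enumerate(train_data_loader):
        optimizer.zero_grad()

        # Hypothetical helper that returns the individual loss components.
        loss_component_1, loss_component_2 = compute_loss_components(model, data)
        component_1_history.append(loss_component_1.item())
        component_2_history.append(loss_component_2.item())

        # README-style placement: the adaptive weights are refreshed for
        # every batch of an epoch that satisfies the condition
        # (epoch 0 is skipped here just so some loss history exists first).
        if current_epoch % epochs_to_make_updates == 0 and current_epoch != 0:
            adapt_weights = softadapt_object.get_component_weights(
                torch.tensor(component_1_history),
                torch.tensor(component_2_history),
            )

        loss = adapt_weights[0] * loss_component_1 + adapt_weights[1] * loss_component_2
        loss.backward()
        optimizer.step()
```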
As you pointed out, one can also update the weights before the optimization loop over the mini batches. This means the weights stay completely fixed until the next qualifying epoch, which also makes sense.
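For comparison, a minimal sketch of that variant (same illustrative names as above), where the weights are refreshed once before iterating over the mini batches and then held fixed for the rest of the epoch:

```python
for current_epoch in range(training_epochs):
    # Alternative placement: update the adaptive weights once per qualifying
    # epoch, before the batch loop, so they stay fixed until the next update
    # (the non-empty check avoids calling with no recorded history).
    if current_epoch % epochs_to_make_updates == 0 and component_1_history:
        adapt_weights = softadapt_object.get_component_weights(
            torch.tensor(component_1_history),
            torch.tensor(component_2_history),
        )

    for batch_idx, data in enumerate(train_data_loader):
        optimizer.zero_grad()
        loss_component_1, loss_component_2 = compute_loss_components(model, data)
        component_1_history.append(loss_component_1.item())
        component_2_history.append(loss_component_2.item())

        loss = adapt_weights[0] * loss_component_1 + adapt_weights[1] * loss_component_2
        loss.backward()
        optimizer.step()
```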
You are more than welcome to submit a pull request with the additional explanation or another example of doing it outside of the inner training loop! I think it would be great to have that detail in the README as well.
Please let me know if what I wrote makes sense, or if you have any additional questions :)
Best, Ali
Thanks for your very clear answer, I understand your point of view. I will submit a PR later today to clarify this in the README.
Closing this issue now since @b-brebion submitted a PR to include an additional use case.
Hi,
Thank you for your implementation of the paper. There's an error in the usage example in the README, if I'm not mistaken. Updating the weights via `get_component_weights()` should take place outside the `train_data_loader` loop. Otherwise, `adapt_weights` will be updated at every single batch of an epoch once the `current_epoch % epochs_to_make_updates == 0` condition is met.

Current implementation:
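In sketch form (illustrative names, not the exact README code), the call to `get_component_weights()` currently sits inside the batch loop, so it runs on every batch of a qualifying epoch:

```python
for current_epoch in range(training_epochs):
    for batch_idx, data in enumerate(train_data_loader):
        ...
        if current_epoch % epochs_to_make_updates == 0:
            # Runs for every batch of a qualifying epoch.
            adapt_weights = softadapt_object.get_component_weights(...)
        ...
```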
My proposed modification:
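Again in sketch form, moving the update in front of the batch loop so it runs once per qualifying epoch:

```python
for current_epoch in range(training_epochs):
    if current_epoch % epochs_to_make_updates == 0:
        # Runs once, before iterating over the mini batches.
        adapt_weights = softadapt_object.get_component_weights(...)
    for batch_idx, data in enumerate(train_data_loader):
        ...
```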