zincware / ZnNL

A Python package for studying neural learning
Eclipse Public License 2.0

Konsti resnet implementation #105

Closed KonstiNik closed 1 year ago

KonstiNik commented 1 year ago

Implementation of a Flax ResNet from HuggingFace.

SamTov commented 1 year ago

Thanks for the PR! In this case, the Black formatting should have been done further down the chain in a separate PR. It makes the larger changes to the code very difficult to review, as there are now 109 files that need looking into. Can you highlight which modules you have changed in the ResNet PR? Alternatively, make a new PR to main where you only do the Black formatting and then merge that one here.

KonstiNik commented 1 year ago
> 1. Did you not have to update the training procedure? When I got this working, I needed to take into consideration the batch statistics and all these other things being passed correctly. I don't see these changes here; what was the solution?

No, I did not have to. The HF call method is directly compatible with constructing a TrainState. After constructing it, the rest is straightforward. Where exactly did you run into issues?
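
In rough terms, the construction looks like this (a minimal sketch; the checkpoint, optimizer, and learning rate are illustrative, not the ZnNL defaults):

    # Sketch: use the HF call method directly as the apply_fn of a Flax TrainState.
    # Checkpoint, optimizer, and learning rate are placeholders.
    import optax
    from flax.training import train_state
    from transformers import FlaxResNetForImageClassification

    model = FlaxResNetForImageClassification.from_pretrained(
        "microsoft/resnet-50", from_pt=True  # convert from PyTorch weights if no Flax weights exist
    )

    state = train_state.TrainState.create(
        apply_fn=model.__call__,  # the HF call method
        params=model.params,      # the HF params, which carry the batch_stats
        tx=optax.adam(1e-3),
    )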

SamTov commented 1 year ago
> > 1. Did you not have to update the training procedure? When I got this working, I needed to take into consideration the batch statistics and all these other things being passed correctly. I don't see these changes here; what was the solution?
>
> No, I did not have to. The HF call method is directly compatible with constructing a TrainState. After constructing it, the rest is straightforward. Where exactly did you run into issues?

The call to the network has a new return signature? It should return the batch stats along with the logits, and these batch stats have to be propagated to the network in the forward passes and during updates. We deal with this in the NTK calculation, but unless it snuck through the last time I worked on it, there won't be any batch stats passed.

KonstiNik commented 1 year ago
> > > 1. Did you not have to update the training procedure? When I got this working, I needed to take into consideration the batch statistics and all these other things being passed correctly. I don't see these changes here; what was the solution?
> >
> > No, I did not have to. The HF call method is directly compatible with constructing a TrainState. After constructing it, the rest is straightforward. Where exactly did you run into issues?
>
> The call to the network has a new return signature? It should return the batch stats along with the logits, and these batch stats have to be propagated to the network in the forward passes and during updates. We deal with this in the NTK calculation, but unless it snuck through the last time I worked on it, there won't be any batch stats passed.

Batch_stats are included in our model_state.params. But good point, I have to check whether the batch_stats get handled properly.

SamTov commented 1 year ago
> > > > 1. Did you not have to update the training procedure? When I got this working, I needed to take into consideration the batch statistics and all these other things being passed correctly. I don't see these changes here; what was the solution?
> > >
> > > No, I did not have to. The HF call method is directly compatible with constructing a TrainState. After constructing it, the rest is straightforward. Where exactly did you run into issues?
> >
> > The call to the network has a new return signature? It should return the batch stats along with the logits, and these batch stats have to be propagated to the network in the forward passes and during updates. We deal with this in the NTK calculation, but unless it snuck through the last time I worked on it, there won't be any batch stats passed.
>
> Batch_stats are included in our model_state.params. But good point, I have to check whether the batch_stats get handled properly.

Here for example, each time they call the model, they collect this other part of the output tuple:

    logits, new_model_state = state.apply_fn(
        {'params': params, 'batch_stats': state.batch_stats},
        batch['image'],
        mutable=['batch_stats'],
    )

It isn't the batch stats, sorry, it is this model state part. This has to be passed to other functions in order for the model to train properly. From my memory, it had something to do with ensuring the batch stats are used and updated correctly. We don't do this in normal training; we just ignore this additional output.

Now in their weight update, they do the following:

    new_state = state.apply_gradients(
        grads=grads, batch_stats=new_model_state['batch_stats']
    )

so they need these stats.
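
Side note: the reason apply_gradients can accept batch_stats there is that their TrainState is subclassed to carry the running statistics as an extra field, roughly like this (a sketch of that pattern, not our code):

    # Sketch: extend Flax's TrainState with a batch_stats field so that
    # apply_gradients(..., batch_stats=...) can store the updated statistics.
    from typing import Any

    from flax.training import train_state


    class TrainState(train_state.TrainState):
        batch_stats: Any = None

create() would then also receive the initial batch stats, so the forward pass can read them from state.batch_stats.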

They also seem to always pass it explicitly in model forward passes:

    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    logits = state.apply_fn(variables, batch['image'], train=False, mutable=False)

This is in the eval step, so nothing involved in training. It may not be necessary since the initial object is a dict anyway, but we do need to be sure.
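
Putting those pieces together, the train step we would need looks roughly like the following. This is only a sketch: it assumes the extended TrainState above, an apply_fn like the one in their example, a batch dict with 'image' and integer 'label' entries, and a plain cross-entropy loss.

    import jax
    import optax


    @jax.jit
    def train_step(state, batch):
        """One training step that threads the mutated batch statistics back into the state."""

        def loss_fn(params):
            # Forward pass with mutable batch stats, as in the snippet above.
            logits, new_model_state = state.apply_fn(
                {'params': params, 'batch_stats': state.batch_stats},
                batch['image'],
                mutable=['batch_stats'],
            )
            loss = optax.softmax_cross_entropy_with_integer_labels(
                logits, batch['label']
            ).mean()
            return loss, new_model_state

        (loss, new_model_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            state.params
        )
        # Update the weights and store the new running statistics.
        state = state.apply_gradients(
            grads=grads, batch_stats=new_model_state['batch_stats']
        )
        return state, loss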

KonstiNik commented 1 year ago

Further changes need to include:

  • [x] Adapt Training Strategies to the updated TrainState
  • [x] Jit larger training functions
  • [ ] Adapt trace opt to updated TrainState
  • [x] Make ResNet example more descriptive

SamTov commented 1 year ago

> Further changes need to include:
>
>   • [x] Adapt Training Strategies to the updated TrainState
>   • [x] Jit larger training functions
>   • [ ] Adapt trace opt to updated TrainState
>   • [x] Make ResNet example more descriptive

You can remove trace opt from this; I can take care of it in my other trace opt PR. I have made enough changes to it there that this would just set the whole thing back.

KonstiNik commented 1 year ago

There was an issue with HF FlaxResNets when using smaller models with layer_type='basic' instead of layer_type='bottleneck' (https://github.com/huggingface/transformers/issues/27257). It has been fixed and merged into the main branch of hf-transformers, so it should be released soon.

For my examples and tests to pass, I have therefore used layer_type='bottleneck'.
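
Concretely, the workaround amounts to something like this (a sketch; all sizes and the label count are illustrative, not the values used in the ZnNL examples and tests):

    # Sketch: build a small HF FlaxResNet with layer_type='bottleneck' until the
    # fix for layer_type='basic' is released. Sizes below are illustrative only.
    from transformers import FlaxResNetForImageClassification, ResNetConfig

    config = ResNetConfig(
        num_channels=3,
        hidden_sizes=[64, 128, 256, 512],
        depths=[2, 2, 2, 2],
        layer_type="bottleneck",  # 'basic' triggers the linked issue
        num_labels=10,
    )
    model = FlaxResNetForImageClassification(config)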