berenslab / retinal-rl

Testing theories about retinal coding in reinforcement learning environments
GNU Affero General Public License v3.0

Implement Weight / Activation Regularization #2

Closed · fabioseel closed this 3 months ago

fabioseel commented 3 months ago

Moved over from the other project, with slight adjustments and improvements.

Usage is pretty straightforward: simply modify your training loop / loss calculation as follows:

model = ...  # define your model

########################
reg = WeightRegularization(model, p=2, weight_decay=0.01)  # register the model with the regularizer
########################

output = model(...)
loss = ...  # calculate your usual loss

########################
loss += reg.penalty() # Just add the additional penalty to the loss
########################

# Proceed as usual
loss.backward()
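
For context, here is a minimal sketch of what a regularizer with this interface could look like. This is not the implementation from this issue: the norm computation, the toy model, and the loss below are illustrative assumptions based only on the usage shown above.

import torch
import torch.nn as nn


class WeightRegularization:
    """Sketch (assumed behavior): collect a model's parameters and return
    their p-norm, scaled by weight_decay, as an extra loss term."""

    def __init__(self, model: nn.Module, p: int = 2, weight_decay: float = 0.01):
        self.model = model
        self.p = p
        self.weight_decay = weight_decay

    def penalty(self):
        # Accumulate the p-norm of every trainable parameter, then scale.
        total = 0.0
        for param in self.model.parameters():
            if param.requires_grad:
                total = total + param.norm(p=self.p)
        return self.weight_decay * total


# Toy usage (model, data, and loss are placeholders, not from this repo):
model = nn.Linear(10, 1)
reg = WeightRegularization(model, p=2, weight_decay=0.01)
x, y = torch.randn(4, 10), torch.randn(4, 1)
loss = nn.functional.mse_loss(model(x), y) + reg.penalty()
loss.backward()

Adding the penalty to the loss (rather than relying on the optimizer's built-in weight_decay) makes it easy to swap in other norms or restrict the penalty to a subset of layers.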