google-research / jaxpruner

Apache License 2.0

Basic config inquiries #6

Closed by louieworth 1 year ago

louieworth commented 1 year ago

Hi,

When I create a new sparsity_pruner with some basic config as follows:

config.sparsity_config.algorithm = 'rigl'

config.sparsity_config.update_freq = 10
config.sparsity_config.update_end_step = 1000
config.sparsity_config.update_start_step = 1
config.sparsity_config.sparsity = 0.95
config.sparsity_config.dist_type = 'erk'

My question is: where is the update_freq config used? Is there an internal counter that tracks the training steps for pruning? Is there source code for this? I could not find it.

To the best of my knowledge, after wrapping a specific optimizer with pruner.wrap_optax, every call to optax.apply_updates increments the pruner's internal step by 1. Is this true?

evcu commented 1 year ago

SparseState has a counter. Here is where the counter is incremented. So every update call to the optimizer will increase the counter.

Depending on the optax optimizer used, inner_state might have a counter too.
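
To make the counter mechanism concrete, here is a minimal pure-Python sketch of a wrapper that keeps a step counter next to the wrapped optimizer's state. It is not jaxpruner's code: SparseStateSketch, Transformation, sgd, wrap_with_counter and is_mask_update_step are hypothetical names, and the schedule rule (step inside the start/end window and divisible by update_freq) is an assumption about how update_freq might be used, not something confirmed in this thread.

```python
from typing import Callable, NamedTuple


class SparseStateSketch(NamedTuple):
    """Hypothetical stand-in for the pruner state: a counter plus the inner optimizer state."""
    count: int
    inner_state: tuple


class Transformation(NamedTuple):
    """Mimics the init/update pair of an optax-style gradient transformation."""
    init: Callable
    update: Callable


def sgd(learning_rate: float) -> Transformation:
    """Stateless inner optimizer, so the only counter lives in the wrapper."""
    def init(params):
        return ()

    def update(grads, state, params=None):
        return [-learning_rate * g for g in grads], state

    return Transformation(init, update)


def wrap_with_counter(inner: Transformation) -> Transformation:
    """Wraps an optimizer so that every update call increments a step counter."""
    def init(params):
        return SparseStateSketch(count=0, inner_state=inner.init(params))

    def update(grads, state, params=None):
        updates, new_inner = inner.update(grads, state.inner_state, params)
        # The counter advances here, once per update call.
        return updates, SparseStateSketch(state.count + 1, new_inner)

    return Transformation(init, update)


def is_mask_update_step(step, update_freq, update_start_step, update_end_step):
    """Assumed rule: inside the [start, end] window and a multiple of update_freq."""
    in_window = update_start_step <= step <= update_end_step
    return in_window and step % update_freq == 0


params = [1.0, -2.0]
tx = wrap_with_counter(sgd(0.1))
state = tx.init(params)
mask_update_steps = []

for _ in range(30):
    grads = [2.0 * p for p in params]  # gradient of sum(p ** 2)
    updates, state = tx.update(grads, state, params)
    # Analogue of optax.apply_updates: adds updates to params, touches no state.
    params = [p + u for p, u in zip(params, updates)]
    if is_mask_update_step(state.count, update_freq=10,
                           update_start_step=1, update_end_step=1000):
        mask_update_steps.append(state.count)
```

In this sketch the counter moves in update, not in the step that applies the updates to the parameters. optax.apply_updates itself only adds the updates to the parameters and carries no state, which matches the reply above that it is the optimizer's update call that increases the counter.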