apple / axlearn

An Extensible Deep Learning Library
Apache License 2.0

Introduce `model_analysis.txt` in trainer. #824

Closed: ds-hwang closed this 2 weeks ago

ds-hwang commented 2 weeks ago

The trainer saves `model_analysis.txt`, which lists details of the model parameters and the training state, e.g.:


```
        16 [16]                 fc/bias
        48 (3, 16)              fc/weight
Total number of model params: 64

State: prng_key=uint32((4,)) mesh_axes=ParameterSpec(shape=[4], dtype=<class 'jax.numpy.uint32'>, mesh_axes=PartitionSpec(None,), initializer=None, factorization=None, fan_axes=None, weight_decay_scale=None)
State: model/fc/bias=float32((16,)) mesh_axes=ParameterSpec(shape=[16], dtype=<class 'jax.numpy.float32'>, mesh_axes=PartitionSpec('model',), initializer=None, factorization=None, fan_axes=None, weight_decay_scale=None)
State: model/fc/weight=float32((3, 16)) mesh_axes=ParameterSpec(shape=(3, 16), dtype=<class 'jax.numpy.float32'>, mesh_axes=PartitionSpec(None, 'model'), initializer=None, factorization=FactorizationSpec(axes=('row', 'col')), fan_axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()), weight_decay_scale=None)
State: learner/optimizer/0/trace/fc/bias=float32((16,)) mesh_axes=TensorSpec(shape=[16], dtype=<class 'jax.numpy.float32'>, mesh_axes=PartitionSpec('model',))
State: learner/optimizer/0/trace/fc/weight=float32((3, 16)) mesh_axes=TensorSpec(shape=(3, 16), dtype=<class 'jax.numpy.float32'>, mesh_axes=PartitionSpec(None, 'model'))
State: learner/optimizer/2/count=int32(()) mesh_axes=TensorSpec(shape=[], dtype=<class 'jax.numpy.int32'>, mesh_axes=PartitionSpec())
Training state size: 0.00 GiB
Training state size (partitioned): 0.00 GiB
Max training state size (partitioned): 0.00 GiB
```
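
For readers wondering where the numbers in this report come from, here is a minimal sketch of the arithmetic, written in plain Python rather than against axlearn's actual implementation. Everything in it is an assumption for illustration: `STATE`, `MESH`, `param_table`, and `state_bytes` are hypothetical names (not axlearn APIs), each state entry is described by hand as (path, shape, bytes per element, mesh axis per dimension), and the `"model"` mesh axis is assumed to have size 4. The parameter count is the product of each shape, summed over `model/` entries; the state size multiplies element counts by element size, and the per-device figure divides every sharded dimension by its mesh-axis size (rounding up, so it is an upper bound on a single shard).

```python
import math

# Hypothetical hand-written description of the example state above:
# (path, shape, bytes per element, mesh axis name per dimension or None).
STATE = [
    ("prng_key", (4,), 4, (None,)),
    ("model/fc/bias", (16,), 4, ("model",)),
    ("model/fc/weight", (3, 16), 4, (None, "model")),
    ("learner/optimizer/0/trace/fc/bias", (16,), 4, ("model",)),
    ("learner/optimizer/0/trace/fc/weight", (3, 16), 4, (None, "model")),
    ("learner/optimizer/2/count", (), 4, ()),
]
# Assumed mesh: the "model" axis spans 4 devices (illustrative only).
MESH = {"model": 4}


def param_table(state, prefix="model/"):
    """Return (report lines, total) for entries under `prefix` (model params)."""
    lines, total = [], 0
    for path, shape, _, _ in sorted(state, key=lambda entry: entry[0]):
        if not path.startswith(prefix):
            continue  # skip PRNG keys and optimizer state
        count = math.prod(shape)  # math.prod(()) == 1 for scalars
        total += count
        lines.append(f"{count:>10} {str(shape):<20} {path[len(prefix):]}")
    lines.append(f"Total number of model params: {total}")
    return lines, total


def state_bytes(state, mesh):
    """Return (unpartitioned bytes, upper bound on bytes of one device's shard)."""
    full, per_device = 0, 0
    for _, shape, itemsize, axes in state:
        full += math.prod(shape) * itemsize
        # Split each sharded dimension across its mesh axis, rounding up.
        shard = [math.ceil(dim / mesh[axis]) if axis else dim
                 for dim, axis in zip(shape, axes)]
        per_device += math.prod(shard) * itemsize
    return full, per_device


if __name__ == "__main__":
    lines, _ = param_table(STATE)
    print("\n".join(lines))
    full, per_device = state_bytes(STATE, MESH)
    print(f"Training state size: {full / 2**30:.2f} GiB")
    print(f"Per-device shard size (upper bound): {per_device / 2**30:.2f} GiB")
```

For this toy state the totals come to 64 parameters, 532 bytes unpartitioned, and 148 bytes per device, which is why every GiB figure in the report rounds to 0.00.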

Note: this functionality is based on `print_model_analysis.py`.