ds-hwang closed this issue 2 weeks ago.
The trainer saves `model_analysis.txt` to show details of the model parameters, e.g.:
Contents of `model_analysis.txt`:

```
16 [16] fc/bias
48 (3, 16) fc/weight
Total number of model params: 64
State: prng_key=uint32((4,)) mesh_axes=ParameterSpec(shape=[4], dtype=<class 'jax.numpy.uint32'>, mesh_axes=PartitionSpec(None,), initializer=None, factorization=None, fan_axes=None, weight_decay_scale=None)
State: model/fc/bias=float32((16,)) mesh_axes=ParameterSpec(shape=[16], dtype=<class 'jax.numpy.float32'>, mesh_axes=PartitionSpec('model',), initializer=None, factorization=None, fan_axes=None, weight_decay_scale=None)
State: model/fc/weight=float32((3, 16)) mesh_axes=ParameterSpec(shape=(3, 16), dtype=<class 'jax.numpy.float32'>, mesh_axes=PartitionSpec(None, 'model'), initializer=None, factorization=FactorizationSpec(axes=('row', 'col')), fan_axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()), weight_decay_scale=None)
State: learner/optimizer/0/trace/fc/bias=float32((16,)) mesh_axes=TensorSpec(shape=[16], dtype=<class 'jax.numpy.float32'>, mesh_axes=PartitionSpec('model',))
State: learner/optimizer/0/trace/fc/weight=float32((3, 16)) mesh_axes=TensorSpec(shape=(3, 16), dtype=<class 'jax.numpy.float32'>, mesh_axes=PartitionSpec(None, 'model'))
State: learner/optimizer/2/count=int32(()) mesh_axes=TensorSpec(shape=[], dtype=<class 'jax.numpy.int32'>, mesh_axes=PartitionSpec())
Training state size: 0.00 GiB
Training state size (partitioned): 0.00 GiB
Max training state size (partitioned): 0.00 GiB
```
Note: this functionality is based on `print_model_analysis.py`.
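A minimal pure-Python sketch of the arithmetic behind the example output, assuming the parameter shapes and dtypes are already available as plain tuples and strings (in a JAX trainer they would typically come from walking the parameter pytree; here they are hard-coded). The helper names `param_count`, `model_param_table`, `state_size_gib` and the `DTYPE_BYTES` table are hypothetical and are not the API of the trainer or of `print_model_analysis.py`. It reproduces the per-parameter counts (16 for `fc/bias`, 3 × 16 = 48 for `fc/weight`, 64 in total) and shows why every size line reads 0.00 GiB: the whole training state is only 532 bytes.

```python
import math

# Hypothetical bytes-per-element table for the dtypes seen in the example.
DTYPE_BYTES = {"float32": 4, "int32": 4, "uint32": 4, "bfloat16": 2}


def param_count(shape):
    """Number of elements in a tensor of this shape (1 for a scalar, shape ())."""
    return math.prod(shape)


def model_param_table(params):
    """Return ('count shape name' lines plus a total line, total count).

    `params` maps a parameter path such as "fc/bias" to its shape tuple.
    """
    lines = []
    total = 0
    for name, shape in sorted(params.items()):
        count = param_count(shape)
        total += count
        lines.append(f"{count} {shape} {name}")
    lines.append(f"Total number of model params: {total}")
    return lines, total


def state_size_gib(state):
    """Total (unpartitioned) size of all training-state tensors, in GiB.

    `state` maps a state path to a (shape, dtype_name) pair.
    """
    num_bytes = sum(param_count(shape) * DTYPE_BYTES[dtype] for shape, dtype in state.values())
    return num_bytes / 2**30


if __name__ == "__main__":
    lines, _ = model_param_table({"fc/bias": (16,), "fc/weight": (3, 16)})
    print("\n".join(lines))

    # Same state entries as the example: model params, optimizer trace, step count.
    state = {
        "prng_key": ((4,), "uint32"),
        "model/fc/bias": ((16,), "float32"),
        "model/fc/weight": ((3, 16), "float32"),
        "learner/optimizer/0/trace/fc/bias": ((16,), "float32"),
        "learner/optimizer/0/trace/fc/weight": ((3, 16), "float32"),
        "learner/optimizer/2/count": ((), "int32"),
    }
    print(f"Training state size: {state_size_gib(state):.2f} GiB")
```

Running it prints `16 (16,) fc/bias`, `48 (3, 16) fc/weight`, `Total number of model params: 64`, and `Training state size: 0.00 GiB`. Shapes are printed as Python tuples, so `fc/bias` appears as `(16,)` rather than the `[16]` in the real file. The sketch ignores sharding, so it only models the unpartitioned size line.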