zachcp opened this issue 1 week ago (status: Open)
It would be really nice to be able to print a model's "shape" similar to the way PyTorch does. Does anyone already have a utility that does something like this? For reference, this is what PyTorch prints for the AMPLIFY model:
```
print(AMPLIFY_MODEL)
AMPLIFY(
  (encoder): Embedding(27, 960, padding_idx=0)
  (transformer_encoder): ModuleList(
    (0-31): 32 x EncoderBlock(
      (q): Linear(in_features=960, out_features=960, bias=False)
      (k): Linear(in_features=960, out_features=960, bias=False)
      (v): Linear(in_features=960, out_features=960, bias=False)
      (wo): Linear(in_features=960, out_features=960, bias=False)
      (resid_dropout): Dropout(p=0, inplace=False)
      (ffn): SwiGLU(
        (w12): Linear(in_features=960, out_features=5120, bias=False)
        (w3): Linear(in_features=2560, out_features=960, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
      (ffn_dropout): Dropout(p=0, inplace=False)
    )
  )
  (layer_norm_2): RMSNorm()
  (decoder): Linear(in_features=960, out_features=27, bias=True)
)
```
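PyTorch builds this text recursively: each module prints its type name plus a one-line summary (its `extra_repr()`), then each child as `(name): <child text>` indented by two spaces, and `ModuleList` collapses runs of identical children into a single entry such as `(0-31): 32 x EncoderBlock(`. A framework without this could reproduce the format with a small recursive formatter. Below is a minimal, framework-agnostic sketch; `Node`, its fields, and `format_module` are hypothetical stand-ins for whatever layer type your framework exposes, not an existing API.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for a framework's layer type (not a real API)."""
    type_name: str
    extra: str = ""  # one-line summary, like PyTorch's extra_repr()
    children: list[tuple[str, Node]] = field(default_factory=list)
    is_list: bool = False  # True for ModuleList-like containers


def _indent(text: str, prefix: str) -> str:
    # Indent every line after the first; the first line follows "(name): ".
    return text.replace("\n", "\n" + prefix)


def format_module(node: Node, prefix: str = "  ") -> str:
    """Render a module tree in the style of PyTorch's print(model)."""
    if not node.children:
        # Leaf: everything on one line, e.g. "RMSNorm()" or "Linear(...)".
        return f"{node.type_name}({node.extra})"

    names = [name for name, _ in node.children]
    texts = [format_module(child, prefix) for _, child in node.children]
    lines = [node.extra] if node.extra else []

    i = 0
    while i < len(texts):
        j = i + 1
        # Only list containers collapse runs of identically formatted children.
        if node.is_list:
            while j < len(texts) and texts[j] == texts[i]:
                j += 1
        if j - i > 1:
            label, body = f"{names[i]}-{names[j - 1]}", f"{j - i} x {texts[i]}"
        else:
            label, body = names[i], texts[i]
        lines.append(f"({label}): {_indent(body, prefix)}")
        i = j

    inner = "\n".join(prefix + line for line in lines)
    return f"{node.type_name}(\n{inner}\n)"


def linear(i: int, o: int, bias: bool = False) -> Node:
    return Node("Linear", f"in_features={i}, out_features={o}, bias={bias}")


# A trimmed-down AMPLIFY-like tree, built by hand for illustration.
block = Node("EncoderBlock", children=[
    ("q", linear(960, 960)),
    ("ffn", Node("SwiGLU", children=[("w12", linear(960, 5120)),
                                     ("w3", linear(2560, 960))])),
    ("attention_norm", Node("RMSNorm")),
])
model = Node("AMPLIFY", children=[
    ("encoder", Node("Embedding", "27, 960, padding_idx=0")),
    ("transformer_encoder", Node("ModuleList", is_list=True,
                                 children=[(str(k), block) for k in range(32)])),
    ("decoder", linear(960, 27, bias=True)),
])
print(format_module(model))
```

In a real framework, the `Node` tree would instead be produced by walking the model's own submodules; the formatter itself only needs a type name, a summary string, and an ordered list of named children.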