tracel-ai / models

Models and examples built with Burn

[Bert] Feature: Custom Model Outputs #31

Closed bkonkle closed 6 months ago

bkonkle commented 6 months ago

Closes #22

bkonkle commented 6 months ago

Thanks for the review! I tested it, but then made a last-minute change to add with_pooling_layer to the config instead of passing it in as an argument. :sweat_smile: I failed to retest afterwards, and since this config property isn't present in the original Bert model config, I need to default it to false. I'll fix it shortly.

I might be able to work up some GitHub Actions code to run those examples automatically as part of the PR checks. I'll open a separate PR for that if I do.

Update: Defaulting it to false with #[config(default = false)] doesn't actually prevent the error, because the loader still expects the field when it deserializes the config from the base model file. The fact that the field isn't present in the base config file is why I originally structured the flag as an argument, but when Nathan suggested moving it into the config I didn't think it would be an issue. :sweat_smile: I'm working on a solution now.
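
The failure mode can be reproduced with plain serde, assuming the config derive deserializes via serde and does not emit #[serde(default)] for #[config(default = ...)] fields. A minimal sketch (BertConfig and its fields here are stand-ins, not the real struct; requires the serde and serde_json crates):

```rust
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct BertConfig {
    hidden_size: usize,
    // Without #[serde(default)], deserialization fails when the field
    // is absent from the JSON, regardless of any Rust-side default.
    with_pooling_layer: bool,
}

fn main() {
    // A pretrained base model config file won't contain the new field...
    let base = r#"{ "hidden_size": 768 }"#;
    // ...so parsing returns Err: missing field `with_pooling_layer`.
    let parsed: Result<BertConfig, _> = serde_json::from_str(base);
    println!("{parsed:?}");
}
```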

bkonkle commented 6 months ago

I went with pub with_pooling_layer: Option<bool> to avoid the problems with loading the base model config, coupled with .unwrap_or(false) to resolve the wrapped value.
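
Roughly the shape of the change (a sketch rather than the exact diff; surrounding fields elided, and it assumes Burn's #[derive(Config)] deserializes via serde, which treats a missing Option field as None):

```rust
use burn::config::Config;

#[derive(Config, Debug)]
pub struct BertModelConfig {
    // ...other fields elided...

    /// Optional so that base model config files that predate this
    /// field still deserialize; a missing field becomes None.
    pub with_pooling_layer: Option<bool>,
}

impl BertModelConfig {
    /// Hypothetical helper: resolve the optional flag, treating an
    /// absent value as false (no pooling layer).
    pub fn pooling_enabled(&self) -> bool {
        self.with_pooling_layer.unwrap_or(false)
    }
}
```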

Examples are working again for me:

```
Model variant: roberta-base
Input: Shape { dims: [3, 63] } // (Batch Size, Seq_len)
Roberta Sentence embedding Shape { dims: [3, 768] } // (Batch Size, Embedding_dim)

Model variant: bert-base-uncased
Input: Shape { dims: [3, 64] } // (Batch Size, Seq_len)
Roberta Sentence embedding Shape { dims: [3, 768] } // (Batch Size, Embedding_dim)

Model variant: roberta-large
Input: Shape { dims: [3, 63] } // (Batch Size, Seq_len)
Roberta Sentence embedding Shape { dims: [3, 1024] } // (Batch Size, Embedding_dim)
```

I'm using this in the burn-transformers library like this: https://github.com/bkonkle/burn-transformers/blob/0.2.0/src/models/bert/sequence_classification/text_classification.rs#L130-L131
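
The consumer side is roughly of this shape (a hypothetical sketch, not the actual code at the linked lines; it builds on the BertModelConfig sketched above):

```rust
// Sequence classification uses the pooled sentence embedding, so the
// downstream crate opts in to the pooling layer explicitly.
fn enable_pooling(mut config: BertModelConfig) -> BertModelConfig {
    config.with_pooling_layer = Some(true);
    config
}
```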