Closed: bkonkle closed this pull request 6 months ago.
Thanks for the review! I tested it, but then made the last-minute change to add `with_pooling_layer` to the config instead of passing it in as an argument. :sweat_smile: I failed to test it again afterwards, and since this config property isn't present in the original Bert model config, I need to default it to `false`. I'll fix that shortly.
I might be able to work up some GitHub Actions code to run those examples automatically as part of PR checks. I'll open a separate PR for that if I do.
Update: Defaulting it to `false` with `#[config(default = false)]` doesn't actually prevent the error, since the loader still looks for the field when it attempts to load the config from the base model file. The fact that it's not present in the base config file is why I originally structured the flag as an argument, but when Nathan suggested moving it into the config I didn't think it would be an issue. :sweat_smile: I'm working towards a solution now.
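For anyone following along, here's a minimal sketch of that failure mode using plain serde rather than Burn's `Config` derive (so the details differ from the actual loader, and the struct here is illustrative): a default that only applies when constructing the struct in code doesn't help the deserializer when the key is simply missing from the file.

```rust
use serde::Deserialize;

// Illustrative stand-in for the model config; the real struct uses
// Burn's `Config` derive, but the missing-field behavior is analogous.
#[derive(Debug, Deserialize)]
struct StrictConfig {
    hidden_size: usize,
    // A plain `bool` field is required by the deserializer, even if a
    // default exists for constructing the struct in code.
    with_pooling_layer: bool,
}

fn main() {
    // A base model config file that predates the flag has no
    // `with_pooling_layer` key.
    let json = r#"{ "hidden_size": 768 }"#;
    let result: Result<StrictConfig, _> = serde_json::from_str(json);

    // Fails with a "missing field `with_pooling_layer`" error.
    assert!(result.is_err());
}
```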
I went with `pub with_pooling_layer: Option<bool>` to avoid the problems with loading the base model config, coupled with `.unwrap_or(false)` to resolve the wrapped value.
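Roughly, that fix looks like this (again sketched with plain serde instead of Burn's `Config` derive; the names are illustrative):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ModelConfig {
    hidden_size: usize,
    // `Option<bool>` deserializes to `None` when the key is absent, so
    // older base model config files still load cleanly.
    with_pooling_layer: Option<bool>,
}

fn main() {
    let json = r#"{ "hidden_size": 768 }"#;
    let config: ModelConfig = serde_json::from_str(json).unwrap();

    // `.unwrap_or(false)` resolves the wrapped value, treating a
    // missing flag as "no pooling layer".
    let with_pooling = config.with_pooling_layer.unwrap_or(false);
    assert!(!with_pooling);
}
```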
Examples are working again for me:
Model variant: roberta-base
Input: Shape { dims: [3, 63] } // (Batch Size, Seq_len)
Roberta Sentence embedding Shape { dims: [3, 768] } // (Batch Size, Embedding_dim)
Model variant: bert-base-uncased
Input: Shape { dims: [3, 64] } // (Batch Size, Seq_len)
Roberta Sentence embedding Shape { dims: [3, 768] } // (Batch Size, Embedding_dim)
Model variant: roberta-large
Input: Shape { dims: [3, 63] } // (Batch Size, Seq_len)
Roberta Sentence embedding Shape { dims: [3, 1024] } // (Batch Size, Embedding_dim)
I'm using this in the `burn-transformers` library like this: https://github.com/bkonkle/burn-transformers/blob/0.2.0/src/models/bert/sequence_classification/text_classification.rs#L130-L131
Closes #22

- Makes `pad_token_idx` public for things like the batcher to use.
- Removes `.clone()` in a few places where it isn't needed.
- Adds `.envrc` to gitignore for direnv users.
- Adds `.vscode` to gitignore for VS Code users.