**wiseodd** opened this issue 4 months ago
Currently, we require the user to do this themselves:
```python
from collections.abc import MutableMapping

import torch
from torch import nn
from transformers import (
    GPT2Config,
    GPT2ForSequenceClassification,
    PreTrainedTokenizer,
)


class MyGPT2(nn.Module):
    def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
        super().__init__()
        config = GPT2Config.from_pretrained("gpt2")
        config.pad_token_id = tokenizer.pad_token_id
        config.num_labels = 2
        self.hf_model = GPT2ForSequenceClassification.from_pretrained(
            "gpt2", config=config
        )

    def forward(self, data: MutableMapping) -> torch.Tensor:
        # Move the batch to whichever device the model lives on.
        device = next(self.parameters()).device
        input_ids = data["input_ids"].to(device)
        attn_mask = data["attention_mask"].to(device)
        output_dict = self.hf_model(input_ids=input_ids, attention_mask=attn_mask)
        return output_dict.logits
```
Can we provide a generic wrapper? I suspect there's very little variance across use cases. Maybe something like:
```python
from torch import nn
from transformers import PreTrainedModel


def huggingface_wrapper(hf_model: PreTrainedModel) -> nn.Module:
    class WrappedModel(nn.Module):
        def __init__(self, hf_model: PreTrainedModel) -> None:
            super().__init__()  # required so nn.Module registers the submodule
            self.hf_model = hf_model

        def forward(self, data):
            # Forward the whole batch dict to the HF model and return raw logits.
            output_dict = self.hf_model(**data)
            return output_dict.logits

    return WrappedModel(hf_model)
```
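For illustration, a minimal usage sketch of the proposed wrapper. The GPT-2 setup mirrors the hand-written example above; the pad-token assignment is an assumption (GPT-2 ships without a pad token), and `huggingface_wrapper` refers to the proposal, not an existing API. Note that this generic version drops the explicit `.to(device)` handling of the hand-written wrapper, so inputs would have to be on the model's device already:

```python
from transformers import GPT2Config, GPT2ForSequenceClassification, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # assumption: reuse EOS as pad token

config = GPT2Config.from_pretrained("gpt2")
config.pad_token_id = tokenizer.pad_token_id
config.num_labels = 2
hf_model = GPT2ForSequenceClassification.from_pretrained("gpt2", config=config)

model = huggingface_wrapper(hf_model)
batch = tokenizer(["an example sentence"], return_tensors="pt", padding=True)
logits = model(batch)  # BatchEncoding is a mapping, so **data unpacking works
print(logits.shape)    # torch.Size([1, 2])
```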