aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Generic wrapper for Huggingface transformers #206

Open wiseodd opened 4 months ago

wiseodd commented 4 months ago

Currently, we require users to write this kind of wrapper themselves:

from collections.abc import MutableMapping

import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2ForSequenceClassification, PreTrainedTokenizer


class MyGPT2(nn.Module):
    def __init__(self, tokenizer: PreTrainedTokenizer) -> None:
        super().__init__()
        config = GPT2Config.from_pretrained("gpt2")
        config.pad_token_id = tokenizer.pad_token_id
        config.num_labels = 2
        self.hf_model = GPT2ForSequenceClassification.from_pretrained(
            "gpt2", config=config
        )

    def forward(self, data: MutableMapping) -> torch.Tensor:
        # `data` is the batch dict a HF tokenizer/collator produces;
        # move the tensors to the model's device and return plain logits.
        device = next(self.parameters()).device
        input_ids = data["input_ids"].to(device)
        attn_mask = data["attention_mask"].to(device)
        output_dict = self.hf_model(input_ids=input_ids, attention_mask=attn_mask)
        return output_dict.logits
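
For context, `data` here is just the batch dict a Huggingface tokenizer produces, so usage looks roughly like this (a minimal sketch; the example sentence and the padding setup are only illustrative):

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

model = MyGPT2(tokenizer)
batch = tokenizer(["an example sentence"], return_tensors="pt", padding=True)
logits = model(batch)  # tensor of shape (1, 2): one logit per label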

Can we provide a generic wrapper? I suspect there is very little variance across use cases. Maybe something like:

from transformers import PreTrainedModel


def huggingface_wrapper(hf_model: PreTrainedModel) -> nn.Module:
    class WrappedModel(nn.Module):
        def __init__(self, hf_model: PreTrainedModel) -> None:
            # nn.Module must be initialized before submodules are assigned
            super().__init__()
            self.hf_model = hf_model

        def forward(self, data: MutableMapping) -> torch.Tensor:
            # Move the batch to the model's device and unpack it as kwargs
            device = next(self.parameters()).device
            data = {k: v.to(device) for k, v in data.items()}
            output_dict = self.hf_model(**data)
            return output_dict.logits

    return WrappedModel(hf_model)
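
The wrapped model could then go straight into Laplace as usual. A minimal sketch, assuming `train_loader` yields `(batch_dict, labels)` pairs; the loader and the choice of last-layer Kron are placeholders, not a recommendation:

from laplace import Laplace

model = huggingface_wrapper(hf_model)
la = Laplace(
    model,
    "classification",
    subset_of_weights="last_layer",
    hessian_structure="kron",
)
la.fit(train_loader)  # each batch: (tokenizer batch dict, label tensor)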