Exscientia / abodybuilder3

Apache License 2.0

Add a device map #3

Closed: ideasbyjin closed this issue 4 months ago

ideasbyjin commented 4 months ago

Hey guys, good stuff on this. I'd recommend adding a device_map option so the ProtT5 weights can be loaded directly onto the GPU.

Passing device_map does add a dependency on accelerate, though:

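from pathlib import Path

from transformers import BertModel, T5EncoderModel

# LightningModule comes from whichever Lightning package the project already uses.


class ProtTransEmbedder(LightningModule):
    def __init__(
        self,
        weights_dir: Path,
        model_type: str,
        device_map: str = "auto",
    ) -> None:
        ...
        if model_type == "bert":
            self.model = BertModel.from_pretrained(
                self.weights_dir, add_pooling_layer=False, device_map=device_map
            )
        elif model_type == "t5":
            # device_map places the weights on the target device(s) at load time
            self.model = T5EncoderModel.from_pretrained(
                self.weights_dir, device_map=device_map
            )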

henrykenlay commented 4 months ago

Hi ideasbyjin, thank you for your interest and suggestion. My collaborator @fdreyer made a PR based on your snippet: https://github.com/Exscientia/abodybuilder3/pull/4. If this looks okay to you we can merge it :)
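
Since device_map brings in accelerate, one possible way to keep that dependency optional is to build the from_pretrained keyword arguments conditionally. The following is only a rough sketch, not the code in the PR: the helper names accelerate_available and from_pretrained_kwargs and the has_accelerate parameter are hypothetical and introduced purely for illustration.

```python
from importlib.util import find_spec
from typing import Any, Dict, Optional


def accelerate_available() -> bool:
    """Return True if the accelerate package can be found on this system."""
    return find_spec("accelerate") is not None


def from_pretrained_kwargs(
    model_type: str,
    device_map: Optional[str] = "auto",
    has_accelerate: Optional[bool] = None,  # hypothetical override, handy for tests
) -> Dict[str, Any]:
    """Build keyword arguments for from_pretrained (illustrative helper only)."""
    if model_type not in ("bert", "t5"):
        raise ValueError(f"Unsupported model_type: {model_type!r}")
    if has_accelerate is None:
        has_accelerate = accelerate_available()

    kwargs: Dict[str, Any] = {}
    if model_type == "bert":
        # Mirrors the snippet in this thread: BERT is loaded without its pooling layer.
        kwargs["add_pooling_layer"] = False

    if device_map is not None:
        # Only forward device_map when accelerate is present; fail early otherwise.
        if not has_accelerate:
            raise ImportError(
                "device_map requires the accelerate package; "
                "install it or pass device_map=None"
            )
        kwargs["device_map"] = device_map
    return kwargs
```

With something like this, the constructor could call T5EncoderModel.from_pretrained(self.weights_dir, **from_pretrained_kwargs("t5", device_map)), and users without accelerate could pass device_map=None to get the previous loading behaviour.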