uw-ipd / RoseTTAFold2

MIT License
160 stars · 36 forks

'cuda:0' parameter passthrough for multi-GPU setup #15

Open · avilella opened 1 year ago

avilella commented 1 year ago

Hi, would it be possible to pass through the GPU id when running RoseTTAFold2, so that on machines with more than one GPU, multiple jobs can run in parallel? Thanks.

Example of where the device currently defaults to GPU 0 (network/predict.py):

class Predictor():
    def __init__(self, model_weights, device="cuda:0"):
        # define model name
        self.model_weights = model_weights
        self.device = device
        self.active_fn = nn.Softmax(dim=1)

        # define model & load model
        self.model = RoseTTAFoldModule(
            **MODEL_PARAM
        ).to(self.device)

        could_load = self.load_model(self.model_weights)
        if not could_load:
            print("ERROR: failed to load model")
            sys.exit()
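Since `Predictor.__init__` already accepts a `device` argument, one way to expose it would be a command-line flag that is forwarded to the constructor. A minimal sketch, assuming a `--device` flag and an argparse-based entry point (the flag name, `main` function, and weights path are illustrative, not RoseTTAFold2's actual CLI; the model-loading logic is omitted):

```python
import argparse

class Predictor:
    def __init__(self, model_weights, device="cuda:0"):
        self.model_weights = model_weights
        self.device = device  # any torch device string, no longer pinned to GPU 0

def main(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda:0",
                        help='torch device string, e.g. "cuda:1" or "cpu"')
    args = parser.parse_args(argv)
    # forward the flag to the constructor instead of relying on the default
    return Predictor(model_weights="weights.pt", device=args.device)

pred = main(["--device", "cuda:1"])
print(pred.device)  # cuda:1
```

As a workaround that needs no code change, CUDA's `CUDA_VISIBLE_DEVICES` environment variable remaps device numbering per process, so `CUDA_VISIBLE_DEVICES=1 python network/predict.py ...` makes physical GPU 1 appear as `cuda:0` to that process.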