torchmd / torchmd-net

Training neural network potentials

[WIP] Support arbitrary outputs in TorchMD_Net #239

Open RaulPPelaez opened 1 year ago

RaulPPelaez commented 1 year ago

Following the discussion in https://github.com/torchmd/torchmd-net/issues/198, this PR attempts to give TorchMD_Net the ability to return more than one output ("y") and its derivative ("neg_dy").

This PR is still a draft as I am trying to figure out the final design.

This PR introduces user-facing breaking changes:

  • It changes some names in the configuration file (for instance, Scalar is no longer a thing), although a conversion could be made when processing the configuration.
  • The Datasets must provide "energy" and "force" instead of "y" and "neg_dy".
  • TorchMD_Net is always expected to compute at least the energy, instead of a generic label called "y". Maybe I am missing some use cases here, so we will see...

New design proposed for the outputs of the model:

  • TorchMD_Net is composed of a representation model plus an arbitrary number of heads stacked sequentially.
  • There is no distinction between a Prior and what used to be an OutputModel; they are all Heads now.
  • The EnergyHead is always the first one and the ForceHead the last (if derivative=True).
  • There is some level of customization, akin to the Heads, for computing the loss of each output and reducing the total loss.
  • The user provides a list of weights (like y_weight and neg_dy_weight now) for each model output that should be considered in the loss computation.

This is the BaseHead interface I propose:

import torch
from torch import nn

class BaseHead(nn.Module):
    def __init__(self, dtype=torch.float32):
        super(BaseHead, self).__init__()
        self.dtype = dtype

    def reset_parameters(self):
        pass

    def per_point(self, point_features, results, z, pos, batch, extra_args):
        # Operates on per-atom features; may add or modify per-atom entries in results.
        return point_features, results

    def per_sample(self, point_features, results, z, pos, batch, extra_args):
        # Runs after all per_point calls; reduces per-atom quantities to per-sample ones.
        return point_features, results

The forward call of TorchMD_Net would then go like this:

        # Run every head's per-point stage first, then every head's per-sample
        # stage, so per-sample reductions see the per-point contributions of all heads.
        results = {}
        point_features = self.representation_model(z, pos, batch, q=q, s=s)
        for head in self.head_list:
            point_features, results = head.per_point(point_features, results, z=z, pos=pos, batch=batch, extra_args=extra_args)
        for head in self.head_list:
            point_features, results = head.per_sample(point_features, results, z=z, pos=pos, batch=batch, extra_args=extra_args)

Each head is free to add a new key to results, to modify the point_features, or to modify the existing contents of results (e.g., adding to the energy). For instance, the EnergyHead:

from torch_geometric.utils import scatter  # reduce="sum" by default
from torchmdnet.models.utils import act_class_mapping

class EnergyHead(BaseHead):
    def __init__(self,
                 hidden_channels,
                 activation="silu",
                 dtype=torch.float32):
        super(EnergyHead, self).__init__(dtype=dtype)
        act_class = act_class_mapping[activation]
        self.output_network = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
            act_class(),
            nn.Linear(hidden_channels // 2, 1, dtype=dtype),
        )
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.output_network[0].weight)
        self.output_network[0].bias.data.fill_(0)
        nn.init.xavier_uniform_(self.output_network[2].weight)
        self.output_network[2].bias.data.fill_(0)

    def per_point(self, point_features, results, z, pos, batch, extra_args):
        # Predict one energy contribution per atom.
        results["energy"] = self.output_network(point_features)
        return point_features, results

    def per_sample(self, point_features, results, z, pos, batch, extra_args):
        # Sum the per-atom contributions into one energy per sample in the batch.
        results["energy"] = scatter(results["energy"], batch, dim=0)
        return point_features, results
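
Hypothetically, a ForceHead under this interface might look like the sketch below. This is not code from this PR; it assumes pos was created with requires_grad=True before the representation model ran, and that it runs after the EnergyHead has filled results["energy"]:

class ForceHead(BaseHead):
    # Hypothetical sketch: forces as the negative gradient of the accumulated
    # energy, which is why this head must run last.
    def per_sample(self, point_features, results, z, pos, batch, extra_args):
        results["force"] = -torch.autograd.grad(
            results["energy"].sum(),
            pos,
            create_graph=self.training,  # keep the graph so training can backprop through the forces
        )[0]
        return point_features, results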

There are some challenges I still have to deal with:

  • Not sure how happy TorchScript is going to be with this.
  • Not sure how the user should specify a list of predefined heads. Perhaps something like an option head_list: energy_head, coulomb_prior, some_other_prior, charge_head, some_charge_prior, force_head (see the sketch after the task list below).

Tasks:

  • Adapt TorchMD_Net
    • Make Equivariant versions of the heads for ET.
  • Adapt LNNP
  • Adapt Datasets
  • Make priors into heads
  • Generalize the loss computation
  • Handle user input
  • Update tests
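
As a rough illustration of the head_list idea, the option could be resolved into modules through a registry. This is a hypothetical sketch; the registry contents and construction arguments are assumptions, not part of this PR:

# Hypothetical sketch: resolving a user-provided head_list option into modules.
HEAD_REGISTRY = {
    "energy_head": lambda hidden_channels: EnergyHead(hidden_channels=hidden_channels),
    "force_head": lambda hidden_channels: ForceHead(),
}

def build_heads(head_list_option, hidden_channels):
    # e.g. head_list_option = "energy_head, force_head"
    names = [name.strip() for name in head_list_option.split(",")]
    return nn.ModuleList(HEAD_REGISTRY[name](hidden_channels) for name in names)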

giadefa commented 1 year ago

Why are we changing these things? We agreed on them a while back.

g


peastman commented 1 year ago

I was thinking of something a bit more generic than this. You can define an arbitrary set of output heads and loss terms. I imagine the description in the config file looking something like this.

output_heads:
  - scalar:
      name: energy
  - coulomb  # the Coulomb head is hardcoded to output a scalar "energy" and a vector "charges"
losses:
  - l2:
      output: energy  # since multiple heads have "energy" outputs, they get summed before computing the loss
      dataset_field: y
      weight: 1.0
  - gradient_l2:
      output: energy
      dataset_field: neg_dy
      weight: 0.1
  - l2:
      output: charges
      dataset_field: mbis_charges
      weight: 0.1

The configuration for a totally different sort of model might look like this.

output_heads:
  - scalar:
      name: solubility
losses:
  - l2:
      output: solubility
      dataset_field: solubility
      # if weight is omitted, it defaults to 1
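
To make the semantics concrete, the loss aggregation implied by such a config might look roughly like the sketch below. All names here are illustrative assumptions, and only plain l2 terms are handled (gradient_l2 would additionally differentiate the named output with respect to the positions):

import torch
import torch.nn.functional as F

def total_loss(loss_config, outputs, batch_data):
    # outputs: output name -> prediction, with same-named head outputs already
    # summed (e.g. the scalar and Coulomb "energy" contributions).
    total = torch.zeros(())
    for term in loss_config:
        pred = outputs[term["output"]]
        ref = batch_data[term["dataset_field"]]
        weight = term.get("weight", 1.0)  # if weight is omitted, it defaults to 1
        total = total + weight * F.mse_loss(pred, ref)
    return total
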
peastman commented 1 year ago

Is it ok if I try implementing the design described above?

RaulPPelaez commented 1 year ago

Hi Peter, I am working on it but I have not had much time, sorry about that. It is fine if you want to give it a try; feel free to open a new PR if/when you have something and we can iterate. I would love to see your take. I like your design very much, by the way, with perhaps one exception: I would rather the gradient be a property of the heads instead of the losses. Thinking about how an inference configuration should work, when reading it I would not immediately think to look at the loss section.
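
For instance, a hypothetical variant of the schema above with the gradient declared on the head (the gradient key is assumed, not part of either proposal):

output_heads:
  - scalar:
      name: energy
      gradient: true  # assumed key: the head itself exposes "energy" and its negative gradient as outputs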

giadefa commented 1 year ago

We already have an implementation of what I think is what you need, so maybe wait until Raul works out what it is we are already doing.

G
