choderalab / pinot

Probabilistic Inference for NOvel Therapeutics
MIT License

Small modifications #36

Closed dnguyen1196 closed 4 years ago

dnguyen1196 commented 4 years ago

Fix the in-place input modification problem in GN.

Changed "h0" to "h"

dnguyen1196 commented 4 years ago

@yuanqing-wang Just realized that in Sequential, we are also modifying the input graph in place. The following code in Sequential.forward calls both apply_atom_in_graph and GN.forward, which potentially use different interfaces: apply_atom_in_graph takes in one graph and returns a modified version, while GN.forward, the way we want it, would take in a graph and (optional) node features. It will take some time to work through what interface we need. Should we follow a "standard" where we always pass (graph, node features) to avoid the in-place modification problem?

```python
        g.apply_nodes(lambda nodes: {"h": self.f_in(nodes.data["h"])})

        for exe in self.exes:
            g = getattr(self, exe)(g)
```

```python
import dgl
import torch


class Sequential(torch.nn.Module):
    def __init__(
        self,
        layer,
        config,
        feature_units=117,
        input_units=128,
        model_kwargs={},
    ):
        super(Sequential, self).__init__()

        # the initial dimensionality
        dim = input_units

        # record the name of the layers in a list
        self.exes = []

        # initial featurization
        self.f_in = torch.nn.Sequential(
            torch.nn.Linear(feature_units, input_units), torch.nn.Tanh()
        )

        # make a pytorch function on tensors on graphs
        def apply_atom_in_graph(fn):
            def _fn(g):
                g.apply_nodes(lambda node: {"h": fn(node.data["h"])})  # <-- THIS MODIFIES THE INPUT
                return g

            return _fn

        # parse the config
        for idx, exe in enumerate(config):

            try:
                exe = float(exe)

                if exe >= 1:
                    exe = int(exe)
            except BaseException:
                pass

            # int -> feedforward
            if isinstance(exe, int):
                setattr(self, "d" + str(idx), layer(dim, exe, **model_kwargs))

                dim = exe
                self.exes.append("d" + str(idx))

            # str -> activation
            elif isinstance(exe, str):
                activation = getattr(torch.nn.functional, exe)

                setattr(self, "a" + str(idx), apply_atom_in_graph(activation))

                self.exes.append("a" + str(idx))

            # float -> dropout
            elif isinstance(exe, float):
                dropout = torch.nn.Dropout(exe)
                setattr(self, "o" + str(idx), apply_atom_in_graph(dropout))

                self.exes.append("o" + str(idx))

    def forward(self, g, return_graph=False):

        # NOTE: this is the in-place modification of the input graph discussed above
        g.apply_nodes(lambda nodes: {"h": self.f_in(nodes.data["h"])})

        for exe in self.exes:
            g = getattr(self, exe)(g)

        if return_graph:
            return g

        h_hat = dgl.sum_nodes(g, "h")

        return h_hat
```
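
For concreteness, here's a rough sketch of the (graph, node features) convention I have in mind. None of this is existing pinot code; `GNLayer` and everything inside it are made up just to illustrate the signature:

```python
import dgl
import dgl.function as fn
import torch


class GNLayer(torch.nn.Module):
    """Toy layer illustrating a (graph, node features) interface."""

    def __init__(self, in_units, out_units):
        super(GNLayer, self).__init__()
        self.linear = torch.nn.Linear(in_units, out_units)

    def forward(self, g, h=None):
        # node features are optional: fall back to what's on the graph
        if h is None:
            h = g.ndata["h"]

        # message passing happens in a local scope, so the caller's
        # graph is never modified in place
        with g.local_scope():
            g.ndata["h"] = self.linear(h)
            g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            return g, g.ndata["h"]
```

A Sequential-style container could then thread `(g, h)` through its layers and only write `h` back to the graph (or pool it with `dgl.sum_nodes`) at the very end.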
yuanqing-wang commented 4 years ago

I actually just discussed this with @maxentile and @jchodera this morning for our espaloma project. It seems that further down the road we're gonna be interested in passing graphs around when we do slightly fancier things, namely hierarchical message passing.

Something like this, where doublets and triplets also hold parameters:

[image: illustration of hierarchical message passing, with doublet and triplet nodes layered above the atom graph]

In this case I suppose a graph would be the best way to pass these kinds of data.
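
In dgl that could look roughly like the sketch below (a made-up toy example, not espaloma code; the `atom` / `doublet` / `triplet` type names are just for illustration):

```python
import dgl
import dgl.function as fn
import torch

# three atoms; atoms (0, 1) and (1, 2) form doublets 0 and 1,
# and atoms (0, 1, 2) form triplet 0
g = dgl.heterograph({
    ("atom", "in_doublet", "doublet"): ([0, 1, 1, 2], [0, 0, 1, 1]),
    ("atom", "in_triplet", "triplet"): ([0, 1, 2], [0, 0, 0]),
})

g.nodes["atom"].data["h"] = torch.randn(3, 16)

# pool atom representations up to the doublet and triplet nodes,
# which can then hold their own parameters
g.multi_update_all(
    {
        "in_doublet": (fn.copy_u("h", "m"), fn.sum("m", "h")),
        "in_triplet": (fn.copy_u("h", "m"), fn.sum("m", "h")),
    },
    "sum",
)

print(g.nodes["doublet"].data["h"].shape)  # torch.Size([2, 16])
print(g.nodes["triplet"].data["h"].shape)  # torch.Size([1, 16])
```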

yuanqing-wang commented 4 years ago

Let's think a little bit about what would be the best way to do this if we do want to pass graphs.

maxentile commented 4 years ago

> I actually just discussed this with @maxentile and @jchodera this morning for our espaloma project. It seems that further down the road we're gonna be interested in passing graphs around when we do slightly fancier things, namely hierarchical message passing. Let's think a little bit about what would be the best way to do this if we do want to pass graphs. In this case I suppose a graph would be the best way to pass these kinds of data.

Choices related to how graphs are modified don't have to be coupled across these contexts.

In the other context, I think it would be safest if in-place modifications can be avoided to the extent possible (g' <- return_a_different_graph(g), not modify_graph_in_place(g)). I'm not sure I see why constructing (or passing messages using) the hierarchical graph you illustrate would require in-place modifications of the input graph.
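
For instance (a sketch; `apply_atom_out_of_place` is hypothetical, mirroring `apply_atom_in_graph` from the code above), dgl's `local_scope` lets a layer do its node-wise computation without touching the caller's graph:

```python
import dgl
import torch


def apply_atom_out_of_place(fn):
    def _fn(g):
        # writes to node data inside local_scope() are discarded on exit,
        # so the caller's graph is left untouched
        with g.local_scope():
            g.apply_nodes(lambda nodes: {"h": fn(nodes.data["h"])})
            return g.ndata["h"]

    return _fn


# usage: the relu is applied out of place; g.ndata["h"] is unchanged
g = dgl.graph(([0, 1], [1, 2]))
g.ndata["h"] = torch.randn(3, 8)
h_new = apply_atom_out_of_place(torch.relu)(g)
```

The trade-off is that this hands back features rather than a graph, which circles back to the (graph, node features) interface question above.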

yuanqing-wang commented 4 years ago

@maxentile if you look at the functions heterographs provide that allow us to carry out those kinds of operations, it's mostly in-place

https://docs.dgl.ai/api/python/heterograph.html?highlight=computing#computing-with-dglheterograph

maxentile commented 4 years ago

Message-passing in dgl updates tensors etc. in-place, yes. I may have been confused about the context of your comment. (I thought your comment was about modifying the structure of the input graph in place, to add nodes.)