NREL / nfp

Keras layers for end-to-end learning with rdkit and pymatgen

About bond_state being updated in the message function #2

Closed: akirasosa closed this issue 5 years ago

akirasosa commented 5 years ago

Hi, I have a question about the interaction network.

In the code below, bond_state is updated in the message function: https://github.com/NREL/nfp/blob/master/examples/schnet_edgeupdate.py#L116

However, according to the original paper, the output of the edge update network is used as the input to the subsequent edge update.

[figure from the original paper illustrating the edge update network]

Is this intended?

Thanks in advance! Akira

pstjohn commented 5 years ago

Sorry for the delayed response!

I think we're doing the same thing in my code and in the figure; I've just combined the edge update network and the 'interaction network' into a single function.

This is the message function code:

    source_atom = source_atom_gather([atom_state, connectivity])
    target_atom = target_atom_gather([atom_state, connectivity])

    # Edge update network
    bond_state = Concatenate()([source_atom, target_atom, bond_state])
    bond_state = Dense(2*atom_features, activation='softplus')(bond_state)
    bond_state = Dense(atom_features)(bond_state)

    # message function
    bond_state = Dense(atom_features, activation='softplus')(bond_state)
    bond_state = Dense(atom_features, activation='softplus')(bond_state)
    source_atom = Dense(atom_features)(source_atom)
    messages = Multiply()([source_atom, bond_state])
    messages = ReduceBondToAtom(reducer='sum')([messages, connectivity])

    # state transition function
    messages = Dense(atom_features, activation='softplus')(messages)
    messages = Dense(atom_features)(messages)
    atom_state = Add()([atom_state, messages])

So here, we do the edge update first, and save the new bond state for the subsequent message function call.
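
For clarity, here's a minimal sketch of how that updated bond_state carries over between message-passing steps; the message_block wrapper and num_message_passes are hypothetical names for illustration, not from the repo:

    def message_block(atom_state, bond_state, connectivity):
        # body: the edge update, message, and state transition code shown above
        return atom_state, bond_state

    num_message_passes = 3  # illustrative depth
    for _ in range(num_message_passes):
        # each pass consumes the bond_state produced by the previous pass
        atom_state, bond_state = message_block(atom_state, bond_state, connectivity)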

akirasosa commented 5 years ago

Hi @pstjohn,

Thanks for your response. Here is the code example I had in mind.

    source_atom = source_atom_gather([atom_state, connectivity])
    target_atom = target_atom_gather([atom_state, connectivity])

    # Edge update network
    bond_state = Concatenate()([source_atom, target_atom, bond_state])
    bond_state = Dense(2*atom_features, activation='softplus')(bond_state)
    bond_state = Dense(atom_features)(bond_state)

    # message function
    bond_msg = Dense(atom_features, activation='softplus')(bond_state)  # <------
    bond_msg = Dense(atom_features, activation='softplus')(bond_msg)
    source_atom = Dense(atom_features)(source_atom)
    messages = Multiply()([source_atom, bond_msg])
    messages = ReduceBondToAtom(reducer='sum')([messages, connectivity])

    # state transition function
    messages = Dense(atom_features, activation='softplus')(messages)
    messages = Dense(atom_features)(messages)
    atom_state = Add()([atom_state, messages])

Anyway, the results probably wouldn't change much.

P.S. I'm currently working on a Kaggle competition to predict quantum molecular properties. I have tried SchNet with edge updates, referring to this repo. How about joining if you have time? 😄

The results are similar to the original SchNet, but it allows playing with edges more flexibly. Training is somewhat slower than the original SchNet.

Regards, Akira

pstjohn commented 5 years ago

Ah ok, yes, then your version might be what they intended. Are you calculating the RBF expansions inside the neural network? We found our version was a lot faster if you calculated the RBF expansions outside the network and sent them in as bond_features.
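
For reference, precomputing the expansions could look something like the NumPy sketch below; the function name and basis parameters are illustrative, not the exact values nfp uses:

    import numpy as np

    def rbf_expand(distances, mu_min=0.0, mu_max=12.0, n_basis=150, gamma=10.0):
        # Gaussian radial basis expansion: (n_bonds,) -> (n_bonds, n_basis)
        mu = np.linspace(mu_min, mu_max, n_basis)
        return np.exp(-gamma * (distances[:, None] - mu[None, :]) ** 2)

    # computed once per molecule during preprocessing, then fed to the
    # model as its bond_features input rather than recomputed every batch
    bond_features = rbf_expand(np.array([0.96, 1.54, 2.10]))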

akirasosa commented 5 years ago

> Are you calculating the RBF expansions inside the neural network?

No, I do it in the data loader (I use PyTorch). Still, I'm not sure why it's slower.
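
Something like the following PyTorch sketch, where the expansion happens in the Dataset rather than in the model's forward pass (the class name and basis parameters are made up for illustration):

    import torch
    from torch.utils.data import Dataset

    class MoleculeDataset(Dataset):
        def __init__(self, distances_per_mol):
            # list of 1-D tensors of interatomic distances, one per molecule
            self.distances_per_mol = distances_per_mol
            self.mu = torch.linspace(0.0, 12.0, 150)

        def __len__(self):
            return len(self.distances_per_mol)

        def __getitem__(self, idx):
            d = self.distances_per_mol[idx]
            # Gaussian RBF expansion done here, outside the network
            return torch.exp(-10.0 * (d[:, None] - self.mu[None, :]) ** 2)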

I'll close this ticket. Thanks!