akirasosa closed this issue 5 years ago.
Sorry for the delayed response!
I think my code and the figure do the same thing; I've just combined the edge update network and the 'interaction network' into a single function.
This is the message function code:
```python
source_atom = source_atom_gather([atom_state, connectivity])
target_atom = target_atom_gather([atom_state, connectivity])
# Edge update network
bond_state = Concatenate()([source_atom, target_atom, bond_state])
bond_state = Dense(2*atom_features, activation='softplus')(bond_state)
bond_state = Dense(atom_features)(bond_state)
# message function
bond_state = Dense(atom_features, activation='softplus')(bond_state)
bond_state = Dense(atom_features, activation='softplus')(bond_state)
source_atom = Dense(atom_features)(source_atom)
messages = Multiply()([source_atom, bond_state])
messages = ReduceBondToAtom(reducer='sum')([messages, connectivity])
# state transition function
messages = Dense(atom_features, activation='softplus')(messages)
messages = Dense(atom_features)(messages)
atom_state = Add()([atom_state, messages])
```
So here, we do the edge update first, and save the new bond state for the subsequent message function call.
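For readers less familiar with the Multiply / ReduceBondToAtom(reducer='sum') step, here is a minimal pure-Python sketch of what a sum reduction from bonds to atoms does. It is not nfp's implementation: the helper names are hypothetical, the Dense layers are omitted, and it assumes `connectivity[i] = (receiving_atom, sending_atom)`, which may not match the column order nfp actually uses.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def gather_sender(atom_state: Sequence[Vector],
                  connectivity: Sequence[Tuple[int, int]]) -> List[Vector]:
    """Hypothetical stand-in for source_atom_gather: one sender vector per bond."""
    return [list(atom_state[sender]) for _receiver, sender in connectivity]


def multiply(a: Sequence[Vector], b: Sequence[Vector]) -> List[Vector]:
    """Element-wise product per bond, like Keras' Multiply() layer."""
    return [[x * y for x, y in zip(u, v)] for u, v in zip(a, b)]


def reduce_bond_to_atom_sum(messages: Sequence[Vector],
                            connectivity: Sequence[Tuple[int, int]],
                            num_atoms: int) -> List[Vector]:
    """Sum each bond's message onto its receiving atom (assumed to be column 0)."""
    width = len(messages[0]) if messages else 0
    totals = [[0.0] * width for _ in range(num_atoms)]
    for msg, (receiver, _sender) in zip(messages, connectivity):
        for k, value in enumerate(msg):
            totals[receiver][k] += value
    return totals


# Toy chain of 3 atoms; each bond is stored in both directions.
atom_state = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
connectivity = [(0, 1), (1, 0), (1, 2), (2, 1)]
bond_msg = [[1.0, 1.0], [1.0, 1.0], [2.0, 1.0], [1.0, 1.0]]  # made-up values

source_atom = gather_sender(atom_state, connectivity)
messages = multiply(source_atom, bond_msg)
atom_messages = reduce_bond_to_atom_sum(messages, connectivity, num_atoms=3)
print(atom_messages)  # [[3.0, 4.0], [11.0, 8.0], [3.0, 4.0]]
```

Atom 1 receives two messages (from atoms 0 and 2), so its row is their sum; atoms 0 and 2 each receive one.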
Hi @pstjohn,
Thanks for your response. Here is the code I had in mind.
```python
source_atom = source_atom_gather([atom_state, connectivity])
target_atom = target_atom_gather([atom_state, connectivity])
# Edge update network
bond_state = Concatenate()([source_atom, target_atom, bond_state])
bond_state = Dense(2*atom_features, activation='softplus')(bond_state)
bond_state = Dense(atom_features)(bond_state)
# message function
# <------ separate variable, so bond_state keeps the edge update output
bond_msg = Dense(atom_features, activation='softplus')(bond_state)
bond_msg = Dense(atom_features, activation='softplus')(bond_msg)
source_atom = Dense(atom_features)(source_atom)
messages = Multiply()([source_atom, bond_msg])
messages = ReduceBondToAtom(reducer='sum')([messages, connectivity])
# state transition function
messages = Dense(atom_features, activation='softplus')(messages)
messages = Dense(atom_features)(messages)
atom_state = Add()([atom_state, messages])
```
Either way, I don't expect the results to change much.
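To make the difference between the two versions concrete, here is a toy pure-Python sketch (no Keras) of which bond features reach the next interaction block. The functions edge_update and message_mlp are hypothetical fixed stand-ins for the learned Dense stacks, and the atom states are held constant to isolate the bond path.

```python
from typing import Callable, List, Tuple

Vector = List[float]
Block = Callable[[Vector, Vector, Vector], Tuple[Vector, Vector]]


# Hypothetical fixed functions standing in for the learned Dense stacks.
def edge_update(src: Vector, tgt: Vector, bond: Vector) -> Vector:
    return [s + t + b for s, t, b in zip(src, tgt, bond)]


def message_mlp(bond: Vector) -> Vector:
    return [10.0 * b for b in bond]


def block_overwrite(src: Vector, tgt: Vector, bond_state: Vector) -> Tuple[Vector, Vector]:
    """Like the repo snippet: the message layers reassign bond_state."""
    bond_state = edge_update(src, tgt, bond_state)
    bond_state = message_mlp(bond_state)  # edge update output is replaced here
    return bond_state, bond_state  # (state for next block, message features)


def block_separate(src: Vector, tgt: Vector, bond_state: Vector) -> Tuple[Vector, Vector]:
    """Like the proposed version: message features live in bond_msg."""
    bond_state = edge_update(src, tgt, bond_state)
    bond_msg = message_mlp(bond_state)
    return bond_state, bond_msg  # next block sees the edge update output


def run(block: Block, num_blocks: int, bond0: Vector) -> Vector:
    src, tgt = [1.0], [2.0]  # atom states held fixed to isolate the bond path
    bond = bond0
    for _ in range(num_blocks):
        bond, _bond_msg = block(src, tgt, bond)
    return bond


print(run(block_overwrite, 2, [0.0]))  # [330.0]
print(run(block_separate, 2, [0.0]))   # [6.0]
```

Within a single block both versions produce the same messages; they only diverge from the second block onward, where the separate-variable version feeds the previous edge update output into the next edge update.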
P.S. I'm currently taking part in a Kaggle competition to predict quantum molecular properties, and I've tried SchNet with edge updates, based on this repo. Would you like to join if you have time? 😄
The results are similar to the original SchNet, but the edge updates make it possible to work with edges more flexibly. Training is somewhat slower than with the original SchNet.
Regards, Akira
Ah OK, yes, then your version might be what they intended. Are you calculating the RBF expansions inside the neural network? We found that our version was a lot faster when we calculated the RBF expansions outside the network and passed them in as bond_features.
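A minimal sketch of precomputing a SchNet-style Gaussian radial basis expansion of bond distances outside the network, so the model only receives the expanded features. The function names, the center grid, and gamma are illustrative assumptions, not the values used in this repo's examples; a real grid would use many more centers over the relevant distance range.

```python
import math
from typing import List, Sequence


def make_centers(d_min: float, d_max: float, num_centers: int) -> List[float]:
    """Evenly spaced Gaussian centers from d_min to d_max inclusive (num_centers >= 2)."""
    step = (d_max - d_min) / (num_centers - 1)
    return [d_min + i * step for i in range(num_centers)]


def rbf_expand(distances: Sequence[float], centers: Sequence[float],
               gamma: float) -> List[List[float]]:
    """Map each distance d to [exp(-gamma * (d - mu)**2) for mu in centers]."""
    return [[math.exp(-gamma * (d - mu) ** 2) for mu in centers]
            for d in distances]


# Computed once per molecule (e.g. in a data loader), then fed to the model
# as bond features instead of expanding distances inside the network.
centers = make_centers(0.0, 3.0, 4)  # [0.0, 1.0, 2.0, 3.0]
bond_features = rbf_expand([1.0, 2.5], centers, gamma=10.0)
print(len(bond_features), len(bond_features[0]))  # 2 4
print(bond_features[0][1])  # 1.0, since d == mu exactly
```

Because the expansion has no trainable parameters, doing it ahead of time removes repeated work from every training step without changing what the model sees.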
> are you calculating the rbf expansions inside the neural network?

No, I do it in the data loader (I use PyTorch). Still, I'm not sure why it's slower.
I'll close this ticket. Thanks!
Hi, I have a question about Interaction Network.
In the code linked below, bond_state is updated in the message function: https://github.com/NREL/nfp/blob/master/examples/schnet_edgeupdate.py#L116
However, according to the original paper, the output of the Edge Update Network is used as the input to the subsequent edge update.
Is this intended?
Thanks in advance!
Akira