materialsvirtuallab / m3gnet

Materials graph network with 3-body interactions featuring a DFT surrogate crystal relaxer and a state-of-the-art property predictor.
BSD 3-Clause "New" or "Revised" License
231 stars 59 forks

Training with include_states and small batch size throws error #21

Closed: dgaines2 closed this issue 2 years ago

dgaines2 commented 2 years ago

I ran into a strange issue when trying to train M3GNet models with states included.

I've added a commit below with a failing unit test.

The test is a copy of test_band_gap with only two changes: include_states = True and batch_size = 32. With these settings, training throws an error that looks like a mismatch between the dimensions expected by different layers. I can include the full error traceback if it would be helpful.
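A general property of minibatching may explain why batch size matters. Unless the dataset size is a multiple of batch_size, the final batch is smaller than the others, so batches can carry different numbers of graphs (and therefore of atoms, bonds, and state rows). The sketch below only illustrates that arithmetic. The 100-structure dataset is a made-up example, the helper name is hypothetical, and the thread does not confirm this is the cause.

```python
def batch_sizes(n_samples, batch_size):
    """Sizes of consecutive batches when n_samples items are split into
    batches of at most batch_size; the last batch may be smaller."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    full, remainder = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    if remainder:
        sizes.append(remainder)  # partial final batch
    return sizes


# Hypothetical dataset size; the real test data size is not given in the thread.
print(batch_sizes(100, 32))   # [32, 32, 32, 4]
print(batch_sizes(100, 128))  # [100]
```

With a large batch size, every structure may land in a single batch. A smaller one produces batches of more than one size.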

chc273 commented 2 years ago

This is indeed weird. When I turned off graph mode to debug it, the error went away. I'll need more time to look into it.

dgaines2 commented 2 years ago

I think the error is somewhere in MultiFieldReadout. In the original MEGNet, I believe the readout was performed on the atoms, bonds, and states, but M3GNet currently seems to be configured to use only the atom readout. If I manually edit MultiFieldReadout to use ignore_states=True, the error goes away on my end (and the model trains rather well anyway). I'm not sure how much model performance would improve if bonds and states were also included in the readout, but I'd be curious to hear your thoughts on this too.
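In MEGNet-style models, one way to build a per-graph readout vector is to pool the atom features, pool the bond features, and append the graph-level state vector. The pure-Python sketch below is a hypothetical illustration of that idea, not m3gnet's MultiFieldReadout. The function names, the use of mean pooling, and the ignore_bonds/ignore_states flags (loosely modeled on the ignore_states option mentioned above) are all assumptions. The sketch also shows one place a dimension mismatch can surface: the state block must have exactly one row per graph in the batch.

```python
def segment_mean(rows, segment_ids, n_segments):
    """Average the feature rows that share a segment (graph) id."""
    dim = len(rows[0]) if rows else 0
    sums = [[0.0] * dim for _ in range(n_segments)]
    counts = [0] * n_segments
    for row, seg in zip(rows, segment_ids):
        counts[seg] += 1
        for j, x in enumerate(row):
            sums[seg][j] += x
    # Graphs with no rows get zeros instead of a division by zero.
    return [[s / c if c else 0.0 for s in total] for total, c in zip(sums, counts)]


def multi_field_readout(atom_feats, atom_graph, bond_feats, bond_graph,
                        state_feats, n_graphs, ignore_bonds=False, ignore_states=False):
    """Hypothetical MEGNet-style readout returning one vector per graph."""
    if not ignore_states and len(state_feats) != n_graphs:
        # The state block needs exactly one row per graph in the batch.
        raise ValueError(f"expected {n_graphs} state rows, got {len(state_feats)}")
    out = segment_mean(atom_feats, atom_graph, n_graphs)
    if not ignore_bonds:
        pooled_bonds = segment_mean(bond_feats, bond_graph, n_graphs)
        out = [a + b for a, b in zip(out, pooled_bonds)]  # concatenate features
    if not ignore_states:
        out = [o + list(s) for o, s in zip(out, state_feats)]
    return out


# Toy batch: graph 0 has two atoms and two bonds, graph 1 has one of each.
atoms = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
bonds = [[10.0], [20.0], [30.0]]
states = [[0.5], [0.7]]
print(multi_field_readout(atoms, [0, 0, 1], bonds, [0, 0, 1], states, n_graphs=2))
# [[2.0, 3.0, 15.0, 0.5], [5.0, 6.0, 30.0, 0.7]]
print(multi_field_readout(atoms, [0, 0, 1], bonds, [0, 0, 1], states, n_graphs=2,
                          ignore_bonds=True, ignore_states=True))
# [[2.0, 3.0], [5.0, 6.0]]
```

Skipping the state block, as with ignore_states=True in the comment above, also skips the only check that depends on the number of state rows.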

chc273 commented 2 years ago

It turns out to be a shape tracing issue. This commit should fix it: 65876ac4d09d81d6a4f038240cfd0ecdde0f875c
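The actual fix is in the linked commit and is not reproduced here. The toy analogy below uses pure Python. It is not TensorFlow's tracing machinery and not m3gnet code. It shows one common way a shape bug can appear only when a function is traced. One version captures the batch's graph count on the first call and reuses it, the way a shape can be frozen when a graph-mode function is traced. The other reads the count from every batch. The names trace_once, attach_states, and attach_states_dynamic are hypothetical.

```python
def trace_once(fn):
    """Toy 'tracer': records len(batch) on the first call and reuses it on
    every later call, like a shape frozen when a graph function is traced."""
    traced_n = None

    def wrapper(batch):
        nonlocal traced_n
        if traced_n is None:
            traced_n = len(batch)
        return fn(batch, traced_n)

    return wrapper


def attach_states(batch, n_graphs):
    """Pair each graph in the batch with one (dummy) state row."""
    states = [0.0] * n_graphs
    if len(states) != len(batch):
        raise ValueError(f"shape mismatch: {len(batch)} graphs vs {len(states)} state rows")
    return list(zip(batch, states))


def attach_states_dynamic(batch):
    """Read the graph count from each batch instead of reusing a cached one."""
    return attach_states(batch, len(batch))


batches = [[1.0] * 32, [1.0] * 32, [1.0] * 4]  # two full batches, one partial

print([len(attach_states_dynamic(b)) for b in batches])  # [32, 32, 4]
stale = trace_once(attach_states)
for b in batches:
    try:
        stale(b)
    except ValueError as err:
        print(err)  # shape mismatch: 4 graphs vs 32 state rows
```

In the toy version, only the cached-count path fails, and only on the batch whose size differs from the first one.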