fabiannagel / schnax

An implementation of SchNet in JAX and JAX-MD.

Break down model into representation and output #5

Open fabiannagel opened 2 years ago

fabiannagel commented 2 years ago

A schnax/SchNetPack model consists of two main components:
1) A SchNet representation block (embeddings, distance expansions, interactions)
2) An atomwise output block (simple 2-layer MLP, aggregation)
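For illustration, here is a minimal JAX sketch of that split. It uses plain functions over parameter dicts rather than Haiku modules, and it is not the schnax implementation. The function names, the simplified interaction (one dense filter layer plus a residual update), the Gaussian expansion settings and the parameter layout are all assumptions. The only point is that the representation block and the output block own separate parameter subtrees and meet at the per-atom feature matrix.

```python
import jax
import jax.numpy as jnp

N_GAUSSIANS = 16  # assumed size of the distance expansion


def shifted_softplus(x):
    # SchNet-style activation: softplus shifted so that ssp(0) == 0.
    return jax.nn.softplus(x) - jnp.log(2.0)


def gaussian_expansion(d, cutoff=5.0):
    # Expand each distance into Gaussians centred on an even grid in [0, cutoff].
    centers = jnp.linspace(0.0, cutoff, N_GAUSSIANS)
    width = centers[1] - centers[0]
    return jnp.exp(-0.5 * ((d[..., None] - centers) / width) ** 2)


def representation(params, Z, R):
    """Representation block: embeddings, distance expansion, interactions."""
    x = params["embedding"][Z]                                   # (n_atoms, F)
    diff = R[:, None, :] - R[None, :, :]
    d = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + 1e-12)            # eps avoids a NaN gradient at d == 0
    rbf = gaussian_expansion(d)                                  # (n_atoms, n_atoms, G)
    mask = 1.0 - jnp.eye(Z.shape[0])                             # no self-interaction
    for p in params["interactions"]:
        W = shifted_softplus(rbf @ p["filter"])                  # pairwise filters, (n, n, F)
        v = jnp.sum(W * x[None, :, :] * mask[..., None], axis=1) # sum over neighbours j
        x = x + shifted_softplus(v @ p["dense"])                 # residual feature update
    return x


def atomwise_output(params, x):
    """Output block: 2-layer MLP applied per atom, then summed over atoms."""
    h = shifted_softplus(x @ params["w1"] + params["b1"])
    y = h @ params["w2"] + params["b2"]                          # (n_atoms, 1)
    return jnp.sum(y)


def model(params, Z, R):
    # The two blocks share nothing but the per-atom feature matrix x.
    x = representation(params["representation"], Z, R)
    return atomwise_output(params["output"], x)


def init_params(key, n_species=10, n_features=32, n_interactions=3, n_hidden=16):
    keys = jax.random.split(key, 3 + 2 * n_interactions)

    def normal(k, shape):
        return 0.1 * jax.random.normal(k, shape)

    interactions = [
        {"filter": normal(keys[3 + 2 * i], (N_GAUSSIANS, n_features)),
         "dense": normal(keys[4 + 2 * i], (n_features, n_features))}
        for i in range(n_interactions)
    ]
    return {
        "representation": {"embedding": normal(keys[0], (n_species, n_features)),
                           "interactions": interactions},
        "output": {"w1": normal(keys[1], (n_features, n_hidden)), "b1": jnp.zeros(n_hidden),
                   "w2": normal(keys[2], (n_hidden, 1)), "b2": jnp.zeros(1)},
    }


params = init_params(jax.random.PRNGKey(0))
Z = jnp.array([1, 6, 8])                                            # atomic numbers
R = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.2, 0.0]])  # positions
energy = model(params, Z, R)
```

Because the atom contributions are summed, permuting the atoms (consistently in `Z` and `R`) should leave the energy unchanged, which makes a convenient sanity check for either block.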

This structure could be made clearer in the current implementation. Potential benefits:

Annoyingly, this would probably break the weight mapping from PyTorch to Haiku.

sirmarcel commented 2 years ago

I think that'd be nice. Wouldn't that just mean slightly changing the map from SchNetPack weights to Haiku weights? I'm not quite sure how Haiku handles member modules, so I might be getting this wrong.
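A Haiku parameter dict maps module names to dicts of arrays, so if the existing modules were only re-parented under a representation and an output module, the PyTorch-to-Haiku map might need little more than a prefix rename. The sketch below assumes exactly that. `rename_modules`, the module names and the new prefixes are hypothetical; the names Haiku actually assigns to member modules would have to be checked against the keys of freshly initialised params.

```python
def rename_modules(params, prefix_map):
    """Re-key a Haiku-style {module_name: {param_name: array}} dict.

    prefix_map maps an old module-name prefix to its new (nested) prefix.
    Names matching no prefix are kept unchanged; parameter leaves are untouched.
    """
    renamed = {}
    for name, leaves in params.items():
        new_name = name
        for old, new in prefix_map.items():
            # Match whole path components only, so "interaction_1" does not match "interaction_10".
            if name == old or name.startswith(old + "/"):
                new_name = new + name[len(old):]
                break
        renamed[new_name] = leaves
    return renamed


# Hypothetical module names; strings stand in for the converted weight arrays.
flat = {
    "embeddings": {"embeddings": "E"},
    "interaction_0/filter_network": {"w": "W0", "b": "b0"},
    "atomwise/linear_0": {"w": "A0", "b": "a0"},
}
nested = rename_modules(flat, {
    "embeddings": "representation/embeddings",
    "interaction_0": "representation/interaction_0",
    "atomwise": "output/atomwise",
})
```

If the renaming really is this mechanical, the existing weight map could be composed with a function like this rather than rewritten.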

sirmarcel commented 2 years ago

This would also solve the current awkward situation of having to extract mean and std from the atomwise spec in the schnetkitloader.
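A minimal sketch of what that could look like, assuming the output block keeps mean and std as its own fields and applies them per atom before summing (assumed here to mirror SchNetPack's atomwise standardization). `AtomwiseStats`, `output_stats_from_spec` and the spec keys are hypothetical, and `y_atoms` stands in for the per-atom MLP output.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass(frozen=True)
class AtomwiseStats:
    """Hypothetical: standardization statistics owned by the output block."""
    mean: float
    stddev: float

    def aggregate(self, y_atoms):
        # Undo the target standardization per atom, then sum over atoms.
        return jnp.sum(y_atoms * self.stddev + self.mean)


def output_stats_from_spec(spec):
    # `spec` stands in for whatever the loader reads; its keys are hypothetical.
    return AtomwiseStats(mean=float(spec["mean"]), stddev=float(spec["stddev"]))


stats = output_stats_from_spec({"mean": -1.5, "stddev": 2.0})
y_atoms = jnp.zeros((4, 1))        # per-atom outputs of the output MLP, 4 atoms
energy = stats.aggregate(y_atoms)  # 4 * mean = -6.0
```

With the statistics attached to the output block, the loader would hand them over once when building that block instead of other code reaching into the atomwise spec.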