IntelLabs / matsciml

Open MatSci ML Toolkit is a framework for prototyping and scaling out deep learning models for materials discovery, supporting widely used materials science datasets, and built on top of PyTorch Lightning, the Deep Graph Library, and PyTorch Geometric.
MIT License

Quality of life and helper callback functions #237

Closed · laserkelvin closed this 3 months ago

laserkelvin commented 4 months ago

This PR adds a number of changes aimed at informing the user about what is happening under the hood, particularly during training.

One of the bigger philosophical changes is a shift toward logging specifically with TensorBoardLogger and WandbLogger, writing functions tailored to them rather than treating loggers entirely in the abstract as before.
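
As a rough illustration of what "tailored to the logger" can mean in practice (a sketch, not the PR's implementation; the helper name and figure argument are made up), the idea is to dispatch on the concrete logger type instead of only going through the abstract interface:

    from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

    def log_figure(trainer, tag, figure):
        # hypothetical helper: send a matplotlib figure to whichever logger is attached
        for logger in trainer.loggers:
            if isinstance(logger, TensorBoardLogger):
                # TensorBoard exposes its SummaryWriter via `experiment`
                logger.experiment.add_figure(tag, figure, global_step=trainer.global_step)
            elif isinstance(logger, WandbLogger):
                import wandb
                # the wandb run object is exposed via `experiment`
                logger.experiment.log({tag: wandb.Image(figure)})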

Summary

My intention for the TrainingHelperCallback is for it to act as a guide to best practices: we can refine it as we go and discover new things, and hopefully it will be useful for everyone, including new users. A rough sketch of the shape such a callback could take is below.
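
The hook body and message here are illustrative assumptions, not the callback shipped in this PR:

    import pytorch_lightning as pl

    class TrainingHelperCallback(pl.Callback):
        """Illustrative sketch: surface common training pitfalls to the user."""

        def on_before_optimizer_step(self, trainer, pl_module, optimizer):
            # flag parameters that never received a gradient, which often points
            # to an unused module or a detached computational graph
            for name, param in pl_module.named_parameters():
                if param.requires_grad and param.grad is None:
                    pl_module.print(f"Parameter {name} received no gradient this step.")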

laserkelvin commented 4 months ago

I have somehow broken SAM and need to fix it before review.

laserkelvin commented 4 months ago

I think I have a lead on what the issue is: because of how SAM works, and because of the modifications that "stash" embeddings in the batch structure, we now end up with two disjoint computational graphs, which causes backward to break.

This needs a bit of thought to fix...
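
For intuition, a toy reproduction under my reading of the problem (not matsciml code): once the first loss has been backpropagated, a later loss built on top of the cached embeddings tries to walk back into a graph whose buffers are gone.

    import torch

    encoder = torch.nn.Linear(4, 4)
    head = torch.nn.Linear(4, 1)
    x = torch.randn(2, 4)

    cache = {"embeddings": encoder(x)}  # built on the first graph
    loss1 = head(cache["embeddings"]).sum()
    loss1.backward()                    # frees the first graph's buffers

    # second pass (as in SAM's two-step update): the stale cached embeddings
    # still point into the first, now freed, graph
    loss2 = head(cache["embeddings"]).sum()
    loss2.backward()                    # RuntimeError: trying to backward through the graph a second time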

laserkelvin commented 4 months ago

I confirmed this by changing out BaseTaskModule.forward:

        if "embeddings" in batch:
            embeddings = batch.get("embeddings")
        else:
            embeddings = self.encoder(batch)
            batch["embeddings"] = embeddings
        outputs = self.process_embedding(embeddings)
        return outputs

Removing the branch and just running the encoder + processing the embeddings works (i.e., not trying to grab cached embeddings).

Ideally there would be a way to check whether the embeddings originated from the same computational graph, but that would take a lot more surgery than this PR warrants. I'll think of an alternative to this.
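
One cheaper heuristic (hypothetical, not something this PR does) would be to record the in-place version counters of the encoder parameters when the embeddings are stashed; an in-place weight perturbation like SAM's bumps those counters, so a mismatch would mean the cache is stale:

    def _param_versions(module):
        # torch bumps Tensor._version on in-place updates to a tensor
        return tuple(p._version for p in module.parameters())

    if "embeddings" in batch and batch.get("embedding_versions") == _param_versions(self.encoder):
        embeddings = batch["embeddings"]
    else:
        embeddings = self.encoder(batch)
        batch["embeddings"] = embeddings
        batch["embedding_versions"] = _param_versions(self.encoder)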

The reason we are stashing the embeddings is to benefit the multitask case, where we don't want to run the encoder X times for X tasks and datasets.
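
Roughly, the multitask payoff looks like this (a sketch with assumed names, not the actual multitask module):

    # single shared encoder pass, stashed so each task head can reuse it
    embeddings = encoder(batch)
    batch["embeddings"] = embeddings
    outputs = {}
    for task_name, task in tasks.items():
        # no second (or Xth) encoder call here
        outputs[task_name] = task.process_embedding(batch["embeddings"])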