som-shahlab / ehr_ml

Code for doing machine learning with various EHRs
MIT License

Create a more standard training loop interface for pretraining #8

Open · jason-fries opened this issue 3 years ago

jason-fries commented 3 years ago

Currently, clmbr_train_model hides the more familiar training loop structure from users. In most demos and APIs, the boilerplate looks like what's outlined at https://github.com/PyTorchLightning/pytorch-lightning, with this structure:

import os
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

autoencoder = LitAutoEncoder()  # LightningModule from the Lightning README example
trainer = pl.Trainer()
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))

This is basically the standard form: specific details around the loss are configured in the model architecture, and the trainer class handles things like progress bars, choice of optimizer, etc.
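For reference, a minimal LitAutoEncoder roughly along the lines of the linked Lightning README (a sketch, not code from this repo) shows where the loss and optimizer live: inside the LightningModule rather than in the training loop.

import torch
from torch import nn
import pytorch_lightning as pl

class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # encoder/decoder sizes match the 28x28 MNIST images used above
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def training_step(self, batch, batch_idx):
        # the loss is defined here, in the model, not in the outer loop
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self.decoder(self.encoder(x))
        return nn.functional.mse_loss(x_hat, x)

    def configure_optimizers(self):
        # the optimizer choice is also part of the model definition
        return torch.optim.Adam(self.parameters(), lr=1e-3)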

What is the lift required to provide a demo and to refactor the code to support this type of workflow?

woffett commented 3 years ago

The refactor PR puts pre-training into this kind of API:

model = CLMBRFeaturizer(config, info)
dataset = PatientTimelineDataset(extract_path, ontology_path, info_path)
model.fit(dataset)  # runs CLMBR pretraining on the dataset

The original clmbr_train_model still works; it just uses this API. I'll leave this issue open until piton_private is updated to reflect these changes.
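A minimal sketch, assuming the CLMBRFeaturizer / PatientTimelineDataset API above and a hypothetical signature for the wrapper (not the actual ehr_ml code), of how the old entry point could sit on top of the new interface:

# Hypothetical sketch, not the actual ehr_ml implementation; imports for
# CLMBRFeaturizer and PatientTimelineDataset are omitted because the exact
# module paths depend on the refactor PR.
def clmbr_train_model(config, info, extract_path, ontology_path, info_path):
    # Build the model and dataset exactly as in the new fit-based API,
    # then run pretraining, keeping the old entry point as a thin wrapper.
    model = CLMBRFeaturizer(config, info)
    dataset = PatientTimelineDataset(extract_path, ontology_path, info_path)
    model.fit(dataset)
    return model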