dgcnz / relaxed-equivariance-dynamics

Code for "Effect of equivariance on training dynamics"

[Meta Issue] Wang 2022 Extension: Cross evaluation by changing model equivariance (alpha) and dataset equivariance #61

Closed: dgcnz closed this issue 5 months ago

dgcnz commented 5 months ago

Description

This issue concerns the extension of Wang 2022 Figure 4, which consists of fixing a dataset's equivariance (levels = [full, partial]) and training a relaxed equivariant model with k different values of alpha.

So we have a total of 2 * k runs per model, where 2 is the number of datasets (one per equivariance level) and k is the number of alpha values tested.
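A minimal sketch of the run grid for one model, assuming each run is identified by a (dataset level, alpha) pair. The alpha values below are hypothetical placeholders; the issue does not say which k values are tested.

```python
from itertools import product

# Dataset equivariance levels named in the issue.
DATASET_LEVELS = ["full", "partial"]
# Hypothetical alpha values (k = 5 here); replace with the values actually tested.
ALPHAS = [0.0, 1e-3, 1e-2, 1e-1, 1.0]


def run_grid(levels, alphas):
    """Return one (dataset_level, alpha) pair per training run."""
    return list(product(levels, alphas))


if __name__ == "__main__":
    grid = run_grid(DATASET_LEVELS, ALPHAS)
    for level, alpha in grid:
        print(f"dataset={level} alpha={alpha}")
    # len(DATASET_LEVELS) * len(ALPHAS) = 2 * k runs for a single model
    print(len(grid), "runs")
```

Using a Cartesian product keeps the count at exactly 2 * k per model, so adding an alpha value adds one run per dataset level.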

Configs

SmokePlume configs:

Model default configs (DO NOT MODIFY THESE FILES; other experiment configs rely on the defaults set here):

Experiment configs:

At a minimum, these configs must be tested with trainer.fast_dev_run to ensure that the model processes data correctly. This does not exercise model checkpointing or early stopping, so we'll have to add tests for those. Examples can be found in the Makefile target test_wang2022_figure_4, which you can run with make test_wang2022_figure_4.

Example testing command:

```shell
python -m src.train experiment=wang2022/equivariance_test/rgroup +trainer.fast_dev_run=True data.batch_size=8
```
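To smoke-test several experiment configs in one go, the command above could be wrapped in a small script. This is a hypothetical sketch, not the Makefile's test_wang2022_figure_4 target: the script name and the helper functions are made up for illustration, and experiment names are passed in rather than discovered.

```python
import subprocess
import sys


def fast_dev_run_command(experiment, batch_size=8):
    """Build the fast_dev_run smoke-test command for one experiment config."""
    return [
        sys.executable, "-m", "src.train",
        f"experiment={experiment}",
        "+trainer.fast_dev_run=True",
        f"data.batch_size={batch_size}",
    ]


def smoke_test(experiments, batch_size=8):
    """Run fast_dev_run for each experiment and return the names that failed."""
    failed = []
    for experiment in experiments:
        result = subprocess.run(fast_dev_run_command(experiment, batch_size))
        if result.returncode != 0:
            failed.append(experiment)
    return failed


if __name__ == "__main__":
    # Hypothetical usage: python smoke_test.py wang2022/equivariance_test/rgroup
    failures = smoke_test(sys.argv[1:])
    if failures:
        print("fast_dev_run failed for:", ", ".join(failures))
        sys.exit(1)
```

Using sys.executable runs src.train with the same interpreter (and virtual environment) as the script. Like the single command, this only checks that data flows through the model; checkpointing and early stopping still need separate tests.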

Legend:

Tasks

Questions

Feel free to add more questions or tasks.

dgcnz commented 5 months ago

Example of downloading a checkpoint from wandb and loading it:

```python
import wandb
from src.utils.wandb import get_all_checkpoints, download_artifact
from src.models.wang2022_module import Wang2022LightningModule
from pathlib import Path

# W&B location of the run whose checkpoints we want
entity = "uva-dl2"
project = "wang2022"
run_id = "bpiojh4w"

# Get all checkpoints logged by the run
checkpoints = get_all_checkpoints(run_id, project, entity)
print(checkpoints)
with wandb.init(project=project, entity=entity, job_type="run-evaluation-test") as run:
    # Download the first checkpoint artifact and restore the Lightning module from it
    artifact_dir = download_artifact(run, checkpoints[0], project, entity)
    model = Wang2022LightningModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
    print(model)
```

dgcnz commented 5 months ago

done