lucmos / relreps

Relative representations can be leveraged to enable solving tasks regarding "latent communication": from zero-shot model stitching to latent space comparison between diverse settings.
https://openreview.net/forum?id=SrC-nwieGJ
MIT License

AE latent space training code #22

Open dribnet opened 4 months ago

dribnet commented 4 months ago

I am trying to reproduce figure 1 from the paper:

[Image: Figure 1 from the paper]

I've found the code in fig:latent-rotation/visualize.ipynb and am attempting to get it to work. IIUC, it assumes that pre-baked model checkpoints have already been downloaded with the down.sh helper script.

If I've got this right, could you point me to the code that trained these MNIST AE checkpoints? I'd like to recreate your results, including model training, so that I can run follow-up experiments that vary the upstream parameters of the latent spaces.

dribnet commented 4 months ago

I've managed to get the getckpts download utility to run, but it appears the referenced checkpoints are no longer available on wandb (or perhaps I don't have permission to access them). For all checkpoints, I get errors such as the following:

│ ╭──────────────────────────── locals ────────────────────────────╮                   │
│ │  entity = 'gladia'                                             │                   │
│ │    path = 'gladia/rae/1fdme9ar'                                │                   │
│ │ project = 'rae'                                                │                   │
│ │  run_id = '1fdme9ar'                                           │                   │
│ │    self = <wandb.apis.public.api.Api object at 0x7ffac1a7b0d0> │                   │
│ ╰────────────────────────────────────────────────────────────────╯                   │

CommError: Could not find run <Run gladia/rae/1fdme9ar (not found)>                     

If this is the case, it would help to understand how best to recreate these four 2D MNIST AE latent spaces so that the results can be faithfully replicated.


Update: I found the old multirun_machine2.sh dev script, which appears to contain the AE training routine. Here's that script:

#!/bin/bash
# Reconstruction

python src/rae/run.py -m \
  core.tags='[classification, absolute, fig:ae-rotations, small_cnn]' \
  'nn/data/datasets=vision/mnist' \
  'train.seed_index=0,1,2,3,4,5' \
  nn/module=classifier \
  nn/module/model=cnn \
  train=classification \
  nn.module.model.latent_dim=2 \
  nn.data.anchors_num=500 \
  "nn.module.model.hidden_dims=[2, 3, 4, 8]" \
  "nn.module.optimizer.lr=5e-4" \
  train.trainer.max_epochs=40
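The -m flag makes Hydra run a multirun sweep, launching one job per comma-separated value of train.seed_index (six jobs here). When a sweep fails on a config error, it can be easier to debug a single job first. The sketch below rebuilds the same overrides for one seed without -m. run_single_seed and the DRY_RUN switch are hypothetical helpers, not part of the repository, and the function is assumed to be called from the repository root, like the original script.

```bash
#!/usr/bin/env bash
# Hypothetical helper: run (or, with DRY_RUN=1, only print) the sweep's
# overrides for a single seed, without Hydra's -m multirun flag.
# Assumes it is called from the repository root, like the original script.
set -euo pipefail

run_single_seed() {
  local seed="$1"
  # Same overrides as the multirun script; only train.seed_index is pinned.
  local -a cmd=(
    python src/rae/run.py
    'core.tags=[classification, absolute, fig:ae-rotations, small_cnn]'
    'nn/data/datasets=vision/mnist'
    "train.seed_index=${seed}"
    'nn/module=classifier'
    'nn/module/model=cnn'
    'train=classification'
    'nn.module.model.latent_dim=2'
    'nn.data.anchors_num=500'
    'nn.module.model.hidden_dims=[2, 3, 4, 8]'
    'nn.module.optimizer.lr=5e-4'
    'train.trainer.max_epochs=40'
  )
  if [[ "${DRY_RUN:-0}" == 1 ]]; then
    echo "${cmd[*]}"   # print the command instead of running it
  else
    "${cmd[@]}"
  fi
}
```

For example, DRY_RUN=1 run_single_seed 0 prints the command for seed 0, and looping over seeds 0 to 5 runs the six jobs one after another, roughly what the sweep does (output directories may differ from Hydra's multirun layout).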

After updating the entity in conf/train/classification.yaml to match my wandb login, I was able to get this script to run. However, it then bombed out with what appears to be a missing config key:

omegaconf.errors.InterpolationKeyError: Interpolation key 'nn.data.datasets.val_fixed_sample_idxs' not found
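The error means that some config interpolates ${nn.data.datasets.val_fixed_sample_idxs}, but the composed config has no such key. One way to narrow it down is to list where the key is interpolated and where, if anywhere, it is defined. This sketch assumes the Hydra configs are .yaml files under conf/ (as conf/train/classification.yaml suggests); report_key is a hypothetical helper.

```bash
#!/usr/bin/env bash
# Hypothetical helper: show where a config key is interpolated and where it
# is defined, to spot a reference whose target config lacks the key.
set -euo pipefail

report_key() {
  local conf_dir="$1" key="$2"
  echo "== interpolations referencing ${key} =="
  # Matches interpolations such as ${nn.data.datasets.<key>}.
  grep -rnF --include='*.yaml' -- ".${key}}" "$conf_dir" || echo "(none)"
  echo "== YAML definitions of ${key} =="
  # Matches lines of the form "<key>:" (optionally indented).
  grep -rnE --include='*.yaml' -- "^[[:space:]]*${key}:" "$conf_dir" || echo "(none)"
}
```

Running report_key conf val_fixed_sample_idxs from the repository root prints both lists. If the key is only referenced, the dataset config selected by nn/data/datasets=vision/mnist presumably lacks it; it could then be added to that config, or appended on the command line with Hydra's + prefix (for instance '+nn.data.datasets.val_fixed_sample_idxs=[0,1,2,3]', where the indices are a hypothetical placeholder).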

I'm pleased to find that the AE training code was there all along, but I'm currently stuck on updating the configuration so that it launches successfully.

Flegyas commented 3 months ago

Hi @dribnet, thank you for your interest, and I apologize for the delay!

All the files needed to reproduce the experiments are uploaded to Google Drive using DVC, so you shouldn't need to retrain or run anything to reproduce any of them. You only need the dvc and dvc-gdrive dependencies installed via pip.

In this case, the notebook expects the files to be in the "checkpoints" folder (indexed by the checkpoints.dvc file). Therefore, you can pull just that folder by running dvc pull checkpoints from within the fig:latent-rotation folder.

If you want to download all the files for the project, you can run dvc pull from the root folder.
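The steps above can be sketched as a small shell function. It assumes checkpoints.dvc sits directly inside the fig:latent-rotation folder (as the instructions imply) and that pip installs into the environment used for the notebook; pull_latent_rotation_checkpoints is a hypothetical helper, not part of the repository.

```bash
#!/usr/bin/env bash
# Hypothetical helper for the DVC steps above. Pass the path of the
# fig:latent-rotation folder in your clone as the only argument.
set -euo pipefail

pull_latent_rotation_checkpoints() {
  local fig_dir="$1"
  # The notebook reads from "checkpoints", which checkpoints.dvc indexes.
  if [[ ! -f "${fig_dir}/checkpoints.dvc" ]]; then
    echo "no checkpoints.dvc in ${fig_dir}" >&2
    return 1
  fi
  # DVC plus the Google Drive remote plugin.
  python -m pip install dvc dvc-gdrive
  # Pull only the checkpoints folder; run plain `dvc pull` at the repo root
  # to download every tracked file instead.
  (cd "$fig_dir" && dvc pull checkpoints)
}
```

For example: pull_latent_rotation_checkpoints "path/to/fig:latent-rotation", where the path is a placeholder for your clone's layout. The existence check runs first so that a wrong path fails fast, before anything is installed.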

Let us know whether everything works (the Google Drive folder linked to DVC should be shared correctly, in read-only mode).