seuqaj114 / paig

Code for the paper Physics-as-Inverse-Graphics: Joint Unsupervised Learning of Objects and Physics from Video
MIT License

Physics-as-Inverse-Graphics

This repo contains the code for the paper Physics-as-Inverse-Graphics: Unsupervised Physical Parameter Estimation from Video (https://arxiv.org/abs/1905.11169).

Running experiments

To train, run:

PYTHONPATH=. python runners/run_physics.py --task=spring_color --model=PhysicsNet --epochs=500 \
  --batch_size=100 --save_dir=<experiment_folder> --autoencoder_loss=3.0 --base_lr=3e-4 --anneal_lr=true \
  --color=true --eval_every_n_epochs=10 --print_interval=100 --debug=false --use_ckpt=false

At the end of training, this automatically evaluates on the test set (using the extrapolation range). To run only evaluation on a previously trained model, add the flags --use_ckpt=true and --test_mode=true:

PYTHONPATH=. python runners/run_physics.py --task=spring_color --model=PhysicsNet --epochs=500 \
  --batch_size=100 --save_dir=<experiment_folder> --autoencoder_loss=3.0 --base_lr=3e-4 \
  --color=true --eval_every_n_epochs=10 --print_interval=100 --debug=false \
  --use_ckpt=true --test_mode=true

This uses the checkpoint found in <experiment_folder>. To evaluate a checkpoint stored in a different folder, add --ckpt_dir:

PYTHONPATH=. python runners/run_physics.py --task=spring_color --model=PhysicsNet --epochs=500 \
  --batch_size=100 --save_dir=<experiment_folder> --autoencoder_loss=3.0 --base_lr=3e-4 \
  --color=true --eval_every_n_epochs=10 --print_interval=100 --debug=false \
  --use_ckpt=true --test_mode=true --ckpt_dir=<folder_with_checkpoint>
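For scripted experiments, the training and evaluation commands above can be assembled programmatically. The sketch below is a hypothetical helper (`build_command` and `COMMON_FLAGS` are not part of the repo, and `experiments/run1` is just an example path); it uses only the flags shown in the commands above and returns an argument list suitable for `subprocess.run`. It assumes you launch the result from the repo root, mirroring `PYTHONPATH=.`.

```python
import shlex

# Flags shared by the training and evaluation commands in this README.
COMMON_FLAGS = {
    "task": "spring_color",
    "model": "PhysicsNet",
    "epochs": "500",
    "batch_size": "100",
    "autoencoder_loss": "3.0",
    "base_lr": "3e-4",
    "color": "true",
    "eval_every_n_epochs": "10",
    "print_interval": "100",
    "debug": "false",
}


def build_command(save_dir, mode="train", ckpt_dir=None):
    """Return an argv list for runners/run_physics.py (hypothetical helper).

    mode="train" starts a fresh run with LR annealing; mode="eval" only
    evaluates a checkpoint. ckpt_dir is used only in eval mode; when it is
    None, the checkpoint in save_dir is used.
    """
    flags = dict(COMMON_FLAGS, save_dir=save_dir)
    if mode == "train":
        flags.update(anneal_lr="true", use_ckpt="false")
    elif mode == "eval":
        flags.update(use_ckpt="true", test_mode="true")
        if ckpt_dir is not None:
            flags["ckpt_dir"] = ckpt_dir
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return ["python", "runners/run_physics.py"] + [f"--{k}={v}" for k, v in flags.items()]


if __name__ == "__main__":
    # Print a copy-pasteable shell command; run it from the repo root.
    print("PYTHONPATH=. " + shlex.join(build_command("experiments/run1", mode="eval")))
```

To launch directly, something like `subprocess.run(build_command("experiments/run1"), env={**os.environ, "PYTHONPATH": "."}, check=True)` should work from the repo root. Returning a list rather than a single string sidesteps shell quoting entirely.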

To continue training a model from a checkpoint, use the same command as above but with --test_mode=false. Note that in this case base_lr is used as the starting learning rate, because no global learning rate variable is saved in the checkpoint. If you restart training after annealing has already been applied, set base_lr to the annealed value accordingly.
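When resuming after annealing, the new --base_lr should match the rate the schedule had already reached. The sketch below assumes, purely for illustration, a step schedule that multiplies the learning rate by `decay_factor` once every `decay_every` epochs. The repo's actual annealing rule lives in its training code and may differ, so substitute it before relying on these numbers. `annealed_lr` and `resume_lr_flag` are hypothetical helpers.

```python
def annealed_lr(base_lr, epochs_done, decay_factor, decay_every):
    """Learning rate after `epochs_done` epochs under a HYPOTHETICAL step schedule.

    Assumes the rate is multiplied by `decay_factor` once every `decay_every`
    epochs; this is an illustrative stand-in, not the repo's actual schedule.
    """
    if not 0 < decay_factor <= 1:
        raise ValueError("decay_factor must be in (0, 1]")
    if decay_every <= 0:
        raise ValueError("decay_every must be positive")
    return base_lr * decay_factor ** (epochs_done // decay_every)


def resume_lr_flag(base_lr, epochs_done, decay_factor, decay_every):
    """Return the --base_lr flag to pass when resuming from a checkpoint."""
    lr = annealed_lr(base_lr, epochs_done, decay_factor, decay_every)
    return f"--base_lr={lr:g}"
```

For instance, with base_lr=3e-4, a hypothetical factor of 0.5 every 100 epochs, and 250 epochs already completed, `resume_lr_flag(3e-4, 250, 0.5, 100)` returns `--base_lr=7.5e-05` (two decays have been applied).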

Notes on flags, hyperparameters, and general training behavior:

Tasks

There are currently 5 tasks implemented in this repo: