
Distribution Shift Framework

This repository contains the code of the distribution shift framework presented in A Fine-Grained Analysis on Distribution Shift (Wiles et al., 2022).

Contents

The framework allows you to train models with different training methods on datasets undergoing specific kinds of distribution shift.

Training Methods

Currently the following training methods are supported (by setting the algorithm config option):

Model Architectures

The model config option can be set to one of the following architectures:

Datasets

You can train on the following datasets (by setting the dataset_name config option):

Each dataset has a task (e.g. shape prediction on dSprites, set with the label config option) and a set of properties (e.g. the colour of the shape in dSprites, set with the property_label config option).
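For example, using the option names that appear in the full command under Config Options below, the following would train on Shapes3D with shape as the classification label and object hue as the property (a minimal sketch; all other options keep their defaults):

python3 -m distribution_shift_framework.classification.experiment \
  --jaxline_mode=train \
  --config=distribution_shift_framework/classification/config.py:dataset_name=shapes3d,label=label_shape,property_label=label_object_hue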

Distribution Shift Scenarios

You can evaluate your model under different conditions by varying the distribution of labels and properties in the configs: you assign each part of the distribution a probability of being sampled.
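To illustrate the idea, here is a hypothetical Python sketch (not the framework's actual config schema): each (label, property) part of the distribution gets a sampling probability, and skewing these probabilities creates a shift.

import random

# Hypothetical sketch only -- not the framework's actual config schema.
# Each (label, property) part of the distribution gets a probability of
# being sampled; names and numbers are illustrative.
train_distribution = {
    ('cube', 'red'): 0.45,
    ('cube', 'blue'): 0.05,
    ('sphere', 'red'): 0.05,
    ('sphere', 'blue'): 0.45,
}

def sample_part(distribution):
  # Pick one part of the distribution according to its probability.
  parts = list(distribution)
  weights = [distribution[part] for part in parts]
  return random.choices(parts, weights=weights, k=1)[0]

print(sample_part(train_distribution))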

Additionally, you can modify these scenarios with two conditions:

These scenarios are set through the test_case config option, using the keywords given in parentheses, with an optional modifier separated by a full stop, e.g. lowdata.noise for low data drift with added label noise.
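For instance, reusing the options-string syntax described under Config Options below (all other options keep their defaults), training under low data drift with added label noise would look like:

python3 -m distribution_shift_framework.classification.experiment \
  --jaxline_mode=train \
  --config=distribution_shift_framework/classification/config.py:test_case=lowdata.noise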

Future Additions

We plan to add additional methods, models and datasets from the paper as well as the raw results from all the experiments.

Usage Instructions

Installing

The following has been tested using Python 3.9.9.

For GPU support with JAX, edit requirements.txt before running run.sh (e.g., use jaxlib==0.1.67+cuda111). See JAX's installation instructions for more details.
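For instance, the edit might look like the following sketch (the exact wheel depends on your CUDA version, and the -f line points pip at JAX's release index):

# requirements.txt (illustrative lines only)
-f https://storage.googleapis.com/jax-releases/jax_releases.html
jaxlib==0.1.67+cuda111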

Execute run.sh to create and activate a virtualenv, install all necessary dependencies and run a test program to ensure that you can import all the modules.

# Run from the parent directory.
sh distribution_shift_framework/run.sh

Running the Code

To train a model, activate the virtualenv:

source /tmp/distribution_shift_framework/bin/activate

and then run

python3 -m distribution_shift_framework.classification.experiment \
--jaxline_mode=train \
--config=distribution_shift_framework/classification/config.py

For evaluation, run

python3 -m distribution_shift_framework.classification.experiment \
--jaxline_mode=eval \
--config=distribution_shift_framework/classification/config.py

Config Options

Common changes can be made through an options string appended to the config file path. The following options are available:

Multiple options must be separated by commas. For example:

python3 -m distribution_shift_framework.classification.experiment \
--jaxline_mode=train \
--config=distribution_shift_framework/classification/config.py:algorithm=SagNet,test_case=lowdata.noise,model=truncatedresnet18,property_label=label_object_hue,label=label_shape,dataset_name=shapes3d

This would train a truncated ResNet18 with the SagNet algorithm in the low data setting with added label noise on the Shapes3D dataset. Shape is used as the label for classification, while object hue is used as the property over which the distribution shifts.

Sweeps

By default, the program generates sweeps over multiple hyperparameters depending on the chosen training method, dataset and distribution shift scenario. The sweep_index option lets you choose which of the configs in the sweep to run.
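For example, to run only the first config of a sweep (a sketch assuming sweep_index is passed in the options string like the other config options, with indexing starting at 0):

python3 -m distribution_shift_framework.classification.experiment \
  --jaxline_mode=train \
  --config=distribution_shift_framework/classification/config.py:sweep_index=0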

Citing this work

If you use this code (or any derived code) in your work, please cite the accompanying paper:

@inproceedings{wiles2022fine,
  title={A Fine-Grained Analysis on Distribution Shift},
  author={Olivia Wiles and Sven Gowal and Florian Stimberg and Sylvestre-Alvise Rebuffi and Ira Ktena and Krishnamurthy Dj Dvijotham and Ali Taylan Cemgil},
  booktitle={International Conference on Learning Representations},
  year={2022},
  url={https://openreview.net/forum?id=Dl4LetuLdyK}
}

License and Disclaimer

Copyright 2022 DeepMind Technologies Limited.

All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the License. You may obtain a copy of the Apache 2.0 license at

https://www.apache.org/licenses/LICENSE-2.0

All non-code materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY License). You may obtain a copy of the CC-BY License at:

https://creativecommons.org/licenses/by/4.0/legalcode

You may not use the non-code portions of this file except in compliance with the CC-BY License.

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

This is not an official Google product.