alan-turing-institute / learning-machines-drift

A Python package for monitoring dataset drift in secure environments

Sketch out API for drift package #2

OscartGiles opened this issue 2 years ago (status: Open)

OscartGiles commented 2 years ago

API design

This is a very rough first sketch of what the LM drift detection might look like.

The key feature of the drift detection library is that it will compare datasets seen in production to a reference dataset (normally the dataset used to train the production model).

There are a few things to keep in mind:

An example of using the package

A simple API design based on whylogs might look like this. This first example assumes we have a trained classifier we are using in production.

with DriftDetector(tag="test_tag", expect_features = True, expect_labels = True, expect_latent = False) as detector:

    # Normally X and Y would come from a model fit. Here we use a sample dataset
    X, Y = datasets.logistic_model()

    detector.log_features(X)
    detector.log_labels(Y)

Let's assume this stores the features and labels in some persistent storage (e.g. disk, database). The tag argument to DriftDetector is a unique tag for the model being monitored. The expect_{features | labels | latent} arguments are optional; if one is set to True, an exception will be raised if the corresponding data hasn't been logged by the time the context manager exits.
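As a rough illustration, the expect_* checks could be enforced on context exit. The following is a minimal sketch under that assumption; none of these internals are implied by the API above.

# A minimal sketch, assuming the detector simply tracks which log_* methods
# have been called; all internals here are illustrative, not the real package.
class DriftDetector:
    def __init__(self, tag, expect_features=False, expect_labels=False, expect_latent=False):
        self.tag = tag
        self._expected = {
            "features": expect_features,
            "labels": expect_labels,
            "latent": expect_latent,
        }
        self._logged = set()

    def log_features(self, X):
        self._logged.add("features")
        # ... persist X to disk or a database under self.tag ...

    def log_labels(self, Y):
        self._logged.add("labels")

    def log_latent(self, latent):
        self._logged.add("latent")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Raise if anything marked as expected was never logged
        missing = [
            name for name, expected in self._expected.items()
            if expected and name not in self._logged
        ]
        if missing:
            raise RuntimeError(f"Expected data was not logged: {missing}")
        return False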

We must register a reference dataset to compare drift against. This is normally the dataset used to train the model. This can be registered before model inference, in which case we could extend our output above to provide information on drift at model inference time:


register_reference(tag="prod_model", features=X_ref, labels=Y_ref)

with DriftDetector(tag="prod_model", expect_features=True, expect_labels=True, expect_latent=False) as detector:

    # Normally X and Y would come from a model fit. Here we use a sample dataset
    X, Y = datasets.logistic_model()

    detector.log_features(X)
    detector.log_labels(Y)

    # Write drift detection metrics to stdout
    detector.summary()

Or we could register it afterwards:


register_reference(tag="prod_model", features=X_ref, labels=Y_ref)

with DriftDetector(tag="prod_model") as detector:
    # Write drift detection metrics to stdout
    detector.summary()

Summary

We can load the data associated with a tag, ensuring the reference dataset is loaded.

with DriftDetector(tag="prod_model", expect_reference = True) as detector:
     # Write drift detection metrics to stdout
    detector.summary()

Rather than just writing to stdout, we probably want to return the results in a data structure (e.g. a dictionary or data class).
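For instance, summary() might return a mapping from feature name to a small dataclass of metrics. Purely as a sketch (the metric choice, names and array shapes here are all assumptions), a two-sample Kolmogorov-Smirnov test per feature against the reference could look like:

from dataclasses import dataclass

import numpy as np
from scipy import stats


@dataclass
class FeatureDrift:
    statistic: float  # two-sample KS statistic
    p_value: float


def summarise_drift(X_ref: np.ndarray, X_logged: np.ndarray) -> dict:
    """Compare each logged feature against the reference distribution."""
    results = {}
    for i in range(X_ref.shape[1]):
        ks = stats.ks_2samp(X_ref[:, i], X_logged[:, i])
        results[f"feature_{i}"] = FeatureDrift(statistic=ks.statistic, p_value=ks.pvalue)
    return results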

OscartGiles commented 2 years ago

Ok so I'm thinking we might want to start by registering a dataset:

# Given we have a reference dataset
X_reference, Y_reference = datasets.logistic_model()

# When we register the dataset
detector = DriftDetector(tag="test")
detector.register_ref_dataset(features=X_reference, labels=Y_reference)

and once that is done we might want to ask some basic things about our dataset:

summary = detector.ref_summary()

The question is what things might we want to know:

  1. Probably something about the data shape (N features, N rows, N labels)
  2. dtypes?
  3. Statistics

Let's start with 1. What data structure are we going to return this in (not a dictionary)?

Maybe a dataclass?

from dataclasses import dataclass


@dataclass
class FeatureSummary:
    n_rows: int
    n_features: int


@dataclass
class LabelSummary:
    n_rows: int
    n_labels: int


@dataclass
class DataShapes:
    features: FeatureSummary
    labels: LabelSummary


@dataclass
class BaselineSummary:
    shapes: DataShapes
    ...
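A hypothetical ref_summary could then populate this from the registered arrays, e.g. (numpy arrays assumed; the function body is an illustration only):

import numpy as np


def ref_summary(X_reference: np.ndarray, Y_reference: np.ndarray) -> BaselineSummary:
    # Build the summary dataclass from the shapes of the registered data
    return BaselineSummary(
        shapes=DataShapes(
            features=FeatureSummary(
                n_rows=X_reference.shape[0],
                n_features=X_reference.shape[1],
            ),
            labels=LabelSummary(
                n_rows=Y_reference.shape[0],
                n_labels=len(np.unique(Y_reference)),
            ),
        )
    )
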

OscartGiles commented 2 years ago

Some more spec

Registering a reference dataset

Given:

Logging data

Logging data to monitor for drift might look something like this:

with DriftDetector(
    tag="test", expect_features=True, expect_labels=True, expect_latent=True
) as detector:

    detector.log_features(X)
    detector.log_labels(Y_pred)
    detector.log_latent(latent_x)

Of interest are the expect_features, expect_labels and expect_latent arguments to the DriftDetector constructor. These say that an exception should be raised if the corresponding data has not been logged by the time the context manager exits.

We should also only call each of the methods log_features, log_labels and log_latent once. A second call could raise an exception, but perhaps it should be a warning instead -- or at least configurable via a parameter to the class constructor (e.g. raise_when_called_twice).
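A minimal sketch of that guard, written as a method on the detector (raise_when_called_twice is the flag suggested above, and _logged is an assumed internal set):

import warnings


# Sketch of a once-only guard inside DriftDetector; each log_* method
# would call self._check_once("features") etc. before persisting data.
def _check_once(self, name: str) -> None:
    if name in self._logged:
        msg = f"log_{name} has already been called for tag {self.tag!r}"
        if self.raise_when_called_twice:
            raise RuntimeError(msg)
        warnings.warn(msg)
    self._logged.add(name)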

Drift statistics

OscartGiles commented 2 years ago

How to filter data

How to split out methods

  1. In the first instance assume we have enough data to measure drift
  2. In future iterations consider single data point anomaly detection, interpretability.

How to represent data in memory

Types of timestamp

For the first iteration, assume the data is logged as soon as inference is made. In a future iteration we could add the ability to log historic data with a batch identifier (which could be a timestamp, or something else).
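As a sketch of that extension (batch_id and _store are hypothetical, not part of the API above):

from datetime import datetime, timezone


# Hypothetical: log_features takes an optional batch identifier, defaulting
# to the UTC time of logging; historic data could pass its own identifier.
def log_features(self, X, batch_id=None):
    if batch_id is None:
        batch_id = datetime.now(timezone.utc).isoformat()
    self._store(batch_id, X)  # _store is an assumed persistence helper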

OscartGiles commented 2 years ago

What to do when logging data if the structure of the data changes

For example, a column is dropped, a column is renamed, or a type changes.

Should we:
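Whichever policy we choose, the first step is detecting the change. A sketch, assuming the logged data and reference are pandas DataFrames (check_schema is hypothetical):

import pandas as pd


def check_schema(reference: pd.DataFrame, logged: pd.DataFrame) -> list:
    """Report structural differences between logged data and the reference."""
    problems = []
    dropped = set(reference.columns) - set(logged.columns)
    added = set(logged.columns) - set(reference.columns)
    if dropped:
        problems.append(f"columns dropped or renamed: {sorted(dropped)}")
    if added:
        problems.append(f"unexpected new columns: {sorted(added)}")
    # For columns present in both, compare dtypes
    for col in set(reference.columns) & set(logged.columns):
        if reference[col].dtype != logged[col].dtype:
            problems.append(
                f"dtype of {col!r} changed: {reference[col].dtype} -> {logged[col].dtype}"
            )
    return problems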