blei-lab / treeffuser

Treeffuser is an easy-to-use package for probabilistic prediction on tabular data with tree-based diffusion models.
https://blei-lab.github.io/treeffuser/
MIT License

Wrapper class for samples #91

Closed aagrande closed 3 months ago

aagrande commented 3 months ago

This adds a simple wrapper class for the main output of Treeffuser: the samples.

Subclassing and custom methods

EDIT: see the update below by @ANazaret. Samples subclasses ndarray. Samples:

See NumPy's documentation on subclassing for further reference.

Usage

The goal is to facilitate the user experience when computing estimates from Treeffuser samples, such as mean, quantiles, correlation matrices, modes, etc.

samples = model.sample(x, n_samples=100)

samples = Samples(samples)
samples.sample_mean() # conditional mean
samples.sample_quantile(q=[0.05, 0.95])  # conditional quantiles
samples.sample_kde(bandwidth="scott")  # conditional KDEs
...
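
For comparison, here is a rough sketch of what the first two estimates look like with plain NumPy on the raw array. It assumes the samples are laid out as (n_samples, batch, y_dim), which is not stated in this thread, and it uses random draws as a stand-in for `model.sample`.

```python
import numpy as np

# Stand-in for `model.sample(x, n_samples=100)`.
# Assumption: the output has shape (n_samples, batch, y_dim).
rng = np.random.default_rng(0)
raw = rng.normal(size=(100, 5, 1))

# Conditional mean: average over the sample axis (axis 0).
cond_mean = raw.mean(axis=0)  # shape (5, 1)

# Conditional quantiles: one slice per requested level, stacked on axis 0.
cond_quantiles = np.quantile(raw, [0.05, 0.95], axis=0)  # shape (2, 5, 1)
```

The wrapper's appeal is that the user no longer has to remember which axis indexes the samples.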

ANazaret commented 3 months ago

Thanks @aagrande! For the record, we chose not to define Samples as a subclass of np.ndarray because the inherited NumPy methods let the user transform a Samples object into a corrupted state. For example, taking the transpose returned a nonsensical Samples object.
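
To illustrate the failure mode, here is a minimal sketch. Both classes below are hypothetical and are not the Treeffuser implementation; the (n_samples, batch, y_dim) layout is also an assumption. The subclass inherits `.T`, which silently moves the sample axis, while a composition-based wrapper only exposes the methods it defines.

```python
import numpy as np


class SubclassedSamples(np.ndarray):
    """Hypothetical ndarray subclass; axis 0 is assumed to index samples."""

    def __new__(cls, input_array):
        return np.asarray(input_array).view(cls)

    def sample_mean(self):
        # Averages over axis 0, whatever that axis currently holds.
        return np.asarray(self).mean(axis=0)


class WrappedSamples:
    """Hypothetical wrapper that holds the array instead of inheriting from it."""

    def __init__(self, input_array):
        self._samples = np.asarray(input_array)

    def sample_mean(self):
        return self._samples.mean(axis=0)

    def to_numpy(self):
        # Explicit escape hatch; returns a copy so the wrapper stays intact.
        return self._samples.copy()


raw = np.zeros((100, 5, 1))  # assumed layout: (n_samples, batch, y_dim)

sub = SubclassedSamples(raw)
transposed = sub.T  # still a SubclassedSamples, but axes are now (y_dim, batch, n_samples)
bad_mean_shape = transposed.sample_mean().shape  # averages over y_dim, not over samples

wrapped = WrappedSamples(raw)
has_transpose = hasattr(wrapped, "T")  # the wrapper exposes no transpose
```

With composition, any operation that would break the sample-axis convention has to go through `to_numpy()`, which makes it explicit that the result is no longer a Samples object.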