opeltre / topos

Statistics and Topology

Continuous domains #39


opeltre commented 1 year ago

Support continuous domains, e.g. $E_i = \mathbb{R}^{n_i}$, by using either polynomial observables on $E_i$ (degree-2 energies yielding Gaussian kernels) or smooth parameterizations from a vector space $\Theta_i$ to $C(E_i)$ (e.g. an `nn.Module` instance).

1. Gaussian mixtures

To handle $k$-modal Gaussian mixtures on $V$, the configuration space can be expanded to $kV = V^{\sqcup k} = V \sqcup \dots \sqcup V$.

Defining a quadratic observable $H$ on $kV$ as a collection $(H_i)_{1 \leq i \leq k}$ of quadratic polynomials on $V$, the associated Gibbs distribution on $kV$ is projected to a Gaussian mixture on $V$:

$$ p(x|i) = \frac {\mathrm{e}^{-H_i(x)}}{Z_i} $$

$$ p(x, i) = \pi_i \: p(x | i) $$

$$ p(x) = \sum_i \pi_i \: \frac {\mathrm{e}^{-H_i(x)}}{Z_i} $$

where each $H_i(x)$ may be expressed in the canonical form $H_i(x_i) + \frac 1 2 \langle x - x_i | Q_i | x - x_i \rangle$, with $x_i$ the mode and $Q_i$ the precision of the $i$-th component.
In particular, the mixture log-likelihood is contained in the mean-value term $H_i(x_i)$:

$$ -\ln p(x, i) = - \ln \pi_i + \frac 1 2 \langle x - x_i | Q_i | x - x_i \rangle + \ln Z_i $$

Gaussian mixtures have the major advantage of being explicit and easy to sample from.
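For instance, sampling reduces to drawing a mode index $i \sim \pi$ and then a Gaussian sample with mean $x_i$ and precision $Q_i$. A minimal PyTorch sketch, with placeholder parameters standing in for $(\pi_i, x_i, Q_i)$:

```python
import torch
from torch.distributions import Categorical, MultivariateNormal

# Placeholder mixture parameters: weights pi_i, modes x_i, precisions Q_i.
k, n = 3, 2
pi = torch.softmax(torch.randn(k), 0)            # mixture weights pi_i
means = torch.randn(k, n)                        # modes x_i
A = torch.randn(k, n, n)
Q = A @ A.transpose(-1, -2) + torch.eye(n)       # positive-definite precisions Q_i

def sample_mixture(n_samples):
    """Sample x ~ p(x) = sum_i pi_i exp(-H_i(x)) / Z_i."""
    i = Categorical(probs=pi).sample((n_samples,))                # mode index ~ pi
    x = MultivariateNormal(means[i], precision_matrix=Q[i]).sample()
    return x, i                                                   # x | i is Gaussian

x, i = sample_mixture(1024)
```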

2. Neural networks

There is no real difficulty in defining observables on continuous domains as `nn.Module` instances. However, the Gibbs state correspondence fails to be explicit when the Hamiltonian $H_i : E_i \to \mathbb{R}$ is a generic `nn.Module` instance.

One may rely on different strategies to sample from $\frac {1}{Z_i} \mathrm{e}^{-H_i}$ without resorting to naive Monte Carlo methods. Diffusion models, inspired by annealed importance sampling, have recently proved very successful.
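As a simple baseline among these strategies, unadjusted Langevin dynamics only requires $\nabla H_i$, which autograd provides for any `nn.Module` energy. A rough sketch (step size and iteration count are arbitrary placeholders; mixing would need care in practice):

```python
import torch

def langevin_sample(H, x0, n_steps=1000, eps=1e-2):
    """Draw approximate samples from exp(-H)/Z by iterating
    x <- x - eps * grad H(x) + sqrt(2 * eps) * noise."""
    x = x0.detach().clone()
    for _ in range(n_steps):
        x.requires_grad_(True)
        (grad,) = torch.autograd.grad(H(x).sum(), x)
        x = (x - eps * grad + (2 * eps) ** .5 * torch.randn_like(x)).detach()
    return x
```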

Strategy

Keep using core classes

The state sheaf for $\mathbb{R}^n$ could be taken to be `Field(Domain(n))` without implementation changes. Subtleties remain for handling products of discrete and continuous domains; the proposal is to view discrete microstates as `Field`s with `torch.long` values, as sketched below.
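As a rough illustration of the intended representation (plain tensors stand in for `Field` instances here, which is an assumption about the final API), a product of a discrete and a continuous factor would carry one `torch.long` and one floating-point component:

```python
import torch

# Hypothetical mixed state on E = {0, ..., q-1} x R^n, batched over B samples.
B, q, n = 16, 5, 3
x_discrete = torch.randint(q, (B,))     # microstate indices, dtype=torch.long
x_continuous = torch.randn(B, n)        # continuous coordinates, float dtype

# A potential on the product: table lookup on the discrete factor,
# a smooth function on the continuous one.
table = torch.randn(q)
def H(xd, xc):
    return table[xd] + .5 * (xc ** 2).sum(-1)

energies = H(x_discrete, x_continuous)  # shape (B,)
```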

Arrow interface

Harmonise the concept of a real observable $f : \mathbb{R}^n \to \mathbb{R}$ with that of a real-valued field $t : \mathbb{Z} / n \mathbb{Z} \to \mathbb{R}$ by inheriting from a common `topos.Observable` mixin class. This class should do little more than assume a (batchable) `__call__` method is implemented and that `src` and `tgt` attributes are defined, while ensuring compatibility with `fp.Arrow(A, B)` and functorial operations (`Field.cofmap`, etc.).
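A minimal sketch of what this mixin could look like (only `src`, `tgt` and the batchable `__call__` come from the description above; the `Quadratic` subclass and its attributes are hypothetical):

```python
import torch

class Observable:
    """Mixin contract: `src`, `tgt` attributes and a batchable `__call__`."""
    src, tgt = None, None

    def __call__(self, x):
        raise NotImplementedError("subclasses implement a batchable __call__")

class Quadratic(Observable):
    """Hypothetical observable H(x) = <x - x0 | Q | x - x0> / 2 on R^n."""

    def __init__(self, Q, x0):
        self.src, self.tgt = Q.shape[-1], 1   # placeholder domain/codomain tags
        self.Q, self.x0 = Q, x0

    def __call__(self, x):
        d = x - self.x0
        return .5 * torch.einsum('...i,ij,...j->...', d, self.Q, d)
```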

Microstates as fields

Describe a (batched) state $x_{\tt a} \in E_{\tt a}$ as a `Field` instance with `dtype=torch.long`, and `Field.__call__` as index selection.

A section $(x_{\tt a})_{{\tt a} \in K}$ can then be viewed as a length-$|K|$ batch of microstates in the disjoint union $\bigsqcup_{{\tt a} \in K} E_{\tt a}$. Evaluating the total energy then amounts to evaluating local potentials individually before summing batch values.
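Concretely, with plain tensors standing in for `Field` instances (again an assumption), the disjoint union becomes a concatenation with offsets, and the total energy a batched gather followed by a sum:

```python
import torch

# Local potentials H_a over each flattened domain E_a (sizes 6 and 4 here).
H = torch.cat([torch.randn(6), torch.randn(4)])   # disjoint union of potentials
offsets = torch.tensor([0, 6])                    # where each E_a starts

# A section (x_a) becomes a length-|K| batch of indices in the union.
section = torch.tensor([2, 1])                    # x_a = 2, x_b = 1
batch = offsets + section                         # indices into the union

total_energy = H[batch].sum()                     # local values, then sum
```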

Global sections (in coordinate format) are projected onto a batch of indices by calling `fp.Torus.index` on each of the appropriate slices. To work with batches, the loop only needs to be split according to hyperedge dimensions.
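Assuming `fp.Torus.index` implements standard row-major (strided) indexing, which is a guess, the projection batches naturally once hyperedges are grouped by dimension:

```python
import torch

def flat_index(coords, shape):
    """Row-major index of a batch of coordinate tuples in a domain of given shape."""
    strides = torch.tensor(shape[1:] + (1,)).flipud().cumprod(0).flipud()
    return (coords * strides).sum(-1)

# One batched call per group of hyperedges sharing the same dimension.
edges_2d = torch.tensor([[1, 2], [0, 3]])   # two 2-dimensional coordinates
idx = flat_index(edges_2d, (4, 5))          # tensor([7, 3])
```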