gchq / coreax

A library for coreset algorithms, written in JAX for fast execution and GPU support.
Apache License 2.0

[Feature]: sliced score matching for learning score functions. #75

Closed tl82649 closed 1 year ago

tl82649 commented 1 year ago

Is your proposal related to a problem?

For the Stein kernel, we need to supply the score function $\nabla_\mathbf{x} \log f_X(\mathbf{x})$ of a PDF $f_X(\mathbf{x})$. For a finite sample of $n$ data points $(\mathbf{x}_1,\dots,\mathbf{x}_n)$ in $d$ dimensions, the score function needs to be inferred from the sample.

This is currently performed explicitly by fitting a kernel density estimate (KDE) to the data points and computing the log transform and gradient of the induced density function, prior to use in the Stein kernel. This involves the $\mathcal{O}(dn^2)$ KDE computation, along with a choice of kernel function and hyperparameters (currently a radial basis function with the median heuristic, which also needlessly computes the normalising constant). It can also result in model mismatch for complex, high-dimensional data distributions.
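For context, a minimal sketch of this style of KDE-based score (an illustration assuming a Gaussian kernel with a fixed bandwidth, not the exact coreax implementation):

```python
import jax.numpy as jnp
from jax import grad
from jax.scipy.special import logsumexp

def kde_score(samples, bandwidth=1.0):
    """Score function of a Gaussian KDE fitted to `samples` of shape (n, d)."""
    def log_kde(x):
        # log of the KDE at a single point x; the normalising constant is an
        # additive constant in log space, so it vanishes under the gradient
        sq_dists = jnp.sum((samples - x) ** 2, axis=1)
        return logsumexp(-0.5 * sq_dists / bandwidth**2)
    # each evaluation is O(dn), so evaluating at all n points costs O(dn^2)
    return grad(log_kde)
```

This also makes the point about the normalising constant concrete: it only shifts the log density by a constant, so it contributes nothing to the gradient.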

Describe the solution you would like

I propose using score matching to learn $\nabla_\mathbf{x} \log f_X(\mathbf{x})$ directly from $(\mathbf{x}_1,\dots,\mathbf{x}_n)$ using a neural network $h_\theta(\mathbf{x})$ parameterised by weights $\theta$. Specifically, I intend to implement sliced score matching as outlined in Song et al.'s 2019 paper.

Given the non-trivial function approximation properties of neural networks, this should allow for more complex score functions that better match those induced by the data distributions. The hope is that this will improve the coresets output by methods using the Stein kernel. Moreover, with gradient-based optimisation, the score function can be learned in $\mathcal{O}(dn)$ time.

Describe alternatives you've considered

Score matching originates from Hyvärinen's 2005 paper. However, as Song et al. point out, the Fisher divergence objective function involves computation of a Hessian trace, which is inefficient for large $d$ and not conducive to autodiff.

Sliced score matching uses random vector projection to convert this to a Hessian-vector product (with optional neural network approximation of $\nabla_\mathbf{x} \log f_X(\mathbf{x})$), which can be computed efficiently in JAX.
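Concretely (my paraphrase of the objective in Song et al. 2019), for projection vectors $\mathbf{v}$ drawn from, e.g., a standard Gaussian or Rademacher distribution, the sliced objective to minimise over $\theta$ is

$$J(\theta) = \mathbb{E}_{\mathbf{v}}\,\mathbb{E}_{\mathbf{x}}\left[\mathbf{v}^\intercal \nabla_\mathbf{x} h_\theta(\mathbf{x})\,\mathbf{v} + \tfrac{1}{2}\left(\mathbf{v}^\intercal h_\theta(\mathbf{x})\right)^2\right],$$

and when $\mathbb{E}[\mathbf{v}\mathbf{v}^\intercal] = \mathbf{I}$ (as for standard Gaussian or Rademacher $\mathbf{v}$), the second term can be integrated analytically to $\tfrac{1}{2}\|h_\theta(\mathbf{x})\|_2^2$.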

Additional context

No response


pc532627 commented 1 year ago

@tl82649 shall be the developer on this feature; @pc532627 shall be the reviewer and main point of contact for technical discussions.

pc532627 commented 1 year ago

Additional technical points:

I believe adding a noise parameter (as discussed in Yang Song's blog post) initially is a great idea and can facilitate future changes as well as this one. Presumably the general impact of the noise is to combat poor extrapolation by the neural network into areas of space that are not well covered by the set of samples.

pc532627 commented 1 year ago

Regarding the use of JAX jvp or vjp, the performance will depend on the size of matrices and the order in which they are operated on. How are you determining which neural network architecture to use for this approximation of the score function? Do you intend to include capability in the codebase to define the number of layers, activation functions, learning rates, optimisers, ...?

tl82649 commented 1 year ago

Implementation plan

Given that this will involve a separate optimisation process with several steps, I propose to create a new file: ./coreax/score_matching.py. There will be some minor changes to ./coreax/kernel_herding.py and added unit tests.

New files

- ./coreax/score_matching.py: the score matching implementation, including the training loop

How the new changes will function

This will be used as an alternative to other score functions in calls to Stein kernel herding, e.g. kernel_herding.stein_kernel_herding_block. For example, at ./examples/weighted_herding.py L58, I would replace the use of rbf_grad_log_f_X with

```python
score_function = sliced_score_matching(X, epochs=100)
coreset, Kc, Kbar = stein_kernel_herding_block(X, C, stein_kernel_pc_imq_element, score_function, nu=nu, max_size=1000)
```

which would invoke the score matching learner, fit it to X and use it in the call to stein_kernel_herding_block.

Unit testing

For unit testing, I propose to use various densities with known analytic score functions, e.g. Gaussian, Gaussian mixture and multivariate versions of each. Test pass/fail should be predicated on acceptable mean-squared error (MSE) or similar over sample points from the true measure.
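For instance, a sketch of one such test for a standard Gaussian, whose analytic score is $-\mathbf{x}$ (assuming the sliced_score_matching interface proposed below):

```python
import jax.numpy as jnp
from jax import random

def test_gaussian_score(tol=1e-1):
    # samples from N(0, I_2), whose score function is x -> -x
    X = random.normal(random.PRNGKey(0), (1000, 2))
    learned_score = sliced_score_matching(X, epochs=100)
    # pass/fail predicated on MSE against the analytic score at the samples
    mse = jnp.mean((learned_score(X) - (-X)) ** 2)
    assert mse < tol
```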

tl82649 commented 1 year ago

> Additional technical points:
>
> I believe adding a noise parameter (as discussed in Yang Song's blog post) initially is a great idea and can facilitate future changes as well as this one. Presumably the general impact of the noise is to combat poor extrapolation by the neural network into areas of space that are not well covered by the set of samples.

That's right: it's to get better learning in areas of low density, but there's a trade-off between coverage and distortion. We could do something like a mixture of noise models, i.e. Eq. 7 in Song's blog.
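As a rough sketch of what the perturbation step might look like (function names here are hypothetical), with a single fixed scale and a mixture-of-scales variant:

```python
import jax.numpy as jnp
from jax import random

def perturb(key, X, sigma=0.1):
    # single fixed noise model: learn the score of the sigma-smoothed density
    return X + sigma * random.normal(key, X.shape)

def perturb_mixture(key, X, sigmas=(0.1, 0.5, 1.0)):
    # mixture of noise models: draw one scale per sample from a fixed set
    k1, k2 = random.split(key)
    idx = random.randint(k1, (X.shape[0], 1), 0, len(sigmas))
    return X + jnp.asarray(sigmas)[idx] * random.normal(k2, X.shape)
```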

tl82649 commented 1 year ago

> Regarding the use of JAX jvp or vjp, the performance will depend on the size of matrices and the order in which they are operated on. How are you determining which neural network architecture to use for this approximation of the score function? Do you intend to include capability in the codebase to define the number of layers, activation functions, learning rates, optimisers, ...?

I'm initially using the same network architecture as this paper. It's a simple 3-layer network with dense linear transforms and softplus activations. Hidden dimension 128 across the layers (bar output, which is $d$). I plan to put network architectures in a new file ./coreax/networks.py. Perhaps future features could include suitable pre-designed network architectures, e.g. conditional on things like data dimensionality, or other data features? We can also make layers, activations, etc. configurable.
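A minimal sketch of what such a module could look like in Flax (the field names and exact layer layout here are assumptions, matching the ScoreNetwork(hidden_dim, d) call used below):

```python
from flax import linen as nn

class ScoreNetwork(nn.Module):
    """Dense network mapping R^d -> R^d to approximate the score function."""
    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.softplus(nn.Dense(self.hidden_dim)(x))
        x = nn.softplus(nn.Dense(self.hidden_dim)(x))
        # final layer is a plain linear map to d dimensions
        return nn.Dense(self.output_dim)(x)
```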

I propose to make the learning rate configurable initially. I'm starting with the Adam optimiser, but the optimiser can also be configurable. Network design will use Flax, and optimisation will use Optax. We should also think about learning rate scheduling.
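Something like the following for create_train_state, consistent with the create_train_state(sn, k2, lr, d) call in the training loop below (a sketch; the committed version may differ):

```python
import jax.numpy as jnp
import optax
from flax.training import train_state

def create_train_state(module, key, lr, d):
    # initialise parameters with a dummy d-dimensional input and attach Adam
    params = module.init(key, jnp.ones((1, d)))['params']
    return train_state.TrainState.create(
        apply_fn=module.apply, params=params, tx=optax.adam(lr)
    )
```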

Given that the score is $\mathbf{x} \mapsto \nabla_\mathbf{x} \log f_X(\mathbf{x})$ for $\mathbf{x} \in \mathbb{R}^d$, the Jacobian (of the score, i.e. the Hessian of the log density) will be $d \times d$, so I think the quadratic form

$$\mathbf{v}^\intercal \, \nabla_\mathbf{x} h_\theta(\mathbf{x}) \, \mathbf{v}$$

and subsequent objective function in Sec 3.2 of the paper should be computed with jvp like

```python
from functools import partial
from jax import jit, jvp

@partial(jit, static_argnames=["score_network"])
def sliced_score_matching_loss_element(x, v, score_network):
    # s = h_theta(x); u = (grad_x h_theta(x)) v via a single forward-mode JVP
    s, u = jvp(score_network, (x,), (v,))
    # v @ u is the quadratic form above; the .5 * s @ s term assumes Gaussian
    # or Rademacher random vectors for the norm-squared term
    obj = v @ u + .5 * s @ s
    return obj
```

and the outer optimisation step, i.e. over the $K$ network parameters $\theta$, will use reverse mode since the Jacobian here will be $1 \times K$:

```python
import jax
from jax import jit

@jit
def train_step(state, X, V):
    # differentiate the mean sliced score matching loss w.r.t. the parameters
    loss = lambda params: sliced_score_matching_loss(
        lambda x: state.apply_fn({'params': params}, x)
    )(X, V).mean()
    grads = jax.grad(loss)(state.params)
    state = state.apply_gradients(grads=grads)
    return state
```

where the sliced_score_matching_loss function is a vmapped version of the element function above:

```python
from jax import vmap

def sliced_score_matching_loss(score_network):
    # map over the M random vectors for a fixed x, then over the n data points
    inner = vmap(lambda x, v: sliced_score_matching_loss_element(x, v, score_network), (None, 0), 0)
    return vmap(inner, (0, 0), 0)
```
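As a quick sanity check of the shapes (a throwaway snippet using the analytic score of $\mathcal{N}(\mathbf{0}, \mathbf{I})$, i.e. $\mathbf{x} \mapsto -\mathbf{x}$, as a stand-in network):

```python
from jax import random

key_x, key_v = random.split(random.PRNGKey(0))
X = random.normal(key_x, (4, 2))       # n = 4 points in d = 2 dimensions
V = random.normal(key_v, (4, 3, 2))    # M = 3 random vectors per point
losses = sliced_score_matching_loss(lambda x: -x)(X, V)
assert losses.shape == (4, 3)          # one objective value per (x, v) pair
```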

Then a main training loop like the following returns the learned score function:

```python
from jax import random
from tqdm import tqdm

def sliced_score_matching(X, rtype="normal", M=1, lr=1e-3, epochs=10, batch_size=64, hidden_dim=128):
    n, d = X.shape
    k1, k2 = random.split(random.PRNGKey(0))
    sn = ScoreNetwork(hidden_dim, d)
    # one set of M projection vectors per data point
    if rtype == "normal":
        V = random.normal(k1, (n, M, d))
    else:
        V = random.rademacher(k1, (n, M, d), dtype=float)
    state = create_train_state(sn, k2, lr, d)
    batch_key = random.PRNGKey(1)
    for _ in tqdm(range(epochs)):
        # split the key each step so a fresh random batch is drawn every epoch
        batch_key, subkey = random.split(batch_key)
        idx = random.randint(subkey, (batch_size,), 0, n)
        state = train_step(state, X[idx, :], V[idx, :])
    return lambda x: state.apply_fn({'params': state.params}, x)
```

tl82649 commented 1 year ago

I have now added a first implementation to the feature/score-matching branch. No unit tests yet, nor noise perturbation, but I've added score matching examples to show its use. I'd appreciate a quick, casual review before writing the noise perturbation in.

Once all these are complete, I'll raise a pull request for formal review.

pc532627 commented 1 year ago

Regarding noise - I'd be happy for us to test out a fixed noise model and then consider adding additional noise models as an additional ticket, if we see signs that this does indeed improve coreset quality.

I strongly support defining networks in a separate part of the code (as you suggest) - we could consider automated methods for hyper-parameter tuning in the future too, given a user specified time budget.

The implementation of sliced_score_matching_loss_element looks good at first glance to me, in agreement with the objective in Section 3.2 of "Sliced Score Matching: A Scalable Approach to Density and Score Estimation".

I see no obvious issues with the noted training steps & loops. As far as I can tell we have no concept of early stopping here, but I would consider that an additional feature consideration for the future, not part of this ticket.

A couple of notes from looking through the code (quoted and addressed in the reply below).

Functionally, as far as I can tell from an initial review, this implementation plan is fine.

tl82649 commented 1 year ago

@pc532627: I've made a final push prior to raising the pull request.

> We are making the assumption of Gaussian or Rademacher distributions inside of coreax/score_matching.py. This is noted on line 29. I wonder if we want to be more explicit to users. If we build documentation, they will be able to see this in the relevant pages; however, we might want to raise a warning explicitly for absolute clarity to users about when the equations are valid and when they are not. Just a thought - we should agree a stance on this for any similar assumptions we make going forward.

I have now made this agnostic to the random variable choice, and instead added a use_analytic argument to coreax.score_matching.sliced_score_matching that switches between the objective functions where the integral is analytic or not. It would be up to the user to set this flag based on the choice of random variable (the rgenerator argument).
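For illustration, the two objectives would differ only in the final term, along these lines (a sketch against the jvp-based loss element above; not the exact committed code):

```python
from jax import jvp

def loss_element_general(x, v, score_network):
    # valid for any projection distribution: keeps the (v^T s)^2 projection term
    s, u = jvp(score_network, (x,), (v,))
    return v @ u + 0.5 * (v @ s) ** 2

def loss_element_analytic(x, v, score_network):
    # valid when E[v v^T] = I (e.g. Gaussian or Rademacher): the expectation of
    # the squared projection integrates analytically to the squared norm
    s, u = jvp(score_network, (x,), (v,))
    return v @ u + 0.5 * s @ s
```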

> I wonder if create_train_state is better placed inside of the neural network part of the code, rather than the score matching file. The same goes for train_step.

Agreed. I've moved create_train_state into coreax.networks. However, the train steps are now tied to the different score matching methods, so I've left them in coreax.score_matching for now.

> We've got hard-coded strings inside of sliced_score_matching. As a general rule, we've found that having a constants module, where such strings are defined, can help enormously in making edits as the codebase grows, without introducing additional bugs. As a future ticket, I might suggest creating a constants file and importing names and strings from it.

Good idea about hard-coded strings, though I've removed them here in favour of a use_analytic flag.

> There is a significant amount of code here. I think this highlights the need to get the style automation in place ASAP and ensure we work to it, as it should make reviews more consistent too.

I think I've applied all style guidelines, but do correct anything you see or point out if I've missed something.

> Note to raise another ticket relating to transforming the new examples into an end-to-end test along the lines of what is being done for the existing ones.

It might be worth noting that these new tests take a little longer given the various training steps. Not a huge amount of time (< 1 min on my machine).

> From running the example, tqdm is a required dependency but it is not listed in setup.py as far as I can tell. Also, there is an assumed location you run examples from (e.g. the root of the repository) - we may want to update the file path handling based on whether it is run from the root or from inside the examples directory.

tqdm added to setup.py. There are a few uses of relative paths in the codebase; maybe worth raising a separate ticket to address these consistently?

> There are several aspects of style to tidy up, but these are by no means urgent.

Hopefully now covered by me applying the style guidelines.

> I think we currently have examples/pounce/pounce.gif as well as examples/data/pounce/pounce.gif - is this intentional?

I believe the former is what the README is pointing to, and the latter is what the example scripts output. Intentional yes, to avoid overwriting the README example.

pc532627 commented 1 year ago

Merged into main as part of the pull request; issue resolved.