LSSTDESC / tomo_challenge

2020 Tomographic binning challenge

ZotBin #34

Open dkirkby opened 4 years ago

dkirkby commented 4 years ago

A tomographic binning method from the UC Irvine team (zot?)

The basic idea is to perform two stages of optimization:

  1. Divide the feature space into a large number, O(10K), of rectangular cells, then iteratively combine cells into O(100) groups according to their joint feature and redshift distribution similarities.

  2. Combine the O(100) groups into O(1-10) final bins by optimizing a metric such as FOM_3x2 or FOM_DETF_3x2.

An example showing the redshift distributions of 200 groups partitioning the Buzzard riz feature space, obtained after step 1: [figure: redshift distributions of the 200 groups]

Uses code from a separate repo that can be installed with:

pip install git+https://github.com/dkirkby/zotbin.git

More details and plots to follow...

EiffL commented 4 years ago

Interesting! Thanks for the entry! Looking forward to seeing the results :-)

And to answer one question from one of your notebooks:

> Why are the buzzard scores significantly lower than the DC2 ones?

They have pretty different redshift distributions: DC2 has a higher mean redshift, and I think that's mostly what drives the difference in metrics.

dkirkby commented 4 years ago

More details on the first optimization step to group similar bins in multidimensional feature space, which is independent of any choice of metric:

Initial bins are defined by a rectangular grid in feature space, with grid
edges chosen such that the projections onto each feature axis contain an equal
number of samples.  Any bins that contain no galaxies form one group and
are removed from subsequent analysis.
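
As a concrete sketch, such an equal-occupancy grid can be built from per-feature quantiles along these lines (the function name and signature are illustrative, not the actual zotbin API):

```python
import numpy as np

def equal_occupancy_grid(X, nbins):
    """Build a rectangular grid whose 1D projections hold equal numbers
    of samples, and assign each sample to a grid cell.

    X: (nsamples, nfeatures) array of features.
    nbins: number of grid cells along each feature axis.
    """
    quantiles = np.linspace(0, 1, nbins + 1)
    # Edges along each axis are empirical quantiles of that feature.
    edges = [np.quantile(X[:, k], quantiles) for k in range(X.shape[1])]
    # Per-axis bin index of each sample, clipped so that boundary
    # values fall into the first/last bin.
    idx = np.stack([
        np.clip(np.searchsorted(e, X[:, k], side='right') - 1, 0, nbins - 1)
        for k, e in enumerate(edges)], axis=1)
    return edges, idx
```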

Similarity is defined as the product of independent feature and redshift similarities,
with values in the range 0-1, where 1 indicates maximum similarity.

Feature similarity is defined as exp(-(dr / sigma) ** 2), where dr is the Euclidean
(grid index) separation in the multidimensional rectangular feature grid.
The hyperparameter sigma controls the relative importance of the feature and
redshift similarities in the subsequent grouping. Values of dr are normalized
such that dr=1 corresponds to the full range of each feature, i.e., the grid
size along each feature axis.
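
A minimal sketch of this feature similarity, assuming dr is computed from integer grid indices and normalized by the grid size along each axis (illustrative names, not the zotbin code):

```python
import numpy as np

def feature_similarity(idx1, idx2, nbins, sigma=1.0):
    """exp(-(dr / sigma)**2) between two grid cells, where the index
    separation along each axis is normalized so that dr=1 spans the
    full range of a feature (the grid size along that axis)."""
    delta = (np.asarray(idx1) - np.asarray(idx2)) / nbins
    return np.exp(-np.sum(delta ** 2) / sigma ** 2)
```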

Redshift similarity is based on the histogram of redshifts associated with each
feature bin, interpreted as a vector.  When `weighted` is False, similarity is
calculated as the cosine of the angle between the two vectors.  Since all components
of each redshift vector are histogram bin contents >= 0, the resulting cosine
similarities are all in the range 0-1.

Since cosine similarity uses normalized vectors, it does not give more weight to
a feature bin containing more samples.  An alternative, when `weighted` is True,
is to use the similarity score |z1 + z2| / (|z1| + |z2|), which is equivalent
to a weighted average of the cosine similarities between z1 + z2 and each of
z1 and z2, with weights wk = |zk| / (|z1| + |z2|).
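
Both redshift similarities are one-liners; here z1 and z2 are the redshift histograms of two groups (again an illustrative sketch, not the zotbin code):

```python
import numpy as np

def cosine_similarity(z1, z2):
    """Cosine of the angle between two redshift histograms; since all
    bin contents are >= 0, the result lies in the range 0-1."""
    return np.dot(z1, z2) / (np.linalg.norm(z1) * np.linalg.norm(z2))

def weighted_similarity(z1, z2):
    """|z1 + z2| / (|z1| + |z2|): the weighted average of the cosine
    similarities between z1 + z2 and each of z1, z2 with weights
    wk = |zk| / (|z1| + |z2|), so larger bins carry more weight."""
    return np.linalg.norm(z1 + z2) / (np.linalg.norm(z1) + np.linalg.norm(z2))
```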

Feature bins are grouped iteratively, by combining the pair of groups with the
maximum similarity, until either a minimum number of groups is reached or else
all remaining groups are above a minimum sample threshold.  There is also a
maximum sample threshold, and pairs whose combination would exceed this
threshold are never combined.

For testing purposes, grouping can also be terminated after a fixed number
of iterations.  The incremental updates of the feature and redshift similarity
matrices can also be periodically validated against much slower calculations
from scratch.

When two groups are merged, their redshift histograms are added to calculate
updated redshift similarities with all other remaining groups.  The updated
feature similarities use the minimum separation dr between the bins of
the merged group and the bins of each other group. This ensures that the maximum
feature similarity occurs between adjacent groups, regardless of their size.
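
Putting the pieces together, the grouping stage amounts to a greedy agglomerative loop roughly like the following, reusing `weighted_similarity` from the sketch above. This is a simplified sketch: it recomputes the total similarity matrix every iteration instead of applying the incremental updates described above.

```python
import numpy as np

def group_bins(zhist, fsim, min_groups, min_samples, max_samples):
    """Greedily merge feature bins into groups.

    zhist: list of per-group redshift histograms (non-negative arrays).
    fsim:  (ngrp, ngrp) symmetric matrix of feature similarities.
    """
    while len(zhist) > min_groups and min(h.sum() for h in zhist) < min_samples:
        n = len(zhist)
        # Total similarity = feature similarity x redshift similarity.
        sim = np.array([[fsim[a, b] * weighted_similarity(zhist[a], zhist[b])
                         for b in range(n)] for a in range(n)])
        np.fill_diagonal(sim, -1.0)
        # Pairs whose combined size would exceed max_samples never merge.
        sizes = np.array([h.sum() for h in zhist])
        sim[sizes[:, None] + sizes[None, :] > max_samples] = -1.0
        i, j = np.unravel_index(np.argmax(sim), sim.shape)
        if sim[i, j] < 0:
            break  # no allowed merge remains
        i, j = min(i, j), max(i, j)
        # Merge group j into group i: add redshift histograms; take the
        # elementwise max of feature similarities, which corresponds to
        # the minimum dr between the merged group and each other group.
        zhist[i] = zhist[i] + zhist.pop(j)
        fsim[i, :] = np.maximum(fsim[i, :], fsim[j, :])
        fsim[:, i] = fsim[i, :]
        fsim = np.delete(np.delete(fsim, j, axis=0), j, axis=1)
    return zhist, fsim
```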

The second optimization step is simply a direct maximization of the chosen metric with respect to a weight matrix of shape (nbin, ngrp) with ngrp ~ 200. The metric is calculated using a fast reweighting scheme implemented in jax and optimized using jax gradients.
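
Schematically, that second stage can look like this in jax, assuming a differentiable `metric_fn` supplied by the fast reweighting code (all names here are hypothetical, and the real optimizer may differ):

```python
import jax

def optimize_weights(metric_fn, ngrp, nbin, nsteps=1000, lr=0.01):
    """Maximize a differentiable metric over an (nbin, ngrp) weight matrix.

    metric_fn: maps weights (columns summing to 1, so each group is fully
    distributed over the final bins) to a scalar such as FOM_DETF_3x2.
    """
    params = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (nbin, ngrp))

    @jax.jit
    def step(p):
        # Softmax over bins keeps each group's weights a valid partition.
        loss = lambda q: -metric_fn(jax.nn.softmax(q, axis=0))
        return p - lr * jax.grad(loss)(p)  # plain gradient descent

    for _ in range(nsteps):
        params = step(params)
    return jax.nn.softmax(params, axis=0)
```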

dkirkby commented 4 years ago

Preliminary 3x2 DETF FoM scores using the Buzzard dataset restricted to riz:

| nbin | DETF 3x2 FoM |
|------|--------------|
| 2    | 20.5         |
| 4    | 55.6         |
| 6    | 69.6         |
| 8    | 78.7         |

An example of optimized bins for n=4: [figure: optimized redshift bins for n=4]

dkirkby commented 4 years ago

There are 2 classifiers implemented here, called ZotBin and ZotNet, with example yaml files:

python bin/challenge.py example/zotbin_dc2.yaml
python bin/challenge.py example/zotnet_buzzard.yaml

The graphs below show the results when optimizing each of the 3x2 metrics calculated by jax-cosmo. First, for Buzzard riz data: [figure: Buzzard results]

Similarly, for DC2 riz data: [figure: DC2 results]

joezuntz commented 4 years ago

@dkirkby - thanks so much for your entries!

I'm currently putting together the environment to run all these, and I hit this error:

from zotbin.nnet import learn_nnet, apply_nnet
ImportError: cannot import name 'apply_nnet'

I installed the head of the master branch of zotbin - is there a different version I should install?

dkirkby commented 4 years ago

@joezuntz Sorry I missed this earlier. I just pushed a commit to fix this, which simply removes apply_nnet from the import:

  from zotbin.nnet import learn_nnet
joezuntz commented 4 years ago

Thanks!

joezuntz commented 4 years ago

When trying to run this method on our cluster GPU (12 GB TITAN V) I consistently get this error:

Traceback (most recent call last):
  File "bin/run_one.py", line 62, in <module>
    scores = run_one(name, bands, settings, training_data, training_z, validation_data,
  File "/home/jzuntz/tomo_challenge/bin/challenge.py", line 121, in run_one
    C.train(train_data,train_z)
  File "/home/jzuntz/tomo_challenge/bin/../tomo_challenge/classifiers/zotbin.py", line 92, in train
    U = self.preprocessor(features)
  File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/zotbin/flow.py", line 110, in flow_map
    Y_normal, _ = bijection_direct(final_params, jnp.array(Y_preproc))
  File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/flows/bijections/bijections.py", line 460, in direct_fun
    return feed_forward(params, direct_funs, inputs)
  File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/flows/bijections/bijections.py", line 455, in feed_forward
    inputs, log_det_jacobian = apply_fun(param, inputs, **kwargs)
  File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/flows/bijections/made.py", line 38, in direct_fun
    log_weight, bias = apply_fun(params, inputs).split(2, axis=1)
  File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/experimental/stax.py", line 302, in apply_fun
    inputs = fun(param, inputs, rng=rng, **kwargs)
  File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/flows/bijections/made.py", line 18, in apply_fun
    return np.dot(inputs, W * mask) + b
  File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 2924, in dot
    return lax.dot(a, b, precision=precision)
  File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/lax/lax.py", line 595, in dot
    return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
  File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/lax/lax.py", line 629, in dot_general
    return dot_general_p.bind(lhs, rhs,
  File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/core.py", line 274, in bind
    return self.impl(*args, **kwargs)
  File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/interpreters/xla.py", line 224, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
  File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/interpreters/xla.py", line 264, in xla_primitive_callable
    compiled = backend_compile(backend, built_c, options)
  File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/interpreters/xla.py", line 325, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: Resource exhausted: Out of memory while trying to allocate 3300966784 bytes.

Any ideas welcome. I have tried various memory-allocation settings.

dkirkby commented 4 years ago

I ran tests successfully on an 11 GB RTX 2080 GPU, but perhaps you are running with more input features than I tested. I did all my testing with riz only.

You could either:

joezuntz commented 4 years ago

There seems to be an issue when importing both tensorflow and jax at the same time, as both claim all the GPU memory by default. I was able to fix it with some environment variables.
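
The environment variables in question are likely along these lines (the thread does not say exactly which were used; these are the standard settings for the JAX/TensorFlow GPU preallocation conflict):

```python
import os

# Stop JAX/XLA from preallocating most of the GPU memory at import time...
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# ...and let TensorFlow grow its allocation on demand instead.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

# Both must be set before the first import of jax / tensorflow.
import jax
import tensorflow as tf
```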

dkirkby commented 4 years ago

Yes, this is a known issue. There are several workarounds suggested here, in case you didn't already find this.