Open dkirkby opened 4 years ago
Interesting! Thanks for the entry! Looking forward to see the results :-)
And to answer one question from one of your notebooks:
Why are the Buzzard scores significantly lower than the DC2 ones? They have fairly different redshift distributions: DC2 has a higher mean redshift, and I think that is mostly what drives the difference in metrics.
More details on the first optimization step, which groups similar bins in multidimensional feature space and is independent of any choice of metric:
Initial bins are defined as a rectangular grid in feature space, with grid edges chosen such that the projections onto each feature axis contain an equal number of samples. Any bins that contain no galaxies form one group and are removed from subsequent analysis.
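The equal-count grid construction can be sketched with per-axis quantile edges. This is an illustrative stand-in, not the actual zotbin API (the function name and signature are hypothetical):

```python
import numpy as np

def equal_count_grid(features, nbins_per_axis):
    """Assign each sample to a cell of a rectangular grid whose edges are
    per-axis quantiles, so each 1D projection holds (nearly) equal counts.
    Hypothetical sketch, not the zotbin implementation."""
    nsamples, nfeatures = features.shape
    indices = np.empty((nsamples, nfeatures), dtype=int)
    for k in range(nfeatures):
        # Quantile-based edges give equal counts along this axis.
        edges = np.quantile(features[:, k], np.linspace(0, 1, nbins_per_axis + 1))
        # Clip so the maximum value lands in the last bin.
        indices[:, k] = np.clip(
            np.searchsorted(edges, features[:, k], side='right') - 1,
            0, nbins_per_axis - 1)
    # Flatten the multi-index to a single cell id per sample.
    return np.ravel_multi_index(indices.T, (nbins_per_axis,) * nfeatures)

rng = np.random.default_rng(0)
X = rng.normal(size=(10000, 3))
cells = equal_count_grid(X, 10)  # cell ids in [0, 10**3)
```

With 10 bins per axis and 3 features this yields O(1K) cells; the O(10K) scale mentioned below just uses more bins per axis.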
Similarity is defined as the product of independent feature and redshift similarities, with values in the range 0-1, where 1 indicates maximum similarity.
Feature similarity is defined as exp(-(dr / sigma) ** 2), where dr is the Euclidean
(grid index) separation in the multidimensional feature-space rectangular grid.
The hyperparameter sigma controls the relative importance of the feature and
redshift similarities in the subsequent grouping. Values of dr are normalized
such that dr=1 corresponds to the full range of each feature, i.e., the grid
size along each feature axis.
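The normalized-separation formula above can be written as a few lines of numpy; this is a minimal sketch (the function name is hypothetical):

```python
import numpy as np

def feature_similarity(idx1, idx2, grid_shape, sigma):
    """exp(-(dr / sigma)**2), with dr the Euclidean grid-index separation
    and each axis normalized so dr=1 spans the full grid extent.
    Illustrative sketch, not the zotbin implementation."""
    delta = (np.asarray(idx1) - np.asarray(idx2)) / np.asarray(grid_shape)
    dr = np.sqrt(np.sum(delta ** 2))
    return np.exp(-(dr / sigma) ** 2)
```

Identical grid cells give similarity 1, and similarity decays toward 0 as the separation approaches the full grid extent, with sigma setting the decay scale.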
Redshift similarity is based on the histogram of redshifts associated with each
feature bin, interpreted as a vector. When weighted is False, similarity is
calculated as the cosine of the angle between two vectors. Since all components
of the redshift vector are histogram bin contents >= 0, the resulting cosine
similarities are all in the range 0-1.
Since cosine similarity uses normalized vectors, it does not give more weight to
a feature bin containing more samples. An alternative, when weighted is True,
is to use a similarity score of |z1 + z2| / (|z1| + |z2|), which is equivalent
to a weighted average of the cosine similarities between z1+z2 and z1 or z2,
respectively, with weights wk = |zk| / (|z1| + |z2|).
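Both variants reduce to short numpy expressions; a minimal sketch (hypothetical name, not the zotbin code):

```python
import numpy as np

def redshift_similarity(z1, z2, weighted=False):
    """Similarity of two redshift histograms z1, z2 (non-negative bin
    contents). Unweighted: cosine of the angle between the vectors.
    Weighted: |z1+z2| / (|z1|+|z2|), which equals the norm-weighted
    average of cos(z1+z2, z1) and cos(z1+z2, z2).
    Illustrative sketch only."""
    z1, z2 = np.asarray(z1, float), np.asarray(z2, float)
    n1, n2 = np.linalg.norm(z1), np.linalg.norm(z2)
    if weighted:
        return np.linalg.norm(z1 + z2) / (n1 + n2)
    return np.dot(z1, z2) / (n1 * n2)
```

The equivalence follows from expanding |z1+z2|**2 = (z1+z2)·z1 + (z1+z2)·z2 and dividing by |z1+z2| (|z1|+|z2|).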
Feature bins are grouped iteratively, by combining the pair of groups with the
maximum similarity, until either a minimum number of groups is reached or else
all remaining groups are above a minimum sample threshold. There is also a
maximum sample threshold, and pairs whose combination would exceed this
threshold are never combined.
For testing purposes, grouping can also be terminated after a fixed number
of iterations. The incremental updates of the feature and redshift similarity
matrices can also be periodically validated against much slower calculations
from scratch.
When two groups are merged, their redshift histograms are added to calculate
updated feature similarities with all other remaining groups. The updated
feature similarities use the minimum separation dr between the bins of
the merged group and bins of each other group. This ensures that the maximum
feature similarity occurs between adjacent groups, regardless of their size.
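The grouping loop can be sketched as a greedy pairwise merge. This is a simplified stand-in (names and the similarity-update rule are illustrative, not the zotbin code; in particular it updates a single combined similarity matrix with a max over pairs, mimicking the minimum-dr rule, rather than incrementally updating separate feature and redshift similarities):

```python
import numpy as np

def greedy_group(similarity, counts, min_groups, min_count, max_count):
    """Greedily merge the most-similar pair of groups until either only
    min_groups remain or all groups hold at least min_count samples;
    never form a group exceeding max_count. Returns a group label per
    initial bin. Simplified illustrative sketch."""
    n = len(counts)
    counts = np.asarray(counts, dtype=float).copy()
    sim = np.asarray(similarity, dtype=float).copy()
    np.fill_diagonal(sim, -np.inf)
    labels = np.arange(n)
    active = np.ones(n, dtype=bool)
    while active.sum() > min_groups and counts[active].min() < min_count:
        # Mask inactive groups and pairs whose merged size is too large.
        allowed = sim.copy()
        allowed[counts[:, None] + counts[None, :] > max_count] = -np.inf
        allowed[~active, :] = -np.inf
        allowed[:, ~active] = -np.inf
        i, j = np.unravel_index(np.argmax(allowed), allowed.shape)
        if not np.isfinite(allowed[i, j]):
            break  # no mergeable pair remains
        # Merge j into i: combine counts, relabel, and take the max
        # similarity (analogous to using the minimum separation dr).
        counts[i] += counts[j]
        labels[labels == labels[j]] = labels[i]
        active[j] = False
        sim[i, :] = np.maximum(sim[i, :], sim[j, :])
        sim[:, i] = sim[i, :]
        sim[i, i] = -np.inf
    return labels

sim = np.array([[0.0, 0.9, 0.1, 0.1],
                [0.9, 0.0, 0.8, 0.1],
                [0.1, 0.8, 0.0, 0.9],
                [0.1, 0.1, 0.9, 0.0]])
labels = greedy_group(sim, [1, 1, 1, 1], min_groups=2, min_count=2, max_count=10)
```

Here the two most-similar adjacent pairs merge first, leaving two groups of two bins each.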
The second optimization step is a direct maximization of the chosen metric with respect to a weight matrix of shape (nbin, ngrp), with ngrp ~ 200. The metric is calculated using a fast reweighting scheme implemented in jax and optimized using jax gradients.
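As a toy illustration of this step, one can parameterize the weight matrix with a softmax and ascend the metric's gradient in jax. The `metric` below is a hypothetical stand-in that merely rewards non-overlapping per-bin n(z); it is not the jax-cosmo 3x2pt figure of merit, and all names are illustrative:

```python
import jax
import jax.numpy as jnp

def soften(theta):
    """Map an unconstrained (nbin, ngrp) array to non-negative weights
    that sum to one over bins, sharing each group among final bins."""
    return jax.nn.softmax(theta, axis=0)

def metric(weights, group_nz):
    """Stand-in for the jax-cosmo metric: negative overlap of the
    per-bin redshift distributions. Purely illustrative."""
    nz = weights @ group_nz  # (nbin, nz) per-bin redshift distributions
    nz = nz / jnp.sum(nz, axis=1, keepdims=True)
    overlap = jnp.sum(nz[:, None, :] * nz[None, :, :])
    return -overlap

@jax.jit
def step(theta, group_nz, lr=0.1):
    # One gradient-ascent step on the metric via jax autodiff.
    grad = jax.grad(lambda t: metric(soften(t), group_nz))(theta)
    return theta + lr * grad
```

The real optimization replaces `metric` with the jax-compiled reweighting of the chosen 3x2 metric, but the structure (softmax weights plus jax gradients) is the same idea.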
Preliminary 3x2 DETF FoM scores using the Buzzard dataset restricted to riz:
| nbin | DETF 3x2 |
|---|---|
| 2 | 20.5 |
| 4 | 55.6 |
| 6 | 69.6 |
| 8 | 78.7 |
An example of optimized bins for n=4:
There are two classifiers implemented here, called ZotBin and ZotNet, with example yaml files:
python bin/challenge.py example/zotbin_dc2.yaml
python bin/challenge.py example/zotnet_buzzard.yaml
The graphs below show the results when optimizing each of the 3x2 metrics calculated by jax-cosmo. First, for Buzzard riz data:
Similarly for DC2 riz data:
@dkirkby - thanks so much for your entries!
I'm currently putting together the environment to run all these, and I get this error:
from zotbin.nnet import learn_nnet, apply_nnet
ImportError: cannot import name 'apply_nnet'
I installed the head of the master branch of zotbin - is there a different version I should install?
@joezuntz Sorry I missed this earlier. I just pushed a commit to fix this, which simply removes apply_nnet
from the import:
from zotbin.nnet import learn_nnet
Thanks!
When trying to run this method on our cluster GPU (12 GB TITAN V) I consistently get this error:
Traceback (most recent call last):
File "bin/run_one.py", line 62, in <module>
scores = run_one(name, bands, settings, training_data, training_z, validation_data,
File "/home/jzuntz/tomo_challenge/bin/challenge.py", line 121, in run_one
C.train(train_data,train_z)
File "/home/jzuntz/tomo_challenge/bin/../tomo_challenge/classifiers/zotbin.py", line 92, in train
U = self.preprocessor(features)
File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/zotbin/flow.py", line 110, in flow_map
Y_normal, _ = bijection_direct(final_params, jnp.array(Y_preproc))
File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/flows/bijections/bijections.py", line 460, in direct_fun
return feed_forward(params, direct_funs, inputs)
File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/flows/bijections/bijections.py", line 455, in feed_forward
inputs, log_det_jacobian = apply_fun(param, inputs, **kwargs)
File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/flows/bijections/made.py", line 38, in direct_fun
log_weight, bias = apply_fun(params, inputs).split(2, axis=1)
File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/experimental/stax.py", line 302, in apply_fun
inputs = fun(param, inputs, rng=rng, **kwargs)
File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/flows/bijections/made.py", line 18, in apply_fun
return np.dot(inputs, W * mask) + b
File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 2924, in dot
return lax.dot(a, b, precision=precision)
File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/lax/lax.py", line 595, in dot
return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/lax/lax.py", line 629, in dot_general
return dot_general_p.bind(lhs, rhs,
File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/core.py", line 274, in bind
return self.impl(*args, **kwargs)
File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/interpreters/xla.py", line 224, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/interpreters/xla.py", line 264, in xla_primitive_callable
compiled = backend_compile(backend, built_c, options)
File "/home/jzuntz/.conda/envs/tomo/lib/python3.8/site-packages/jax/interpreters/xla.py", line 325, in backend_compile
return backend.compile(built_c, compile_options=options)
RuntimeError: Resource exhausted: Out of memory while trying to allocate 3300966784 bytes.
Any ideas welcome. I have tried various memory-allocation settings.
I ran tests successfully on an 11 GB RTX 2080 GPU, but perhaps you are running with more input features than I tested. I did all my testing with riz only.
You could either:
There seems to be an issue when importing both tensorflow and jax at the same time as both claim all the GPU memory. I was able to fix it with some environment variables.
Yes, this is a known issue. There are several workarounds suggested here, in case you didn't already find this.
A tomographic binning method from the UC Irvine team (zot?)
The basic idea is to perform two stages of optimization:
Divide the feature space into a large number, O(10K), of rectangular cells, then iteratively combine cells into O(100) groups according to their joint feature and redshift distribution similarities.
Combine the O(100) groups into O(1-10) final bins by optimizing a metric such as FOM_3x2 or FOM_DETF_3x2.
An example showing the redshift distributions of 200 groups partitioning the Buzzard riz feature space, obtained after step 1:
Uses code from a separate repo that can be installed with:
More details and plots to follow...