aimalz / justice

A semi-unsupervised classifier for noisy, irregular, sparse transient and variable astronomical lightcurves
MIT License

Dataset for random negative generation #78

Open gatoatigrado opened 5 years ago

gatoatigrado commented 5 years ago

The TF learned alignment model will output a vector for each point in a lightcurve.

We'd like to train a model to generate these vectors, such that the correct alignment produces high dot products between correctly aligned samples.
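As a rough illustration of that objective (not the model itself), the sketch below scores a pair of points by the dot product of their vectors. The vectors and the `dot` helper are made up for this example; in practice the vectors would come from the learned TF model.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    return sum(a * b for a, b in zip(u, v))


# Hypothetical per-point vectors; values are chosen by hand for illustration.
left_point = [1.0, 0.0, 2.0]
aligned_right_point = [1.0, 0.5, 2.0]       # similar direction, high score
misaligned_right_point = [-1.0, 2.0, 0.0]   # different direction, low score

positive_score = dot(left_point, aligned_right_point)
negative_score = dot(left_point, misaligned_right_point)
```

Training would then push `positive_score` up and `negative_score` down.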

A synthetic dataset for training this model should generate negatives and positives.

For this, we should use tf.data.Dataset, a modern TensorFlow API for representing a stream of input data. I think this link has a decent introduction to the API with examples: http://adventuresinmachinelearning.com/tensorflow-dataset-tutorial/. In the end, we'd like an input_fn() that returns a tf.data.Dataset instance, probably with output_shapes

{
    'left.band_r.before_flux': tf.TensorShape([window_size]),
    ...
    'right.band_r.before_flux': tf.TensorShape([window_size]),
    'label': tf.TensorShape()
}

and output_types

{
    'left.band_r.before_flux': tf.float32,
    ...
    'right.band_r.before_flux': tf.float32,
    'label': tf.int32
}

where 'label' is 0 for negatives and 1 for positives.
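To make the element layout concrete, here is a minimal pure-Python sketch of one element with that structure. The `make_example` helper, the `WINDOW_SIZE` value, and the flux numbers are assumptions for illustration only; a generator of such dicts could then be wrapped with tf.data.Dataset.from_generator.

```python
WINDOW_SIZE = 4  # assumed value; the issue leaves window_size unspecified


def make_example(left, right, label, window_size=WINDOW_SIZE):
    """Merge two per-point feature dicts into one labelled example."""
    if label not in (0, 1):
        raise ValueError("label must be 0 (negative) or 1 (positive)")
    example = {}
    for prefix, features in (("left", left), ("right", right)):
        for key, values in features.items():
            if len(values) != window_size:
                raise ValueError(
                    f"{prefix}.{key} has length {len(values)}, "
                    f"expected {window_size}"
                )
            # Store floats so the values match a tf.float32 output type.
            example[f"{prefix}.{key}"] = [float(v) for v in values]
    example["label"] = label
    return example


# Hypothetical feature windows for two unrelated points.
point_a = {"band_r.before_flux": [0.1, 0.2, 0.3, 0.4]}
point_b = {"band_r.before_flux": [1.0, 0.9, 0.8, 0.7]}
negative = make_example(point_a, point_b, label=0)
```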

https://github.com/aimalz/justice/pull/76 has an example of generating a dataset for all points in a light curve. Here, we want to sample points instead of feeding all of them at once, and also prefix the tensors with left/right. For the negatives, we may be able to do some of this with dataset combinators:

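import tensorflow as tf

left: tf.data.Dataset = get_random_points_dataset()
right: tf.data.Dataset = get_random_points_dataset()

def prefix_mapper(left_item, right_item):
    # Zipping a tuple of datasets passes each component as its own argument.
    result = {f"left.{key}": tensor for key, tensor in left_item.items()}
    result.update({f"right.{key}": tensor for key, tensor in right_item.items()})
    result["label"] = 0  # this is a random negative
    return result

# Dataset.zip takes a single (nested) structure of datasets, hence the tuple.
negatives = tf.data.Dataset.zip((left, right)).map(prefix_mapper)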

TedSinger commented 5 years ago

bcolz_source = plasticc_data.PlasticcBcolzSource.get_default()
meta_table = bcolz_source.get_table('training_set_meta')

TedSinger commented 5 years ago

def gen_ids(meta_table):
    meta_map = meta_table.where('object_id > 0', outcols=['object_id'])
    for row in meta_map:
        yield row.object_id

gatoatigrado commented 5 years ago

another option:

all_ids = meta_table['object_id'][:]
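
With the IDs in memory, picking object pairs for random negatives could look roughly like the sketch below. It assumes a negative should pair two different objects, which this thread does not state; `sample_negative_id_pairs` and the ID values are hypothetical stand-ins.

```python
import random


def sample_negative_id_pairs(object_ids, num_pairs, seed=None):
    """Yield (left_id, right_id) pairs drawn from two distinct objects."""
    rng = random.Random(seed)
    ids = list(object_ids)
    if len(ids) < 2:
        raise ValueError("need at least two object ids to form a pair")
    for _ in range(num_pairs):
        # sample() draws without replacement, so the two positions differ.
        left_id, right_id = rng.sample(ids, 2)
        yield left_id, right_id


# Stand-in for meta_table['object_id'][:]; these IDs are made up.
example_ids = [101, 102, 103, 104]
pairs = list(sample_negative_id_pairs(example_ids, num_pairs=3, seed=0))
```

Each pair would still need a sampled point from each object's lightcurve before it becomes a training example.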