secondmind-labs / trieste

A Bayesian optimization toolbox built on TensorFlow
Apache License 2.0

Investigate improvements to the get-unique-points algorithm #781

Closed · khurram-ghani closed this issue 8 months ago

khurram-ghani commented 1 year ago

Describe the feature you'd like

Investigate whether this algorithm can be parallelised. There is a previous, incomplete parallel implementation as well as the current sequential one. See this PR for context.

j-wilson commented 10 months ago

This can be done in O(N + M^2) time and space using a Bloom filter, where N is the number of points and M is the number of duplicate points. Roughly (see the small sketch after this list):

  1. Quantize each point to the desired precision.
  2. Convert each quantized point to a string.
  3. Hash each stringified, quantized point into a bin.
  4. Test for equality within each bin where a collision occurred.
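
For intuition, here is a toy illustration of steps 1-3 (my own sketch, not code from trieste; the 2 ** 63 - 1 bucket count simply mirrors the implementation further down):

import tensorflow as tf

# Rows 0 and 1 agree to one decimal place; row 2 differs.
points = tf.constant([[0.123, 0.456], [0.121, 0.458], [0.9, 0.1]], dtype=tf.float64)

quantized = tf.math.round(10 ** 1 * points)  # step 1: quantize to 1 decimal place
strings = tf.strings.reduce_join(tf.strings.as_string(quantized), axis=-1)  # step 2: stringify
bin_ids = tf.strings.to_hash_bucket_fast(strings, 2 ** 63 - 1)  # step 3: hash into bins

# Rows 0 and 1 quantize to the same string and hence share a bin id; the
# exact test in step 4 then decides which rows in that bin are duplicates.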

Here is a naive implementation with O(N^2) time and space complexity that computes all pairwise distances and returns the indices of the unique rows:

from typing import Any, Optional

import tensorflow as tf


@tf.function
def get_unique_rows_dense(matrix: tf.Tensor, precision: Optional[int] = None) -> tf.Tensor:
    # Optionally quantize to `precision` decimal places.
    matrix = (
        matrix
        if precision is None
        else tf.math.round(10 ** precision * matrix)
    )
    # Pairwise squared distances via ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>.
    sq_norms = tf.reduce_sum(tf.square(matrix), axis=-1, keepdims=True)
    sq_dists = (
        sq_norms
        + tf.transpose(sq_norms)
        - tf.matmul(2 * matrix, matrix, transpose_b=True)
    )
    # For each row, argmin returns the index of the *first* row at minimal
    # (zero) distance, so a row is unique iff its argmin is its own index.
    argmin = tf.argmin(sq_dists, axis=-1)
    unique = tf.where(argmin == tf.range(tf.shape(matrix)[0], dtype=argmin.dtype))
    return tf.squeeze(unique, axis=-1)
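
As a quick sanity check (a toy input of my own, not from the original comment):

X_small = tf.constant([[0.0, 1.0], [0.0, 1.0], [2.0, 3.0]], dtype=tf.float64)
print(get_unique_rows_dense(X_small))  # [0 2]: row 1 duplicates row 0, so only index 0 is kept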

And here is a fancier implementation using the approach suggested above:

@tf.function
def get_unique_rows(
    matrix: tf.Tensor, precision: Optional[int] = None, **kwargs: Any,
) -> tf.Tensor:
    # Optionally quantize to `precision` decimal places.
    matrix = (
        matrix
        if precision is None
        else tf.math.round(10 ** precision * matrix)
    )
    # Stringify each row and hash it into a bucket. Identical rows always
    # share a bucket; distinct rows only rarely collide, and any such
    # collisions are resolved exactly below.
    strings = tf.strings.reduce_join(tf.strings.as_string(matrix), axis=-1)
    bin_ids = tf.strings.to_hash_bucket_fast(strings, 2 ** 63 - 1)
    unique_ids, membership = tf.unique(bin_ids, out_idx=tf.int64)

    nrows = tf.shape(matrix)[0]
    nbins = tf.size(unique_ids)  # number of occupied bins
    if nbins == nrows:  # no bin holds more than one row, so every row is unique
        return tf.range(nrows, dtype=tf.int64)

    def deduplicate(k: tf.Tensor) -> tf.Tensor:
        # Exactly compare the rows that landed in bin k.
        indices = tf.squeeze(tf.where(membership == k), axis=-1)
        rows = tf.gather(matrix, indices)
        return tf.gather(indices, get_unique_rows_dense(rows))

    bins = tf.range(nbins, dtype=tf.int64)
    bin_counts = tf.math.bincount(tf.cast(membership, tf.int32))
    collisions = bin_counts > 1

    # Rows in single-occupancy bins are unique as-is.
    singletons = tf.squeeze(
        tf.where(tf.reduce_any(membership[:, None] == bins[~collisions], -1)), axis=-1
    )

    # Resolve collisions bin by bin. Bins hold varying numbers of rows, so
    # map_fn needs a ragged output signature; **kwargs can carry extra
    # map_fn options such as parallel_iterations.
    deduplications = tf.map_fn(
        deduplicate,
        bins[collisions],
        fn_output_signature=tf.RaggedTensorSpec(shape=[None], dtype=tf.int64),
        **kwargs,
    )
    return tf.concat([singletons, deduplications.flat_values], axis=0)
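
One usage note, based on my reading of the code above: singleton bins are emitted before the deduplicated collision bins, so the returned indices are not sorted; apply tf.sort if the original row order matters. On the toy input from the sanity check above:

unique_idx = get_unique_rows(X_small)                  # [2 0]: the singleton row comes first
unique_rows = tf.gather(X_small, tf.sort(unique_idx))  # rows 0 and 2, in original order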

Running on my laptop with X = tf.random.uniform(shape=[16384, 4], dtype=tf.float64) gives

%timeit get_unique_rows_dense(X, precision=1)
728 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit get_unique_rows(X, precision=1)
177 ms ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

and

%timeit get_unique_rows_dense(X, precision=None)
713 ms ± 9.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit get_unique_rows(X, precision=None)
20.7 ms ± 2.52 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

In the latter case, the additional speedup occurs because, at full floating-point precision, no two sampled points coincide: every row hashes to its own bin, so nbins == nrows and the function returns early without any pairwise comparisons.
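
As a quick check of that claim (my own snippet, reusing the X defined above): hashing the stringified rows of a fresh float64 sample should occupy one bin per row, which is exactly the nbins == nrows early-return condition.

strings = tf.strings.reduce_join(tf.strings.as_string(X), axis=-1)
bin_ids = tf.strings.to_hash_bucket_fast(strings, 2 ** 63 - 1)
print(int(tf.size(tf.unique(bin_ids).y)))  # 16384, barring an astronomically unlikely hash collision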