tzaeschke / phtree-cpp

PH-Tree C++ implementation
Apache License 2.0

Can I find knn-search results after applying a filter? #151

Closed: rockingdice closed this issue 1 year ago

rockingdice commented 1 year ago

My use case is finding the nearest target that matches certain conditions, e.g. an enemy unit that is within a cannon's firing range. However, the firing range has a minimum limit, so I want to exclude enemies that are too close. I tried using begin_knn_query to find the nearest enemy target, but it returns an empty result when I filter by distance. Is there any way to achieve my goal, or do I have to find all targets in the firing range first and then filter them manually?

tzaeschke commented 1 year ago

Your approach should work. Use begin_knn_query with some filter that excludes neighbors that are too close. Can you post the code that you tried? Maybe even a dataset or small example?

However, nearest neighbor queries are usually quite a bit slower than window queries or sphere queries. Depending on how many units are within range, it may be faster to do a sphere/range query and then check all of the results in order to find the desired target.
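The "range query, then check all results" approach can be sketched without the PH-tree itself. In the snippet below, `candidates` stands in for the results of a sphere query of radius `radius_max` around the cannon; `Point2`, `Distance` and `NearestInRing` are hypothetical helpers (not phtree-cpp API), and the distance is assumed to be 2D Euclidean.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <optional>
#include <vector>

using Point2 = std::array<double, 2>;

double Distance(const Point2& a, const Point2& b) {
    return std::hypot(a[0] - b[0], a[1] - b[1]);
}

// Picks the candidate closest to `center` whose distance lies in
// [radius_min, radius_max]; returns std::nullopt if there is none.
// `candidates` stands in for the results of a sphere query (hypothetical).
std::optional<Point2> NearestInRing(
    const std::vector<Point2>& candidates, const Point2& center,
    double radius_min, double radius_max) {
    assert(radius_min <= radius_max);
    std::optional<Point2> best;
    double best_dist = 0.0;
    for (const auto& p : candidates) {
        const double d = Distance(center, p);
        if (d < radius_min || d > radius_max) {
            continue;  // too close to fire at, or out of range
        }
        if (!best || d < best_dist) {
            best = p;
            best_dist = d;
        }
    }
    return best;
}
```

For example, with the center at (0, 0), radius_min = 2, radius_max = 5 and candidates (1, 0), (0, 3), (4, 0) and (6, 0), the function returns (0, 3): (1, 0) is too close and (6, 0) is out of range.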

rockingdice commented 1 year ago

Thanks for your quick reply! It's good to know the efficiency difference between those queries. I've done some tests and tracked the problem down a bit further, to class IteratorKnnHS:


  private:
    void FindNextElement() {
        while (remaining_ > 0 && !(queue_n_.empty() && queue_v_.empty())) {
            bool use_v = !queue_v_.empty();
            if (use_v && !queue_n_.empty()) {
                use_v = queue_v_.top().first <= queue_n_.top().first;
            }
            if (use_v) {
                // data entry
                auto& result = queue_v_.top();
                --remaining_;
                this->SetCurrentResult(result.second);
                current_distance_ = result.first;
                queue_v_.pop();
                return;
            } else {
                // inner node
                auto top = queue_n_.top();
                auto& node = top.second->GetNode();
                auto d_node = top.first;
                queue_n_.pop();

                if (d_node > max_node_dist_ && queue_v_.size() >= remaining_) {
                    // ignore this node
                    continue;
                }

                for (auto& entry : node.Entries()) {
                    const auto& e2 = entry.second;
                    if (this->ApplyFilter(e2)) {                 // <<< filter applied here
                        // ... (rest of the excerpt omitted)

Here queue_n_ contains only one element, the nearest unit, which does not satisfy the filter. The iterator calls SetFinished() after the loop, so the result set is empty.

I used a custom filter that accepts only elements whose distance from the center lies between radius_min and radius_max (a ring):

template <typename CONVERTER, typename DISTANCE>
class FilterRing {
    using KeyExternal = typename CONVERTER::KeyExternal;
    using KeyInternal = typename CONVERTER::KeyInternal;
    using ScalarInternal = typename CONVERTER::ScalarInternal;
    static constexpr auto DIM = CONVERTER::DimInternal;

  public:
    template <typename DIST = DistanceEuclidean<CONVERTER::DimExternal>>
    FilterRing(
        const KeyExternal& center,
        const double radius_min,
        const double radius_max,
        const CONVERTER& converter,
        DIST&& distance_function = DIST())
    : center_external_{center}
    , center_internal_{converter.pre(center)}
    , radius_min_{radius_min}
    , radius_max_{radius_max}
    , converter_{converter}
    , distance_function_(std::forward<DIST>(distance_function)){};

    template <typename T>
    [[nodiscard]] bool IsEntryValid(const KeyInternal& key, const T&) const {
        KeyExternal point = converter_.get().post(key);
        auto dist = distance_function_(center_external_, point);
        return dist <= radius_max_ && dist >= radius_min_;
    }

    /*
     * Calculate whether AABB encompassing all possible points in the node intersects with the
     * sphere.
     */
    [[nodiscard]] bool IsNodeValid(const KeyInternal& prefix, std::uint32_t bits_to_ignore) const {
        // we always want to traverse the root node (bits_to_ignore == 64)

        if (bits_to_ignore >= (detail::MAX_BIT_WIDTH<ScalarInternal> - 1)) {
            return true;
        }

        ScalarInternal node_min_bits = detail::MAX_MASK<ScalarInternal> << bits_to_ignore;
        ScalarInternal node_max_bits = ~node_min_bits;

        KeyInternal closest_in_bounds;
        for (dimension_t i = 0; i < DIM; ++i) {
            // calculate lower and upper bound for dimension for given node
            ScalarInternal lo = prefix[i] & node_min_bits;
            ScalarInternal hi = prefix[i] | node_max_bits;

            // choose value closest to center for dimension
            closest_in_bounds[i] = std::clamp(center_internal_[i], lo, hi);
        }

        KeyExternal closest_point = converter_.get().post(closest_in_bounds);
        auto dist = distance_function_(center_external_, closest_point);
        return dist <= radius_max_ && dist >= radius_min_;
    }

  private:
    KeyExternal center_external_;
    KeyInternal center_internal_;
    double radius_min_;
    double radius_max_;
    std::reference_wrapper<const CONVERTER> converter_;
    DISTANCE distance_function_;
};

template <typename CONVERTER, typename DISTANCE>
class FilterMultiMapRing : public FilterRing<CONVERTER, DISTANCE> {
    using Key = typename CONVERTER::KeyExternal;
    using KeyInternal = typename CONVERTER::KeyInternal;

  public:
    template <typename DIST = DistanceEuclidean<CONVERTER::DimExternal>>
    FilterMultiMapRing(
        const Key& center, double radius_min, double radius_max, const CONVERTER& converter, DIST&& dist_fn = DIST())
    : FilterRing<CONVERTER, DISTANCE>(center, radius_min, radius_max, converter, std::forward<DIST>(dist_fn)){};

    template <typename ValueT>
    [[nodiscard]] inline bool IsBucketEntryValid(const KeyInternal&, const ValueT&) const noexcept {
        return true;
    }
};

You can test this by putting some points near the search point and some in the valid range. If you query for a single nearest neighbor (k = 1), the result is empty. The expected result is the nearest target within the valid range.
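The test described above can be reproduced with a toy model of the best-first search. This is a simplification, not the phtree-cpp implementation: the tree is replaced by a flat list of leaf boxes (`Node`), and, like the loop in IteratorKnnHS, the search keeps one priority queue for nodes and one for values, both ordered by distance. Entries must pass the entry filter to reach the value queue, and nodes rejected by the node filter are never expanded. All names (`Node`, `MinDist`, `KnnFiltered`, ...) are hypothetical.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

using Point2 = std::array<double, 2>;

// Toy leaf node: an axis-aligned box [lo, hi] holding some points.
struct Node {
    Point2 lo;
    Point2 hi;
    std::vector<Point2> points;
};

double Dist(const Point2& a, const Point2& b) {
    return std::hypot(a[0] - b[0], a[1] - b[1]);
}

// Distance from c to the closest point of the node's box.
double MinDist(const Node& n, const Point2& c) {
    Point2 p;
    for (std::size_t i = 0; i < 2; ++i) {
        assert(n.lo[i] <= n.hi[i]);  // std::clamp requires lo <= hi
        p[i] = std::clamp(c[i], n.lo[i], n.hi[i]);
    }
    return Dist(c, p);
}

using EntryFilter = std::function<bool(const Point2&)>;
using NodeFilter = std::function<bool(const Node&)>;

// Best-first kNN with a node queue and a value queue, both ordered by distance.
std::vector<Point2> KnnFiltered(
    const std::vector<Node>& nodes, const Point2& c, std::size_t k,
    const NodeFilter& node_ok, const EntryFilter& entry_ok) {
    using NodeItem = std::pair<double, const Node*>;
    using ValueItem = std::pair<double, Point2>;
    auto cmp_n = [](const NodeItem& a, const NodeItem& b) { return a.first > b.first; };
    auto cmp_v = [](const ValueItem& a, const ValueItem& b) { return a.first > b.first; };
    std::priority_queue<NodeItem, std::vector<NodeItem>, decltype(cmp_n)> queue_n(cmp_n);
    std::priority_queue<ValueItem, std::vector<ValueItem>, decltype(cmp_v)> queue_v(cmp_v);

    for (const auto& n : nodes) {
        if (node_ok(n)) {  // rejected nodes are never expanded
            queue_n.emplace(MinDist(n, c), &n);
        }
    }
    std::vector<Point2> result;
    while (result.size() < k && !(queue_n.empty() && queue_v.empty())) {
        // Take a value if it is at least as close as the closest unexpanded node.
        bool use_v = !queue_v.empty() &&
                     (queue_n.empty() || queue_v.top().first <= queue_n.top().first);
        if (use_v) {
            result.push_back(queue_v.top().second);
            queue_v.pop();
        } else {
            const Node* n = queue_n.top().second;
            queue_n.pop();
            for (const auto& p : n->points) {
                if (entry_ok(p)) {  // only accepted entries become candidates
                    queue_v.emplace(Dist(c, p), p);
                }
            }
        }
    }
    return result;
}
```

With the center at (0, 0), a ring of [2, 5], and a single box [-4, 4] x [-4, 4] holding (1, 0) and (3, 0), a node filter that checks only the box's closest point rejects the box (its closest point is the center itself, at distance 0), so a k = 1 query returns nothing. A node filter that accepts every box returns (3, 0).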

I don't think there's a simple way to solve the problem, since the filter affects the results and you cannot precalculate how many results you need to fetch before filtering. But I'd still like to hear your thoughts on it :)

For now, I may query for all targets and pick the nearest one in my own logic afterwards. It won't be much work and should still be about as efficient as possible.

rockingdice commented 1 year ago

Oh, I just realized the problem may be my filter. How can I test whether a node's distance is less than the minimum radius?

tzaeschke commented 1 year ago

I would suggest having the IsNodeValid function simply return true. The function is complicated to write, and I think it only helps if there are roughly 100-1000 or more entities inside the minimum radius. In fact, the function may be too expensive and end up hurting performance.

I also agree that your IsNodeValid function is not correct.

First, the last line should use || instead of && (on second thought, it may not really matter). Second, I think it is wrong because it returns true only if the closest point (e.g. the closest corner) of the node is at least radius_min away. Instead, you should exclude a node only if its farthest point (e.g. corner) is inside the minimum radius. E.g.:

ScalarInternal ci = center_internal_[i];
closest_in_bounds[i] = ci < lo ? lo : ci > hi ? hi : ci; // just for comparison, you can use clamp() here
farthest_in_bounds[i] = ci < lo ? hi : ci > hi ? lo : x; // (TODO: calculate `x` to equal lo or hi, whichever is farther away)

Unfortunately, the above does NOT work because you cannot calculate a distance from the internal representation. Instead you will have to convert lo/hi to external coordinates first.
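Once the node's bounds are in external coordinates, the closest/farthest test can be sketched as below. This assumes 2D points, Euclidean distance, and bounds `lo`/`hi` that have already been converted from the internal representation; `BoxMayIntersectRing` and the other names are hypothetical, not phtree-cpp API. Per dimension it takes the bound farther from the center by comparing against the midpoint lo + (hi - lo) / 2, and it prunes a node only if the node lies entirely outside `radius_max` or entirely inside `radius_min`.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

constexpr std::size_t DIM = 2;
using PointD = std::array<double, DIM>;

double Euclidean(const PointD& a, const PointD& b) {
    double sum = 0.0;
    for (std::size_t i = 0; i < DIM; ++i) {
        const double d = a[i] - b[i];
        sum += d * d;
    }
    return std::sqrt(sum);
}

// Returns true if the box [lo, hi] (external coordinates) may contain points
// whose distance from `center` lies in [radius_min, radius_max].
bool BoxMayIntersectRing(const PointD& lo, const PointD& hi, const PointD& center,
                         double radius_min, double radius_max) {
    PointD closest;
    PointD farthest;
    for (std::size_t i = 0; i < DIM; ++i) {
        assert(lo[i] <= hi[i]);
        const double ci = center[i];
        closest[i] = std::clamp(ci, lo[i], hi[i]);
        // The farther bound: lo if the center is above the midpoint, else hi.
        farthest[i] = (ci > lo[i] + (hi[i] - lo[i]) / 2) ? lo[i] : hi[i];
    }
    const double dist_min = Euclidean(center, closest);
    const double dist_max = Euclidean(center, farthest);
    // Prune only if the box is entirely outside radius_max or entirely inside radius_min.
    return dist_min <= radius_max && dist_max >= radius_min;
}
```

For a ring [2, 5] around (0, 0), a box [-4, 4] x [-4, 4] that contains the center is kept (its farthest corner is about 5.66 away), while [-1, 1] x [-1, 1] is pruned because its farthest corner is only about 1.41 away.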

E.g. something like this:

    bool IsNodeValid(const KeyInternal& prefix, std::uint32_t bits_to_ignore) const {
        // we always want to traverse the root node (bits_to_ignore == 64)

        if (bits_to_ignore >= (detail::MAX_BIT_WIDTH<ScalarInternal> - 1)) {
            return true;
        }

        ScalarInternal node_min_bits = detail::MAX_MASK<ScalarInternal> << bits_to_ignore;
        ScalarInternal node_max_bits = ~node_min_bits;

        KeyInternal lo_in;
        KeyInternal hi_in;
        for (dimension_t i = 0; i < DIM; ++i) {
            // calculate lower and upper bound for dimension for given node
            lo_in[i] = prefix[i] & node_min_bits;
            hi_in[i] = prefix[i] | node_max_bits;
        }

        KeyExternal lo_ex = converter_.get().post(lo_in);
        KeyExternal hi_ex = converter_.get().post(hi_in);
        KeyExternal closest_point;
        KeyExternal farthest_point;
        for (dimension_t i = 0; i < DIM; ++i) {
            // per dimension: the bound closest to / farthest from the center
            // (ScalarExternal assumes `using ScalarExternal = typename CONVERTER::ScalarExternal;` in the class)
            ScalarExternal lo = lo_ex[i];
            ScalarExternal hi = hi_ex[i];
            ScalarExternal ci = center_external_[i];
            closest_point[i] = ci < lo ? lo : ci > hi ? hi : ci; // just for comparison, you still can use clamp() here
            farthest_point[i] = ci > lo + (hi - lo) / 2 ? lo : hi; // farther of lo/hi: compare ci with the midpoint
        }

        auto dist_min = distance_function_(center_external_, closest_point);
        auto dist_max = distance_function_(center_external_, farthest_point);
        // keep the node only if it reaches inside radius_max_ and extends beyond radius_min_
        return dist_min <= radius_max_ && dist_max >= radius_min_;
    }

Disclaimer: I did not test this code, so please take it with a grain of salt :-)

rockingdice commented 1 year ago

@tzaeschke Wow! Thanks for your explanation. I'll try it later; for now, I'll close this issue because I have no further questions. Great work on this library, my friend!