rust-random / rand

A Rust library for random number generation.
https://crates.io/crates/rand

Add methods weight, weights, and total_weight to weighted_index.rs #1420

Closed MichaelOwenDyer closed 3 months ago

MichaelOwenDyer commented 3 months ago

Summary

After facing some difficulties with the implementation in https://github.com/rust-random/rand/pull/1403, this PR is an alternative solution to the same problem. If this PR is merged, the other can be closed.

Motivation

Please read the motivation section in https://github.com/rust-random/rand/pull/1403. This PR aims to solve the problem mentioned there by exposing insights into the current state of the distribution, allowing the end user to calculate new weights based on the current values and then call the existing update_weights method with those.

Details

Three new methods have been added to weighted_index.rs. Here are their signatures:

weights returns a WeightedIndexIter which iterates over all the weights in the distribution:

pub fn weights(&self) -> WeightedIndexIter<'_, X>
    where
        X: for<'a> ::core::ops::SubAssign<&'a X>
            + Clone

weight returns the weight at a specific index:

pub fn weight(&self, index: usize) -> Option<X>
    where 
        X: for<'a> ::core::ops::SubAssign<&'a X>
            + Clone

total_weight returns the sum of all weights:

pub fn total_weight(&self) -> X
    where X: Clone

In my opinion these methods are also not strongly coupled with the current implementation of WeightedIndex, as I cannot imagine a future implementation which would not be able to support these operations.
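To make the intended workflow concrete, here is a minimal, self-contained sketch. It does not use rand itself: `ToyWeightedIndex`, `from_weights` and `doubled_at` are hypothetical stand-ins, and the cumulative-sum storage is an assumption for illustration. It shows how a single weight could be recovered with the `SubAssign` bound from the signatures above, and how a caller could derive new weights from the current ones.

```rust
use std::ops::SubAssign;

/// Toy stand-in for a weighted index (NOT rand's type). Storing only running
/// sums of the weights is an assumption made for this illustration.
struct ToyWeightedIndex<X> {
    cumulative: Vec<X>,
    total: X,
}

impl ToyWeightedIndex<u32> {
    /// Hypothetical constructor, specialised to u32 to keep the sketch short.
    fn from_weights(weights: &[u32]) -> Self {
        let mut running = 0u32;
        let cumulative = weights
            .iter()
            .map(|w| {
                running += w;
                running
            })
            .collect();
        Self { cumulative, total: running }
    }
}

impl<X> ToyWeightedIndex<X>
where
    X: Clone + for<'a> SubAssign<&'a X>,
{
    /// Weight at `index`: its running sum minus the previous running sum.
    fn weight(&self, index: usize) -> Option<X> {
        let mut weight = self.cumulative.get(index)?.clone();
        if index > 0 {
            weight -= &self.cumulative[index - 1];
        }
        Some(weight)
    }

    /// Lazily yields every weight by calling `weight` for each valid index.
    fn weights(&self) -> impl Iterator<Item = X> + '_ {
        (0..self.cumulative.len()).filter_map(move |i| self.weight(i))
    }
}

impl<X: Clone> ToyWeightedIndex<X> {
    /// Sum of all weights; needs only `Clone`, as in the signature above.
    fn total_weight(&self) -> X {
        self.total.clone()
    }
}

/// Builds a new weight list in which entry `index` is twice its current value.
fn doubled_at(dist: &ToyWeightedIndex<u32>, index: usize) -> Option<Vec<u32>> {
    let doubled = dist.weight(index)? * 2;
    Some(
        dist.weights()
            .enumerate()
            .map(|(i, w)| if i == index { doubled } else { w })
            .collect(),
    )
}

fn main() {
    let dist = ToyWeightedIndex::from_weights(&[2, 3, 5]);
    println!("total = {}", dist.total_weight());
    println!("new weights = {:?}", doubled_at(&dist, 1));
}
```

With the real type, the values computed this way would then be passed to the existing update_weights method.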

I have also added tests for the new methods.

Looking forward to your reviews :)

dhardy commented 3 months ago

In my opinion these methods are also not strongly coupled with the current implementation of WeightedIndex, as I cannot imagine a future implementation which would not be able to support these operations.

Ah, this is why you didn't expose cumulative_weights directly? I strongly doubt we'd need to adjust this.

weights returns a Vec of all the weights in the distribution

An alternative worth considering is to return an iterator over weights: the caller can choose to collect or use on-the-fly and can skip as required.

My suggestion:

MichaelOwenDyer commented 3 months ago

I've changed the name of get to weight (I actually only named it get for consistency with WeightedTreeIndex, but I think weight is better too).

Also created a custom Iterator type WeightedIndexIter which just calls WeightedIndex::weight with an incrementing index. One issue I ran into here: the crate-wide lint level of missing_debug_implementations is set to deny, but deriving Debug for this type refuses to work because apparently <X as uniform::SampleUniform>::Sampler doesn't implement Debug. This is strange, because WeightedIndex itself derives Debug just fine. I've never seen this before; I added a temporary #[allow(missing_debug_implementations)] to fix compilation. @dhardy, maybe you know what the issue is here?
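For illustration, an iterator of this shape can be sketched as follows. `Weights` and `WeightIter` are hypothetical stand-ins (not the types in this PR), with `weight` assumed to return `None` once the index is out of range. The point is that the caller decides whether to collect, skip, or consume values on the fly.

```rust
/// Hypothetical container standing in for the distribution.
struct Weights(Vec<u32>);

impl Weights {
    /// Returns the weight at `index`, or `None` if out of range.
    fn weight(&self, index: usize) -> Option<u32> {
        self.0.get(index).copied()
    }

    fn iter(&self) -> WeightIter<'_> {
        WeightIter { source: self, index: 0 }
    }
}

/// Iterator that borrows the container and tracks the next index to read.
struct WeightIter<'a> {
    source: &'a Weights,
    index: usize,
}

impl<'a> Iterator for WeightIter<'a> {
    type Item = u32;

    fn next(&mut self) -> Option<u32> {
        // Stop as soon as `weight` reports that the index is out of range.
        let weight = self.source.weight(self.index)?;
        self.index += 1;
        Some(weight)
    }
}

fn main() {
    let weights = Weights(vec![2, 3, 5]);
    // Collect everything into a Vec...
    let all: Vec<u32> = weights.iter().collect();
    // ...or skip entries and reduce without allocating.
    let tail_sum: u32 = weights.iter().skip(1).sum();
    println!("{:?} {}", all, tail_sum);
}
```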

cumulative_weights can just return &[X]. Is it useful though?

I'm really not sure if it's worth it to expose that. My problem is solved without it, and it would be the only implementation added in this PR which returns a reference type, meaning WeightedIndex is cursed to own a cumulative weight collection for all eternity 😕. Even if the implementation is likely never to change, just seems like we shouldn't lock ourselves in if we don't have to.

dhardy commented 3 months ago

This is strange, because WeightedIndex itself derives Debug just fine.

You're right, it is strange. #[derive] has this weird rule that each generic parameter gains a bound on the derived trait (in this case, X: Debug). This was never actually correct, but often works... in this case, it turns out that WeightedIndex<X> needs both X: Debug (because of the first two fields) and also X::Sampler: Debug (for the last field).

I think in earlier versions of Rust this derive simply wouldn't have worked. You can read more on this topic here: https://smallcultfollowing.com/babysteps//blog/2022/04/12/implied-bounds-and-perfect-derive/

The really odd part is that now #[derive(Debug)] on WeightedIndexIter gets the X: Debug bound (because that's weird standard practice), but not the X::Sampler: Debug bound (because that would require deeper analysis).

I think the best solution is to write an explicit impl with the correct bounds:

impl<'a, X> Debug for WeightedIndexIter<'a, X>
where
    X: SampleUniform + PartialOrd + Debug,
    X::Sampler: Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("WeightedIndexIter")
            .field("weighted_index", &self.weighted_index)
            .field("index", &self.index)
            .finish()
    }
}
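The same situation can be reproduced in a self-contained miniature. The names `HasSampler`, `Index` and `IndexIter` below are hypothetical stand-ins for SampleUniform, WeightedIndex and WeightedIndexIter. The sketch assumes, in line with the behaviour described in this thread, that the derive works on the type whose field is the associated type itself, but not on a wrapper that only holds a reference to it.

```rust
use std::fmt::{self, Debug};

/// Hypothetical stand-in for a trait with an associated sampler type.
trait HasSampler {
    type Sampler;
}

impl HasSampler for u32 {
    type Sampler = ();
}

/// The derive works here, matching what the thread reports for WeightedIndex.
#[derive(Debug)]
struct Index<X: HasSampler> {
    weights: Vec<X>,
    sampler: X::Sampler,
}

/// Deriving Debug here would only add `X: Debug`, which is not enough to
/// prove `Index<X>: Debug`, so the impl is written by hand instead.
struct IndexIter<'a, X: HasSampler> {
    index: &'a Index<X>,
    pos: usize,
}

impl<'a, X> Debug for IndexIter<'a, X>
where
    X: HasSampler + Debug,
    X::Sampler: Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IndexIter")
            .field("index", &self.index)
            .field("pos", &self.pos)
            .finish()
    }
}

fn main() {
    let index = Index::<u32> { weights: vec![1, 2], sampler: () };
    let iter = IndexIter { index: &index, pos: 0 };
    println!("{:?}", iter);
}
```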

Another solution is to use a variant of derive which allows explicit bounds. The following works, but there's no need for the extra dependency.

#[impl_tools::autoimpl(Debug where X: Debug, X::Sampler: Debug)]

I'm really not sure if it's worth it to expose that. My problem is solved without it,

Fair point. Let's leave that off then.

dhardy commented 3 months ago

Incidentally, the bounds on the Clone impl are also incorrect: they shouldn't require X: Clone. So we could just add this instead:

#[impl_tools::autoimpl(Clone)]
#[impl_tools::autoimpl(Debug where X: trait, X::Sampler: trait)]

That's another issue though.
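As a sketch of why the `X: Clone` bound is unnecessary, assuming the remark refers to a derived Clone on a type that only holds a shared reference and an index: the hypothetical `SliceIter` below can be cloned even when its element type `NotClone` cannot, because cloning copies only the reference and the position.

```rust
/// Deliberately does not implement Clone.
struct NotClone(u32);

/// Hypothetical iterator holding a shared reference and a position.
struct SliceIter<'a, X> {
    items: &'a [X],
    pos: usize,
}

// Manual impl with no `X: Clone` bound. `#[derive(Clone)]` would add that
// bound, making `SliceIter<'_, NotClone>` impossible to clone.
impl<'a, X> Clone for SliceIter<'a, X> {
    fn clone(&self) -> Self {
        Self { items: self.items, pos: self.pos }
    }
}

impl<'a, X> Iterator for SliceIter<'a, X> {
    type Item = &'a X;

    fn next(&mut self) -> Option<&'a X> {
        let item = self.items.get(self.pos)?;
        self.pos += 1;
        Some(item)
    }
}

fn main() {
    let data = [NotClone(1), NotClone(2)];
    let mut first = SliceIter { items: &data, pos: 0 };
    first.next();
    // The clone resumes from the same position, independently of `first`.
    let mut second = first.clone();
    println!("{:?}", second.next().map(|n| n.0));
    println!("{:?}", first.next().map(|n| n.0));
}
```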

MichaelOwenDyer commented 3 months ago

Thanks for the explanation and the interesting link :)

I went ahead and added manual implementations for Debug and Clone. A bit of boilerplate code, but I would say adding a new dependency should have its own PR later, if it comes to that 😃

Is there anything else to discuss?