MaxHalford / pytorch-resample

🎲 Iterable dataset resampling in PyTorch
MIT License
91 stars, 5 forks

Weighted sampling without knowing all classes beforehand #4

Open RoyJames opened 1 year ago

RoyJames commented 1 year ago

Nice work! I was about to use this handy tool until I realized my problem was even trickier.

I'm dealing with a very large (TB-scale) WebDataset, which also inherits from IterableDataset, and the dataset keeps growing. My goal is to do balanced sampling based on some attributes of my samples. In other words, I want N classes, each with equal weight. This would be straightforward with this resampling tool if I knew all the possible classes. However, finding them requires iterating through the whole dataset, which takes hours. That would be fine if I only had to do it once, but I will have to do it over and over, since my dataset will grow in size (and new classes will appear). I'm also aware that one workaround is to handle this outside the training loop by maintaining an incremental list of classes on disk.

Still, it would be even better if this could be handled during training. I was thinking of building desired_dist dynamically: initialize it as an empty dict and add unseen classes on the fly, all with the same constant weight. It seems like this might work, but I haven't tested it. Do you think this is worth having in the repo? And do you see any caveats? Any suggestions are appreciated.

MaxHalford commented 1 year ago

Hey there! So you don't even know the number of classes in advance?

I think it would be OK to have some kind of adaptive method like the one you're describing. You want the classes to be distributed uniformly, right? In that case, you could update desired_dist each time a new class appears. I think this could be worth having in the repo, but you would need to do some exploration/experiments first.

MaxHalford commented 1 year ago

I just implemented the following code, using some utilities from River:

```python
import collections
import random
from river import datasets
from river import utils

seen = collections.Counter()  # how many times each class has been emitted

sampling_rate = 0.5
actual_dist = collections.Counter()  # class counts in the incoming stream
n = 0
rng = random.Random(42)

dataset = datasets.ImageSegments()

for x, y in dataset:
    actual_dist[y] += 1
    n += 1
    n_classes = len(actual_dist)
    f = 1 / n_classes  # desired (uniform) frequency over the classes seen so far
    g = actual_dist[y]

    # Emit the sample a Poisson number of times, inversely proportional
    # to the observed frequency g / n of its class
    rate = sampling_rate * f / (g / n)

    for _ in range(utils.random.poisson(rate, rng=rng)):
        seen[y] += 1

seen
```

```
Counter({'path': 168,
         'window': 176,
         'grass': 157,
         'cement': 153,
         'brickface': 155,
         'foliage': 164,
         'sky': 198})
```

This is an adaptation of the balanced undersampling algorithm described here. Instead of specifying the desired distribution beforehand, we compute it on the fly from the number of classes seen so far. There are good reasons to expect this algorithm to regulate itself well: a class that is over-represented in the stream has a large g / n, and therefore a small sampling rate.
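A rough way to see why, assuming the class frequencies p_y in the stream are stationary and all k classes have already appeared. Write ρ for sampling_rate, k_t for n_classes, g_t(y) for actual_dist[y] and n_t for n at step t:

$$
\lambda_t(y) = \rho \, \frac{1/k_t}{g_t(y)/n_t} \approx \frac{\rho}{k\,p_y},
\qquad
\mathbb{E}\left[\text{copies of } y \text{ emitted over } N \text{ samples}\right] \approx \underbrace{N p_y}_{\text{occurrences of } y} \cdot \frac{\rho}{k\,p_y} = \frac{\rho N}{k}.
$$

The frequency p_y cancels, so each class is expected to be emitted about ρN/k times (about ρN in total), whatever its share of the stream. The approximation is rough early on, before g_t(y)/n_t has settled and before every class has appeared.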

RoyJames commented 1 year ago

Wow, thank you for the quick response, and for even writing an example! I think this depicts exactly what I wanted to do, and the resulting distribution here is promising. I will test it on much larger, real-world datasets with many more classes. Will report back later!

MaxHalford commented 1 year ago

Yes, but I just realized that the ImageSegments dataset has a uniform class distribution to begin with 😬

RoyJames commented 1 year ago

I went ahead and tried implementing this idea. It requires only about a one-line change in the __iter__() method, the same for all three sampling methods: `f = collections.defaultdict(lambda: 1.0 / len(g))`. Or, even simpler, we could drop f (i.e., desired_dist) entirely if we always want balanced sampling: in both under- and over-sampling, f cancels out (when evaluating f[y] / f[self._pivot]), and in hybrid sampling it is merely a constant scaling factor.
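A caveat worth checking with the defaultdict line, depending on where f is built (this is an observation about Python's defaultdict, not something tested on the data in this thread): defaultdict stores the default value the first time a missing key is read. If f is created once rather than rebuilt for every incoming sample, f[y] stays frozen at 1 / len(g) as of the moment y was first looked up, so early classes keep a larger weight than later ones and f[y] / f[self._pivot] no longer cancels. The sketch below shows the effect next to a small mapping (UniformDist is a hypothetical name) that recomputes the weight on each lookup.

```python
import collections


class UniformDist:
    """Hypothetical mapping: every label gets 1 / (number of classes seen so far).

    Nothing is stored on lookup, so all weights shrink together as new
    classes show up in `counts` (e.g. a sampler's actual_dist Counter).
    """

    def __init__(self, counts):
        self.counts = counts

    def __getitem__(self, label):
        return 1.0 / max(len(self.counts), 1)


g = collections.Counter()
cached = collections.defaultdict(lambda: 1.0 / len(g))  # built once, outside the loop
uniform = UniformDist(g)

g["a"] += 1
first_a = cached["a"]  # 1 / 1, and now stored in `cached`
g["b"] += 1
first_b = cached["b"]  # 1 / 2

# defaultdict kept the stale value for "a"; UniformDist recomputes it
print(cached["a"], cached["b"])    # 1.0 0.5
print(uniform["a"], uniform["b"])  # 0.5 0.5
```

If f is rebuilt inside the per-sample loop, or 1.0 / len(g) is computed directly where it is needed, the stale-weight issue does not arise.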

But I still cannot get it to work in my situation. I believe this sampling strategy struggles when there is a huge number of classes. A few observations:

  1. Under-sampling is unusable with too many classes, because rejection sampling gradually slows down as the number of classes increases. With thousands of classes the acceptance ratio becomes very small, so samples from all non-pivot classes are very likely to be rejected. That makes sampling 10x slower, which is unacceptable.
  2. Over-sampling does not give a very balanced result in my case. To recap, we are essentially using random_poisson(g[self._pivot]/g[y]) to guide over-sampling. It seems to me that this should work just as it does when all classes are known beforehand, but I may have overlooked some differences. Here is one of my test runs ("total samples" below is the number of samples I repeatedly drew from the IterableDataset (a WebDataset), so the ideal count for each class is about 230):
    total samples: 160000, total classes: 696
    top 10 labels: [('01849ea601', 524), ('01859d4762', 508), ('018d016635', 496), ('0188543100', 484), ('017ce1e680', 474), ('00ec79213f', 472), ('00d4ca0109', 468), ('004881ef1a', 462), ('00629ef528', 461), ('018893291f', 458)]
    least 10 labels: [('0174dbd520', 16), ('01afe4a35f', 22), ('01b8cbbac0', 25), ('01ecfd1a78', 31), ('01e14d77bf', 33), ('01b8478d5d', 33), ('004017ba82', 36), ('01da1f1a18', 41), ('00379062e6', 42), ('01c0e8b6ea', 51)]
  3. For the same test as in 2, hybrid sampling (with sampling_rate=1.0) gives more balanced results, although it is 3x slower than over-sampling:
    total samples: 160000, total classes: 697
    top 10 labels: [('01ffd9f2de', 349), ('007e630bab', 336), ('017456594c', 336), ('025b854813', 330), ('025f077598', 330), ('024a9aa8c9', 327), ('007a3dbb06', 327), ('003199b017', 325), ('02408ad599', 325), ('01333eefba', 319)]
    least 10 labels: [('00c96b6fff', 104), ('001519f234', 105), ('01e14d77bf', 107), ('027114332c', 107), ('01bf64059b', 112), ('01da1f1a18', 120), ('003fb666a9', 121), ('0098f67c04', 121), ('01b8478d5d', 124), ('0112504d5e', 125)]
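To put a rough number on observation 1, here is a back-of-the-envelope sketch. It assumes, as an approximation of a rejection-based under-sampler with a uniform target, that the pivot is the rarest class and that a sample of class y is kept with probability p_min / p_y, where p_y is the class's frequency in the stream. The long-run fraction of samples kept is then sum over y of p_y * (p_min / p_y) = k * p_min, so the slowdown is driven by how rare the rarest class is, which tends to get worse as k grows. The Zipf-shaped frequencies (zipf_probs) are a hypothetical stand-in for a real label distribution.

```python
def zipf_probs(k, s=1.0):
    """Hypothetical class frequencies: p_i proportional to 1 / i**s, i = 1..k."""
    weights = [1.0 / i**s for i in range(1, k + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def expected_acceptance(class_probs):
    """Fraction of stream samples kept if each sample of class y is accepted
    with probability p_min / p_y (uniform target, rarest class as pivot)."""
    p_min = min(class_probs)
    return sum(p * (p_min / p) for p in class_probs)  # equals k * p_min


for k in (10, 100, 1000):
    for s in (1.0, 2.0):
        rate = expected_acceptance(zipf_probs(k, s))
        print(f"k={k:5d} s={s}: keeps {rate:.2%} of samples, ~{1 / rate:.0f} reads per kept sample")
```

With s=1 the kept fraction works out to 1 / H_k (the k-th harmonic number), and with s=2 it shrinks roughly like 1 / k.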

Sorry that I cannot give a minimal code example to reproduce this; the dataset and associated modules are difficult to disentangle. To fully test it, a dummy unbalanced dataset with a high number of classes may need to be crafted.

1073521013 commented 1 year ago

Thanks to the author for the patient answers! I also have a problem that the current solution can't easily handle. I'm dealing with a very large WebDataset too, and we need batches in which a certain attribute label is consistent. For example, we have 10 labels with an uneven distribution, and we want each batch to contain data with a single, consistent label, with labels sampled uniformly. Because IterableDataset is so restrictive (it can't be combined with samplers), we can only customize the IterableDataset itself. This will probably need some caching mechanism, but we don't want it to hurt performance too much. I have come up with some solutions but they haven't been resolved, so could you please provide some suggestions? Any suggestions are appreciated.

MaxHalford commented 1 year ago

Thanks @RoyJames for providing more context. I don't have any spare time at the moment, but I'd love to come back to this later! Just to be clear, though: the HybridSampler is working decently, right? The "only" issue is performance, correct? If so, maybe we can optimize it.

> To fully test it, a dummy unbalanced dataset with a high number of classes may need to be crafted.

@RoyJames I encourage you to share a function that simulates such a dataset! That way, any of us can validate whatever solution we come up with.
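Something along these lines could serve as a starting point. It is a hypothetical generator (simulate_stream and its parameters are made up for this sketch) that yields (x, y) pairs with Zipf-distributed labels, so a few classes dominate, many are rare, and rare classes tend to show up for the first time late in the stream, loosely mimicking a growing dataset.

```python
import collections
import itertools
import random


def simulate_stream(n_samples, n_classes, s=1.2, seed=0):
    """Yield (x, y) pairs with label y drawn from a Zipf(s) distribution.

    x is just the sample index; y is a string label such as "class_0".
    """
    rng = random.Random(seed)
    labels = [f"class_{i}" for i in range(n_classes)]
    # Cumulative weights for 1 / (i + 1)**s, so rng.choices avoids re-summing on each draw
    cum_weights = list(itertools.accumulate(1.0 / (i + 1) ** s for i in range(n_classes)))
    for x in range(n_samples):
        y = rng.choices(labels, cum_weights=cum_weights, k=1)[0]
        yield x, y


counts = collections.Counter(y for _, y in simulate_stream(50_000, 500))
print(len(counts), "classes seen")
print("most common:", counts.most_common(3))
print("rarest:", counts.most_common()[-3:])
```

Raising s makes the imbalance more extreme; raising n_classes makes the long tail longer.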

Also, I think it would be great if you could contribute what you did to this package. For example, it would be nice if all methods took an optional desired_dist. If desired_dist=None, each method would use the trick you just described. See what I mean? Obviously, some documentation would be needed.

> I have come up with some solutions but they haven't been resolved, so could you please provide some suggestions?

Hello @1073521013! Do you know the number of classes beforehand? If so, you should be able to use this package out of the box. @RoyJames' issue is that the number wasn't known beforehand.

1073521013 commented 1 year ago

Perhaps I didn't express myself clearly. I know the number of classes beforehand, but I need every sample in a batch to have the same label. The uniform sampling here applies to the overall distribution rather than to any specific batch.
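One possible direction, sketched under assumptions (single_label_batches is a hypothetical helper, not part of pytorch-resample): keep one buffer per label as the cache, and emit a batch as soon as a label's buffer is full. Memory stays bounded by roughly n_labels * batch_size samples. To make labels uniform across batches overall, the incoming stream could first be passed through a balanced resampler like the ones discussed in this thread.

```python
import collections


def single_label_batches(stream, batch_size):
    """Group (x, y) pairs into batches in which every sample has the same label.

    One buffer per label acts as the cache; a batch is emitted as soon as a
    buffer holds batch_size samples. Partial buffers left at the end are dropped.
    """
    buffers = collections.defaultdict(list)
    for x, y in stream:
        buffers[y].append(x)
        if len(buffers[y]) == batch_size:
            yield y, buffers.pop(y)


stream = [(i, "a" if i % 3 else "b") for i in range(12)]
for label, xs in single_label_batches(stream, batch_size=2):
    print(label, xs)
```

In PyTorch, this generator could be returned from an IterableDataset's __iter__, with the DataLoader created with batch_size=None so it does not batch the already-batched output again (the xs would still need to be collated into tensors). With num_workers > 0, each worker holds its own copy of the dataset and therefore its own buffers.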

RoyJames commented 1 year ago

@MaxHalford I cannot conclude whether HybridSampler works decently (as the actual distribution is still not quite balanced, but better than OverSampler), but it could be that I have to sample much more to see whether it gives better results. I'll try to do that next.

But before that, I am confused about how actual_dist is updated in all 3 samplers. I see self.actual_dist[y] += 1 invoked whenever a sample is fetched, regardless of whether it is discarded, rather than when it is yielded. And even if a sample is yielded multiple times by the Poisson process, the current code only increments its count by 1. I think self.actual_dist[y] should be updated every time a sample is yielded, for example:

```python
for _ in range(utils.random_poisson(rate, rng=self.rng)):
    self.actual_dist[y] += 1
    yield x, y
```

And when we see a new class for the first time, we would do the following to sample it unconditionally:

```python
for x, y in self.dataset:
    if self.actual_dist[y] == 0:
        self.actual_dist[y] += 1
        yield x, y
        continue
```

Am I understanding it wrong? Because it seems your current implementation already works well on your examples, even though the actual_dist does not track the "actual distribution" we've sampled.

MaxHalford commented 1 year ago

> I cannot conclude whether HybridSampler works decently (as the actual distribution is still not quite balanced, but better than OverSampler), but it could be that I have to sample much more to see whether it gives better results. I'll try to do that next.

My bad, I had misread your results.

> Am I understanding it wrong? Because it seems your current implementation already works well on your examples, even though the actual_dist does not track the "actual distribution" we've sampled.

No, actual_dist tracks the distribution of the incoming data, regardless of whether a sample is kept. But this could be wrong! I haven't worked with this code in 3 years. Feel free to challenge it :) That's what open source is about.

RoyJames commented 1 year ago

I first kept actual_dist as is. After scaling up the number of samples by 10x, I'm seeing a more balanced final distribution with HybridSampler (we should expect about 2952 samples per class):

total samples: 1600000, total classes: 542
top 50 labels: [('01c2cafbdc', 3199), ('01a91c3465', 3173), ('017e0fcfe7', 3171), ('00328607b5', 3157), ('01b0737506', 3151), ('004881ef1a', 3151), ('01bd414599', 3150), ('0064f09b4b', 3146), ('00f458c1f3', 3144), ('01aed7c251', 3140), ('004459128f', 3138), ('01a44ed5d1', 3137), ('01ddd70b02', 3137), ('01e49440ac', 3135), ('01b7a16a23', 3134), ('01d57cb84b', 3134), ('01dd3c974c', 3130), ('01e1bfd863', 3125), ('00d7497d66', 3123), ('019c3f0212', 3123), ('01c804653a', 3117), ('01859d4762', 3115), ('0018b42a2d', 3114), ('01d132f05b', 3114), ('01daa9ba1d', 3114), ('01a3cd8acf', 3112), ('0032f55ac2', 3111), ('0191833fd8', 3110), ('0081e23fdf', 3108), ('007a6fb8bf', 3105), ('000cf301a9', 3105), ('00009949b3', 3103), ('01a9230854', 3103), ('01c1c5b414', 3102), ('01dc67f8df', 3102), ('00e2b3f397', 3100), ('0066f97ac5', 3099), ('0061fc0c7b', 3099), ('01a839a670', 3098), ('01bc2f9bfe', 3095), ('00d5637ee7', 3095), ('01ead8a87e', 3095), ('0100a7d69b', 3092), ('01c681ad16', 3091), ('019aa38376', 3090), ('01bafad405', 3090), ('009aa75e72', 3089), ('0192f0a236', 3089), ('0016246a91', 3089), ('005f2c6805', 3088)]
least 50 labels: [('0112504d5e', 2586), ('0111e2768b', 2673), ('0159ebc5e0', 2684), ('010d53c3c2', 2686), ('01115e41ff', 2699), ('01107e6b36', 2702), ('0131af3e9c', 2704), ('0151a11811', 2709), ('01480c48a9', 2710), ('0116a9ae8c', 2711), ('0157c27ca7', 2716), ('0131ef5c3b', 2717), ('010465487d', 2718), ('0136b8a970', 2727), ('012210b09f', 2727), ('010d1283c2', 2732), ('01426dadfe', 2733), ('01029adb23', 2736), ('014e914a44', 2740), ('0105e3e139', 2743), ('0174dbd520', 2746), ('0111e1d755', 2750), ('0172f7df10', 2753), ('0178e578f0', 2755), ('0120500f2a', 2757), ('010ec5280a', 2757), ('012c835b66', 2763), ('01117e7ea9', 2765), ('0138f8d13e', 2767), ('0161391129', 2768), ('010d81932f', 2768), ('0168ec8355', 2768), ('0108d3d094', 2769), ('0167523ae3', 2770), ('0137e6af59', 2773), ('010bfc0e69', 2774), ('01252e2548', 2776), ('0134a425df', 2777), ('014db700e6', 2777), ('0152f8da86', 2781), ('0105f646ac', 2783), ('016d46e43b', 2783), ('0167aad6e7', 2787), ('0109960a7e', 2787), ('015b9d3c8f', 2788), ('016c24255f', 2790), ('01363e90e7', 2791), ('016646916a', 2792), ('0114ca18d7', 2795), ('0166552452', 2796)]

For comparison, I also ran OverSampler again, and HybridSampler consistently outperforms OverSampler in my tests:

total samples: 1600000, total classes: 542
top 50 labels: [('017ce31ae5', 3357), ('018b7a1c97', 3356), ('01c3eaf587', 3346), ('000e535568', 3343), ('018893291f', 3338), ('01e49440ac', 3334), ('0109960a7e', 3329), ('0170ac6f80', 3327), ('018d016635', 3319), ('01533ad263', 3319), ('0142ff4c4d', 3318), ('01bd414599', 3317), ('017e0fcfe7', 3314), ('018e01ed7e', 3314), ('01bc2f9bfe', 3312), ('01849ea601', 3309), ('01a0900de8', 3306), ('00015e9c35', 3302), ('00143475c5', 3301), ('018e4f7bde', 3300), ('012be72302', 3300), ('014c5ee427', 3299), ('0108d3d094', 3298), ('0114ca18d7', 3297), ('01a9230854', 3296), ('003199b017', 3294), ('010bfc0e69', 3292), ('0063745b53', 3291), ('017ce1e680', 3287), ('01a63af1aa', 3287), ('012029ed67', 3287), ('0167255c01', 3285), ('0192f0a236', 3284), ('017456594c', 3279), ('019d7533fd', 3277), ('01859d4762', 3276), ('016646916a', 3276), ('004a3c0c3b', 3274), ('019b934834', 3273), ('01333eefba', 3271), ('01ddd70b02', 3269), ('016a08ddd8', 3268), ('016e15e58e', 3265), ('01de083c06', 3264), ('01a4d12dcb', 3263), ('01715b0958', 3262), ('0131c14762', 3260), ('01daa9ba1d', 3257), ('00328607b5', 3255), ('01c7476deb', 3253)]
least 50 labels: [('00c3f0e7c6', 2265), ('00c96b6fff', 2289), ('0098f67c04', 2295), ('009e6fb329', 2321), ('00c151620f', 2322), ('00ff84f57a', 2340), ('00e77ea0f4', 2345), ('00e3963727', 2346), ('00e86396ae', 2375), ('00af360965', 2376), ('00bc782bfe', 2380), ('00e7fec432', 2381), ('00eaa8925b', 2383), ('00b4f0212e', 2399), ('00dd99b22a', 2403), ('00ad62eca9', 2410), ('00cd57da93', 2416), ('00e49f922a', 2423), ('00c9c17011', 2423), ('00cf679539', 2428), ('008ff4dcf9', 2436), ('00f5d86abe', 2441), ('0091e37358', 2442), ('00d443b51c', 2446), ('00f318fe2e', 2447), ('00ec9305e8', 2460), ('00e1a7623a', 2462), ('00bac0ff14', 2464), ('00d5637ee7', 2465), ('00a46a27af', 2473), ('00b244d2ff', 2475), ('009677a424', 2477), ('00a693cb8d', 2478), ('00f71751f9', 2484), ('00936e5750', 2484), ('00d795792d', 2487), ('00ac2a807b', 2494), ('00f35cf868', 2498), ('00b18a4112', 2498), ('00ba7ba929', 2498), ('00bbe8b58d', 2499), ('00ff6dcff5', 2501), ('00da019987', 2503), ('00de1f41bd', 2504), ('00e2df5a1b', 2505), ('00dd785acc', 2505), ('00b99ee0d3', 2507), ('00b17df7f8', 2509), ('00a76e27ba', 2516), ('00b130b141', 2519)]
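To compare the two 1.6M-sample runs a bit more concretely, one can compute how far the extreme counts quoted above sit from the ideal 1600000 / 542 ≈ 2952 per class. This uses only the largest and smallest counts from the logs, so it is a coarse summary rather than a full analysis.

```python
total_samples, n_classes = 1_600_000, 542
ideal = total_samples / n_classes  # about 2952 per class

# (largest count, smallest count) copied from the logs above
runs = {
    "HybridSampler": (3199, 2586),
    "OverSampler": (3357, 2265),
}

for name, (largest, smallest) in runs.items():
    print(
        f"{name}: max/ideal={largest / ideal:.3f}, "
        f"min/ideal={smallest / ideal:.3f}, max/min={largest / smallest:.3f}"
    )
```

By these numbers, HybridSampler's extremes sit at about +8% / -12% of the ideal count, versus about +14% / -23% for OverSampler.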

I guess HybridSampler rejects more samples, so it still takes 2.6x longer than OverSampler. But that isn't a major concern for me, because training at this scale is usually bottlenecked by compute rather than by data loading.

Next, when I changed actual_dist to track the distribution of samples actually yielded, the results became noticeably more imbalanced. So I think your current implementation is the correct (or at least the better) one.

I would love to contribute the balanced sampling method for unknown distributions back to the repo, but it needs more rigorous testing first (perhaps as unit tests in this repo?). Let's see when one of us (or someone else) can find some time. API-wise, I'm thinking of two options:

  1. Add a balanced: bool flag to __init__() to indicate that we want balanced sampling, and ignore desired_dist when balanced=True.
  2. Allow desired_dist to be None (the default) or empty, and in these cases assume the user wants balanced sampling regardless of the data.

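Option 2 might look roughly like this. This is a hypothetical sketch of the target-distribution logic only (BalancedTarget, update and weight are made-up names, not existing pytorch-resample API): a missing or empty desired_dist switches to uniform weights over the classes seen so far, recomputed on every call.

```python
import collections


class BalancedTarget:
    """Hypothetical helper holding the target distribution f for a sampler."""

    def __init__(self, desired_dist=None):
        # None or {} means "balanced": uniform over whatever classes show up
        self.desired_dist = dict(desired_dist) if desired_dist else None
        self.actual_dist = collections.Counter()

    def update(self, label):
        self.actual_dist[label] += 1

    def weight(self, label):
        if self.desired_dist is None:
            return 1.0 / max(len(self.actual_dist), 1)
        return self.desired_dist[label]


balanced = BalancedTarget()
for y in ["cat", "dog", "cat", "bird"]:
    balanced.update(y)
print(balanced.weight("cat"))  # 1/3, since three classes have been seen

fixed = BalancedTarget({"cat": 0.5, "dog": 0.5})
print(fixed.weight("dog"))  # 0.5, taken from desired_dist
```

Option 1 would behave the same way, just triggered by the balanced flag instead of by the value of desired_dist.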
MaxHalford commented 1 year ago

Thanks for that recap @RoyJames, very handy.

On my end, my intuition tells me you could improve the algorithm if you have a rough idea of how many classes will be seen. Something like this:

```python
import collections
import random
from river import datasets
from river import utils

seen = collections.Counter()

sampling_rate = 0.5
actual_dist = collections.Counter()
n = 0
rng = random.Random(42)
expected_n_classes = 100  # rough prior guess of how many classes will eventually appear

dataset = datasets.ImageSegments()

for x, y in dataset:
    actual_dist[y] += 1
    n += 1
    # Never assume fewer classes than expected, even before they have been seen
    n_classes = max(expected_n_classes, len(actual_dist))
    f = 1 / n_classes
    g = actual_dist[y]

    rate = sampling_rate * f / (g / n)

    for _ in range(utils.random.poisson(rate, rng=rng)):
        seen[y] += 1

seen
```

Let me know if that helps!
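For anyone who wants to try this without River, here is a dependency-free sketch of the same loop as a generator. The names balanced_stream, poisson and max_rate are hypothetical; poisson is a plain Knuth sampler standing in for River's utils.random.poisson. One thing the formula implies: a class seen for the first time after n samples has g / n = 1 / n, so its rate is sampling_rate * n / n_classes, and a late newcomer can be emitted many times in a row. The optional max_rate cap is an assumption added here to limit such bursts; it is not part of the code above.

```python
import collections
import math
import random


def poisson(rate, rng):
    """Sample Poisson(rate) with Knuth's method, in chunks of at most 30.

    Poisson variables add up, so chunking keeps exp(-chunk) from underflowing.
    """
    total = 0
    while rate > 0:
        chunk = min(rate, 30.0)
        rate -= chunk
        threshold = math.exp(-chunk)
        k, p = 0, rng.random()
        while p > threshold:
            k += 1
            p *= rng.random()
        total += k
    return total


def balanced_stream(stream, sampling_rate=0.5, expected_n_classes=None, max_rate=None, seed=42):
    """Re-emit each (x, y) a Poisson number of times so classes come out roughly balanced."""
    rng = random.Random(seed)
    actual_dist = collections.Counter()  # class counts in the *input* stream
    n = 0
    for x, y in stream:
        actual_dist[y] += 1
        n += 1
        n_classes = len(actual_dist)
        if expected_n_classes is not None:
            n_classes = max(expected_n_classes, n_classes)
        rate = sampling_rate * (1 / n_classes) / (actual_dist[y] / n)
        if max_rate is not None:
            rate = min(rate, max_rate)  # hypothetical cap on bursts
        for _ in range(poisson(rate, rng)):
            yield x, y


data_rng = random.Random(0)
data = [(i, data_rng.choices("abc", weights=[0.7, 0.2, 0.1])[0]) for i in range(20_000)]
print(collections.Counter(y for _, y in balanced_stream(data)))
```

With expected_n_classes set, f stays at 1 / expected_n_classes until more classes than that have been seen, which presumably keeps classes that happen to appear early from being emitted at the inflated rate 1 / (small n_classes).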