RoyJames opened this issue 1 year ago.
Hey there! So you don't even know the number of classes in advance?
I think it would be OK to have some kind of adaptive method like you're describing. You want the classes to be distributed uniformly, right? Well, in that case, you could update desired_dist each time a new class appears. I think this could be worth having in the repo, but you need to do some exploration/experiments first.
I just implemented the following code, using some stuff from River:

```python
import collections
import random

from river import datasets
from river import utils

seen = collections.Counter()
sampling_rate = 0.5
actual_dist = collections.Counter()
n = 0
rng = random.Random(42)

dataset = datasets.ImageSegments()

for x, y in dataset:
    actual_dist[y] += 1
    n += 1
    n_classes = len(actual_dist)  # classes seen so far
    f = 1 / n_classes  # desired share: uniform over the classes seen so far
    g = actual_dist[y]
    rate = sampling_rate * f / (g / n)  # desired share / observed share of class y
    for _ in range(utils.random.poisson(rate, rng=rng)):  # keep Poisson(rate) copies
        seen[y] += 1

seen
```
Counter({'path': 168,
'window': 176,
'grass': 157,
'cement': 153,
'brickface': 155,
'foliage': 164,
'sky': 198})
This is an adaptation of the balanced undersampling algorithm described here. Instead of specifying the desired distribution beforehand, we calculate it on the fly, based on the number of classes seen so far. There are good reasons to expect this algorithm to regulate itself well.
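A back-of-the-envelope argument for that claim, assuming the class proportions p_c of the stream are stationary and all k classes have already appeared (s is sampling_rate, n the number of samples read so far, g_c the running count of class c, and f = 1/k as in the code above):

```latex
\lambda_c = s \cdot \frac{f}{g_c / n} = s \cdot \frac{1/k}{g_c / n},
\qquad \frac{g_c}{n} \to p_c
\;\;\Longrightarrow\;\; \lambda_c \approx \frac{s}{k\, p_c}

\mathbb{E}[\text{samples of class } c \text{ kept}]
\approx \underbrace{n\, p_c}_{\text{arrivals of } c} \cdot \frac{s}{k\, p_c}
= \frac{s\, n}{k}
```

The p_c cancels, so every class is expected to receive about s·n/k samples however rare it is. With s = 0.5, k = 7 and the 2,310 samples of ImageSegments, that is about 165 per class, close to the counts above (153 to 198). The approximation is weakest early in the stream, while g_c/n is still noisy and new classes are still appearing.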
Wow, thank you for the quick response and for even showing me an example! I think this depicts exactly what I wanted to do. The actual distribution here is promising. I will test it on much larger practical datasets with many more classes. Will report back later!
Yes, but I actually realized that the ImageSegments dataset has a uniform distribution to begin with 😬
I went ahead and tried implementing this idea. In fact, the idea requires little more than a one-line change in the __iter__() method, the same for all three types of sampling methods:

```python
f = collections.defaultdict(lambda: 1.0 / len(g))
```

Or, even simpler, we can get rid of f (namely desired_dist) altogether if we always want balanced sampling, because in both under- and over-sampling f cancels out (i.e., when you evaluate f[y]/f[self._pivot]), and in hybrid sampling it is merely a constant scaling factor.
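One subtlety with the defaultdict one-liner: collections.defaultdict calls its factory only the first time a missing key is looked up, and then stores the result. So if f[y] is first looked up when y is the j-th distinct class seen, it stays at 1/j even after more classes arrive, earlier classes keep larger weights, and f[y]/f[self._pivot] is no longer 1. A minimal sketch of a mapping that recomputes the uniform weight on every lookup instead (UniformOverSeen is a hypothetical name, not part of the package):

```python
import collections
from collections.abc import Mapping


class UniformOverSeen(Mapping):
    """Hypothetical read-only mapping: every class in `counts` gets weight 1 / len(counts).

    The weight is recomputed on each lookup, so it always reflects the number of
    classes seen so far, unlike a defaultdict, which caches its first value per key.
    """

    def __init__(self, counts):
        self._counts = counts  # live reference to the running class counts (e.g. actual_dist)

    def __getitem__(self, y):
        if y not in self._counts:
            raise KeyError(y)
        return 1.0 / len(self._counts)

    def __iter__(self):
        return iter(self._counts)

    def __len__(self):
        return len(self._counts)


# Compare the two on a stream where four classes appear one after another.
g = collections.Counter()
cached = collections.defaultdict(lambda: 1.0 / len(g))
live = UniformOverSeen(g)
for y in ["a", "b", "c", "d"]:
    g[y] += 1
    _ = cached[y]  # first lookup stores 1 / len(g) as of this moment
print(dict(cached))  # weights frozen at 1, 1/2, 1/3, 1/4
print({y: live[y] for y in g})  # every class gets 1/4
```

Whether the samplers accept an arbitrary Mapping as desired_dist is an assumption worth checking; dropping f entirely, as suggested above, sidesteps the issue.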
But I still cannot get it to work in my situation. I believe this sampling strategy will suffer when there is a huge number of classes. A few observations:

With under-sampling, ratio is going to be very small when we have thousands of classes, so all non-pivot classes are highly likely to be rejected. That makes sampling 10x slower, which is unacceptable.

With over-sampling, the number of copies is drawn from random_poisson(g[self._pivot]/g[y]). It seems to me that this should work just like when we knew all classes beforehand, but I may have overlooked some differences. One of my test runs gives the following ("total samples" below means the number of samples I have repeatedly drawn from the IterableDataset (a WebDataset), so the ideal count for each class would be about 230):
total samples: 160000, total classes: 696
top 10 labels: [('01849ea601', 524), ('01859d4762', 508), ('018d016635', 496), ('0188543100', 484), ('017ce1e680', 474), ('00ec79213f', 472), ('00d4ca0109', 468), ('004881ef1a', 462), ('00629ef528', 461), ('018893291f', 458)]
least 10 labels: [('0174dbd520', 16), ('01afe4a35f', 22), ('01b8cbbac0', 25), ('01ecfd1a78', 31), ('01e14d77bf', 33), ('01b8478d5d', 33), ('004017ba82', 36), ('01da1f1a18', 41), ('00379062e6', 42), ('01c0e8b6ea', 51)]
With hybrid sampling (sampling_rate=1.0), the results seem more balanced, although it is 3x slower than over-sampling:
total samples: 160000, total classes: 697
top 10 labels: [('01ffd9f2de', 349), ('007e630bab', 336), ('017456594c', 336), ('025b854813', 330), ('025f077598', 330), ('024a9aa8c9', 327), ('007a3dbb06', 327), ('003199b017', 325), ('02408ad599', 325), ('01333eefba', 319)]
least 10 labels: [('00c96b6fff', 104), ('001519f234', 105), ('01e14d77bf', 107), ('027114332c', 107), ('01bf64059b', 112), ('01da1f1a18', 120), ('003fb666a9', 121), ('0098f67c04', 121), ('01b8478d5d', 124), ('0112504d5e', 125)]
Sorry that I cannot give a minimal example to reproduce it, because the dataset and associated modules are difficult to disentangle. To fully test it, a dummy unbalanced dataset with a high number of classes may need to be crafted.
Thanks to the author for the patient answers! I have a related problem that the current solution does not easily handle. I'm dealing with a very large WebDataset too, and we need each batch to share the same value of a certain attribute label. For example, we have 10 labels with an uneven distribution, but we want each batch to contain data with one single label, with the labels covered uniformly. Because IterableDataset is too restrictive to be combined with samplers, we can only customize the IterableDataset itself. I expect this will need some caching mechanism, but we do not want it to hurt performance too much. I have come up with some solutions but they haven't been resolved, so could you please provide some suggestions? Any suggestions are appreciated.
Thanks @RoyJames for providing more context. I don't have any spare time at the moment, but I'd love to come back to this later! Just to be clear, though, the HybridSampler is working decently, right? The "only" issue is performance, correct? If so, maybe we can optimize it.
To fully test it, a dummy unbalanced dataset with a high number of classes may need to be crafted.
@RoyJames I encourage you to share a function that simulates such a dataset! This way, any of us can validate whatever solution we come up with.
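A minimal sketch of such a simulator, assuming a Zipf-like label distribution is a reasonable stand-in for the real data (make_imbalanced_stream and its parameters are hypothetical):

```python
import collections
import itertools
import random


def make_imbalanced_stream(n_samples, n_classes, zipf_exponent=1.0, seed=0):
    """Hypothetical helper: yield (x, y) pairs with Zipf-distributed labels.

    Class i (0-based) has weight 1 / (i + 1) ** zipf_exponent, so class 0 is the
    most frequent and class n_classes - 1 the rarest. x is just the sample index,
    since only the labels matter when testing a resampler. Labels are 10-digit
    hex strings, loosely mimicking the labels in the logs of this thread.
    """
    rng = random.Random(seed)
    labels = [f"{i:010x}" for i in range(n_classes)]
    # Cumulative weights let random.choices use a binary search instead of a linear scan.
    cum_weights = list(
        itertools.accumulate(1.0 / (i + 1) ** zipf_exponent for i in range(n_classes))
    )
    for i in range(n_samples):
        y = rng.choices(labels, cum_weights=cum_weights)[0]
        yield i, y


# Usage: 50k samples over 500 classes; the head class dwarfs the tail.
counts = collections.Counter(y for _, y in make_imbalanced_stream(50_000, 500))
print(len(counts), counts.most_common(3))
```

Because the generator is seeded, the same stream can be replayed for each sampler variant being compared.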
Also, I think it would be great if you could contribute what you did to this package. For example, it would be nice if all methods took an optional desired_dist. If desired_dist=None, then each method would use the trick you just did. See what I mean? Obviously, some documentation would be needed.
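A minimal sketch of that API idea for the hybrid case, written as a plain generator so it runs without PyTorch. hybrid_resample and its poisson helper are hypothetical (Knuth's method is used only to keep the example dependency-free, in place of River's utils.random.poisson); when desired_dist is None, the target is uniform over the classes seen so far, as in the River snippet earlier in this thread.

```python
import collections
import math
import random


def poisson(rate, rng):
    """Draw from Poisson(rate) with Knuth's method (fine for small to moderate rates)."""
    limit = math.exp(-rate)
    k, p = 0, rng.random()
    while p > limit:
        k += 1
        p *= rng.random()
    return k


def hybrid_resample(stream, sampling_rate=1.0, desired_dist=None, seed=42):
    """Yield Poisson(rate) copies of each (x, y), with rate = sampling_rate * f / (g / n).

    f is desired_dist[y] if desired_dist is given, otherwise 1 / (number of classes
    seen so far); g / n is the running share of class y among the n samples read.
    """
    rng = random.Random(seed)
    actual_dist = collections.Counter()  # counts every sample read, kept or not
    n = 0
    for x, y in stream:
        actual_dist[y] += 1
        n += 1
        f = 1.0 / len(actual_dist) if desired_dist is None else desired_dist[y]
        rate = sampling_rate * f / (actual_dist[y] / n)
        for _ in range(poisson(rate, rng)):
            yield x, y


# Usage: a 9:1 two-class stream comes out roughly balanced.
labels = ["a"] * 9000 + ["b"] * 1000
random.Random(0).shuffle(labels)
stream = list(enumerate(labels))
print(collections.Counter(y for _, y in hybrid_resample(stream)))
```

Passing an explicit desired_dist keeps the current behaviour, so the None default would be backwards compatible for callers that already supply one.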
I have come up with some solutions but they haven't been resolved, so could you please provide some suggestions?
Hello @1073521013! Do you know the number of classes beforehand? If so, you should be able to use this package out-of-the-box. @RoyJames' issue is that the number wasn't known beforehand.
Perhaps I didn't express myself clearly: I know the number of classes beforehand, but I need each batch to return samples with the same label. The uniform sampling here is about the overall view rather than a specific batch.
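A minimal sketch of one way to get single-label batches from an iterable stream, assuming the per-label buffering (the "caching") cost is acceptable: keep one buffer per label and emit a batch as soon as a buffer fills, so at most n_labels × (batch_size − 1) samples are held at once. single_label_batches is a hypothetical name. On its own, this makes batch labels follow the stream's label frequencies, so for uniform labels across batches it would be fed an already balanced stream, e.g. the output of one of the samplers discussed in this thread.

```python
import collections


def single_label_batches(stream, batch_size, flush_remainder=False):
    """Group an (x, y) stream into (label, [x, ...]) batches whose samples share one label.

    One buffer per label; a batch is yielded as soon as its label's buffer is full.
    Partially filled buffers are emitted at the end only if flush_remainder is True.
    """
    buffers = collections.defaultdict(list)
    for x, y in stream:
        buf = buffers[y]
        buf.append(x)
        if len(buf) == batch_size:
            yield y, buf
            buffers[y] = []  # start a fresh buffer; the yielded list is left untouched
    if flush_remainder:
        for y, buf in buffers.items():
            if buf:
                yield y, buf


# Usage: labels "a" and "b" interleaved 2:1, batches of 4.
demo = [(i, "a" if i % 3 else "b") for i in range(30)]
for label, xs in single_label_batches(demo, batch_size=4):
    print(label, xs)
```

In PyTorch, a generator like this could sit inside a custom IterableDataset's __iter__, with DataLoader(dataset, batch_size=None) so that the DataLoader does not re-batch the already formed batches.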
@MaxHalford I cannot conclude whether HybridSampler works decently (as the actual distribution is still not quite balanced, but better than OverSampler), but it could be that I have to sample much more to see whether it gives better results. I'll try to do that next.
But before that, I am confused about the way actual_dist is updated in all three samplers. I see self.actual_dist[y] += 1 invoked when a sample is fetched, regardless of whether we discard it, not when it is yielded. And even if we yield one sample multiple times in the Poisson process, the current code only increments its count by 1. I think self.actual_dist[y] should be updated whenever a sample is yielded anywhere in the code, for example:

```python
for _ in range(utils.random_poisson(rate, rng=self.rng)):
    self.actual_dist[y] += 1  # count every copy actually yielded
    yield x, y
```

And when we see a new class for the first time, we would do the following to sample it unconditionally:

```python
for x, y in self.dataset:
    if self.actual_dist[y] == 0:  # first sighting of class y
        self.actual_dist[y] += 1
        yield x, y
        continue
```
Am I understanding it wrong? Because it seems your current implementation already works well on your examples, even though the actual_dist does not track the "actual distribution" we've sampled.
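To compare the two bookkeeping choices on the same stream, a sketch with the choice made switchable. AdaptiveOverSampler and its count_on parameter are hypothetical; the per-sample rate g[pivot]/g[y], with pivot the most frequent class, follows the over-sampling rule quoted earlier in this thread, and poisson is a small Knuth-method helper standing in for the package's random_poisson.

```python
import collections
import math
import random


def poisson(rate, rng):
    """Knuth's method for drawing from Poisson(rate); adequate for moderate rates."""
    limit = math.exp(-rate)
    k, p = 0, rng.random()
    while p > limit:
        k += 1
        p *= rng.random()
    return k


class AdaptiveOverSampler:
    """Hypothetical balanced over-sampler over an iterable of (x, y) pairs.

    count_on="arrival": actual_dist counts every sample read (the current behaviour).
    count_on="yield": actual_dist counts samples actually yielded, and a class seen
    for the first time is yielded once unconditionally (the proposed change).
    """

    def __init__(self, dataset, count_on="arrival", seed=42):
        if count_on not in ("arrival", "yield"):
            raise ValueError("count_on must be 'arrival' or 'yield'")
        self.dataset = dataset
        self.count_on = count_on
        self.rng = random.Random(seed)
        self.actual_dist = collections.Counter()

    def __iter__(self):
        g = self.actual_dist
        g_max = 0  # counts only grow, so the pivot's count can be tracked incrementally
        for x, y in self.dataset:
            if self.count_on == "arrival":
                g[y] += 1
                g_max = max(g_max, g[y])
                copies = poisson(g_max / g[y], self.rng)
            else:
                copies = 1 if g[y] == 0 else poisson(g_max / g[y], self.rng)
                g[y] += copies
                g_max = max(g_max, g[y])
            for _ in range(copies):
                yield x, y


# Usage: run both variants on the same 3:1 stream and compare their outputs.
labels = ["a"] * 300 + ["b"] * 100
random.Random(1).shuffle(labels)
stream = list(enumerate(labels))
for mode in ("arrival", "yield"):
    print(mode, collections.Counter(y for _, y in AdaptiveOverSampler(stream, count_on=mode)))
```

With count_on="yield", actual_dist ends up equal to the distribution of what was yielded; with "arrival", it equals the distribution of what was read, which is the distinction discussed here.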
I cannot conclude whether HybridSampler works decently (as the actual distribution is still not quite balanced, but better than OverSampler), but it could be that I have to sample much more to see whether it gives better results. I'll try to do that next.
My bad, I had misread your results.
Am I understanding it wrong? Because it seems your current implementation already works well on your examples, even though the actual_dist does not track the "actual distribution" we've sampled.
No, so actual_dist tracks the distribution of the data, regardless of whether the sample is kept or not. But this could be wrong! I haven't worked with this code for 3 years now. Feel free to challenge it :). That's what open source is about.
I first kept actual_dist as is. After scaling up the sampling count by 10x, I'm seeing a more balanced final distribution with HybridSampler (the expected count is about 2952 samples per class):
total samples: 1600000, total classes: 542
top 50 labels: [('01c2cafbdc', 3199), ('01a91c3465', 3173), ('017e0fcfe7', 3171), ('00328607b5', 3157), ('01b0737506', 3151), ('004881ef1a', 3151), ('01bd414599', 3150), ('0064f09b4b', 3146), ('00f458c1f3', 3144), ('01aed7c251', 3140), ('004459128f', 3138), ('01a44ed5d1', 3137), ('01ddd70b02', 3137), ('01e49440ac', 3135), ('01b7a16a23', 3134), ('01d57cb84b', 3134), ('01dd3c974c', 3130), ('01e1bfd863', 3125), ('00d7497d66', 3123), ('019c3f0212', 3123), ('01c804653a', 3117), ('01859d4762', 3115), ('0018b42a2d', 3114), ('01d132f05b', 3114), ('01daa9ba1d', 3114), ('01a3cd8acf', 3112), ('0032f55ac2', 3111), ('0191833fd8', 3110), ('0081e23fdf', 3108), ('007a6fb8bf', 3105), ('000cf301a9', 3105), ('00009949b3', 3103), ('01a9230854', 3103), ('01c1c5b414', 3102), ('01dc67f8df', 3102), ('00e2b3f397', 3100), ('0066f97ac5', 3099), ('0061fc0c7b', 3099), ('01a839a670', 3098), ('01bc2f9bfe', 3095), ('00d5637ee7', 3095), ('01ead8a87e', 3095), ('0100a7d69b', 3092), ('01c681ad16', 3091), ('019aa38376', 3090), ('01bafad405', 3090), ('009aa75e72', 3089), ('0192f0a236', 3089), ('0016246a91', 3089), ('005f2c6805', 3088)]
least 50 labels: [('0112504d5e', 2586), ('0111e2768b', 2673), ('0159ebc5e0', 2684), ('010d53c3c2', 2686), ('01115e41ff', 2699), ('01107e6b36', 2702), ('0131af3e9c', 2704), ('0151a11811', 2709), ('01480c48a9', 2710), ('0116a9ae8c', 2711), ('0157c27ca7', 2716), ('0131ef5c3b', 2717), ('010465487d', 2718), ('0136b8a970', 2727), ('012210b09f', 2727), ('010d1283c2', 2732), ('01426dadfe', 2733), ('01029adb23', 2736), ('014e914a44', 2740), ('0105e3e139', 2743), ('0174dbd520', 2746), ('0111e1d755', 2750), ('0172f7df10', 2753), ('0178e578f0', 2755), ('0120500f2a', 2757), ('010ec5280a', 2757), ('012c835b66', 2763), ('01117e7ea9', 2765), ('0138f8d13e', 2767), ('0161391129', 2768), ('010d81932f', 2768), ('0168ec8355', 2768), ('0108d3d094', 2769), ('0167523ae3', 2770), ('0137e6af59', 2773), ('010bfc0e69', 2774), ('01252e2548', 2776), ('0134a425df', 2777), ('014db700e6', 2777), ('0152f8da86', 2781), ('0105f646ac', 2783), ('016d46e43b', 2783), ('0167aad6e7', 2787), ('0109960a7e', 2787), ('015b9d3c8f', 2788), ('016c24255f', 2790), ('01363e90e7', 2791), ('016646916a', 2792), ('0114ca18d7', 2795), ('0166552452', 2796)]
For comparison, I also ran this again for OverSampler, and HybridSampler does outperform OverSampler consistently in my tests:
total samples: 1600000, total classes: 542
top 50 labels: [('017ce31ae5', 3357), ('018b7a1c97', 3356), ('01c3eaf587', 3346), ('000e535568', 3343), ('018893291f', 3338), ('01e49440ac', 3334), ('0109960a7e', 3329), ('0170ac6f80', 3327), ('018d016635', 3319), ('01533ad263', 3319), ('0142ff4c4d', 3318), ('01bd414599', 3317), ('017e0fcfe7', 3314), ('018e01ed7e', 3314), ('01bc2f9bfe', 3312), ('01849ea601', 3309), ('01a0900de8', 3306), ('00015e9c35', 3302), ('00143475c5', 3301), ('018e4f7bde', 3300), ('012be72302', 3300), ('014c5ee427', 3299), ('0108d3d094', 3298), ('0114ca18d7', 3297), ('01a9230854', 3296), ('003199b017', 3294), ('010bfc0e69', 3292), ('0063745b53', 3291), ('017ce1e680', 3287), ('01a63af1aa', 3287), ('012029ed67', 3287), ('0167255c01', 3285), ('0192f0a236', 3284), ('017456594c', 3279), ('019d7533fd', 3277), ('01859d4762', 3276), ('016646916a', 3276), ('004a3c0c3b', 3274), ('019b934834', 3273), ('01333eefba', 3271), ('01ddd70b02', 3269), ('016a08ddd8', 3268), ('016e15e58e', 3265), ('01de083c06', 3264), ('01a4d12dcb', 3263), ('01715b0958', 3262), ('0131c14762', 3260), ('01daa9ba1d', 3257), ('00328607b5', 3255), ('01c7476deb', 3253)]
least 50 labels: [('00c3f0e7c6', 2265), ('00c96b6fff', 2289), ('0098f67c04', 2295), ('009e6fb329', 2321), ('00c151620f', 2322), ('00ff84f57a', 2340), ('00e77ea0f4', 2345), ('00e3963727', 2346), ('00e86396ae', 2375), ('00af360965', 2376), ('00bc782bfe', 2380), ('00e7fec432', 2381), ('00eaa8925b', 2383), ('00b4f0212e', 2399), ('00dd99b22a', 2403), ('00ad62eca9', 2410), ('00cd57da93', 2416), ('00e49f922a', 2423), ('00c9c17011', 2423), ('00cf679539', 2428), ('008ff4dcf9', 2436), ('00f5d86abe', 2441), ('0091e37358', 2442), ('00d443b51c', 2446), ('00f318fe2e', 2447), ('00ec9305e8', 2460), ('00e1a7623a', 2462), ('00bac0ff14', 2464), ('00d5637ee7', 2465), ('00a46a27af', 2473), ('00b244d2ff', 2475), ('009677a424', 2477), ('00a693cb8d', 2478), ('00f71751f9', 2484), ('00936e5750', 2484), ('00d795792d', 2487), ('00ac2a807b', 2494), ('00f35cf868', 2498), ('00b18a4112', 2498), ('00ba7ba929', 2498), ('00bbe8b58d', 2499), ('00ff6dcff5', 2501), ('00da019987', 2503), ('00de1f41bd', 2504), ('00e2df5a1b', 2505), ('00dd785acc', 2505), ('00b99ee0d3', 2507), ('00b17df7f8', 2509), ('00a76e27ba', 2516), ('00b130b141', 2519)]
I guess HybridSampler rejects more samples, so it still takes 2.6x more time than OverSampler. But even that is not a major concern for me, because training at such a large scale usually has its bottleneck in computation rather than data loading.
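A rough way to reason about the slowdown, assuming stationary class probabilities p_c and the per-sample rates used in this thread: the expected number of samples yielded per sample read is the sum over classes of p_c·rate_c, which works out to k·p_max for over-sampling, k·p_min for under-sampling and s (sampling_rate) for hybrid sampling. Under these assumptions, hybrid sampling with s = 1 must read k·p_max times as many samples as over-sampling to yield the same number, and under-sampling becomes expensive whenever the rarest class is much rarer than average. yields_per_read is a hypothetical helper:

```python
def yields_per_read(probs, scheme, sampling_rate=1.0):
    """Expected samples yielded per sample read, assuming stationary class probabilities.

    Per-sample rates, with f uniform over the k classes:
      over:   p_max / p_c                   (pivot = most frequent class)
      under:  p_min / p_c                   (pivot = rarest class; an acceptance probability)
      hybrid: sampling_rate * (1 / k) / p_c
    The expected yield per read is sum_c p_c * rate_c.
    """
    k = len(probs)
    if scheme == "over":
        rates = [max(probs) / p for p in probs]
    elif scheme == "under":
        rates = [min(probs) / p for p in probs]
    elif scheme == "hybrid":
        rates = [sampling_rate / (k * p) for p in probs]
    else:
        raise ValueError(f"unknown scheme: {scheme!r}")
    return sum(p * r for p, r in zip(probs, rates))


# Usage: three classes with probabilities 0.5, 0.25, 0.25.
for scheme in ("over", "under", "hybrid"):
    print(scheme, yields_per_read([0.5, 0.25, 0.25], scheme))  # about 1.5, 0.75, 1.0
```

The reciprocal of the returned value is the expected number of samples read per sample yielded, which is what dominates data-loading time when reads are the expensive part.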
Next, when I changed actual_dist to track the actual sampled distribution, the results became visibly more unbalanced. So I feel your current implementation is the correct/better one.
I would love to contribute the balanced sampling method for unknown distributions back, but there have to be more rigorous tests first (as unit tests in this repo, maybe?). Let's see when one of us (or someone else) can find some time. API-wise, I'm thinking of two options:

1. Add a balanced: bool flag to __init__() to indicate that we want balanced sampling, and ignore desired_dist if balanced==True.
2. Allow desired_dist to be None (the default) or empty. In these cases, we assume the user wants balanced sampling regardless of what data is given.

Thanks for that recap @RoyJames, very handy.
On my end, my intuition tells me you could improve the algorithm if you have a rough idea of the number of classes that will be seen. So something like this:
```python
import collections
import random

from river import datasets
from river import utils

seen = collections.Counter()
sampling_rate = 0.5
actual_dist = collections.Counter()
n = 0
rng = random.Random(42)
expected_n_classes = 100  # rough prior guess of how many classes will eventually appear

dataset = datasets.ImageSegments()

for x, y in dataset:
    actual_dist[y] += 1
    n += 1
    # Use at least expected_n_classes, so f is not inflated while few classes have been seen
    n_classes = max(expected_n_classes, len(actual_dist))
    f = 1 / n_classes
    g = actual_dist[y]
    rate = sampling_rate * f / (g / n)
    for _ in range(utils.random.poisson(rate, rng=rng)):
        seen[y] += 1

seen
```
Let me know if that helps!
Nice work! I was about to use this handy tool until I realized my problem was even trickier.
I'm dealing with a very large (TB-scale) WebDataset, which also inherits from IterableDataset, and the dataset can keep growing. My goal is to do balanced sampling based on some attributes of my samples. In other words, I want N classes, each with equal weight. This would be straightforward with this resampling tool if I knew all the possible classes. However, that requires iterating through the whole dataset, which can take hours. That would be fine if I only had to do it once, but I may have to do it over and over, since my dataset will grow in the future (and new classes will appear). I'm also aware that one workaround is to manage this outside the loop by maintaining an incremental list of classes on disk.
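A minimal sketch of that out-of-the-loop workaround, assuming the labels are JSON-serializable strings: a small registry file of known classes is merged with newly seen labels and rewritten atomically, and a uniform desired_dist is rebuilt from it. update_class_registry and the file name are hypothetical.

```python
import json
import os
import tempfile


def update_class_registry(path, labels):
    """Merge labels into the JSON list of known classes stored at path; return it sorted."""
    known = set()
    if os.path.exists(path):
        with open(path, encoding="utf-8") as fh:
            known.update(json.load(fh))
    known.update(labels)
    merged = sorted(known)
    # Write to a temporary file in the same directory, then atomically swap it in,
    # so a crash mid-write never leaves a truncated registry behind.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w", encoding="utf-8") as fh:
        json.dump(merged, fh)
    os.replace(tmp_path, path)
    return merged


# Usage: two incremental updates, then a uniform target over all known classes.
with tempfile.TemporaryDirectory() as workdir:
    registry = os.path.join(workdir, "known_classes.json")
    update_class_registry(registry, ["cat", "dog"])
    classes = update_class_registry(registry, ["dog", "fox"])
    desired_dist = {c: 1.0 / len(classes) for c in classes}
    print(classes, desired_dist)
```

Only the new portion of the data (for example, newly added shards) would need scanning to keep the registry current, rather than the whole dataset.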
Still, it would be even better if this could be handled during training. I was thinking of building desired_dist dynamically by initializing it with an empty dict and adding unseen classes to it with equal constant weights on the fly. It seems this might work, but I haven't tested it. Do you think this is something worth having in the repo? And do you see any caveats in doing so? Any suggestions are appreciated.