pavlin-policar / openTSNE

Extensible, parallel implementations of t-SNE
https://opentsne.rtfd.io
BSD 3-Clause "New" or "Revised" License

FFT parameters and runtime for very expanded embeddings #174

Closed: dkobak closed this 3 years ago

dkobak commented 3 years ago

I have been doing some experiments on convergence, running t-SNE for many more iterations than I normally do, and I again noticed something I have seen every now and then: the runtime jumps wildly between "epochs" of 50 iterations. This only happens when the embedding is very expanded, so the FFT gets really slow. Look:

Iteration   50, KL divergence 4.8674, 50 iterations in 1.8320 sec
Iteration  100, KL divergence 4.3461, 50 iterations in 1.8760 sec
Iteration  150, KL divergence 4.0797, 50 iterations in 2.6252 sec
Iteration  200, KL divergence 3.9082, 50 iterations in 4.5062 sec
Iteration  250, KL divergence 3.7864, 50 iterations in 5.4258 sec
Iteration  300, KL divergence 3.6957, 50 iterations in 7.2500 sec
Iteration  350, KL divergence 3.6259, 50 iterations in 9.0705 sec
Iteration  400, KL divergence 3.5711, 50 iterations in 10.1077 sec
Iteration  450, KL divergence 3.5271, 50 iterations in 12.2412 sec
Iteration  500, KL divergence 3.4909, 50 iterations in 13.6440 sec
Iteration  550, KL divergence 3.4604, 50 iterations in 14.6127 sec
Iteration  600, KL divergence 3.4356, 50 iterations in 17.2364 sec
Iteration  650, KL divergence 3.4143, 50 iterations in 17.6973 sec
Iteration  700, KL divergence 3.3986, 50 iterations in 27.9720 sec
Iteration  750, KL divergence 3.3914, 50 iterations in 34.0480 sec
Iteration  800, KL divergence 3.3863, 50 iterations in 34.4572 sec
Iteration  850, KL divergence 3.3820, 50 iterations in 36.9247 sec
Iteration  900, KL divergence 3.3779, 50 iterations in 47.0994 sec
Iteration  950, KL divergence 3.3737, 50 iterations in 40.8424 sec
Iteration 1000, KL divergence 3.3696, 50 iterations in 62.1549 sec
Iteration 1050, KL divergence 3.3653, 50 iterations in 30.6310 sec
Iteration 1100, KL divergence 3.3613, 50 iterations in 44.9781 sec
Iteration 1150, KL divergence 3.3571, 50 iterations in 36.9257 sec
Iteration 1200, KL divergence 3.3531, 50 iterations in 66.3830 sec
Iteration 1250, KL divergence 3.3493, 50 iterations in 37.7215 sec
Iteration 1300, KL divergence 3.3457, 50 iterations in 33.7942 sec
Iteration 1350, KL divergence 3.3421, 50 iterations in 33.7507 sec
Iteration 1400, KL divergence 3.3387, 50 iterations in 59.2065 sec
Iteration 1450, KL divergence 3.3354, 50 iterations in 36.3713 sec
Iteration 1500, KL divergence 3.3323, 50 iterations in 39.1894 sec
Iteration 1550, KL divergence 3.3293, 50 iterations in 67.3239 sec
Iteration 1600, KL divergence 3.3265, 50 iterations in 33.9837 sec
Iteration 1650, KL divergence 3.3238, 50 iterations in 63.5015 sec

For the record, this is on full MNIST with uniform k=15 affinity, n_jobs=-1. Note that after it gets to 30 seconds per 50 iterations, it starts fluctuating between 30 and 60 seconds. This does not make sense.

I suspect it may be related to how the interpolation params are chosen depending on the grid size. Could those heuristics need improvement?

Incidentally, could the interpolation params be relaxed once the embedding becomes very large (e.g. span larger than [-100, 100]) so that the optimisation runs faster without -- perhaps! -- compromising the approximation too much?

CCing @linqiaozhi.

dkobak commented 3 years ago

I found where it came up before: https://github.com/KlugerLab/FIt-SNE/issues/67. However, it stayed unresolved there (I was not motivated enough to investigate and closed that issue without diagnosing the problem).

linqiaozhi commented 3 years ago

Thanks for the CC. My comments on that thread apply here -- it's possible you are bouncing between numbers of interpolation nodes. My first step in diagnosing this would be to print the number of nodes, the embedding size, and the time taken at each iteration.

When you run that many iterations with a large step size, the embedding becomes really large -- and we never really tested it in that setting. It may be that the preset number of nodes is not appropriate for an embedding of that size.

dkobak commented 3 years ago

Hi George. The possible numbers of interpolation nodes are quite finely spaced, though:

cdef list recommended_boxes = [
        25, 36, 50, 55, 60, 65, 70, 75, 80, 85, 90, 96, 100, 110, 120, 130, 140, 150, 175, 200
    ]

-- can it be that bouncing back and forth between two neighbouring values affects the runtime by 2x?

My first step in diagnosing this would be to print the number of nodes, the embedding size, and the time taken at each iteration.

Makes sense. I might give it a try!
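
Something like this should do for the logging -- a quick sketch using openTSNE's callback mechanism (the number of interpolation nodes isn't exposed as far as I know, so I'd log the span per iteration and infer the grid size from that):

import time
import numpy as np
from openTSNE import TSNE

class SpanTimer:
    """Log the embedding span and wall-clock time at every iteration."""
    def __init__(self):
        self.last = time.time()
        self.log = []

    def __call__(self, iteration, error, embedding):
        now = time.time()
        # span = max - min of the embedding coordinates, maximised over dimensions
        span = np.ptp(np.asarray(embedding), axis=0).max()
        self.log.append((iteration, span, now - self.last))
        self.last = now

timer = SpanTimer()
tsne = TSNE(n_jobs=-1, callbacks=timer, callbacks_every_iters=1)
# embedding = tsne.fit(X)  # X is the data matrix; inspect timer.log afterwards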

linqiaozhi commented 3 years ago

-- can it be that bouncing back and forth between two neighbouring values affects the runtime by 2x?

You're right--seems strange--but it's the first thing I'd rule out.

The other strange aspect is that these times are averaged over 50 iterations. If it were just on the border, bouncing between sizes every few iterations, that effect should average out.
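
A quick way to get a feel for it (a rough sketch using numpy's FFT as a stand-in for FFTW -- absolute timings will differ, but non-smooth sizes should stand out):

import timeit
import numpy as np

rng = np.random.default_rng(0)
# 211 is prime; 200, 216 and 250 factor into small primes
for n in (200, 211, 216, 250):
    x = rng.random((n, n))
    t = timeit.timeit(lambda: np.fft.rfft2(x), number=200)
    print(f"grid {n}x{n}: {t:.3f} s for 200 transforms")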

dkobak commented 3 years ago

I am afraid this won't clarify much, but here is a plot of the embedding span (max - min) and the runtime at every iteration.

[Figure runtime-oscillations: embedding span and per-iteration runtime]

dkobak commented 3 years ago

Waaait a second! The runtime starts behaving wildly after the embedding span crosses 200. And that's exactly the last value in recommended_boxes! Can this be the culprit???
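
If the selection logic rounds up to the next value inside the list but falls through unrounded once the computed count exceeds the last entry, that would explain it. Roughly this (a sketch of my suspicion, not the actual openTSNE code):

import bisect

recommended_boxes = [25, 36, 50, 55, 60, 65, 70, 75, 80, 85, 90,
                     96, 100, 110, 120, 130, 140, 150, 175, 200]

def choose_n_boxes(n_boxes):
    # Inside the list: round up to the next recommended (FFT-friendly) size.
    if n_boxes <= recommended_boxes[-1]:
        return recommended_boxes[bisect.bisect_left(recommended_boxes, n_boxes)]
    # Past 200 the raw count would be used as-is, so the FFT size can be
    # arbitrary (even prime) -- which would produce exactly these wild runtimes.
    return n_boxes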

dkobak commented 3 years ago

So I changed the list to

        25, 36, 50, 60, 70, 75, 80, 90, 96, 100, 120, 140, 150, 175, 200,
        250, 300, 350, 400, 450, 500, 1000, 5000, 10000 

(I took out the values that were multiples of 11 and 13 because I wasn't sure whether they slow FFTW down; the list also contained 85, which has a factor of 17...) Here is the result:

[Figure runtime-oscillations-fix: runtimes after extending the recommended_boxes list]

It does look a bit better (smoother?) than before, but the weird steps after the span crosses ~200 are still there.

dkobak commented 3 years ago

What's weird is that those steps all seem to be 18 iterations long...

dkobak commented 3 years ago

Apologies for spamming everybody, but it turns out I had updated the recommended_boxes list in the function for the 1D FFT but not for the 2D FFT. Facepalm. I have now updated it in both places and rerun. Voilà:

[Figure runtime-oscillations-fix: runtimes after updating the list in both the 1D and 2D FFT functions]

dkobak commented 3 years ago

Incidentally, could the interpolation params be relaxed once the embedding becomes very large (e.g. span larger than [-100, 100]) so that the optimisation runs faster without -- perhaps! -- compromising the approximation too much?

Last thing -- and then I am off for today. I tried clipping the number of boxes to 100, i.e. if the recommended value was above 100, I still set it to 100. This made the KL divergence decrease much more slowly than without clipping, so I conclude that it's a bad idea.

pavlin-policar commented 3 years ago

I don't think these are good sizes for the boxes. From the FFTW docs:

FFTW is best at handling sizes of the form 2^a 3^b 5^c 7^d 11^e 13^f, where e+f is either 0 or 1, and the other exponents are arbitrary. Other sizes are computed by means of a slow, general-purpose algorithm (which nevertheless retains O(n log n) performance even for prime sizes). (It is possible to customize FFTW for different array sizes; see Installation and Customization.) Transforms whose sizes are powers of 2 are especially fast, and it is generally beneficial for the last dimension of an r2c/c2r transform to be even.

So, we may just want to generate a larger list of predefined numbers that fit this formula. E.g.

l = set()
for a in range(10):
    for b in range(10):
        for c in range(10):
            for d in range(10):
                # e + f in the FFTW formula must be 0 or 1, so include
                # at most one factor of 11 or 13
                l.add(2**a * 3**b * 5**c * 7**d * 11**0 * 13**0)
                l.add(2**a * 3**b * 5**c * 7**d * 11**0 * 13**1)
                l.add(2**a * 3**b * 5**c * 7**d * 11**1 * 13**0)

l = [x for x in l if 20 <= x <= 1000]  # keep only sizes in a sensible range
print(sorted(l))

[20, 21, 22, 24, 25, 26, 27, 28, 30, 32, 33, 35, 36, 39, 40, 42, 44, 45, 48, 49, 50, 52, 54, 55, 56, 60, 63, 64, 65, 66, 70, 72, 75, 77, 78, 80, 81, 84, 88, 90, 91, 96, 98, 99, 100, 104, 105, 108, 110, 112, 117, 120, 125, 126, 128, 130, 132, 135, 140, 144, 147, 150, 154, 156, 160, 162, 165, 168, 175, 176, 180, 182, 189, 192, 195, 196, 198, 200, 208, 210, 216, 220, 224, 225, 231, 234, 240, 243, 245, 250, 252, 256, 260, 264, 270, 273, 275, 280, 288, 294, 297, 300, 308, 312, 315, 320, 324, 325, 330, 336, 343, 350, 351, 352, 360, 364, 375, 378, 384, 385, 390, 392, 396, 400, 405, 416, 420, 432, 440, 441, 448, 450, 455, 462, 468, 480, 486, 490, 495, 500, 504, 512, 520, 525, 528, 539, 540, 546, 550, 560, 567, 576, 585, 588, 594, 600, 616, 624, 625, 630, 637, 640, 648, 650, 660, 672, 675, 686, 693, 700, 702, 704, 720, 728, 729, 735, 750, 756, 768, 770, 780, 784, 792, 800, 810, 819, 825, 832, 840, 864, 875, 880, 882, 891, 896, 900, 910, 924, 936, 945, 960, 972, 975, 980, 990, 1000]

Could you try out this list of numbers instead? The docs indicate that they should work faster.
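
For checking a single size against the formula, a small helper along these lines would do (hypothetical, just for illustration):

def is_fftw_friendly(n):
    """True if n = 2^a 3^b 5^c 7^d 11^e 13^f with e + f <= 1."""
    for p in (2, 3, 5, 7):
        while n % p == 0:
            n //= p
    # whatever remains must be 1, or a single factor of 11 or 13
    return n in (1, 11, 13)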

dkobak commented 3 years ago

Sure, I can use this -- it's probably better. It's just that, as you can see from yesterday's experiments, values like 201, 211, 213, etc. run slower (sometimes much slower) than 250, so I thought such a fine grid was not really necessary here. But it certainly won't hurt either!

Also, let's maybe cap it at 1000: if the recommended number of boxes comes out above 1000, I'll just set it to 1000...
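
In other words, something like this (a sketch, reusing the list l generated above):

import bisect

allowed = sorted(l)  # the FFTW-friendly sizes generated above

def snap_n_boxes(n_boxes):
    # Round up to the next allowed size; past the largest, just cap at 1000.
    if n_boxes >= allowed[-1]:
        return allowed[-1]
    return allowed[bisect.bisect_left(allowed, n_boxes)]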

dkobak commented 3 years ago

Here is how it looks:

[Figure runtime-oscillations-fix2: runtimes with the denser list of FFTW-friendly sizes]

Does not look like a staircase anymore.

pavlin-policar commented 3 years ago

Hmm, I think this looks better, no? There are still a few spikes, but those might be explained by other things your system was doing. Anyway, does this fix your problem?

I have played around a bit with changing the number of interpolation points, but I think the current setting is very reasonable. For example, if I fixed the number of intervals at e.g. 50 and let the intervals expand with the embedding, the final embedding would exhibit banding at the box boundaries. 100 intervals seemed to work fine, but when I let it run for a while so the embedding had a larger span, I saw some banding there as well. Interestingly enough, the banding was more severe with the standard perplexity-based affinities than with uniform affinities (or it could be the other way around -- I forget).

My feeling is that we might be able to play with the number of intervals and fine-tune it, but the gains would be marginal. The current setting seems pretty good.

dkobak commented 3 years ago

I agree.

Yes, the PR does fix the problem.