iancovert / fastshap

An amortized approach for calculating local Shapley value explanations
MIT License

remove for loop from uniform sampler #4

Closed szvsw closed 9 months ago

szvsw commented 1 year ago

Was just reading through the code to get a better understanding of FastSHAP (in prep for the long email I sent you @iancovert!) and came across this TODO, which seemed like a quick fix.

The idea is to first generate this matrix, which is (num_players + 1) x (num_players), representing each of the possible mask densities:

```
[
    [0, 0, 0, ..., 0],
    [1, 0, 0, ..., 0],
    [1, 1, 0, ..., 0],
    [1, 1, 1, ..., 0],
      .  .  .      .
    [1, 1, 1, ..., 1]
]
```

Then, we create a new matrix in which each member of the batch randomly selects one of the mask density rows, resulting in a batch_size x num_players matrix.

Then, we generate a permutation of the indices 0, 1, ..., num_players - 1 for each row. Finally, we use those permutation indices to rearrange each row, giving the final mask permutations for each batch element.
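A minimal sketch of the idea (not the exact PR code; `num_players` and `batch_size` are placeholder values matching the sampler's interface):

```python
import torch

num_players, batch_size = 12, 32

# (num_players + 1) x num_players lower-triangular matrix of mask densities.
densities = torch.tril(torch.ones(num_players + 1, num_players), diagonal=-1)

# Each batch element picks one density row uniformly at random.
rows = torch.randint(num_players + 1, (batch_size,))
S = densities[rows]  # batch_size x num_players

# Independent permutation per row (argsort of random keys), applied via gather.
perm = torch.argsort(torch.rand_like(S), dim=-1)
S = torch.gather(S, 1, perm)
```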

Not a super significant PR, but it was simple and quick so I figured I'd try to knock it out!

iancovert commented 1 year ago

This looks reasonable! Just to be sure, were you able to test that it runs successfully (at least for a couple epochs)?

szvsw commented 1 year ago

Not yet! Been away from my GPU. I just did a/b testing with the outputs of the original Module and the updated Module.

For testing, I figured I'd run the census notebook with both the original and the new version, then compare runtimes as well as results.

I can also put together a little benchmark for the Module in isolation to see whether the runtime gains are really worth it. At larger batch sizes they obviously will be, but at low batch sizes with a large number of players the new version might not be more efficient (though I expect it still will be).

I will let you know after I run a full test.

szvsw commented 1 year ago

I ran tests in isolation, gridding over the number of players in the sampler and the number of samples to take. For each pair of n_players and n_samples I sampled 100 times, taking the average of the run times to make the following plot:

[image: average sampler runtime across the n_players × n_samples grid]
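For reference, the isolated benchmark amounted to roughly the following (a sketch: the grid values and the import path are assumptions, and the actual script compared the old and new sampler versions):

```python
import time
from fastshap.utils import UniformSampler  # assumed import path

results = {}
for n_players in [8, 32, 128, 512]:          # illustrative grid values
    for n_samples in [32, 128, 512, 2048]:
        sampler = UniformSampler(n_players)
        times = []
        for _ in range(100):                 # 100 repeats, averaged
            start = time.perf_counter()
            sampler.sample(n_samples)
            times.append(time.perf_counter() - start)
        results[(n_players, n_samples)] = sum(times) / len(times)
```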

Definitely seems worth it! Obviously lots of time is spent doing other things in the surrogate training loop, so I ran some benchmarks there as well, running each configuration until early stopping triggered. As expected, larger batch sizes realize better relative performance for the new version.

| Batch Size | Sampler Version | Epochs | Time (s) | Final Loss | s/epoch | Performance Factor |
|------------|-----------------|--------|----------|------------|---------|---------------------|
| 32  | Old | 83 | 352 | 0.1247 | 4.241 | n/a   |
| 32  | New | 70 | 207 | 0.1251 | 2.957 | 30.3% |
| 64  | Old | 63 | 197 | 0.1258 | 3.127 | n/a   |
| 64  | New | 89 | 175 | 0.1244 | 1.966 | 37.2% |
| 128 | Old | 70 | 178 | 0.1256 | 2.542 | n/a   |
| 128 | New | 70 | 97  | 0.1312 | 1.386 | 44.5% |

szvsw commented 1 year ago

Looks like you were on to this approach here:

https://github.com/iancovert/fastshap/blob/f5fbc52bc046bd3501463e2c8f46ab9a44a6fab5/fastshap/utils.py#L161-L195

It probably makes sense to implement a similar approach for the ShapleySampler, with the necessary adaptations to handle the complementary sampling, since obviously much more time is spent in the FastSHAP training loop. Maybe that can be a separate PR, though.

szvsw commented 1 year ago

I pushed a change which adapts the ShapleySampler to avoid for loops while also still allowing paired sampling.

I ran the whole notebook.

The new version of the FastSHAP training loop took 140 s for 56 epochs (2.6 s/epoch), stopping at a best loss of 0.058353. Comparison of results with KernelSHAP below:

[image: FastSHAP vs. KernelSHAP comparison, new training loop]

The old version of the FastSHAP training loop took 1,057 s for 57 epochs (18.6 s/epoch), stopping at a best loss of 0.059535. Comparison of results with KernelSHAP below (using the same example index as in the test with the new training loop).

[image: FastSHAP vs. KernelSHAP comparison, old training loop]

Seems like a pretty substantial performance gain!

iancovert commented 1 year ago

This is great, thanks for being so thorough! You're right that the ShapleySampler is the more important one to improve, and it does look like I tried something similar (and forgot to port it to the UniformSampler!) but didn't fully optimize it. A couple thoughts:

```python
import torch


class UniformSampler:
    '''
    For sampling player subsets with cardinality chosen uniformly at random.

    Args:
      num_players: number of players.
    '''

    def __init__(self, num_players):
        self.num_players = num_players

    def sample(self, batch_size):
        '''
        Generate sample.

        Args:
          batch_size: number of subset masks to generate.
        '''
        # One uniform value per entry and a single uniform threshold per row.
        # Given the threshold t, each entry is included with probability 1 - t,
        # so the cardinality is Binomial(num_players, 1 - t); integrating over
        # t ~ Uniform(0, 1) makes the cardinality uniform over 0..num_players.
        rand = torch.rand(batch_size, self.num_players)
        ref = torch.rand(batch_size, 1)
        return (rand > ref).float()
```

szvsw commented 1 year ago

Aha! I originally wanted to do a random-vs-threshold comparator to achieve the permutations, but couldn't figure out how to get the correct distribution over the different sizes of S... the trick is to use 1 for the second dimension of the thresholder, not num_players... clever! E.g., each row is equally likely to end up with a 20% threshold, a 30% threshold, a 40% threshold, etc. - cool!

That approach should be much more performant than mine, since it doesn't rely on argsort like mine does. I had to rely on argsort because randperm doesn't have an axis/dim arg, which means my approach's complexity goes with the complexity of argsort, which in turn grows with num_players. This is why, as you can see in my isolated performance test graphic, the new version starts to catch up to the old version as num_players grows: the old version had no time penalty for num_players (or at least a very minimal one), but the new version does.

Switching to the random/threshold comparator approach should resolve this.

This also primarily explains the performance differences you were observing.

So, in summary, before we merge this I will try out the random-threshold comparator approach, which we expect to be even better than the approach currently in the PR.

Re: converting the samplers to nn.Modules - yeah, I was thinking about that as well. These aren't very large operations, so I don't know how beneficial it will be... maybe I'll just add a hard-coded .to(device) for testing purposes to see whether running the ops on the GPU helps a significant amount.

I'll probably be able to do this stuff sometime over the weekend or Monday.

szvsw commented 1 year ago

Updated UniformSampler performance in the feature-removal surrogate training loop (still using census, so num_features = 12). As you can see, the random/thresh version performs significantly better than the tril/gather approach because it avoids the argsort entirely.

| Batch Size | Sampler Version | Epochs | Time (s) | Final Loss | s/epoch | Performance Factor |
|------------|-------------------|--------|----------|------------|---------|---------------------|
| 32  | Old               | 83 | 352 | 0.1247 | 4.241 | n/a   |
| 32  | New (Tril/Gather) | 70 | 207 | 0.1251 | 2.957 | 30.3% |
| 32  | New (Thresh)      | 89 | 236 | 0.1230 | 2.618 | 38.3% |
| 64  | Old               | 63 | 197 | 0.1258 | 3.127 | n/a   |
| 64  | New (Tril/Gather) | 89 | 175 | 0.1244 | 1.966 | 37.2% |
| 64  | New (Thresh)      | 52 | 94  | 0.1276 | 1.808 | 42.2% |
| 128 | Old               | 70 | 178 | 0.1256 | 2.542 | n/a   |
| 128 | New (Tril/Gather) | 70 | 97  | 0.1312 | 1.386 | 44.5% |
| 128 | New (Thresh)      | 76 | 98  | 0.1299 | 1.289 | 50.7% |

To further confirm this, and to confirm the behavior you saw, I ran a new test where I simply duplicated the columns of the census data to get up to 192 features and ran the training loop again, this time only at a batch size of 128. Compared to the num_features = 12 results above, the old version had almost identical performance, as expected (no argsort/feature-count complexity penalty, only a batch-size penalty). The tril/gather approach suffered a significant penalty due to the increased sorting complexity, while the random/thresh approach changed only slightly.

| Batch Size | Sampler Version | Epochs | Time (s) | Final Loss | s/epoch | Performance Factor |
|------------|-------------------|--------|----------|------------|---------|---------------------|
| 128 | Old               | 63 | 160 | - | 2.540 | n/a   |
| 128 | New (Tril/Gather) | 78 | 163 | - | 2.090 | 17.8% |
| 128 | New (Thresh)      | 52 | 75  | - | 1.442 | 43.3% |

This definitely suggests a major advantage for the random/thresh approach (as expected), so I committed it for the UniformSampler.

The ShapleySampler will take a little extra thinking to make sure it only generates rows with 1, ..., num_features - 1 ones, but I'll come up with something.
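For context, the target cardinality distribution is the Shapley kernel weighting over subset sizes, roughly as below (a sketch; presumably this is what the existing Categorical-based ShapleySampler draws `num_included` from):

```python
import torch

num_players, batch_size = 12, 32

# The Shapley kernel puts mass only on cardinalities 1..num_players-1,
# proportional to 1 / (k * (num_players - k)).
k = torch.arange(1, num_players)
weights = 1.0 / (k * (num_players - k))
categorical = torch.distributions.Categorical(probs=weights)

# One cardinality per batch element (offset by 1 since k starts at 1).
num_included = categorical.sample((batch_size,)) + 1
```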

iancovert commented 1 year ago

Well this looks great for the UniformSampler, seems like we know which version to go with! And very nice to see expected behavior wrt the number of players, mainly the larger penalty when we require $O(n \log n)$ sorting.

Re: the other bits:

So what do you think, is this good to merge? Or is there anything else to do first? Thanks again for the attention you put into this!

szvsw commented 1 year ago

Yeah, I got it close to working for the ShapleySampler, but as you mentioned it's not generating quite the correct distribution over the number of features selected. The approach I took essentially sets the threshold for n - 2 of the features according to the selected number of features for that row, forces the remaining two features to 0 and 1 to prevent the grand/null coalitions from being generated, and then does a random rotation. It gets the shape mostly correct except for the few bins at either edge. As you said, it's an issue with the Bernoulli distributions overlapping.

I also added much more thorough benchmarks in a notebook for the Shapley sampler, including GPU tests, in the "benchmark" folder which I uploaded (along with the summary graphics).

There are also some graphics that verify that the distribution of number of features selected is correct (except in the attempt at a "rand/thresh" approach for Shapley sampler).

I was going to do a little write up tomorrow.

The gist is that:

iancovert commented 1 year ago

I just spent some time thinking about the thresholding idea for the ShapleySampler, specifically how to derive the correct PMF from a BetaBinomial distribution. It's impossible to make it work with a single threshold, because we can't prevent sampling either 0 or d players. I thought about a rejection sampling approach, where we ignore rows with 0 or d players, and I don't think that will work either (at least with a Beta-distributed threshold). I also thought about whether we could use the thresholding trick over d - 2 players, so that we could concatenate a 0 and 1 to ensure we avoid the grand/null coalitions, but this didn't seem possible either (plus it might defeat the purpose, since we'll need to shuffle in the 0/1). I see you got close by setting the threshold based on the num_included sample, but yeah weirdly it's off at the edges. Anyway, this is tricky enough that we should maybe set it aside for now. Your notebook is quite helpful though.
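(To spell out the obstruction: conditioning on a per-row threshold $t$, the cardinality is Binomial, so any random-threshold scheme only yields a mixture of Binomials,

$$P(|S| = k \mid t) = \binom{d}{k} (1 - t)^k \, t^{d-k}, \qquad P(|S| = k) = \int_0^1 \binom{d}{k} (1 - t)^k \, t^{d-k} \, p(t) \, dt,$$

which for $t \sim \mathrm{Beta}(\alpha, \beta)$ is the Beta-Binomial PMF. Its support always includes $k = 0$ and $k = d$, while the Shapley kernel needs $P(|S| = k) \propto \frac{1}{k(d-k)}$ on $k \in \{1, \dots, d-1\}$.)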

Let me know when you're happy with your writeup, I see you mentioned that in your last message.

And I guess the last thing is, I see you switched to numpy for the permutation in ShapleySampler rather than using torch.gather? I see in your benchmarking plots that it's faster (NumpyPermuteCPU vs ArgSortCPU). I wonder if we could get the same speed in pure PyTorch by replacing the indices = torch.argsort(torch.rand_like(S), dim=-1) line with torch.randperm - argsorting shouldn't be necessary to generate a random ordering.

szvsw commented 1 year ago

Yeah, at the end of the day, a lot of our problems/challenges with this come from the fact that PyTorch doesn't have great support for random sampling/shuffling/rotating independently along an axis. As I've been diving into this over the past week or two, it turns out there are lots of threads online/feature requests for more parity with numpy.random.choice/shuffle etc.

> I also thought about whether we could use the thresholding trick over d - 2 players, so that we could concatenate a 0 and 1 to ensure we avoid the grand/null coalitions, but this didn't seem possible either (plus it might defeat the purpose, since we'll need to shuffle in the 0/1).

Yeah, the way I handled this was... interesting, not sure if you saw. The idea was to make the first two columns 0/1, with the rest of the columns controlled by a per-row threshold determined by the desired number of players for that row, and then, instead of shuffling each row, to do a simpler independent rotation of each row - essentially generating range(n) for each row, adding a random int from 0...n-1 separately to each range(n), taking the modulus to produce rotation indices, and then using gather to execute the rotations.
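Roughly, in code (a sketch of that idea, not the exact commit; variable names and the exact threshold formula are illustrative):

```python
import torch

batch_size, num_players = 32, 12

# Desired cardinality per row, in 1..num_players-1 (the real sampler draws
# this from the Shapley-kernel Categorical; randint is just a placeholder).
num_included = torch.randint(1, num_players, (batch_size, 1))

# First two columns fixed to 1 and 0 to rule out the grand/null coalitions;
# the remaining columns are thresholded per row so that roughly
# num_included - 1 additional entries are switched on.
fixed = torch.cat([torch.ones(batch_size, 1), torch.zeros(batch_size, 1)], dim=1)
thresh = (num_included - 1).float() / (num_players - 2)
rest = (torch.rand(batch_size, num_players - 2) < thresh).float()
S = torch.cat([fixed, rest], dim=1)

# Random rotation per row instead of a full shuffle: range(n) plus a random
# offset, taken mod n, gives rotation indices that gather then applies.
base = torch.arange(num_players).expand(batch_size, num_players)
offset = torch.randint(num_players, (batch_size, 1))
S = torch.gather(S, 1, (base + offset) % num_players)
```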

> I see you got close by setting the threshold based on the num_included sample, but yeah weirdly it's off at the edges. Anyway, this is tricky enough that we should maybe set it aside for now. Your notebook is quite helpful though.

Yeah, I was wondering if it would be possible to update the weights of the original Categorical sampling distribution to get this to work... essentially redistributing the extra mass that ends up in the 1 and n-1 categories back to the other categories to bring them up to where they're supposed to be... but it seemed like it would require pretty careful thought about those new weights, which I wasn't up for lol...

> And I guess the last thing is, I see you switched to numpy for the permutation in ShapleySampler rather than using torch.gather? I see in your benchmarking plots that it's faster (NumpyPermuteCPU vs ArgSortCPU). I wonder if we could get the same speed in pure PyTorch by replacing the indices = torch.argsort(torch.rand_like(S), dim=-1) line with torch.randperm - argsorting shouldn't be necessary to generate a random ordering.

The problem with randperm is that it can't generate multiple permutations simultaneously - it only accepts an integer argument. The examples in my notebook with the multinomial sampling approach are the closest thing I've come across in PyTorch to generating permutation indices per row directly without an argsort.
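E.g., something along these lines (a sketch of the multinomial trick; the exact usage in the notebook may differ, and `S` is just a placeholder batch of masks):

```python
import torch

batch_size, num_players = 32, 12
S = torch.randint(0, 2, (batch_size, num_players)).float()  # placeholder masks

# multinomial without replacement over uniform weights draws a random
# permutation of column indices independently for each row.
weights = torch.ones(batch_size, num_players)
perm = torch.multinomial(weights, num_players, replacement=False)
S = torch.gather(S, 1, perm)
```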

I was trying to think if there was some sort of approach where you do something like this:

```python
num_ixs = batch_size * num_players
shuffled_ixs = torch.randperm(num_ixs)
shuffled_ixs = shuffled_ixs.reshape(batch_size, num_players)
```

This returns a matrix like this...

```
[
  [ 3,  2,  7],
  [12,  9,  8],
  [ 6, 14, 13],
  [ 5, 10,  1],
  [ 0, 11,  4]
]
```

This ends up reducing to the same problem as the argsort approach with random numbers - i.e. we need to figure out how [3, 2, 7] becomes [1, 0, 2], or how [12, 9, 8] becomes [2, 1, 0].

It took a bit of work, but I did end up figuring out a method for converting this into permutation matrices / [0...n-1] row-index permutations using just PyTorch ops (no for loops, no argsort) by leveraging the integer structure - however, it actually ended up being way slower, smh...

At the end of the day, I think the numpy approach is the best. It takes advantage of proper compiled code which is designed to do exactly what we want.
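For instance, something like this (a sketch of an independent per-row permutation in numpy; not necessarily the exact call used in the commit, and `Generator.permuted` requires NumPy >= 1.20):

```python
import numpy as np
import torch

batch_size, num_players = 32, 12
rng = np.random.default_rng()

# Permutation indices for each row, shuffled independently along axis=1,
# then moved back into torch to reorder the mask tensor with gather.
perm = rng.permuted(np.tile(np.arange(num_players), (batch_size, 1)), axis=1)
perm = torch.from_numpy(perm)

S = torch.randint(0, 2, (batch_size, num_players)).float()  # placeholder masks
S = torch.gather(S, 1, perm)
```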


I did some tests with the numpy method using both the census notebook and the CIFAR notebook.

For the census notebook, there are the expected large performance improvements over the original: since the number of players is so low AND the network is so small, the time spent in forward/backward is small relative to the time spent generating samples.

In the CIFAR notebook, there was almost no difference between the original, argsort-GPU, and numpy modes, presumably because the time spent in forward and backward is much, much greater than the time spent sampling. I tried a few different choices of num_samples and batch_size, and this seemed to hold.


At the end of the day, the lesson here is that when the ratio of time spent in forward/backward to time spent sampling is low (i.e. sampling is a bottleneck), this change can yield large performance improvements, but when it's high, the effects are negligible.

I'm going to delete the benchmarking folder so it doesn't clutter the main repo and just move it to another branch on my fork. Then I would say this is ready to merge!

iancovert commented 1 year ago

Yeah the ShapleySampler distribution is tricky, I wonder if it's even possible to use an efficient thresholding trick... But that's a problem for another day.

I see what you mean about the lack of PyTorch/numpy parity for shuffling. I just looked through some PyTorch discussion board posts about independently shuffling the rows of a tensor - I remember looking into this before, and I guess there's still no great solution. As a last-ditch effort to stick with pure PyTorch, do you think we could match the numpy version's speed by generating separate permutations with something like indices = torch.stack([torch.randperm(d) for _ in range(n)])? We might pay a price for having to generate the permutations sequentially; if so, I'm good to merge the current version.

szvsw commented 1 year ago

I'm pretty sure that approach would be essentially the same as the old one, since it just trades a for loop for a list comprehension.

Having said that, it's worth testing. And as one last thing that might be worth trying if we want to stick with pure (kinda) torch...

Using TorchScript might get us back to what we want (maybe better, if we're lucky), and in a fairly readable way too, e.g.:

```python
import torch

@torch.jit.script
def permute_rows(S: torch.Tensor, num_players: int, batch_size: int) -> torch.Tensor:
    # Same per-row shuffle as the original loop, but compiled with TorchScript
    # so the Python interpreter overhead of the for loop is avoided.
    for i in range(batch_size):
        ixs = torch.randperm(num_players)  # long dtype, required for indexing
        S[i] = S[i, ixs]
    return S
```

This is essentially identical to the original approach, but since the function is now jit'd, it shouldn't incur the Python interpreter penalty for the for loop. It may still not be quite as performant as the numpy version, but it's worth testing.

iancovert commented 1 year ago

The one potential speedup is using torch.gather to apply the reorderings in parallel (see the sketch below), vs my original approach which applies them sequentially. Using numpy isn't so bad though; we could also just go with that.
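Something like this, for concreteness (a sketch; `S` is a placeholder batch of masks):

```python
import torch

batch_size, num_players = 32, 12
S = torch.randint(0, 2, (batch_size, num_players)).float()  # placeholder masks

# Permutations are still generated sequentially, but applied in one gather.
perm = torch.stack([torch.randperm(num_players) for _ in range(batch_size)])
S = torch.gather(S, 1, perm)
```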