pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

GradientShap needs internal_batch_size argument to avoid out-of-memory errors #1350

Open princyok opened 1 week ago

princyok commented 1 week ago

🐛 Bug

GradientShap (captum.attr.GradientShap), which is an extension of Integrated Gradients, needs an internal_batch_size argument for its attribute method, just like IntegratedGradients.attribute already has.

Currently, using any large value for n_samples results in out-of-memory errors, because the input is stacked n_samples times and the whole expanded batch is pushed through a single forward/backward pass. The same kind of issue is already addressed in IntegratedGradients via the internal_batch_size argument, which splits the internally expanded batch into smaller chunks.
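To illustrate the difference, here is a minimal sketch (model, tensor shapes, and the gradient_shap_chunked helper are hypothetical, introduced only for this example). It shows how IntegratedGradients can cap the internally expanded batch with internal_batch_size, while GradientShap currently cannot, and sketches a possible user-side workaround that splits n_samples into chunks and averages the per-chunk attributions. Note the chunked version is not bitwise identical to a single large run, since each chunk resamples its own baselines, but it estimates the same expectation.

```python
import torch
from captum.attr import GradientShap, IntegratedGradients

# Hypothetical model and inputs, for illustration only.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 32), torch.nn.ReLU(), torch.nn.Linear(32, 10)
)
inputs = torch.randn(8, 64)
baselines = torch.zeros(8, 64)

# IntegratedGradients can already limit memory use: the n_steps-times
# expanded input is processed in slices of internal_batch_size rows.
ig = IntegratedGradients(model)
ig_attr = ig.attribute(
    inputs, baselines=baselines, target=0, n_steps=200, internal_batch_size=16
)

# GradientShap has no such argument, so a large n_samples expands the
# input to n_samples * batch_size rows in one forward/backward pass:
gs = GradientShap(model)
# gs_attr = gs.attribute(inputs, baselines=baselines, target=0, n_samples=500)  # may OOM

# Possible workaround (hypothetical helper): run GradientShap on smaller
# chunks of n_samples and average the per-chunk attributions, weighted by
# chunk size, so the result is still a Monte Carlo estimate over n_samples.
def gradient_shap_chunked(gs, inputs, baselines, target, n_samples, chunk_size):
    weighted_sum, total = None, 0
    remaining = n_samples
    while remaining > 0:
        k = min(chunk_size, remaining)
        attr = gs.attribute(inputs, baselines=baselines, target=target, n_samples=k)
        weighted_sum = k * attr if weighted_sum is None else weighted_sum + k * attr
        total += k
        remaining -= k
    return weighted_sum / total

gs_attr = gradient_shap_chunked(
    gs, inputs, baselines, target=0, n_samples=500, chunk_size=25
)
```

A built-in internal_batch_size argument on GradientShap.attribute would make this chunking unnecessary and keep the behavior consistent with IntegratedGradients.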