Hi, PASTA should not incur a high memory footprint because it only registers a forward_pre_hook that modifies the attention_mask passed to the attention modules (see here). The forward pass itself is not changed and no additional parameters are introduced. In our empirical results it shows the same memory consumption as the baseline models.
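For clarity, here is a minimal sketch of that mechanism (toy attention module and steering rule, not the actual PASTA code): a forward_pre_hook rewrites the attention_mask keyword argument before the attention module runs, so nothing about the forward computation or the parameters changes.

```python
import torch
import torch.nn as nn


class TinyAttention(nn.Module):
    """Stand-in attention module that just consumes an additive mask."""

    def forward(self, hidden_states, attention_mask=None):
        scores = hidden_states @ hidden_states.transpose(-1, -2)
        if attention_mask is not None:
            scores = scores + attention_mask
        return torch.softmax(scores, dim=-1) @ hidden_states


def steer_mask_pre_hook(module, args, kwargs):
    # Hypothetical steering rule: add a positive bias to a chosen span of
    # source positions so attention to that span is emphasized.
    mask = kwargs.get("attention_mask")
    if mask is not None:
        mask = mask.clone()                      # do not edit the caller's tensor
        mask[..., 2:5] = mask[..., 2:5] + 5.0
        kwargs["attention_mask"] = mask
    return args, kwargs


attn = TinyAttention()
# with_kwargs=True lets the hook inspect and rewrite keyword arguments (PyTorch >= 2.0).
attn.register_forward_pre_hook(steer_mask_pre_hook, with_kwargs=True)

x = torch.randn(1, 8, 16)            # (batch, seq_len, hidden)
base_mask = torch.zeros(1, 8, 8)     # additive mask, all positions visible
out = attn(x, attention_mask=base_mask)  # hook edits the mask on the way in
print(out.shape)                     # torch.Size([1, 8, 16])
```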
I confirmed that with small inputs like the one in the README there is no noticeable increase in memory consumption.
For my application I'm running very large inputs (~4k tokens), and there does appear to be a moderate increase in memory usage (from ~1 GB with 1 head steered to ~10 GB with 1024 heads steered), but it now seems to me that this isn't an issue with the PASTA code and more likely has to do with something in a lower-level library.
Thanks for your quick response.
I was able to trace this issue back to this bit of code in the edit_attention_mask method:
```python
if head_dim != self.num_attn_head:
    attention_mask = attention_mask.expand(
        bsz, self.num_attn_head, tgt_len, src_len
    ).clone()
```
Specifically, it seems that cloning the attention mask adds extra memory consumption that only becomes apparent with huge input sizes. Removing the .clone() fixes this, but produces a deprecation warning:
```
UserWarning: Use of index_put_ on expanded tensors is deprecated. Please clone() the tensor before performing this operation. This also applies to advanced indexing e.g. tensor[indices] = tensor (Triggered internally at ../aten/src/ATen/native/TensorAdvancedIndexing.cpp:713.)
```
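For scale, here is a rough back-of-envelope for what the cloned per-head mask costs at my sequence lengths (assuming a float32 mask of shape (bsz, num_attn_head, tgt_len, src_len) per attention module; the real dtype, batch size, and head count may differ):

```python
# Rough estimate (assumed shapes/dtype) of the mask materialized by expand(...).clone()
# for a single attention module when all of its heads are steered.
bsz, num_heads, seq_len = 1, 32, 4096
bytes_per_elem = 4  # float32; half precision would be 2
mask_bytes = bsz * num_heads * seq_len * seq_len * bytes_per_elem
print(f"{mask_bytes / 2**30:.2f} GiB per layer")  # 2.00 GiB at 4k tokens
```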
@Spongeorge I believe removing .clone() would lead to errors, since in-place operations for different heads would then modify data at the same memory location (expand only changes the view of the tensor and does not make a copy of it). The cloning is unfortunately necessary: PASTA needs to maintain a separate copy of the attention mask for each head in order to intervene on the heads independently.
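To make the aliasing concrete, here is a small illustration with toy shapes (not the PASTA code path); on recent PyTorch versions the un-cloned write still runs but emits exactly the deprecation warning quoted above:

```python
import torch

# expand() returns a view with stride 0 over the head dimension, so a write
# intended for one head lands in the shared storage used by every head.
mask = torch.zeros(1, 1, 6, 6).expand(1, 4, 6, 6)   # view over 4 heads, no copy
cols = torch.tensor([2, 3])
mask[:, 0, :, cols] = -10.0          # meant to edit head 0 only -> UserWarning
print(mask[0, 3, 0, cols])           # tensor([-10., -10.]): head 3 was modified too

# With .clone(), each head gets its own storage and edits stay independent.
safe = torch.zeros(1, 1, 6, 6).expand(1, 4, 6, 6).clone()
safe[:, 0, :, cols] = -10.0
print(safe[0, 3, 0, cols])           # tensor([0., 0.]): other heads unaffected
```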
Hi, I'm hitting OOM exceptions on a 40GB A100 when trying to steer more than a few heads at a time.
When steering 32 heads (an entire layer in my case) I use almost all of the GPU memory, and when steering just a single head I use very little memory.
Is this memory usage inherent to PASTA's algorithm?