Closed bhsimon0810 closed 3 days ago
Yes. We enhance the model's adaptability to the data by dynamically integrating task-specific knowledge (task vectors) into the model parameters based on the input. This process requires computation at each forward pass, which can increase both the time and the memory cost compared to static merging methods.
However, the additional computation is mainly linear parameter interpolation, which involves no extra matrix multiplications, and the merged model performs inference only once, so the time cost is not very high.
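As a rough illustration, the per-layer merge is just a weighted sum of task vectors added to the pretrained weights. Below is a minimal scalar sketch; the function name `merge_task_vectors` and the toy values are illustrative, not taken from the codebase:

```python
# Hypothetical sketch: merge task vectors into base weights by
# linear interpolation, with one gate coefficient per task vector.
def merge_task_vectors(base, task_vectors, gate_weights):
    """base: dict name -> pretrained weight (scalars here for simplicity);
    task_vectors: list of dicts with the same keys (finetuned - pretrained deltas);
    gate_weights: one interpolation coefficient per task vector."""
    return {
        name: theta0 + sum(w * tv[name] for w, tv in zip(gate_weights, task_vectors))
        for name, theta0 in base.items()
    }

base = {"layer.weight": 0.0}
task_vectors = [{"layer.weight": 1.0}, {"layer.weight": 2.0}]
gate_weights = [0.5, 0.25]  # produced by the gate from the input
merged = merge_task_vectors(base, task_vectors, gate_weights)
print(merged)  # {'layer.weight': 1.0}, i.e. 0.5*1.0 + 0.25*2.0
```

In the real model the same weighted sum is applied elementwise to full weight tensors, which is why the overhead is linear in the parameter count rather than involving any matrix products.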
For example, using this codebase with CLIP-ViT-B/32, 1000 steps of test-time adaptation training take 15–20 minutes on two RTX 3090 GPUs.
In practice, we apply a few optimizations:
```python
if self.gate.num_hidden_layers == 0:
    self.merge_weights(gate_weights)
    output_hidden_states = self.forward_model(hidden_states)
elif self.batch_reduce:  # we set this to True during the TTA phase
    # Average the gate weights over the batch, then merge only once.
    gate_weights = gate_weights.mean(dim=0)
    self.merge_weights(gate_weights)
    output_hidden_states = self.forward_model(hidden_states)
else:  # when performing inference, `batch_reduce` is set to False
    # Merge per sample: each sample gets its own merged weights.
    output_hidden_states = []
    for sample_idx, weights in enumerate(gate_weights):
        self.merge_weights(weights)
        if self.batch_first:
            output_hidden_states.append(
                self.forward_model(hidden_states[sample_idx : sample_idx + 1])
            )
        else:
            output_hidden_states.append(
                self.forward_model(hidden_states[:, sample_idx : sample_idx + 1])
            )
    if self.batch_first:
        output_hidden_states = torch.cat(output_hidden_states, dim=0)
    else:
        output_hidden_states = torch.cat(output_hidden_states, dim=1)
```
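To see why the `batch_reduce` branch is cheaper: the per-sample gate outputs are averaged over the batch first, so the merge runs once per batch instead of once per sample. A toy sketch with plain Python lists (the shapes and numbers are made up for illustration):

```python
# Per-sample gate weights: a batch of 2 samples, 2 task vectors each.
batch_gate_weights = [[0.2, 0.8],
                      [0.6, 0.4]]

# batch_reduce=True: average over the batch dimension (like `mean(dim=0)`),
# yielding a single weight vector -> a single merge for the whole batch.
reduced = [sum(col) / len(col) for col in zip(*batch_gate_weights)]
print(reduced)  # one coefficient per task vector
```

With `batch_reduce=False`, each of the rows above would trigger its own `merge_weights` call and a separate forward pass, which is the per-sample cost the question asks about.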
Thanks for the clarification, and thanks again for your great work! Closing the issue :)
Hi,
Thank you for your great work! I am a bit confused about the time cost. I noticed in your code that each data sample is iterated over to dynamically compute the merged weights. This seems like it could incur a relatively large time cost. Could you provide some clarification on this?