tanganke / weight-ensembling_MoE

Code for paper "Merging Multi-Task Models via Weight-Ensembling Mixture of Experts"

Question about Time Cost #3

Closed bhsimon0810 closed 3 days ago

bhsimon0810 commented 4 days ago

Hi,

Thank you for your great work! I am a bit confused about the time cost. I noticed in your code that each data sample is iterated over to dynamically compute the merged weights, which seems like it could incur a relatively large time cost. Could you provide some clarification on this?

tanganke commented 4 days ago

Yes, we enhance the model’s adaptability to data by dynamically integrating task-specific knowledge (task vectors) into the model parameters based on the input. This process requires computation at each forward pass, potentially increasing both the time and memory cost compared to static methods.

However, since the additional computation mainly involves linear parameter interpolation and no matrix multiplication, and the primary model performs inference only once per input, the extra time cost is not very high.
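To make this concrete, here is a minimal sketch of the merging step. The names are hypothetical (`base_weight` is the pre-trained weight, `task_vectors` holds one weight delta per fine-tuned task, and `gate_weights` are the routing coefficients the gate produces from the input); this is a sketch of the idea, not the exact code:

import torch

def merge_weights(base_weight, task_vectors, gate_weights):
    # W = W_0 + sum_i lambda_i * tau_i: only element-wise scaling and
    # addition, no matrix multiplication is involved.
    merged = base_weight.clone()
    for lam, tau in zip(gate_weights, task_vectors):
        merged = merged + lam * tau
    return merged

# Example: combining two task vectors for a single linear layer.
base = torch.randn(768, 768)
taus = [torch.randn(768, 768) for _ in range(2)]
lams = torch.tensor([0.3, 0.7])  # produced by the gate from the input features
merged = merge_weights(base, taus, lams)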

For example, using this codebase with CLIP-ViT-B/32, it takes 15-20 minutes to perform 1000 steps of test-time adaptation training on two RTX 3090 GPUs.

In practice, we have some optimizations:

  1. This procedure is independent for each input data sample, making it easy to parallelize. In our implementation, we use DDP training (see the sketch after the code block below).
  2. To reduce GPU memory usage and accelerate test-time adaptation training, we calculate the merged weights at the batch level instead of for each individual input sample during the TTA phase. We implement this in FusionBench, as shown below. FusionBench uses some models fine-tuned by ourselves, but the results are comparable.
if self.gate.num_hidden_layers == 0:
    # Input-independent gate: a single merge serves the whole batch.
    self.merge_weights(gate_weights)
    output_hidden_states = self.forward_model(hidden_states)
elif self.batch_reduce:  # we set this to True during the TTA phase
    # Average the gate weights over the batch, then merge once per batch.
    gate_weights = gate_weights.mean(dim=0)
    self.merge_weights(gate_weights)
    output_hidden_states = self.forward_model(hidden_states)
else:  # when performing inference, `batch_reduce` is set to False
    # Merge separately for each sample and run the model one sample at a time.
    output_hidden_states = []
    for sample_idx, weights in enumerate(gate_weights):
        self.merge_weights(weights)
        if self.batch_first:
            output_hidden_states.append(
                self.forward_model(hidden_states[sample_idx : sample_idx + 1])
            )
        else:
            output_hidden_states.append(
                self.forward_model(
                    hidden_states[:, sample_idx : sample_idx + 1]
                )
            )
    # Reassemble the per-sample outputs along the batch dimension.
    if self.batch_first:
        output_hidden_states = torch.cat(output_hidden_states, dim=0)
    else:
        output_hidden_states = torch.cat(output_hidden_states, dim=1)
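Note that `batch_reduce` trades a little per-sample adaptivity for speed and memory: the gate weights are averaged over the batch, so the parameters are merged once per batch rather than once per sample.

Regarding point 1, below is a minimal sketch of how the per-sample independence maps onto DDP training. Everything here is illustrative (the function name, the learning rate, and the entropy-minimization objective are assumptions; the actual TTA loop in FusionBench may differ):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def tta_train(model, dataloader, num_steps=1000):
    # Each process adapts on its own shard of the unlabeled test data.
    dist.init_process_group(backend="nccl")
    local_rank = dist.get_rank() % torch.cuda.device_count()
    model = DDP(model.cuda(local_rank), device_ids=[local_rank])
    # Assumption: only the gate parameters have requires_grad=True during TTA.
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=1e-3
    )
    for _, (images, _) in zip(range(num_steps), dataloader):
        logits = model(images.cuda(local_rank))
        probs = torch.softmax(logits, dim=-1)
        # Entropy minimization on unlabeled test samples.
        loss = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()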
bhsimon0810 commented 3 days ago

Thanks for the clarification, and thanks again for your great work! Closing the issue :)