panjianfei opened 4 months ago
Why subtract `self.vocab_start_index`?
```python
def forward(self, input_):
    assert not torch.any(
        (input_ < 0) | (input_ >= self.num_embeddings)
    ), "An input token is out of bounds of the embedding table"
    if self.tensor_model_parallel_size > 1:
        # Build the mask.
        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
        # Mask the input.
        masked_input = input_.clone() - self.vocab_start_index
        masked_input[input_mask] = 0
    else:
        masked_input = input_
    # Get the embeddings.
    output_parallel = self.weight[masked_input]
    # Mask the output embedding.
    if self.tensor_model_parallel_size > 1:
        output_parallel[input_mask, :] = 0.0
    # Reduce across all the model parallel GPUs.
    output = reduce_from_tensor_model_parallel_region(output_parallel)
    return output
```
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py#L225
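
The subtraction converts a global token ID into a local row index. Each tensor-parallel rank stores only the slice of the embedding table covering `[vocab_start_index, vocab_end_index)`, and that shard's `weight` tensor is indexed from 0, not from `vocab_start_index`. Tokens outside a rank's slice are clamped to row 0 so the lookup stays in bounds, their (meaningless) embeddings are zeroed, and the all-reduce sums the partial outputs so exactly one rank contributes each token's embedding. Below is a minimal single-process sketch of this idea in plain PyTorch; the sizes and the loop simulating two ranks are invented for illustration and are not Megatron's actual code:

```python
import torch

vocab_size, hidden, tp_size = 8, 4, 2
torch.manual_seed(0)
full_weight = torch.randn(vocab_size, hidden)  # the "logical" full embedding table
input_ = torch.tensor([1, 5, 7, 2])            # global token IDs

# Each simulated rank owns a contiguous slice of the vocabulary.
per_rank = vocab_size // tp_size
output = torch.zeros(len(input_), hidden)
for rank in range(tp_size):
    vocab_start_index = rank * per_rank
    vocab_end_index = vocab_start_index + per_rank
    weight = full_weight[vocab_start_index:vocab_end_index]  # local shard, rows 0..per_rank-1

    # Tokens this rank does not own.
    input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
    # Global ID -> local row index; e.g. on rank 1, token 5 maps to row 5 - 4 = 1.
    masked_input = input_.clone() - vocab_start_index
    masked_input[input_mask] = 0  # clamp out-of-range IDs so the lookup stays in bounds

    output_parallel = weight[masked_input]
    output_parallel[input_mask, :] = 0.0  # zero out rows this rank doesn't own
    output += output_parallel             # stands in for the all-reduce across ranks

# The summed partial outputs match a lookup in the full table.
assert torch.equal(output, full_weight[input_])
```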