xrsrke opened 8 months ago
```diff
- from nanotron.core.parallelism.tensor_parallelism.nn import (
-     TensorParallelColumnLinear,
-     TensorParallelEmbedding,
-     TensorParallelLinearMode,
-     TensorParallelRowLinear,
- )
- from nanotron.core.optimizer.zero import ZeroDistributedOptimizer
- from nanotron.dataloaders.nemo import get_nemo_dataloader
+ from nanotron.distributed import ParallelContext, ParallelMode
+ from nanotron.nn.tensor_parallel import ColumnParallelLinear, RowParallelLinear, ParallelEmbedding, ParallelCrossEntropy
+ from nanotron.nn.pipeline_parallel import PipelineBlock, PipelineParallel
+ from nanotron.nn.data_parallel import DataParallel
+ from nanotron.optim import ZeroDistributedOptimizer
+ from nanotron.utils.data import DistributedDataLoader
```

```diff
- dpg = get_process_groups(
-     data_parallel_size=self.config.parallelism.dp,
-     pipeline_parallel_size=self.config.parallelism.pp,
-     tensor_parallel_size=self.config.parallelism.tp,
- )
+ parallel_context = ParallelContext.from_torch(
+     tensor_parallel_size=2,
+     pipeline_parallel_size=4,
+     data_parallel_size=2
+ )
```
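For reference, a minimal sketch of what `ParallelContext.from_torch` could do. The class and method names come from the snippet above; the group construction and attribute names are assumptions for illustration, not existing nanotron code:

```python
# Hypothetical sketch of ParallelContext, for illustration only.
import torch.distributed as dist

class ParallelContext:
    def __init__(self, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
        self.tensor_parallel_size = tensor_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        self.data_parallel_size = data_parallel_size
        # ... build one process group per parallelism dimension (tp / pp / dp)
        # with dist.new_group(...) here ...

    @classmethod
    def from_torch(cls, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
        # assumes torch.distributed is already initialized (e.g. by torchrun)
        world_size = dist.get_world_size()
        assert world_size == tensor_parallel_size * pipeline_parallel_size * data_parallel_size
        return cls(tensor_parallel_size, pipeline_parallel_size, data_parallel_size)
```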
```diff
 class LlamaModel(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-        dpg: DistributedProcessGroups,
         parallel_config: Optional[ParallelismArgs],
+        parallel_context: ParallelContext,
     ):
         super().__init__()

         # Declare all the nodes
-        self.p2p = P2P(dpg.pp_pg, device=torch.device("cuda"))
         self.config = config
         self.parallel_config = parallel_config
-        self.dpg = dpg
         self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
         tp_linear_async_communication = (
             parallel_config.tp_linear_async_communication if parallel_config is not None else False
         )

-        self.token_position_embeddings = PipelineBlock(
-            p2p=self.p2p,
-            module_builder=Embedding,
-            module_kwargs={
-                "tp_pg": dpg.tp_pg,
-                "config": config,
-                "parallel_config": parallel_config,
-            },
-            module_input_keys={"input_ids", "input_mask"},
-            module_output_keys={"input_embeds"},
-        )
+        token_position_embeddings = Embedding(config, parallel_config, parallel_context)
+        self.token_position_embeddings = PipelineBlock(
+            token_position_embeddings, parallel_context,
+            input_keys={"input_ids", "input_mask"}, output_keys={"input_embeds"}
+        )

-        self.decoder = nn.ModuleList(
-            [
-                PipelineBlock(
-                    p2p=self.p2p,
-                    module_builder=LlamaDecoderLayer,
-                    module_kwargs={
-                        "config": config,
-                        "parallel_config": parallel_config,
-                        "tp_pg": dpg.tp_pg,
-                        "layer_idx": layer_idx,
-                    },
-                    module_input_keys={"hidden_states", "sequence_mask"},
-                    module_output_keys={"hidden_states", "sequence_mask"},
-                )
-                for layer_idx in range(config.num_hidden_layers)
-            ]
-        )
+        # the user specifies how many transformer blocks this rank holds (since this is quite simple)
+        num_local_pipeline_stages = ...
+        decoder = nn.ModuleList(
+            [LlamaDecoderLayer(config, layer_idx, parallel_config, parallel_context) for layer_idx in range(num_local_pipeline_stages)]
+        )
+        self.decoder = PipelineBlock(
+            decoder, parallel_context,
+            input_keys={"hidden_states", "sequence_mask"},
+            output_keys={"hidden_states", "sequence_mask"}
+        )

-        self.final_layer_norm = PipelineBlock(
-            p2p=self.p2p,
-            module_builder=RMSNorm,
-            module_kwargs={"normalized_shape": config.hidden_size, "eps": config.rms_norm_eps},
-            module_input_keys={"input"},
-            module_output_keys={"hidden_states"},
-        )
+        final_layer_norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+        self.final_layer_norm = PipelineBlock(
+            final_layer_norm, parallel_context,
+            input_keys={"input"}, output_keys={"hidden_states"}
+        )

-        self.lm_head = PipelineBlock(
-            p2p=self.p2p,
-            # Understand that this means that we return sharded logits that are going to need to be gathered
-            module_builder=TensorParallelColumnLinear,
-            module_kwargs={
-                "in_features": config.hidden_size,
-                "out_features": config.vocab_size,
-                "pg": dpg.tp_pg,
-                "bias": False,
-                "mode": self.tp_mode,
-                "async_communication": tp_linear_async_communication,
-            },
-            module_input_keys={"x"},
-            module_output_keys={"logits"},
-        )
+        # this still returns sharded logits that will need to be gathered
+        lm_head = ColumnParallelLinear(
+            config.hidden_size, config.vocab_size,
+            bias=False, mode=self.tp_mode,
+            async_communication=tp_linear_async_communication
+        )
+        self.lm_head = PipelineBlock(
+            lm_head, parallel_context,
+            input_keys={"x"}, output_keys={"logits"}
+        )

-        self.cast_to_fp32 = PipelineBlock(
-            p2p=self.p2p,
-            module_builder=lambda: lambda x: x.float(),
-            module_kwargs={},
-            module_input_keys={"x"},
-            module_output_keys={"output"},
-        )

     def forward(
         self,
-        input_ids: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
-        input_mask: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
+        input_ids: torch.Tensor,  # [batch_size, seq_length]
+        input_mask: torch.Tensor,  # [batch_size, seq_length]
     ):
         # all tensors are optional as most ranks don't need anything from the dataloader
         output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)

         hidden_encoder_states = {
             "hidden_states": output["input_embeds"],
             "sequence_mask": input_mask,
         }
-        for encoder_block in self.decoder:
-            hidden_encoder_states = encoder_block(**hidden_encoder_states)
+        hidden_encoder_states = self.decoder(**hidden_encoder_states)

         hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]

         sharded_logits = self.lm_head(x=hidden_states)["logits"]

-        fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]
+        fp32_sharded_logits = sharded_logits.float()

         return fp32_sharded_logits, hidden_states
```
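To make the intent above concrete, here is a rough sketch of what the new `PipelineBlock` wrapper could look like. Only the constructor signature (`module, parallel_context, input_keys, output_keys`) is taken from the snippet; everything inside it is an assumption for illustration, not existing nanotron code:

```python
# Hypothetical PipelineBlock sketch, for illustration only.
# The rank that owns `module` runs it; a real implementation would also send /
# receive activations between pipeline stages instead of assuming local tensors.
import torch.nn as nn

class PipelineBlock(nn.Module):
    def __init__(self, module, parallel_context, input_keys, output_keys):
        super().__init__()
        self.module = module
        self.parallel_context = parallel_context
        self.input_keys = set(input_keys)
        self.output_keys = set(output_keys)

    def forward(self, **inputs):
        assert set(inputs) == self.input_keys
        outputs = self.module(**inputs)
        if not isinstance(outputs, dict):
            (only_key,) = self.output_keys  # single-output modules return a bare tensor
            outputs = {only_key: outputs}
        return outputs
```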
```diff
-class LlamaForTraining(BRRRModel):
-    def __init__(
-        self,
-        config: LlamaConfig,
-        dpg: DistributedProcessGroups,
-        parallel_config: Optional[ParallelismArgs],
-        random_states: Optional[RandomStates] = None,
-    ):
-        super().__init__()
-        self.model = LlamaModel(config=config, dpg=dpg, parallel_config=parallel_config)
-        self.loss = PipelineBlock(
-            p2p=self.model.p2p,
-            module_builder=Loss,
-            module_kwargs={"tp_pg": dpg.tp_pg},
-            module_input_keys={
-                "sharded_logits",
-                "label_ids",
-                "label_mask",
-            },
-            module_output_keys={"loss"},
-        )
-        self.dpg = dpg
-        self.config = config
-        self.parallel_config = parallel_config
-
-    def forward(
-        self,
-        input_ids: Union[torch.Tensor, TensorPointer],
-        input_mask: Union[torch.Tensor, TensorPointer],
-        label_ids: Union[torch.Tensor, TensorPointer],
-        label_mask: Union[torch.Tensor, TensorPointer],
-    ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
-        sharded_logits = self.model(
-            input_ids=input_ids,
-            input_mask=input_mask,
-        )
-        loss = self.loss(
-            sharded_logits=sharded_logits,
-            label_ids=label_ids,
-            label_mask=label_mask,
-        )["loss"]
-        return {"loss": loss}
```

```diff
- model = init_model(
-     model_builder=lambda: LlamaForTraining(config=model_config, dpg=dpg, parallel_config=parallel_config),
-     model_config=model_config,
-     parallel_config=parallel_config,
-     dtype=dtype,
-     dpg=dpg,
-     make_ddp=False,
- )
+ # `sync_gradients_across_dp` is eliminated: `DataParallel` registers the backward hooks itself
+ model = LlamaModel(config, parallel_context)
+ model = DataParallel(model, parallel_context)

- outputs = pipeline_engine.train_batch_iter(
-     model=model,
-     pg=dpg.pp_pg,
-     batch=(next(data_iterator) for _ in range(n_micro_batches_per_batch)),
-     nb_microbatches=n_micro_batches_per_batch,
-     grad_accumulator=grad_accumulator,
- )
+ model = PipelineParallel(model, num_microbatches, parallel_context)

- optimizer, grad_accumulator = init_optimizer_and_grad_accumulator(
-     model=model, optimizer_args=optimizer_args, dpg=dpg
- )
+ named_parameters = ...
+ optimizer = ZeroDistributedOptimizer(named_parameters, parallel_context)
```
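Since the snippet claims `DataParallel` can register the gradient-sync hooks itself, here is a minimal sketch of that idea. The per-parameter hook approach and the raw `dp_group` argument (used here instead of the proposed `parallel_context`, to keep the sketch self-contained) are assumptions; a production version would bucket gradients and overlap communication with the backward pass:

```python
# Hypothetical sketch: data parallelism via per-parameter backward hooks.
# Each hook all-reduces (averages) the parameter's gradient across the
# data-parallel group once that gradient has been accumulated.
import torch
import torch.distributed as dist
import torch.nn as nn

class DataParallel(nn.Module):
    def __init__(self, module: nn.Module, dp_group=None):
        super().__init__()
        self.module = module
        self.dp_group = dp_group
        for param in self.module.parameters():
            if param.requires_grad:
                param.register_post_accumulate_grad_hook(self._sync_grad)

    def _sync_grad(self, param: torch.Tensor) -> None:
        dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=self.dp_group)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
```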
```diff
- dataloader = get_nemo_dataloader(
-     dataset=train_dataset,
-     sequence_length=sequence_length,
-     micro_batch_size=micro_batch_size,
-     global_batch_size=global_batch_size,
-     num_workers=config.data.num_loading_workers,
-     cfg=config.data.dataset,
-     consumed_samples=consumed_train_samples,
-     dpg=dpg,
-     input_pp_rank=input_pp_rank,
-     output_pp_rank=output_pp_rank,
-     dataloader_drop_last=True
- )
+ # assume that only the first pipeline stage loads data;
+ # subsequent pipeline stages only receive activations
+ if parallel_context.is_first_rank(ParallelMode.PIPELINE):
+     dataloader = DistributedDataLoader(
+         dataset, sequence_length, microbatch_size, global_batch_size,
+         num_workers, consumed_samples, dataloader_drop_last,
+         parallel_context
+     )

+ for _ in range(epochs):
+     for batch in dataloader:
+         outputs = model(batch)
+
+         # assume that only the last pipeline stage has the loss
+         if parallel_context.get_local_rank(ParallelMode.PIPELINE) == parallel_context.pipeline_parallel_size - 1:
+             loss = ParallelCrossEntropy()(outputs["logits"], targets)  # the logits are still sharded across tensor-parallel ranks
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
```
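The `ParallelCrossEntropy` above has to operate on vocab-sharded logits without gathering them. Below is a sketch of the forward math only (Megatron-style), assuming each tensor-parallel rank holds a contiguous slice of the vocabulary; a real implementation wraps the collectives in a custom `torch.autograd.Function` so the loss can be backpropagated, which is omitted here:

```python
# Hypothetical sketch of the forward math of a vocab-sharded cross entropy.
# Each rank holds logits for vocab ids [vocab_start, vocab_start + shard_size).
# NOTE: the collectives below are not autograd-aware; a real implementation
# uses a custom torch.autograd.Function to propagate gradients.
import torch
import torch.distributed as dist

def parallel_cross_entropy(sharded_logits: torch.Tensor, targets: torch.Tensor, tp_group) -> torch.Tensor:
    # sharded_logits: [num_tokens, vocab_size // tp], targets: [num_tokens]
    shard_size = sharded_logits.size(-1)
    vocab_start = dist.get_rank(tp_group) * shard_size

    # logsumexp over the full vocabulary, computed shard by shard
    row_max = sharded_logits.max(dim=-1).values
    dist.all_reduce(row_max, op=dist.ReduceOp.MAX, group=tp_group)
    sum_exp = torch.exp(sharded_logits - row_max.unsqueeze(-1)).sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)
    logsumexp = row_max + sum_exp.log()

    # pick each target's logit from the rank that owns that vocab slice
    in_shard = (targets >= vocab_start) & (targets < vocab_start + shard_size)
    local_ids = (targets - vocab_start).clamp(0, shard_size - 1)
    target_logits = sharded_logits.gather(-1, local_ids.unsqueeze(-1)).squeeze(-1)
    target_logits = torch.where(in_shard, target_logits, torch.zeros_like(target_logits))
    dist.all_reduce(target_logits, op=dist.ReduceOp.SUM, group=tp_group)

    # cross entropy per token: logsumexp(logits) - logit[target]
    return (logsumexp - target_logits).mean()
```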