```python
import datetime
import logging
import sys

import fsspec
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader

# project-specific imports (RecNet, TrainerConfig, DataConfig, Data, Snapshot,
# OptimizerConfig, send_msg) are defined elsewhere in the repo and omitted here


class Trainer:
    def __init__(self, global_rank, gpu_id: int, trainer_config: TrainerConfig, model: RecNet,
                 optimizer, world_size: int, data_cfg: DataConfig):
        self.global_rank = global_rank
        self.config = trainer_config
        self.world_size = world_size
        self.dataloader = Data(data_cfg)
        self.epochs_run = 0
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.optimizer = optimizer
        if self.config.use_amp:
            self.scaler = torch.cuda.amp.GradScaler()
        # load snapshot if available. only necessary on the first node.
        if self.config.snapshot_path is None:
            self.config.snapshot_path = "snapshot.pt"
        self._load_snapshot()
        # wrap with DDP. this step will sync the model across all processes.
        # self.model = torch.compile(DDP(self.model, device_ids=[gpu_id]))
        torch.set_float32_matmul_precision('high')

    def _load_snapshot(self):
        try:
            snapshot = fsspec.open(self.config.snapshot_path)
            with snapshot as f:
                snapshot_data = torch.load(f, map_location="cpu")
        except FileNotFoundError:
            print("Snapshot not found. Training model from scratch")
            return
        snapshot = Snapshot(**snapshot_data)
        self.model.load_state_dict(snapshot.model_state)
        self.optimizer.load_state_dict(snapshot.optimizer_state)
        self.epochs_run = snapshot.finished_epoch
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def cal_loss(self, score, labels):
        # weighted sum of the three binary cross-entropy task losses
        click_loss = F.binary_cross_entropy(score["click_score"], labels['click_label'])
        add_loss = F.binary_cross_entropy(score["add_score"], labels['add_label'])
        order_loss = F.binary_cross_entropy(score["add_order_score"], labels['order_label'])
        loss = self.config.click_loss_weight * click_loss + \
            self.config.add_loss_weight * add_loss + \
            self.config.order_loss_weight * order_loss
        return loss

    def _run_batch(self, features, labels, train: bool = True):
        with torch.set_grad_enabled(train), \
                torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=self.config.use_amp):
            score = self.model(features)
            loss = self.cal_loss(score, labels)

        if train:
            self.optimizer.zero_grad(set_to_none=True)
            if self.config.use_amp:
                self.scaler.scale(loss).backward()
                if self.config.use_clip_grad:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                if self.config.use_clip_grad:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm_clip)
                self.optimizer.step()
        return loss.item()

    def _run_epoch(self, epoch: int, dataloader: DataLoader, train: bool = True):
        # join() is DDP's context manager for handling uneven inputs across ranks
        with self.model.join():
            for _iter, (features, labels) in enumerate(dataloader):
                features = {feat_name: torch.as_tensor(data=feat_data, dtype=torch.int, device=self.gpu_id)
                            for feat_name, feat_data in features.items()}
                labels = {label_name: torch.as_tensor(data=label_data, dtype=torch.float, device=self.gpu_id)
                          for label_name, label_data in labels.items()}
                step_type = "Train" if train else "Eval"
                batch_loss = self._run_batch(features, labels, train)
                if _iter % 100 == 99:
                    logging.info(
                        f"{datetime.datetime.now()} [GPU{self.gpu_id}] Epoch {epoch} | Iter {_iter} "
                        f"| {step_type} Loss {batch_loss:.5f}")

    def _eval(self, step, epoch, is_eval):
        logging.info("eval start")
        self.dataloader.config.data_file = self.dataloader.config.test_file
        test_loader = self.dataloader.data_input_fn_torch_eval()
        self.model.eval()
        with torch.no_grad():  # disable gradient computation
            loss = 0.0
            _iter = 0
            click_scores = []
            add_scores = []
            order_scores = []
            click_labels = []
            add_labels = []
            order_labels = []
            for _iter, (features, labels) in enumerate(test_loader):
                features = {feat_name: torch.as_tensor(data=feat_data, dtype=torch.int, device=self.gpu_id)
                            for feat_name, feat_data in features.items()}
                labels = {label_name: torch.as_tensor(data=label_data, dtype=torch.float, device=self.gpu_id)
                          for label_name, label_data in labels.items()}
                score = self.model(features, is_eval)
                batch_loss = self.cal_loss(score, labels)
                click_scores.append(score["click_score"].detach().cpu().numpy())
                add_scores.append(score["add_score"].detach().cpu().numpy())
                order_scores.append(score["add_order_score"].detach().cpu().numpy())
                click_labels.append(labels['click_label'].detach().cpu().numpy())
                add_labels.append(labels['add_label'].detach().cpu().numpy())
                order_labels.append(labels['order_label'].detach().cpu().numpy())
                loss += batch_loss
            click_auc = roc_auc_score(np.concatenate(click_labels), np.concatenate(click_scores))
            add_auc = roc_auc_score(np.concatenate(add_labels), np.concatenate(add_scores))
            order_auc = roc_auc_score(np.concatenate(order_labels), np.concatenate(order_scores))
            logging.info(
                f"{datetime.datetime.now()} [GPU{self.gpu_id}] Epoch {epoch} | Iter {step} "
                f"| Eval Loss {loss / (_iter + 1):.5f}, click_auc:{click_auc}, add_auc:{add_auc}, "
                f"order_auc:{order_auc}")
            send_msg("product-feeds-category-rank-model-v62:: click_auc: {}, add_auc: {}, order_auc: {}".format(
                click_auc, add_auc, order_auc))
            if add_auc < 0.7:
                send_msg("@majun26 product-feeds-category-rank-model-v62:: add_auc: {} is too low".format(add_auc))
                sys.exit(1)

    def _save_checkpoint(self, epoch=1):
        state_dict = self.model.cpu().state_dict()
        # strip the prefix added by torch.compile ('_orig_mod.') and DDP ('module.')
        unwanted_prefix = '_orig_mod.module.'
        for k, v in list(state_dict.items()):
            if k.startswith(unwanted_prefix):
                if k.startswith(unwanted_prefix + 'pos_logit'):
                    state_dict.pop(k)
                else:
                    state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
        PATH = f'{self.config.checkpoint_path}/checkpoint.pt'
        torch.save(state_dict, PATH)
        logging.info(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

    def train(self):
        self.dataloader.config.data_file = self.dataloader.config.train_file
        train_loader = self.dataloader.data_input_fn_torch_train(self.world_size, self.global_rank)
        for i in range(self.config.max_epochs):
            self._run_epoch(i, train_loader, train=True)

    def eval(self, is_eval):
        self._eval(1, 1, is_eval)
```
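`_load_snapshot` unpacks the loaded dict into a `Snapshot` with `model_state`, `optimizer_state` and `finished_epoch` fields, but the class itself is not part of the snippet. Below is a minimal sketch consistent with how it is used above; the `save_snapshot` helper is illustrative only and not from the original code:

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict

import torch


@dataclass
class Snapshot:
    # field names inferred from _load_snapshot() above
    model_state: Dict[str, torch.Tensor]   # model.state_dict()
    optimizer_state: Dict[str, Any]        # optimizer.state_dict()
    finished_epoch: int                    # last fully completed epoch


def save_snapshot(model, optimizer, finished_epoch: int, path: str = "snapshot.pt"):
    # hypothetical counterpart to _load_snapshot(): torch.load() on this file
    # yields a dict that Snapshot(**snapshot_data) can unpack
    snapshot = Snapshot(
        model_state=model.state_dict(),
        optimizer_state=optimizer.state_dict(),
        finished_epoch=finished_epoch,
    )
    torch.save(asdict(snapshot), path)
```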
```python
class Optimizer:
    def __init__(self, model: RecNet, opt_config: OptimizerConfig):
        self.model = model
        self.opt_config = opt_config

    def get(self):
        if self.opt_config.name.lower() == 'adam':
            return self.adam()
        if self.opt_config.name.lower() == 'adamw':
            return self.adamw()

    def adam(self):
        return torch.optim.Adam(params=self.model.parameters(), lr=self.opt_config.learning_rate)

    def adamw(self):
        return torch.optim.AdamW(params=self.model.parameters(), lr=self.opt_config.learning_rate)
```
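For completeness, a hedged sketch of how these two classes are typically wired together under `torchrun`; `ddp_setup` and `main` are illustrative names, and the config/model objects (`TrainerConfig`, `DataConfig`, `OptimizerConfig`, `RecNet`) are assumed to be constructed elsewhere in the project:

```python
import os

import torch
from torch.distributed import destroy_process_group, init_process_group


def ddp_setup():
    # torchrun exports RANK, LOCAL_RANK and WORLD_SIZE for every worker process
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


def main(trainer_cfg, data_cfg, opt_cfg, model):
    ddp_setup()
    global_rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    optimizer = Optimizer(model, opt_cfg).get()
    trainer = Trainer(global_rank, local_rank, trainer_cfg, model,
                      optimizer, world_size, data_cfg)
    trainer.train()
    destroy_process_group()

# launched with e.g.: torchrun --standalone --nproc_per_node=gpu train_main.py
```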
@anw90 @qiuxiafei Could you please take a look and help?