AlibabaPAI / torchacc

PyTorch distributed training acceleration framework
Apache License 2.0

When the model is wrapped with self.model = torchacc.accelerate(model) to accelerate training, GPU utilization is extremely low, close to 0 #25

Open a1342772 opened 1 month ago

a1342772 commented 1 month ago
[Three screenshots attached]
class Trainer:
    def __init__(self, global_rank, gpu_id: int, trainer_config: TrainerConfig, model: RecNet, optimizer,
                 world_size: int, data_cfg: DataConfig):
        self.global_rank = global_rank
        self.config = trainer_config
        self.world_size = world_size
        self.dataloader = Data(data_cfg)
        self.epochs_run = 0
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.optimizer = optimizer
        if self.config.use_amp:
            self.scaler = torch.cuda.amp.GradScaler()
        # load snapshot if available. only necessary on the first node.
        if self.config.snapshot_path is None:
            self.config.snapshot_path = "snapshot.pt"
        self._load_snapshot()
        # wrap with DDP. this step will synchronize the model across all processes.
        # self.model = torch.compile(DDP(self.model, device_ids=[gpu_id]))
        torch.set_float32_matmul_precision('high')

    def _load_snapshot(self):
        try:
            snapshot = fsspec.open(self.config.snapshot_path)
            with snapshot as f:
                snapshot_data = torch.load(f, map_location="cpu")
        except FileNotFoundError:
            print("Snapshot not found. Training model from scratch")
            return

        snapshot = Snapshot(**snapshot_data)
        self.model.load_state_dict(snapshot.model_state)
        self.optimizer.load_state_dict(snapshot.optimizer_state)
        self.epochs_run = snapshot.finished_epoch
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def cal_loss(self, score, labels):
        click_loss = F.binary_cross_entropy(score["click_score"], labels['click_label'])
        add_loss = F.binary_cross_entropy(score["add_score"], labels['add_label'])
        order_loss = F.binary_cross_entropy(score["add_order_score"], labels['order_label'])

        loss = self.config.click_loss_weight * click_loss + \
            self.config.add_loss_weight * add_loss + self.config.order_loss_weight * order_loss
        return loss

    def _run_batch(self, features, labels, train: bool = True):
        with torch.set_grad_enabled(train), torch.amp.autocast(device_type="cuda", dtype=torch.float16,
                                                               enabled=self.config.use_amp):
            score = self.model(features)
            loss = self.cal_loss(score, labels)
        if train:
            self.optimizer.zero_grad(set_to_none=True)
            if self.config.use_amp:
                self.scaler.scale(loss).backward()
                if self.config.use_clip_grad:
                    # gradients must be unscaled before clipping when using GradScaler
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                if self.config.use_clip_grad:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_norm_clip)
                self.optimizer.step()
        return loss.item()

    def _run_epoch(self, epoch: int, dataloader: DataLoader, train: bool = True):
        with self.model.join():
            for _iter, (features, labels) in enumerate(dataloader):
                features = {feat_name: torch.as_tensor(data=feat_data, dtype=torch.int, device=self.gpu_id)
                            for feat_name, feat_data in features.items()}
                labels = {label_name: torch.as_tensor(data=label_data, dtype=torch.float, device=self.gpu_id)
                          for label_name, label_data in labels.items()}
                step_type = "Train" if train else "Eval"
                batch_loss = self._run_batch(features, labels, train)
                if _iter % 100 == 99:
                    logging.info(
                        f"{datetime.datetime.now()} [GPU{self.gpu_id}] Epoch {epoch} | Iter {_iter} | {step_type} Loss {batch_loss:.5f}")

    def _eval(self, step, epoch, is_eval):
        logging.info("eval start")
        self.dataloader.config.data_file = self.dataloader.config.test_file
        test_loader = self.dataloader.data_input_fn_torch_eval()

        self.model.eval()
        with torch.no_grad():  # disable gradient computation during evaluation
            loss = 0.0
            _iter = 0
            click_scores = []
            add_scores = []
            order_scores = []
            click_labels = []
            add_labels = []
            order_labels = []
            for _iter, (features, labels) in enumerate(test_loader):
                features = {feat_name: torch.as_tensor(data=feat_data, dtype=torch.int, device=self.gpu_id)
                            for feat_name, feat_data in features.items()}
                labels = {label_name: torch.as_tensor(data=label_data, dtype=torch.float, device=self.gpu_id)
                          for label_name, label_data in labels.items()}
                score = self.model(features, is_eval)
                batch_loss = self.cal_loss(score, labels)
                click_scores.append(score["click_score"].detach().cpu().numpy())
                add_scores.append(score["add_score"].detach().cpu().numpy())
                order_scores.append(score["add_order_score"].detach().cpu().numpy())
                click_labels.append(labels['click_label'].detach().cpu().numpy())
                add_labels.append(labels['add_label'].detach().cpu().numpy())
                order_labels.append(labels['order_label'].detach().cpu().numpy())
                loss += batch_loss
            click_auc = roc_auc_score(np.concatenate(click_labels), np.concatenate(click_scores))
            add_auc = roc_auc_score(np.concatenate(add_labels), np.concatenate(add_scores))
            order_auc = roc_auc_score(np.concatenate(order_labels), np.concatenate(order_scores))
            logging.info(
                f'''{datetime.datetime.now()} [GPU{self.gpu_id}] Epoch {epoch} | Iter {step} | Eval Loss {loss / (_iter + 1):.5f}, 
                click_auc:{click_auc}, add_auc:{add_auc}, order_auc:{order_auc}''')

            send_msg("product-feeds-category-rank-model-v62:: click_auc: {}, add_auc: {}, order_auc: {}".format(
                click_auc, add_auc, order_auc))
            if add_auc < 0.7:
                send_msg("@majun26 product-feeds-category-rank-model-v62:: add_auc: {} is too low".format(add_auc))
                sys.exit(1)

    def _save_checkpoint(self, epoch=1):
        state_dict = self.model.cpu().state_dict()
        # '_orig_mod.module.' is the prefix added by the torch.compile + DDP wrappers
        unwanted_prefix = '_orig_mod.module.'
        for k, v in list(state_dict.items()):
            if k.startswith(unwanted_prefix):
                if k.startswith(unwanted_prefix + 'pos_logit'):
                    state_dict.pop(k)
                else:
                    state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

        PATH = f'{self.config.checkpoint_path}/checkpoint.pt'
        torch.save(state_dict, PATH)
        logging.info(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

    def train(self):
        self.dataloader.config.data_file = self.dataloader.config.train_file
        train_loader = self.dataloader.data_input_fn_torch_train(self.world_size, self.global_rank)
        for i in range(self.config.max_epochs):
            self._run_epoch(i, train_loader, train=True)

    def eval(self, is_eval):
        self._eval(1, 1, is_eval)


class Optimizer:

    def __init__(self, model: RecNet, opt_config: OptimizerConfig):
        self.model = model
        self.opt_config = opt_config

    def get(self):
        if self.opt_config.name.lower() == 'adam':
            return self.adam()
        if self.opt_config.name.lower() == 'adamw':
            return self.adamw()
        raise ValueError(f"Unsupported optimizer: {self.opt_config.name}")

    def adam(self):
        return torch.optim.Adam(params=self.model.parameters(),
                                lr=self.opt_config.learning_rate)

    def adamw(self):
        return torch.optim.AdamW(params=self.model.parameters(),
                                 lr=self.opt_config.learning_rate)
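
For reference, a minimal sketch of how the wrapping described in the title fits into the code above: the commented-out DDP line in __init__ is effectively replaced by torchacc.accelerate. The tiny network below is only a placeholder standing in for RecNet, not the real model.

# Minimal reproduction sketch (placeholder model, not the real RecNet).
import torch
import torch.nn as nn
import torchacc

class TinyRecNet(nn.Module):
    # placeholder standing in for RecNet
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 1)

    def forward(self, x):
        return torch.sigmoid(self.fc(x))

model = TinyRecNet()
# Previously: self.model = torch.compile(DDP(model.to(gpu_id), device_ids=[gpu_id]))
# Now, as in the title, the model is wrapped with torchacc instead,
# and GPU utilization drops to near 0 during training:
model = torchacc.accelerate(model)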
a1342772 commented 1 month ago

@anw90 @qiuxiafei Could you take a look at this, please?