wzcai99 / Pixel-Navigator

Official GitHub Repository for Paper "Bridging Zero-shot Object Navigation and Foundation Models through Pixel-Guided Navigation Skill", ICRA 2024

Release of training code #7

Open FlightVin opened 1 month ago

FlightVin commented 1 month ago

Hi,

Could you release the code (implemented loss functions, specific optimizers, other parameters, etc.) that you folks used for training the models?

Thanks!

wzcai99 commented 1 month ago

The training code is quite simple: we follow PyTorch's DDP (DistributedDataParallel) framework to train the policy. The main hyper-parameters are listed below.

    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=100000)

As described in the paper, the loss has three terms: action loss, distance loss, and tracking loss. The action loss uses torch.nn.functional.cross_entropy, while the other two use mse_loss. We train on 4 GeForce RTX 4090 GPUs with a total batch size of 32.
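For readers following along, here is a minimal sketch of how the three loss terms described above might be combined. It is not the released training code: the function name `pixelnav_loss`, the tensor shapes, the number of actions, and the equal weighting of the terms are all assumptions made for illustration, since the thread does not state them.

```python
import torch
import torch.nn.functional as F


def pixelnav_loss(action_logits, distance_pred, tracking_pred,
                  action_target, distance_target, tracking_target,
                  weights=(1.0, 1.0, 1.0)):
    """Combine action, distance and tracking losses (hypothetical helper).

    Assumed shapes (not confirmed by the thread):
      action_logits:   (B, num_actions) raw scores
      action_target:   (B,) integer class indices
      distance_pred / distance_target: (B,) floats
      tracking_pred / tracking_target: (B, 2) floats, e.g. a pixel location
    """
    w_action, w_distance, w_tracking = weights  # equal weights are an assumption

    # Classification over discrete actions: cross-entropy on raw logits.
    action_loss = F.cross_entropy(action_logits, action_target)
    # Regression terms: mean squared error.
    distance_loss = F.mse_loss(distance_pred, distance_target)
    tracking_loss = F.mse_loss(tracking_pred, tracking_target)

    total_loss = (w_action * action_loss
                  + w_distance * distance_loss
                  + w_tracking * tracking_loss)
    return {
        "total_loss": total_loss,
        "action_loss": action_loss,
        "distance_loss": distance_loss,
        "tracking_loss": tracking_loss,
    }


# Usage with random stand-in data (batch of 4, 6 hypothetical actions).
logits = torch.randn(4, 6, requires_grad=True)
losses = pixelnav_loss(
    action_logits=logits,
    distance_pred=torch.randn(4),
    tracking_pred=torch.randn(4, 2),
    action_target=torch.randint(0, 6, (4,)),
    distance_target=torch.randn(4),
    tracking_target=torch.randn(4, 2),
)
losses["total_loss"].backward()  # gradients flow back to the logits
```

Returning the individual terms alongside the total makes it easy to log each one separately, which helps when one term dominates and the weights need tuning.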

FlightVin commented 1 month ago

Hi,

I wanted some further clarity on how the model is being trained.

Consider the following train function:

    def train(self, dataloader: PixelNavDataloader, num_epochs=10):
        num_batches = len(dataloader) // self.batch_size

        # Optimizer and scheduler
        self.optimizer = optim.Adam(self.pixel_nav_model.parameters(), lr=1e-4)
        self.scheduler = optim.lr_scheduler.LinearLR(self.optimizer, start_factor=1.0, end_factor=0.1, total_iters=num_batches*num_epochs)

        for epoch in range(num_epochs):
            epoch_total_loss = 0.0
            epoch_distance_loss = 0.0
            epoch_goal_loss = 0.0
            epoch_action_loss = 0.0

            # Create random batched indices for this epoch
            indices = np.random.permutation(len(dataloader))

            for batch_idx in range(num_batches):
                batch_indices = indices[batch_idx * self.batch_size : (batch_idx + 1) * self.batch_size]

                # Get batch data from the dataloader
                data = [dataloader.get_traj_at_idx(idx) for idx in batch_indices]
                goal_image_batch, goal_mask_batch, episode_images_batch, \
                action_preds_ground_truth_batch, distance_preds_ground_truth_batch, \
                goal_preds_ground_truth_batch = zip(*data)

                # Convert lists to tensors
                goal_image_batch = torch.stack(goal_image_batch)
                goal_mask_batch = torch.stack(goal_mask_batch)
                episode_images_batch = torch.stack(episode_images_batch)
                action_preds_ground_truth_batch = torch.tensor(action_preds_ground_truth_batch)
                distance_preds_ground_truth_batch = torch.tensor(distance_preds_ground_truth_batch)
                goal_preds_ground_truth_batch = torch.tensor(goal_preds_ground_truth_batch)

                # Train the model on the batch
                losses = self.train_batch(
                    goal_image_batch=goal_image_batch,
                    goal_mask_batch=goal_mask_batch,
                    episode_images_batch=episode_images_batch,
                    action_preds_ground_truth_batch=action_preds_ground_truth_batch,
                    distance_preds_ground_truth_batch=distance_preds_ground_truth_batch,
                    goal_preds_ground_truth_batch=goal_preds_ground_truth_batch
                )

                epoch_total_loss += losses["total_loss"]
                epoch_distance_loss += losses["distance_loss"]
                epoch_goal_loss += losses["goal_loss"]
                epoch_action_loss += losses["action_loss"]

                # Scheduler step after each batch
                self.scheduler.step()

Is the training loop apt? Further, am I using the scheduler as intended?

wzcai99 commented 1 month ago

Yes. In our implementation, the scheduler is stepped once per training step, i.e. after each backpropagation and optimizer update.
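To see what per-step scheduling implies for the learning rate, the sketch below evaluates the closed form given in the PyTorch documentation for `LinearLR`, using the settings quoted earlier in this thread. The helper name `linear_lr` is hypothetical, and this is a plain-Python illustration rather than the repository's code.

```python
def linear_lr(step, base_lr=1e-4, start_factor=1.0, end_factor=0.1,
              total_iters=100_000):
    """Learning rate after `step` calls to scheduler.step() (hypothetical helper).

    The multiplicative factor moves linearly from start_factor to end_factor
    over total_iters steps and is held at end_factor afterwards.
    """
    progress = min(step, total_iters) / total_iters
    factor = start_factor + (end_factor - start_factor) * progress
    return base_lr * factor


# Print the learning rate at the start, midway, at the end, and beyond.
for step in (0, 50_000, 100_000, 200_000):
    print(step, linear_lr(step))
```

Because the scheduler is stepped once per batch, `total_iters` counts batches rather than epochs. Setting it to `num_batches * num_epochs`, as in the question above, spreads the decay over the whole run, whereas a fixed value such as 100000 ends the decay after that many steps regardless of run length.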