rosikand / torchplate

🍽 A minimal and simple experiment module for machine learning research workflows in PyTorch.
https://rosikand.github.io/torchplate/
MIT License

Add ability to log multiple metrics in `train` and add history state #1

Closed: rosikand closed this issue 1 year ago.

rosikand commented 2 years ago

Instead of having `evaluate` provide only the scalar loss value, have it return a dict that contains at least the loss but can also contain other values, such as accuracy, which are logged (printed to the console) during training. Also, add a history state dict to the superclass that keeps track of these values on a per-epoch basis. That way, a user can call `exp.history['loss']` to get the loss for each epoch.
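A minimal sketch of the proposed history state, assuming the per-epoch averages have already been collected into a dict of floats. `record_epoch` is a hypothetical helper, not part of torchplate's API; in the superclass, `history` would be an attribute initialized to `{}` and updated once at the end of every epoch.

```python
def record_epoch(history, epoch_metrics):
    """
    Hypothetical helper (not a torchplate API): append each epoch-averaged
    metric to history, which maps metric name -> list of per-epoch values.
    """
    for key, value in epoch_metrics.items():
        history.setdefault(key, []).append(value)


# e.g., self.history = {} in __init__, then once per epoch inside train():
history = {}
record_epoch(history, {"loss": 0.9, "accuracy": 0.6})  # epoch 1
record_epoch(history, {"loss": 0.5, "accuracy": 0.8})  # epoch 2
print(history["loss"])  # [0.9, 0.5]
```

Because keys are created on first use, any extra metric returned by `evaluate` gets its own history list without further registration.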

rosikand commented 2 years ago

Here is a basic attempt at this:

from abc import ABC, abstractmethod

import torch
from tqdm import tqdm


class MeanMetric:
    """
    Scalar metric designed to be used on a per-epoch basis
    and updated on a per-batch basis, for getting the average
    across the epoch.
    """

    def __init__(self):
        self.vals = []

    def update(self, new_val):
        # convert tensors (e.g., the loss) to Python floats so the
        # autograd graph is not kept alive across batches
        if torch.is_tensor(new_val):
            new_val = new_val.item()
        self.vals.append(new_val)

    def reset(self):
        self.vals = []

    def get(self):
        # guard against division by zero when nothing has been logged
        if not self.vals:
            return float("nan")
        return sum(self.vals) / len(self.vals)


class Experiment(ABC):
    """
    Base experiment superclass. All other experiments
    should inherit from this class. Each sub-experiment
    must provide an implementation of the "evaluate" abstract
    method. A sub-experiment has full autonomy to override
    the basic components such as the training loop "train".
    """

    def __init__(self, model, optimizer, trainloader, wandb_logger=None, verbose=False):
        """
        Experiment superclass initializer. Each subclass must provide
        a model, optimizer, and trainloader at the very least.
        Arguments:
        -----------
        - model: torch nn.Module
        - optimizer: torch optimizer
        - trainloader: torch DataLoader to be used for training
        - wandb_logger: optional object with a .log(dict) method (e.g., a wandb run)
        - verbose: if True, print the epoch-averaged metrics after each epoch
        """
        self.model = model
        self.optimizer = optimizer
        self.trainloader = trainloader
        self.wandb_logger = wandb_logger
        self.verbose = verbose

    @abstractmethod
    def evaluate(self, batch):
        """
        Must return a dict containing at least a "loss" tensor. Any other
        entries (e.g., accuracy) are averaged and logged once per epoch.
        """

    def train(self, num_epochs):
        """
        Training loop.
        """
        self.model.train()
        metrics = {}  # metric name -> MeanMetric, registered on the first batch

        for epoch_num in range(1, num_epochs + 1):  # loop over the dataset num_epochs times
            tqdm_loader = tqdm(self.trainloader)
            tqdm_loader.set_description(f"Epoch {epoch_num}")
            for batch in tqdm_loader:
                self.optimizer.zero_grad()
                evals = self.evaluate(batch)
                if not metrics:
                    # validate the output and register metrics on the first batch;
                    # a local dict (not an instance flag) keeps repeated train() calls working
                    assert isinstance(evals, dict), "evaluate must return a dict"
                    assert "loss" in evals, "evaluate must return a 'loss' value"
                    for key in evals:
                        metrics[key] = MeanMetric()
                loss = evals["loss"]
                loss.backward()
                self.optimizer.step()
                # update metrics
                for key in metrics:
                    metrics[key].update(evals[key])

            epoch_results = {key: metric.get() for key, metric in metrics.items()}
            if self.verbose:
                for key, value in epoch_results.items():
                    print(f"{key}: {value}")
            if self.wandb_logger is not None:
                # log all metrics in one call so they share a single wandb step
                self.wandb_logger.log(epoch_results)
            # reset so the next epoch's averages start fresh
            for metric in metrics.values():
                metric.reset()

        self.model.eval()
        print('Finished Training!')

But the problem is that this only works for mean metrics on a per-epoch basis (i.e., logging accuracy wouldn't work), so this seems to be a harder problem than it looks. One viable option would be to force the user to provide the metric classes, but this goes against the simplicity of the package. Of course, we could also provide common metric classes (e.g., accuracy, IoU).

rosikand commented 2 years ago

I think it might actually make more sense to just provide `do_this_on_batch_end` and `do_this_on_epoch_end` functions, with which users can keep track of the state themselves. Also, provide a `MeanMetric` class in `utils.py`.
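A minimal sketch of the hook idea, assuming hooks with the names proposed above. The hook signatures (`evals` for the batch hook, `epoch_num` for the epoch hook) and the `HookedLoop` and `MyExperiment` classes are hypothetical; to stay self-contained, the example uses plain Python lists as batches and omits the backward pass and optimizer step.

```python
class HookedLoop:
    """
    Hypothetical training-loop skeleton that calls user-overridable
    hooks instead of managing metric state itself.
    """

    def __init__(self, trainloader):
        self.trainloader = trainloader

    def evaluate(self, batch):
        raise NotImplementedError

    def do_this_on_batch_end(self, evals):
        pass  # default: no-op

    def do_this_on_epoch_end(self, epoch_num):
        pass  # default: no-op

    def train(self, num_epochs):
        for epoch_num in range(1, num_epochs + 1):
            for batch in self.trainloader:
                evals = self.evaluate(batch)
                # (loss.backward() and optimizer.step() would go here)
                self.do_this_on_batch_end(evals)
            self.do_this_on_epoch_end(epoch_num)


class MyExperiment(HookedLoop):
    """Hypothetical user subclass that keeps its own metric state via the hooks."""

    def __init__(self, trainloader):
        super().__init__(trainloader)
        self.batch_losses = []
        self.history = {"loss": []}

    def evaluate(self, batch):
        # toy "loss": the mean of the batch values
        return {"loss": sum(batch) / len(batch)}

    def do_this_on_batch_end(self, evals):
        self.batch_losses.append(evals["loss"])

    def do_this_on_epoch_end(self, epoch_num):
        # store the epoch average, then clear the per-batch state
        self.history["loss"].append(sum(self.batch_losses) / len(self.batch_losses))
        self.batch_losses = []


exp = MyExperiment(trainloader=[[1.0, 3.0], [5.0, 7.0]])
exp.train(num_epochs=2)
print(exp.history["loss"])  # [4.0, 4.0]
```

This keeps the base loop free of metric bookkeeping: the superclass only guarantees when the hooks fire, and each subclass decides what state to keep.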

rosikand commented 2 years ago

Update: I think I will push this change sometime soon and require `evaluate` to return a dictionary rather than a numerical loss. The dictionary must contain the loss.
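A minimal sketch of the enforcement step, assuming it runs on whatever `evaluate` returns before the loss is used. `check_evaluate_output` is a hypothetical helper name; it raises explicit exceptions instead of using `assert`, so the check still runs when Python is started with `-O`.

```python
def check_evaluate_output(evals):
    """
    Hypothetical helper: enforce that evaluate() returned a dict
    containing a "loss" entry. Returns the dict unchanged.
    """
    if not isinstance(evals, dict):
        raise TypeError(f"evaluate must return a dict, got {type(evals).__name__}")
    if "loss" not in evals:
        raise KeyError("the dict returned by evaluate must contain a 'loss' key")
    return evals


check_evaluate_output({"loss": 0.42, "accuracy": 0.9})  # passes
try:
    check_evaluate_output(0.42)  # a bare numerical loss is rejected
except TypeError as err:
    print(err)  # evaluate must return a dict, got float
```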

With regard to:

"but the problem is that this only works for mean metrics on a per-epoch basis (i.e., logging accuracy wouldn't work)."

I disagree now: accuracy would still work. Even with a batch size of 1, if you input a series of 100s and 0s, the average is still the mean accuracy!
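The batch-size-1 argument holds because each per-sample accuracy is 0 or 100, so their mean is exactly the epoch accuracy. With larger batches, averaging per-batch accuracies gives the epoch accuracy only when all batches have the same size; a smaller final batch (when the dataset size is not divisible by the batch size) gets over-weighted. The sketch below compares both, assuming each update receives a batch's accuracy together with its size; `WeightedMeanMetric` is a hypothetical name, not part of torchplate.

```python
class WeightedMeanMetric:
    """Hypothetical variant of MeanMetric that weights each update by batch size."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, batch_mean, batch_size):
        # accumulate the sum of per-sample values, not the per-batch mean
        self.total += batch_mean * batch_size
        self.count += batch_size

    def reset(self):
        self.total = 0.0
        self.count = 0

    def get(self):
        return self.total / self.count if self.count else float("nan")


# 10 samples split into batches of 4, 4, and 2; 7 are classified correctly
batches = [(1.0, 4), (0.75, 4), (0.0, 2)]  # (batch accuracy, batch size)

unweighted = sum(acc for acc, _ in batches) / len(batches)
weighted = WeightedMeanMetric()
for acc, size in batches:
    weighted.update(acc, size)

print(unweighted)      # about 0.583: over-weights the small last batch
print(weighted.get())  # 0.7, the true accuracy (7 / 10)
```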

rosikand commented 1 year ago

Also, check out torchmetrics, which follows essentially the same class structure: `update`, `reset`, and `compute` (its counterpart to `get`).

rosikand commented 1 year ago

In addition, provide metrics.py:

from abc import ABC, abstractmethod

import torch


class MeanMetric:
    """
    Scalar metric designed to be used on a per-epoch basis
    and updated on a per-batch basis, for getting the average
    across the epoch.
    """

    def __init__(self):
        self.vals = []

    def update(self, new_val):
        # store plain floats so tensors don't keep the autograd graph alive
        if torch.is_tensor(new_val):
            new_val = new_val.item()
        self.vals.append(new_val)

    def reset(self):
        self.vals = []

    def get(self):
        if not self.vals:
            return float("nan")
        return sum(self.vals) / len(self.vals)


class MeanMetricCustom(ABC):
    """
    Abstract scalar metric. Subclasses must provide the calculation
    given logits and y.
    """

    def __init__(self):
        self.vals = []

    @abstractmethod
    def calculate(self, logits, y):
        """Return a scalar value for one batch."""

    def update(self, logits, y):
        val = self.calculate(logits, y)
        if torch.is_tensor(val):
            val = val.item()
        self.vals.append(val)

    def reset(self):
        self.vals = []

    def get(self):
        if not self.vals:
            return float("nan")
        return sum(self.vals) / len(self.vals)


# then provide some common metrics

Metrics to provide: each is given `logits` and `y` as input.
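As a hypothetical example of one such metric, here is a sketch of the `calculate` step for classification accuracy, assuming `logits` has shape `(batch_size, num_classes)` and `y` holds integer class indices. The `accuracy` function and the `Accuracy` subclass shown in the comment are illustrative names, not part of torchplate.

```python
import torch


def accuracy(logits, y):
    """
    Fraction of samples whose highest-scoring class matches the label.
    logits: (batch_size, num_classes) tensor; y: (batch_size,) class indices.
    """
    preds = logits.argmax(dim=1)
    return (preds == y).float().mean().item()


# Hypothetical wiring into the MeanMetricCustom base class:
# class Accuracy(MeanMetricCustom):
#     def calculate(self, logits, y):
#         return accuracy(logits, y)

logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [1.0, 0.2]])
y = torch.tensor([0, 1, 1])
print(accuracy(logits, y))  # about 0.667: 2 of 3 predictions are correct
```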

rosikand commented 1 year ago

Implemented in version 0.0.7.