Closed: rosikand closed this issue 1 year ago.
Here is a basic try at this:
```python
import torch

class MeanMetric:
    """
    Scalar metric designed for use on a per-epoch basis
    and updated on a per-batch basis, for getting the average
    across the epoch.
    """

    def __init__(self):
        self.vals = []

    def update(self, new_val):
        if torch.is_tensor(new_val):
            new_val = new_val.item()
        self.vals.append(new_val)

    def reset(self):
        self.vals = []

    def get(self):
        return sum(self.vals) / len(self.vals)
```
```python
from abc import ABC, abstractmethod
from tqdm import tqdm

class Experiment(ABC):
    """
    Base experiment superclass. All other experiments
    should inherit from this class. Each sub-experiment
    must provide an implementation of the "evaluate" abstract
    method. A sub-experiment has full autonomy to override
    the basic components such as the training loop "train".
    """

    def __init__(self, model, optimizer, trainloader, wandb_logger=None, verbose=False):
        """
        Experiment superclass initializer. Each subclass must provide
        a model, optimizer, and trainloader at the very least.
        Arguments:
        -----------
        - model: torch nn.Module
        - optimizer: torch optimizer
        - trainloader: torch DataLoader to be used for training
        """
        self.model = model
        self.optimizer = optimizer
        self.trainloader = trainloader
        self.wandb_logger = wandb_logger
        self.verbose = verbose
        self.first_batch = True

    @abstractmethod
    def evaluate(self, batch):
        """Run the forward pass on a batch and return a dict of values
        that must contain a 'loss' entry."""
        pass

    def train(self, num_epochs):
        """
        Training loop.
        """
        self.model.train()
        epoch_num = 0
        metrics = {}
        for epoch in range(num_epochs):  # loop over the dataset num_epochs times
            epoch_num += 1
            tqdm_loader = tqdm(self.trainloader)
            for batch in tqdm_loader:
                tqdm_loader.set_description(f"Epoch {epoch_num}")
                self.optimizer.zero_grad()
                evals = self.evaluate(batch)
                if self.first_batch:
                    assert type(evals) is dict
                    assert "loss" in evals.keys(), "evaluate must return a 'loss' value"
                    # register one metric per returned key
                    for key in evals:
                        metrics[key] = MeanMetric()
                    self.first_batch = False
                loss = evals["loss"]
                loss.backward()
                self.optimizer.step()
                # update metrics
                for key in metrics:
                    metrics[key].update(evals[key])
                if self.verbose:
                    for key in metrics:
                        print(f"{key}: {metrics[key].get()}")
                if self.wandb_logger is not None:
                    for key in metrics:
                        self.wandb_logger.log({str(key): metrics[key].get()})
        self.model.eval()
        print('Finished Training!')
```
But the problem is that this only works for mean metrics on a per-epoch basis (i.e., logging accuracy wouldn't work). So it seems this is a harder problem than it looks. One viable option would be to force the user to provide the metric classes, but that goes against the simplicity of the package. Of course, we could still provide common metric classes (e.g., accuracy, IoU, etc.).
I think it might actually make more sense to just provide `do_this_on_batch_end` and `do_this_on_epoch_end` functions so that users can keep track of the state themselves. Also provide a `MeanMetric` class in `utils.py`.
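One way that hook idea could look, as a hedged sketch: `HookedExperiment`, `CountingExperiment`, and the toy `evaluate` below are illustrative names, not the package's actual API. The point is that the loop stays generic while the subclass owns its metric state.

```python
class HookedExperiment:
    """Hypothetical sketch of the hook-based design: the training loop
    calls two overridable hooks, and subclasses track whatever state
    they want inside them."""

    def do_this_on_batch_end(self, evals):
        pass  # override to accumulate per-batch state

    def do_this_on_epoch_end(self, epoch_num):
        pass  # override to log/reset per-epoch state

    def train(self, batches, num_epochs):
        for epoch in range(num_epochs):
            for batch in batches:
                evals = self.evaluate(batch)
                self.do_this_on_batch_end(evals)
            self.do_this_on_epoch_end(epoch + 1)


class CountingExperiment(HookedExperiment):
    """Toy subclass: counts correct predictions itself, then computes
    epoch accuracy in the epoch-end hook."""

    def __init__(self):
        self.correct = 0
        self.total = 0
        self.epoch_accuracies = []

    def evaluate(self, batch):
        pred, target = batch  # toy "batch": (prediction, label)
        return {"loss": 0.0, "correct": int(pred == target)}

    def do_this_on_batch_end(self, evals):
        self.correct += evals["correct"]
        self.total += 1

    def do_this_on_epoch_end(self, epoch_num):
        self.epoch_accuracies.append(100.0 * self.correct / self.total)
        self.correct = 0
        self.total = 0
```

This keeps non-mean metrics (accuracy, IoU) possible without the package having to know about them.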
Update: I think I will push this change sometime soon and have `evaluate` enforce that a dictionary is returned rather than a bare numerical loss. The dict must contain the loss.
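A torch-free miniature of that contract (`run_epoch` and the toy `evaluate` are illustrative, not part of the package): `evaluate` returns a dict, the loop insists on a `"loss"` key, registers a `MeanMetric` per key on the first batch, and averages everything.

```python
class MeanMetric:
    """Same structure as above, restated so this sketch runs standalone."""
    def __init__(self):
        self.vals = []
    def update(self, new_val):
        self.vals.append(new_val)
    def get(self):
        return sum(self.vals) / len(self.vals)


def run_epoch(batches, evaluate):
    """Register one MeanMetric per returned key on the first batch,
    then update every metric on every batch."""
    metrics = {}
    for i, batch in enumerate(batches):
        evals = evaluate(batch)
        assert isinstance(evals, dict) and "loss" in evals, \
            "evaluate must return a dict containing 'loss'"
        if i == 0:
            for key in evals:
                metrics[key] = MeanMetric()
        for key in metrics:
            metrics[key].update(evals[key])
    return {key: m.get() for key, m in metrics.items()}
```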
With regards to:

> but the problem is that this only works for mean metrics on a per-epoch basis (i.e., logging accuracy wouldn't work).

I disagree now... accuracy would still work. Even on a batch-size-1 basis, if you input a bunch of 100's and 0's, the average is still the mean accuracy!
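A quick sanity check of that claim with toy numbers: at batch size 1, each per-batch "accuracy" is 100 or 0, and their mean is exactly the overall accuracy. (The mean of per-batch accuracies only matches the overall accuracy exactly when every batch has the same size, since each batch gets equal weight in the mean.)

```python
# Toy numbers: per-example accuracies from four batches of size 1,
# three correct (100) and one wrong (0).
per_batch_acc = [100, 0, 100, 100]

mean_acc = sum(per_batch_acc) / len(per_batch_acc)  # what MeanMetric computes
overall_acc = 100 * 3 / 4                           # correct / total

assert mean_acc == overall_acc == 75.0
```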
Also, check out torchmetrics, which basically follows the same (`get`, `reset`, `update`) class structure. In addition, provide `metrics.py`:
```python
from abc import ABC, abstractmethod

class MeanMetric:
    """
    Scalar metric designed for use on a per-epoch basis
    and updated on a per-batch basis, for getting the average
    across the epoch.
    """

    def __init__(self):
        self.vals = []

    def update(self, new_val):
        self.vals.append(new_val)

    def reset(self):
        self.vals = []

    def get(self):
        return sum(self.vals) / len(self.vals)


class MeanMetricCustom(ABC):
    """
    Abstract scalar metric. Subclasses must provide the calculation
    given predictions (logits) and targets (y).
    """

    def __init__(self):
        self.vals = []

    @abstractmethod
    def calculate(self, logits, y):
        # returns a scalar value for this batch
        pass

    def update(self, logits, y):
        self.vals.append(self.calculate(logits, y))

    def reset(self):
        self.vals = []

    def get(self):
        return sum(self.vals) / len(self.vals)
```
Then provide some common metric implementations on top of this, each given `(logits, y)` as input.
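For instance, here is a hedged sketch of what an accuracy metric built on `MeanMetricCustom` could look like. The base class is restated so the snippet runs standalone, and the list-based argmax is just to keep the sketch free of a torch dependency; the real implementation would operate on tensors.

```python
from abc import ABC, abstractmethod

class MeanMetricCustom(ABC):
    """Restated from metrics.py above so this sketch runs standalone."""
    def __init__(self):
        self.vals = []
    @abstractmethod
    def calculate(self, logits, y):
        pass
    def update(self, logits, y):
        self.vals.append(self.calculate(logits, y))
    def reset(self):
        self.vals = []
    def get(self):
        return sum(self.vals) / len(self.vals)


class Accuracy(MeanMetricCustom):
    """Hypothetical common metric: percentage of rows whose argmax
    matches the target label."""
    def calculate(self, logits, y):
        # argmax over each row of plain-list "logits"
        preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
        correct = sum(int(p == t) for p, t in zip(preds, y))
        return 100.0 * correct / len(y)
```

Usage follows the same `update` / `get` / `reset` rhythm as `MeanMetric`, so the training loop can treat both kinds of metric uniformly.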
Implemented in version 0.0.7.
Instead of having `evaluate` only provide the scalar loss value, have it provide a dict containing at least the `loss`, but which can also contain other things like `accuracy` that are logged (printed to console) during training. Also add a `history` state dict to the superclass which keeps track of the values on a per-epoch basis. That way, a user can call `exp.history['loss']` to get the loss for each epoch.
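A minimal, torch-free sketch of how that `history` dict could be maintained (`train_with_history` and the toy `evaluate` are illustrative names, not the package's API): metrics accumulate within an epoch, then each metric's epoch mean is appended to `history` and the metric is reset.

```python
class MeanMetric:
    """Restated so the sketch runs standalone."""
    def __init__(self):
        self.vals = []
    def update(self, new_val):
        self.vals.append(new_val)
    def reset(self):
        self.vals = []
    def get(self):
        return sum(self.vals) / len(self.vals)


def train_with_history(batches, evaluate, num_epochs):
    """Per-epoch bookkeeping: after each epoch, append every metric's
    epoch mean to `history` and reset it for the next epoch."""
    history = {}
    metrics = {}
    for epoch in range(num_epochs):
        for batch in batches:
            evals = evaluate(batch)
            for key, val in evals.items():
                metrics.setdefault(key, MeanMetric()).update(val)
        for key, metric in metrics.items():
            history.setdefault(key, []).append(metric.get())
            metric.reset()
    return history
```

With this shape, `history['loss']` holds one averaged value per epoch, which is exactly what `exp.history['loss']` would return.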