fgnt / padertorch

A collection of common functionality to simplify the design, training and evaluation of machine learning models based on pytorch with an emphasis on speech processing.
MIT License



Padertorch is designed to simplify the training of deep learning models written with PyTorch. While focusing on speech and audio processing, it is not limited to these application areas.

[Figure: Schematic overview of padertorch]


Installation

$ git clone https://github.com/fgnt/padertorch.git
$ cd padertorch && pip install -e '.[all]'

This installs padertorch in editable mode together with all optional dependencies. For a lightweight installation with only the core dependencies, drop [all].


Getting Started

A Short Explanation of padertorch.Module and padertorch.Model

You can build your models on top of padertorch.Module and padertorch.Model. Both expect a forward method that works like the forward method of torch.nn.Module: it takes some data as input, applies some transformations, and returns the network output:

import padertorch as pt

class MyModel(pt.Module):

    def forward(self, example):
        x = example['x']
        out = transform(x)  # `transform` stands for any network operation, e.g., a torch.nn layer
        return out

Additionally, padertorch.Model expects a review method. It receives both the input and the output of the forward call and computes from them the training loss and the metrics to be logged to TensorBoard. The following example shows a classification problem with the cross-entropy loss:

import torch
import padertorch as pt

class MyModel(pt.Model):

    def forward(self, example):
        output = ...  # e.g., output has shape (N, C), where C is the number of classes
        return output

    def review(self, example, output):
        # loss computation, where example['label'] has shape (N,)
        ce_loss = torch.nn.CrossEntropyLoss()(output, example['label'])
        # compute additional metrics without tracking gradients
        with torch.no_grad():
            prediction = torch.argmax(output, dim=1)
            accuracy = (prediction == example['label']).float().mean()
        return {
            'loss': ce_loss,
            'scalars': {'accuracy': accuracy},
        }
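To make the quantities in review concrete, the following framework-free sketch computes the same kind of cross-entropy loss and accuracy for a tiny batch of logits using only the standard library. The function names and sample numbers are hypothetical; in padertorch, the loss would be a Tensor produced by torch.nn.CrossEntropyLoss so that it can be backpropagated.

```python
import math


def cross_entropy(logits, label):
    """Negative log-softmax probability of the target class for one example."""
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def review_sketch(labels, output):
    """Hypothetical stand-in for review: returns a dict shaped like the example above."""
    losses = [cross_entropy(row, y) for row, y in zip(output, labels)]
    # torch.nn.CrossEntropyLoss averages over the batch by default (reduction='mean').
    loss = sum(losses) / len(losses)
    predictions = [row.index(max(row)) for row in output]  # argmax over classes
    accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)
    return {'loss': loss, 'scalars': {'accuracy': accuracy}}


output = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]  # shape (N=2, C=3)
labels = [0, 1]                               # shape (N,)
summary = review_sketch(labels, output)
print(summary['scalars']['accuracy'])  # 0.5
```

The first example is classified correctly and the second is not, hence an accuracy of 0.5.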

See padertorch.summary.tbx_utils.review_dict for how the review dictionary should be constructed. In each training step, the trainer calls forward, passes its output to review, and performs a backpropagation step on the returned loss. Typically, the input to forward is a Tensor for a Module, whereas for a Model it is a dictionary that also contains the additional entries needed in review, e.g., labels. This is only a recommendation; there is no restriction on the input type.
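The order of operations in one training step (forward, then review, then an update based on the returned loss) can be illustrated without any framework. The sketch below fits a single hypothetical weight w; the class ToyModel, the learning rate, the extra 'grad' entry, and the data are assumptions for illustration, and the hand-derived gradient stands in for the autograd backward pass the trainer actually performs.

```python
class ToyModel:
    """Hypothetical one-parameter model y = w * x, trained with a squared error."""

    def __init__(self):
        self.w = 0.0

    def forward(self, example):
        return self.w * example['x']

    def review(self, example, output):
        error = output - example['y']
        # 'grad' is NOT part of a real review dict; it replaces autograd in this sketch.
        return {'loss': error ** 2, 'grad': 2 * error * example['x']}


def train_step(model, example, lr=0.1):
    output = model.forward(example)         # 1. forward
    result = model.review(example, output)  # 2. review computes the loss
    model.w -= lr * result['grad']          # 3. gradient step on the loss
    return result['loss']


model = ToyModel()
data = [{'x': 1.0, 'y': 2.0}, {'x': 2.0, 'y': 4.0}]  # generated by y = 2 * x
for epoch in range(50):
    for example in data:
        loss = train_step(model, example)
print(round(model.w, 3))  # 2.0
```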

While forward (and, for a Model, review) is mandatory, you are free to add any further methods to your models. Since a Module does not need a review method, it can serve as a component of a Model.

How to Integrate your Data and Model with the Trainer

The trainer works with any kind of iterable, e.g., a list, a torch.utils.data.DataLoader, or a lazy_dataset.Dataset. The train method expects an iterable that yields training examples or minibatches of examples. These are forwarded to the model without being interpreted by the trainer, i.e., the yielded entries can have any data type as long as the model is designed to work with them. In our examples, the iterables always yield a dict.
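Because the trainer forwards whatever the iterable yields, even a small hand-written class can serve as training data. The following sketch groups individual example dicts into minibatch dicts; the class name BatchedIterable and the key layout are hypothetical, and in practice you would typically rely on torch.utils.data.DataLoader or lazy_dataset for batching and shuffling.

```python
class BatchedIterable:
    """Re-iterable wrapper that groups example dicts into minibatch dicts.

    Each minibatch maps every key to the list of per-example values. Defining
    __iter__ (instead of passing a bare generator object) lets the same object
    be iterated once per epoch.
    """

    def __init__(self, examples, batch_size):
        self.examples = examples
        self.batch_size = batch_size

    def __iter__(self):
        for start in range(0, len(self.examples), self.batch_size):
            batch = self.examples[start:start + self.batch_size]
            yield {key: [ex[key] for ex in batch] for key in batch[0]}


examples = [{'observation': [float(i)], 'label': i % 2} for i in range(5)]
train_dataset = BatchedIterable(examples, batch_size=2)
for minibatch in train_dataset:
    print(minibatch['label'])
# [0, 1]
# [0, 1]
# [0]
```

The design choice matters for multi-epoch training: a generator object is exhausted after one pass, whereas an object with __iter__ produces a fresh iterator each time it is looped over.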

The Model implements an example_to_device method, which the trainer calls to move the data to the CPU or GPU. By default, example_to_device uses padertorch.data.example_to_device, which recursively converts numpy arrays to Tensors and moves all Tensors to the available device. The training device can be passed directly to Trainer.train; otherwise, the trainer selects it based on torch.cuda.is_available().
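The recursive conversion performed by example_to_device can be pictured as a walk over nested containers that transforms every leaf. The sketch below is not padertorch's implementation: the name apply_to_leaves and the convert callable are hypothetical, and in the real function the leaf transformation would convert numpy arrays to Tensors and move them to the device.

```python
def apply_to_leaves(example, convert):
    """Recursively rebuild dicts, lists and tuples, applying convert to every other value.

    Only plain lists and tuples are handled; namedtuples would need special care.
    """
    if isinstance(example, dict):
        return {key: apply_to_leaves(value, convert) for key, value in example.items()}
    if isinstance(example, (list, tuple)):
        return type(example)(apply_to_leaves(value, convert) for value in example)
    return convert(example)


example = {'observation': [1, 2], 'meta': {'speaker': 'a', 'ids': (3, 4)}}
# Stand-in for "numpy array -> Tensor on device": here only ints are converted.
moved = apply_to_leaves(example, lambda x: float(x) if isinstance(x, int) else x)
print(moved)
# {'observation': [1.0, 2.0], 'meta': {'speaker': 'a', 'ids': (3.0, 4.0)}}
```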

Optionally, you can register an iterable with validation examples using Trainer.register_validation_hook. Some functionality, e.g., keeping track of the best checkpoint, is then based on the validation data.
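Keeping track of the best checkpoint conceptually amounts to comparing a validation metric after each evaluation and remembering the checkpoint with the best value. The class BestCheckpointTracker and the checkpoint names below are a hypothetical illustration (assuming a lower loss is better), not the logic behind Trainer.register_validation_hook.

```python
class BestCheckpointTracker:
    """Remember the checkpoint with the lowest validation loss seen so far."""

    def __init__(self):
        self.best_loss = float('inf')
        self.best_checkpoint = None

    def update(self, checkpoint_name, validation_loss):
        """Return True if this checkpoint is the new best one."""
        if validation_loss < self.best_loss:
            self.best_loss = validation_loss
            self.best_checkpoint = checkpoint_name
            return True
        return False


tracker = BestCheckpointTracker()
history = [('ckpt_1.pth', 0.9), ('ckpt_2.pth', 0.7), ('ckpt_3.pth', 0.8)]
for name, loss in history:
    tracker.update(name, loss)
print(tracker.best_checkpoint)  # ckpt_2.pth
```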

A simple sketch for the trainer setup is given below:

import torch
import padertorch as pt

train_dataset = ...
validation_dataset = ...

class MyModel(pt.Model):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.net = torch.nn.Sequential(...)

    def forward(self, example):
        output = self.net(example['observation'])
        return output

    def review(self, example, output):
        loss = ...  # calculate loss
        with torch.no_grad():
            ...  # calculate general metrics
            if self.training:
                ...  # calculate training specific metrics
            else:
                ...  # calculate validation specific metrics
        return {
            'loss': loss,
            'scalars': {
                'accuracy': ...,
            },
        }  # Further keys: 'images', 'audios', 'histograms', 'texts', 'figures'

trainer = pt.Trainer(
    model=MyModel(),
    optimizer=pt.optimizer.Adam(),
    storage_dir=pt.io.get_new_storage_dir('my_experiment'),  # checkpoints of the trained model are stored here
    summary_trigger=(1, 'epoch'),
    checkpoint_trigger=(1, 'epoch'),
    stop_trigger=(1, 'epoch'),
)
trainer.test_run(train_dataset, validation_dataset)  # quick sanity check of the setup
trainer.register_validation_hook(validation_dataset)
trainer.train(train_dataset)
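The triggers in the Trainer call are given as (interval, unit) pairs, e.g., (1, 'epoch'). As a rough mental model, a periodic trigger fires whenever the counter for its unit reaches a multiple of the interval. The IntervalTriggerSketch class below is a hypothetical simplification; the units 'epoch' and 'iteration' are assumed here, and padertorch's own trigger classes may handle details such as resumed training differently.

```python
class IntervalTriggerSketch:
    """Fire when the counter of `unit` ('epoch' or 'iteration') hits a multiple of `interval`."""

    def __init__(self, interval, unit):
        if unit not in ('epoch', 'iteration'):
            raise ValueError(f'unknown unit: {unit!r}')
        self.interval = interval
        self.unit = unit

    def __call__(self, iteration, epoch):
        counter = epoch if self.unit == 'epoch' else iteration
        return counter > 0 and counter % self.interval == 0


summary_trigger = IntervalTriggerSketch(1000, 'iteration')
checkpoint_trigger = IntervalTriggerSketch(2, 'epoch')
print([it for it in range(0, 3001, 500) if summary_trigger(iteration=it, epoch=0)])
# [1000, 2000, 3000]
print([ep for ep in range(5) if checkpoint_trigger(iteration=0, epoch=ep)])
# [2, 4]
```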

See the Trainer's documentation for an explanation of its signature. If you want to use pt.io.get_new_storage_dir to manage your experiments, you have to define an environment variable STORAGE_ROOT that points to the path where all your experiments will be stored. In the example above, a new numbered directory such as $STORAGE_ROOT/my_experiment/1 is then created. Alternatively, you can use pt.io.get_new_subdir, which takes the path for storing your model directly and does not require an environment variable.
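The numbering scheme described above ($STORAGE_ROOT/my_experiment/1, then 2, and so on) can be approximated with the standard library. The function new_numbered_subdir below is a hypothetical sketch of the idea, not the implementation of pt.io.get_new_storage_dir or pt.io.get_new_subdir; it picks the next unused integer and creates that directory.

```python
import tempfile
from pathlib import Path


def new_numbered_subdir(base):
    """Create and return base/<n>, where n is one more than the largest existing number."""
    base = Path(base)
    base.mkdir(parents=True, exist_ok=True)
    existing = [int(p.name) for p in base.iterdir() if p.is_dir() and p.name.isdigit()]
    candidate = max(existing, default=0) + 1
    while True:
        path = base / str(candidate)
        try:
            path.mkdir()  # raises FileExistsError if another process was faster
            return path
        except FileExistsError:
            candidate += 1


with tempfile.TemporaryDirectory() as storage_root:  # stands in for $STORAGE_ROOT
    first = new_numbered_subdir(Path(storage_root) / 'my_experiment')
    second = new_numbered_subdir(Path(storage_root) / 'my_experiment')
    print(first.name, second.name)  # 1 2
```

Creating the directory with mkdir() and retrying on FileExistsError, rather than checking for existence first, keeps two experiments started at the same time from claiming the same number.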

Features for Application in Deep Learning

Padertorch provides a selection of frequently used network architectures and building blocks, such as activation functions and normalization layers, ready for you to integrate into your own models.

Support for Sequential and Speech Data

Padertorch especially offers support for training with sequential data such as:

We also provide a loss wrapper for permutation-invariant training (PIT) criteria, which are commonly used in, e.g., (speech) source separation.
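The idea behind a permutation-invariant training criterion is to evaluate the loss for every assignment of estimated sources to reference sources and keep the minimum. The stdlib-only sketch below does this for a mean squared error on short lists; the function names are hypothetical, and this is not padertorch's PIT loss wrapper, which operates on Tensors.

```python
from itertools import permutations


def mse(estimate, reference):
    return sum((e - r) ** 2 for e, r in zip(estimate, reference)) / len(reference)


def pit_mse(estimates, references):
    """Return (min loss, best permutation) over all estimate-to-reference assignments."""
    best_loss, best_perm = float('inf'), None
    for perm in permutations(range(len(estimates))):
        # perm[k] is the index of the estimate assigned to reference k.
        loss = sum(mse(estimates[p], ref) for p, ref in zip(perm, references)) / len(references)
        if loss < best_loss:
            best_loss, best_perm = loss, perm
    return best_loss, best_perm


references = [[1.0, 0.0, 1.0], [0.0, 2.0, 0.0]]
estimates = [[0.0, 2.0, 0.1], [1.0, 0.0, 0.9]]  # sources come out in swapped order
loss, perm = pit_mse(estimates, references)
print(perm)  # (1, 0)
```

Enumerating all permutations costs K! loss evaluations for K sources, which is affordable for the two or three speakers typical in separation tasks; for many sources, an assignment algorithm such as the Hungarian method scales better.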

Further Reading

Have a look at the following links to get the most out of your experience with padertorch: