BobMcDear / attorch

A subset of PyTorch's neural network modules, written in Python using OpenAI's Triton.
MIT License

attorch

Introduction
Installation
Layers
Math Functions
PyTorch Fallback
Tests

Introduction

attorch is a subset of PyTorch's nn module, written purely in Python using OpenAI's Triton. Its goal is to be an easily hackable, self-contained, and readable collection of neural network modules whilst maintaining or improving upon the efficiency of PyTorch. In other words, it intends to be a forkable project endowed with a simple, intuitive design that can serve as an accessible starting point for those who are seeking to develop custom deep learning operations but are not satisfied with the speed of a pure PyTorch implementation and do not have the technical expertise or resources to write CUDA kernels.

There already exist a number of wonderful PyTorch-like frameworks powered by Triton, including kernl, xFormers, Unsloth, and fla, but most concentrate mainly on Transformers and NLP applications, whereas attorch aims to be more inclusive by also presenting a variety of layers pertaining to areas besides NLP such as computer vision. Moreover, attorch is not an inference-only package and fully supports both forward and backward passes, meaning it can be used during training as well as inference, though its performance for the latter is generally not on par with dedicated inference engines.

Installation

The only dependencies of attorch are torch==2.4.0 and triton==3.0.0. Please install the specified versions of these two libraries and clone this repository to get started.
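Assuming a standard pip-based setup, these steps might look like the following (the clone URL is inferred from the repository name):

```shell
# Install the pinned dependencies; the exact versions matter.
pip install torch==2.4.0 triton==3.0.0

# Clone the repository (URL inferred from the repository name).
git clone https://github.com/BobMcDear/attorch.git
cd attorch
```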

Layers

Currently implemented layers, all with automatic mixed precision (AMP) support, are:

Unless otherwise noted in their docstrings, the aforementioned layers behave identically to their PyTorch equivalents.

Math Functions

Triton kernels are generally composed of two parts: one handles the loading and storing of the relevant tensors, and the other transforms the data using appropriate mathematical functions. For instance, a layer normalization kernel reads one or several rows from the input (load), standardizes the features (math), and writes the results into a container (store). A selection of these pure math functions is supplied by attorch.math, the objective being to facilitate the implementation of custom kernels and operation fusion.

Although only the forward passes of said functions are available in attorch.math, thanks to their purity and absence of I/O, their gradients can be automatically derived via the triton-autodiff library. Significant portions of attorch's kernels could be refactored by supplanting their math bits with the corresponding attorch.math transformations or their derivatives, but doing so would sacrifice the single-file, self-contained design of attorch, so attorch.math and the rest of attorch will remain separate.
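The load/math/store split can be illustrated outside Triton with a plain-Python sketch of a layer-normalization "kernel" (a pedagogical stand-in, not attorch's actual implementation):

```python
import math

def layernorm_math(row, eps=1e-5):
    """Pure math part: standardize one row of features (no I/O)."""
    mean = sum(row) / len(row)
    var = sum((x - mean) ** 2 for x in row) / len(row)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(x - mean) * inv_std for x in row]

def layernorm_kernel(inp, out):
    """I/O part: read each row, transform it, and write the result back."""
    for i, row in enumerate(inp):       # load
        result = layernorm_math(row)    # math
        out[i] = result                 # store

inp = [[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]]
out = [None, None]
layernorm_kernel(inp, out)
```

Because `layernorm_math` touches no memory and has no side effects, it is exactly the kind of function whose gradient a tool like triton-autodiff can derive mechanically.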

PyTorch Fallback

To ease the integration of attorch and PyTorch layers, the attorch.nn module is offered: it provides an interface to attorch's modules that falls back to PyTorch should a desired layer not be available, as seen below.

```python
from attorch import nn

lin = nn.Linear(10, 20)  # Uses attorch's linear layer
gap = nn.AdaptiveAvgPool2d(1)  # Uses PyTorch's global pooling since GAP is not available in attorch
```
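One common way to implement such a fallback is attribute dispatch: look up the requested name in the preferred namespace first and resort to the secondary one otherwise. The sketch below imitates that pattern with hypothetical stand-in namespaces; it is not attorch's actual code:

```python
class _Namespace:
    """Stand-in for a module exporting layers."""
    def __init__(self, **layers):
        self.__dict__.update(layers)

# Hypothetical stand-ins for attorch's and PyTorch's layer namespaces.
attorch_layers = _Namespace(Linear=lambda i, o: f"attorch Linear({i}, {o})")
torch_layers = _Namespace(
    Linear=lambda i, o: f"torch Linear({i}, {o})",
    AdaptiveAvgPool2d=lambda s: f"torch AdaptiveAvgPool2d({s})",
)

class FallbackNamespace:
    """Serve a layer from the primary namespace, else fall back to the secondary."""
    def __init__(self, primary, secondary):
        self._primary, self._secondary = primary, secondary

    def __getattr__(self, name):
        layer = getattr(self._primary, name, None)
        return layer if layer is not None else getattr(self._secondary, name)

nn = FallbackNamespace(attorch_layers, torch_layers)
lin = nn.Linear(10, 20)        # resolved in the primary (attorch-like) namespace
gap = nn.AdaptiveAvgPool2d(1)  # falls back to the secondary (torch-like) namespace
```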

Tests

Each module can be tested against its PyTorch counterpart to ensure correctness. These tests are included under tests/ and can be executed using pytest. It should be noted that some might fail owing to numerical precision issues, but in most practical use cases, that should not be a problem.
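Such tests typically run both implementations on the same input and compare the outputs within a tolerance rather than exactly, since floating-point reductions can differ in evaluation order. A minimal stdlib-only sketch of that pattern, with stand-in functions rather than attorch's actual tests:

```python
import math

def reference_sum(xs):
    """Stand-in for the reference implementation: left-to-right summation."""
    total = 0.0
    for x in xs:
        total += x
    return total

def kernel_sum(xs):
    """Stand-in for a kernel: pairwise summation, a different reduction order."""
    if len(xs) == 1:
        return xs[0]
    mid = len(xs) // 2
    return kernel_sum(xs[:mid]) + kernel_sum(xs[mid:])

xs = [0.1 * i for i in range(1, 65)]
ref, ker = reference_sum(xs), kernel_sum(xs)

# Bitwise equality may fail; a relative tolerance is the appropriate check.
close = math.isclose(ref, ker, rel_tol=1e-9)
```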