hpcaitech / TensorNVMe

A Python library that transfers PyTorch tensors between CPU and NVMe

[RFC] Notify memory collection and callback at `async_read` and `async_write` #35

Closed: ofey404 closed this 2 years ago

ofey404 commented 2 years ago

What's new

Callbacks can now be written more elegantly, for example:

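# Out of order overlap
for tensor in tensors:
    # Bind the current tensor as a default argument: a plain closure looks up
    # `tensor` when the callback runs, which may be after the loop has moved on.
    def compute_then_offload(tensor=tensor):
        tensor.mul_(2.0)
        offloader.async_write(tensor)

    offloader.async_read(tensor, callback=compute_then_offload)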

offloader.synchronize()

Changes behind the scenes

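The change is simple but powerful: every async_read/async_write call now ends with a wait(), which notifies pending memory collection and runs the callbacks of requests that have already completed.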

void Offloader::async_write(const at::Tensor &tensor, const std::string &key, callback_t callback)
{
    _async_write_nowait(tensor, key, callback);

    // Notify pending memory collection and callback.
    this->aio->wait();
}

To visualize what will happen:


      ┌──────────────────────────────────┐
      │                                  │
      │             ┌──────────────────┐ │   ┌───────────────────┐
      │             │                  │ │   │  Callback 1 & 2   │
      │             │                  │ │   │  will be picked   │
      │             │                  │ │   │  up by            │
      │             │                  │ │   │  async_write3     │
      │             │                  │ │   └────────┬──────────┘
      │             │                  │ │            │
  ────┴─────────────┴──────────────────▼─▼────────────┴─────►
 async_write1    async_write2         done        async_write3
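The timeline above can be modeled with a small, single-threaded toy. This is a hypothetical sketch, not TensorNVMe's implementation: `ToyOffloader`, `simulate_io_completion()` and the use of string names instead of tensors are illustrative assumptions, and I/O completion is triggered by hand so that requests 1 and 2 finish between the second and third write, as in the diagram.

```python
from collections import deque
from typing import Callable, Deque, List, Optional, Tuple

Callback = Optional[Callable[[], None]]


class ToyOffloader:
    """Hypothetical model of 'poll for completions on every call'.

    Not the real TensorNVMe code: completion is simulated explicitly so the
    ordering shown in the diagram can be reproduced deterministically.
    """

    def __init__(self) -> None:
        self._in_flight: Deque[Tuple[str, Callback]] = deque()
        self._done: Deque[Tuple[str, Callback]] = deque()
        self.log: List[str] = []

    def _async_write_nowait(self, name: str, callback: Callback) -> None:
        # Submit the request and return immediately.
        self.log.append(f"submit {name}")
        self._in_flight.append((name, callback))

    def _wait(self) -> None:
        # Non-blocking poll: run callbacks only for requests already finished.
        while self._done:
            _, callback = self._done.popleft()
            if callback is not None:
                callback()

    def async_write(self, name: str, callback: Callback = None) -> None:
        self._async_write_nowait(name, callback)
        self._wait()  # piggyback: flush callbacks of earlier finished requests

    def simulate_io_completion(self) -> None:
        # Stand-in for the disk finishing every outstanding request.
        while self._in_flight:
            self._done.append(self._in_flight.popleft())

    def synchronize(self) -> None:
        # Blocking drain: everything completes, then all remaining callbacks run.
        self.log.append("synchronize")
        self.simulate_io_completion()
        self._wait()


off = ToyOffloader()
off.async_write("w1", lambda: off.log.append("callback 1"))
off.async_write("w2", lambda: off.log.append("callback 2"))
off.simulate_io_completion()  # the "done" point in the diagram
off.async_write("w3", lambda: off.log.append("callback 3"))
off.synchronize()
print(off.log)
# ['submit w1', 'submit w2', 'submit w3', 'callback 1', 'callback 2', 'synchronize', 'callback 3']
```

In this toy, deleting the `self._wait()` line from `async_write` leaves callbacks 1 and 2 queued until `synchronize()`; keeping it lets the third write pick them up, mirroring the diagram.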

Background

In concurrent data structures, it is common practice for an access method to notify waiting threads as part of the operation itself.

Take Python's queue.Queue as an example:

# python3.9/queue.py
class Queue:
    def put(self, item, block=True, timeout=None):
            ...
            self._put(item)
            ...
            self.not_empty.notify()
    def get(self, block=True, timeout=None):
            ...
            item = self._get()
            self.not_full.notify()
            return item
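To see the pattern in isolation, here is a stripped-down, unbounded queue in the same style. It is a simplified sketch: the class name `TinyQueue` is made up, and the real `queue.Queue` also has `maxsize`, a `not_full` condition and timeouts. `put()` calls `notify()` inside the method, so a consumer blocked in `get()` wakes up without any separate signalling call.

```python
import threading
from collections import deque
from typing import Any, Deque, List


class TinyQueue:
    """Simplified, unbounded queue in the style of queue.Queue."""

    def __init__(self) -> None:
        self._items: Deque[Any] = deque()
        self._mutex = threading.Lock()
        self.not_empty = threading.Condition(self._mutex)

    def put(self, item: Any) -> None:
        with self.not_empty:
            self._items.append(item)
            # Wake one consumer blocked in get(), as queue.Queue.put does.
            self.not_empty.notify()

    def get(self) -> Any:
        with self.not_empty:
            # Re-check in a loop to guard against spurious wakeups.
            while not self._items:
                self.not_empty.wait()
            return self._items.popleft()


q = TinyQueue()
received: List[int] = []


def consumer() -> None:
    for _ in range(3):
        received.append(q.get())


t = threading.Thread(target=consumer)
t.start()
for i in range(3):
    q.put(i)
t.join()
print(received)  # [0, 1, 2]
```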

Example and Test

Manual test script:

import torch
from tensornvme import DiskOffloader

offloader = DiskOffloader('./offload')

tensors = []

for _ in range(10):
    tensor = torch.rand(2, 2)
    tensors.append(tensor)
    print(tensor)
    offloader.sync_write(tensor)

print("== computing! ==")

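# Out of order overlap
for tensor in tensors:
    # Bind the current tensor as a default argument: a plain closure looks up
    # `tensor` when the callback runs, which may be after the loop has moved on.
    def compute_then_offload(tensor=tensor):
        tensor.mul_(2.0)
        offloader.async_write(tensor)

    offloader.async_read(tensor, callback=compute_then_offload)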

offloader.synchronize()

for tensor in tensors:
    offloader.sync_read(tensor)
    print(tensor)

Expected output:

tensor([[0.9309, 0.7670],
        [0.6096, 0.1153]])
...
== computing! ==
tensor([[1.8618, 1.5340],
        [1.2191, 0.2306]])
...