ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.

[data] Ray does not warn of as yet undetermined resource depletion #42551

Open Nintorac opened 8 months ago

Nintorac commented 8 months ago

What happened + What you expected to happen

I am trying to train a model using a dataset that is consumed and transformed through Ray Data. Training runs fine for a period of time and then freezes; killing Ray workers can temporarily unfreeze it, but not for long.

Memory and disk seem fine.

I had a hunch that file descriptors might be involved, and I see a spike from ~30k to ~150k open files after launching the repro. Using the commented-out 'render_batch' stub results in about 100k open files and seems to run without issue indefinitely. I'm not sure whether that is related.
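For context, this is roughly how I have been watching the open-file counts. A minimal sketch, assuming Linux; it just polls the system-wide handle count from /proc, nothing Ray-specific:

import time

def system_open_files():
    # /proc/sys/fs/file-nr holds: allocated handles, allocated-but-unused, max.
    with open('/proc/sys/fs/file-nr') as f:
        allocated, _unused, maximum = (int(x) for x in f.read().split())
    return allocated, maximum

if __name__ == '__main__':
    while True:
        allocated, maximum = system_open_files()
        print(f'system-wide open file handles: {allocated} / {maximum}')
        time.sleep(10)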

Versions / Dependencies

Python 3.11.2

librosa==0.10.1
pedalboard==0.8.7
pyarrow==14.0.2
pyarrow-hotfix==0.6
pytorch-lightning==2.0.4
ray==2.9.0
torch==2.1.2
torchaudio==2.1.2
torchdata==0.7.1
torchmetrics==1.3.0
torchtext==0.16.2
torchvision==0.16.2
tqdm==4.66.1

Reproduction script

First, download and extract Dexed.vst3 from here;

then, in the same folder, create this script:

#%%
from functools import partial
from itertools import count
from time import sleep
from librosa import resample
import ray
import torch
from tqdm import tqdm
import pyarrow as pa
from pedalboard import load_plugin
batch_size = 128
RENDER_SAMPLE_RATE = 44100

# Map float audio in [-1, 1] to unsigned 16-bit-range integers (kept as int32).
def float_to_pcm16(audio):
    import numpy

    ints = ((audio + 1) * 32767).astype(numpy.int32)

    return ints

# Stub that stands in for real MIDI generation: it always returns an empty
# note list, so the plugin renders silence for each item.
def to_midi(*args, **kwargs):
    return []

def render_batch(notes, sr: int=8000, duration: float=2.5, bpm: int=120)->pa.array:
    chunk = notes
    # Load a VST3 or Audio Unit plugin from a known path on disk:

    instrument = load_plugin('Dexed.vst3')

    samples = []
    # t = tqdm()

    to_midi_f = partial(to_midi, bpm=bpm)
    for notes in map(to_midi_f, chunk):
        # import time
        # from datetime import datetime
        # import logging
        # logger=logging.getLogger('lol')
        # logger.setLevel(logging.INFO)
        # st = time.time()
        # print(f'starting... {datetime.now()})')

        x = instrument(
            notes,
            duration=duration, # seconds
            sample_rate=RENDER_SAMPLE_RATE,
        )
        # print(time.time()-st)
        x = resample(x, orig_sr=RENDER_SAMPLE_RATE, target_sr=sr)

        samples.append(float_to_pcm16(x.mean(0)))
        # t.update(1)

    # 1/0
    return pa.array(samples)

# def render_batch(notes, sr: int=8000, duration: float=2.5, bpm: int=120)->pa.array:
#     return pa.array((torch.zeros((int(sr * duration),)).numpy() for i in notes))
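# (Using the stub above instead, which skips the plugin and resampling and just
# returns zeros, kept running without issue indefinitely for me.)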

if __name__=='__main__':
    ds = ray.data.range(10000, parallelism=8)

    def f(batch):
        # 2.5 s of audio at the default 8 kHz sample rate (16-bit-range integer samples)
        x = render_batch([[]]*len(batch['id']))
        x = torch.tensor(x.to_pylist())
        return {"x": x, "y": x.unsqueeze(-1)}

    ds = ds.map_batches(f, zero_copy_batch=True, batch_size=128, num_cpus=1)
    for i in tqdm(count(), desc='epoch'):
        c=count()
        print()
        for batch in tqdm(ds.iter_torch_batches(batch_size=batch_size, prefetch_batches=1), desc='batch'):
            if next(c)>20:
                break

# %%

then, to reproduce:

  1. Start Ray. I've been using ray stop --force && ray start --head --num-cpus 0 --include-dashboard=true && ray start --address localhost:6379, but it was happening either way.
  2. Run the script.
  3. Wait for a few epochs (~10 or so, usually, for me).
  4. Observe that no more batches roll through; one way to check which processes are holding descriptors at that point is sketched after this list.
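When it does freeze, one way to check which processes are actually holding the descriptors is something like the following. A minimal sketch, assuming Linux and that psutil is installed; matching 'ray::' / 'raylet' in the command line is an assumption about how the worker processes are titled, not anything Ray guarantees:

import psutil

# Print open-descriptor counts for anything that looks like a Ray process,
# largest first.
rows = []
for proc in psutil.process_iter(['pid', 'cmdline']):
    try:
        cmd = ' '.join(proc.info['cmdline'] or [])
        if 'ray::' in cmd or 'raylet' in cmd:
            rows.append((proc.num_fds(), proc.info['pid'], cmd[:60]))
    except (psutil.NoSuchProcess, psutil.AccessDenied):
        continue

for n_fds, pid, cmd in sorted(rows, reverse=True):
    print(f'{n_fds:>7}  {pid:>7}  {cmd}')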

Hardware

Issue Severity

High: It blocks me from completing my task.

Nintorac commented 7 months ago

FWIW, I think this is a bug in pedalboard, but it would be good to get some log warnings in these situations.
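To illustrate the kind of warning I mean (this is not something Ray does today; the helper name and threshold are made up): a periodic check inside the worker that compares its open-descriptor count against the soft RLIMIT_NOFILE and logs before the limit is hit, for example:

import logging
import resource

import psutil

logger = logging.getLogger(__name__)

def warn_if_fds_running_low(threshold: float = 0.8) -> None:
    # Hypothetical helper: log a warning when this process's open descriptors
    # cross the given fraction of the soft RLIMIT_NOFILE.
    soft, _hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    open_fds = psutil.Process().num_fds()
    if open_fds > threshold * soft:
        logger.warning(
            'Open file descriptors (%d) are at %.0f%% of the soft limit (%d); '
            'the pipeline may stall once the limit is reached.',
            open_fds, 100 * open_fds / soft, soft,
        )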