agberman opened this issue 4 years ago (status: Open).
Hello @agberman,
I would guess you are keeping references to images in some list somewhere and this is preventing them from being freed, but of course I can't tell for sure :(
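To illustrate the kind of problem suspected here, the following is a minimal stdlib-only sketch. `FakeImage` is a hypothetical placeholder, not a pyvips class. The sketch shows that any list still holding a reference keeps an object alive, even after the local name has been deleted. It assumes CPython's reference counting, which frees the object as soon as its last reference disappears.

```python
import gc
import weakref


class FakeImage:
    """Hypothetical stand-in for a large image object (not a pyvips class)."""

    def __init__(self, name):
        self.name = name


def is_alive_after_drop(keep_in_list):
    """Create an object, optionally stash it in a list, delete the local name,
    and report whether the object survived garbage collection."""
    cache = []
    obj = FakeImage("tile")
    ref = weakref.ref(obj)  # a weak reference does not keep obj alive
    if keep_in_list:
        cache.append(obj)   # the list now holds a strong reference
    del obj
    gc.collect()
    alive = ref() is not None
    return alive, len(cache)


print(is_alive_after_drop(keep_in_list=False))  # (False, 0)
print(is_alive_after_drop(keep_in_list=True))   # (True, 1)
```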
Unfortunately, I don't think I can help unless you can make a small, complete program that shows the problem. I would make a copy of your code and then repeatedly cut away chunks until only the part that reproduces the leak is left.
Rather than `extract_area` (also available as `crop`), I would use the new fetch API. You should see a dramatic speed-up. There's a sample program here: https://github.com/libvips/pyvips/issues/100#issuecomment-493960943
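The following is a minimal sketch of reading a tile through the fetch API and turning the result into a numpy array. It assumes a pyvips release that provides `pyvips.Region.new(image)` and `Region.fetch(left, top, width, height)`, and that `fetch` returns band-interleaved uchar pixel bytes. `fetch_tile` and `FakeRegion` are hypothetical names. The stub region lets the sketch run without a slide.

```python
import numpy as np


def fetch_tile(region, left, top, size, bands):
    """Read a size x size tile of 8-bit pixels through a pyvips-style Region.

    Assumes region.fetch(left, top, width, height) returns the raw,
    band-interleaved pixel bytes for that rectangle of a uchar image.
    """
    data = region.fetch(left, top, size, size)
    # View the bytes as a (height, width, bands) uint8 array without copying.
    return np.frombuffer(data, dtype=np.uint8).reshape(size, size, bands)


class FakeRegion:
    """Hypothetical stand-in so the sketch runs without pyvips or a slide."""

    def __init__(self, bands):
        self.bands = bands

    def fetch(self, left, top, width, height):
        return bytes(range(width * height * self.bands))


demo_tile = fetch_tile(FakeRegion(bands=3), 0, 0, 2, 3)
print(demo_tile.shape)  # (2, 2, 3)

# With a real slide (not executed here):
#   import pyvips
#   image = pyvips.Image.new_from_file("/path/to/tumor_001.tif")
#   region = pyvips.Region.new(image)
#   tile = fetch_tile(region, 1000, 2000, 224, image.bands)
```

Creating the region once per image and reusing it for many `fetch` calls is the pattern the linked sample program follows.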
Hi @jcupitt,
I have created a small and complete program to show the continual memory growth:
```python
import sys
import torch
import numpy as np
import pyvips as pv
import PIL.Image as Image
import torch.utils.data as data
import torchvision.transforms as transforms
from random import randint


def main():
    pv.cache_set_max(0)

    # Simulate lib of 269 WSIs from one WSI
    wsipath = '/path/to/tumor_001.tif'
    lib = {'slides': [], 'grid': []}
    maxwidth = pv.Image.new_from_file(wsipath, level=0, access="sequential").width - 300
    maxheight = pv.Image.new_from_file(wsipath, level=0, access="sequential").height - 300
    for i in range(269):
        lib["slides"].append(wsipath)  # normally these would all be paths to different WSIs
        tiles = []
        for j in range(20000):  # choose 20,000 random tiles per slide
            tiles.append((randint(300, maxwidth), randint(300, maxheight)))
        lib["grid"].append(tiles)

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.1, 0.1, 0.1])
    trans = transforms.Compose([transforms.ToTensor(), normalize])
    train_dset = MILdataset(lib, trans)
    train_loader = torch.utils.data.DataLoader(
        train_dset,
        batch_size=512, shuffle=False,
        num_workers=8, pin_memory=True)

    for epoch in range(100):
        probs = inference(epoch, train_loader)


def inference(run, loader):
    with torch.no_grad():
        for i, input in enumerate(loader):
            print('Inference\tEpoch: [{}/{}]\tBatch: [{}/{}]'.format(run+1, 100, i+1, len(loader)))
    return True


class MILdataset(data.Dataset):
    def __init__(self, libraryfile='', transform=None):
        lib = libraryfile
        slides = []
        for i, name in enumerate(lib['slides']):
            sys.stdout.write('Opening SVS headers: [{}/{}]\r'.format(i+1, len(lib['slides'])))
            sys.stdout.flush()
            slides.append(pv.Image.new_from_file(name, level=0, access="sequential"))
        print('')
        grid = []
        slideIDX = []
        for i, g in enumerate(lib['grid']):
            grid.extend(g)
            slideIDX.extend([i]*len(g))
        print('Number of tiles: {}'.format(len(grid)))
        self.slides = slides
        self.grid = grid
        self.slideIDX = slideIDX
        self.transform = transform

    def __getitem__(self, index):
        slideIDX = self.slideIDX[index]
        coord = self.grid[index]
        tmp_img = self.slides[slideIDX].extract_area(coord[0], coord[1], 224, 224)
        np_img = np.ndarray(buffer=tmp_img.write_to_memory(), dtype=np.uint8,
                            shape=[tmp_img.height, tmp_img.width, tmp_img.bands])
        img = Image.fromarray(np_img).convert('RGB')
        img = self.transform(img)
        return img

    def __len__(self):
        return len(self.grid)


if __name__ == '__main__':
    main()
```
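For readers unfamiliar with torchvision, `trans` in the program above turns each RGB tile into a normalized tensor. The following numpy-only sketch approximates what `ToTensor()` followed by `Normalize(mean=[0.5]*3, std=[0.1]*3)` computes. It is written for illustration and is not torchvision's implementation; `to_normalized_chw` is a hypothetical name.

```python
import numpy as np


def to_normalized_chw(hwc_uint8, mean, std):
    """Approximate torchvision's ToTensor() followed by Normalize(mean, std).

    ToTensor scales uint8 HWC pixels to float32 in [0, 1] and moves the
    channel axis first (CHW); Normalize then computes (x - mean) / std
    separately for each channel.
    """
    chw = hwc_uint8.astype(np.float32).transpose(2, 0, 1) / 255.0
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (chw - mean) / std


# An all-white 224x224 RGB tile: 255 -> 1.0 -> (1.0 - 0.5) / 0.1, i.e. about 5.0
white = np.full((224, 224, 3), 255, dtype=np.uint8)
demo = to_normalized_chw(white, [0.5, 0.5, 0.5], [0.1, 0.1, 0.1])
print(demo.shape)  # (3, 224, 224)
```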
Here is a link to the pyramidal TIFF used by this program (tumor_001.tif): https://drive.google.com/open?id=0BzsdkU4jWx9BQnFwak9PbGtBVUk
The only change needed is to update the line `wsipath = '/path/to/tumor_001.tif'` to point to wherever you saved this TIFF. The code should then run.
In practice I am using 269 different pyramidal tiffs (WSIs) like this one, but I have written the above program so that it just uses this same tiff 269 times, each with a unique set of tile coordinates (in the program above generated randomly but in practice these are the foreground tiles which actually show tissue in them). The number of tiles per slide that are randomly generated (20,000) roughly corresponds to how many foreground tiles I can expect from one WSI.
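As a quick check of the scale involved, the following arithmetic gives the number of tiles per epoch and the number of batches per epoch. It uses the numbers in the program above and assumes the DataLoader's default `drop_last=False`, so the last partial batch counts. The tile total is roughly consistent with the "about 5.3 million tiles" per epoch mentioned in the original report.

```python
import math

slides = 269
tiles_per_slide = 20_000
batch_size = 512  # from the DataLoader in the program above

tiles_per_epoch = slides * tiles_per_slide
# With drop_last=False the final partial batch is kept, hence the ceiling.
batches_per_epoch = math.ceil(tiles_per_epoch / batch_size)

print(tiles_per_epoch)    # 5380000
print(batches_per_epoch)  # 10508
```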
Oh well done! I'll have a look, thanks.
Also, I should add that `lib` is a dictionary containing two lists, both of length 269: one called `slides` and one called `grid`. `slides` contains 269 paths to WSIs (here it is just the same path to the same WSI copied 269 times, but in practice they are unique WSIs), and `grid` is a list of 269 sublists. Each sublist contains 20,000 tuples of randomly generated coordinates and corresponds by index to the WSI paths in `slides` (i.e. the 20,000 tuples at `lib["grid"][150]` are the 20,000 tiles we want to extract for training from the WSI stored at `lib["slides"][150]`).
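The following sketch mirrors the loop in `MILdataset.__init__` to show how this layout is flattened into one coordinate list plus a parallel list of slide indices, so that a single dataset index maps back to a (slide, tile) pair. The tiny two-slide library and its paths are hypothetical, and `flatten_lib` is an illustrative name.

```python
def flatten_lib(lib):
    """Flatten lib["grid"] into one coordinate list plus a parallel list of
    slide indices, as MILdataset.__init__ does."""
    grid, slide_idx = [], []
    for i, coords in enumerate(lib["grid"]):
        grid.extend(coords)
        slide_idx.extend([i] * len(coords))
    return grid, slide_idx


# Hypothetical library: 2 slides with 3 and 2 tiles respectively.
lib = {
    "slides": ["/path/a.tif", "/path/b.tif"],
    "grid": [[(300, 300), (524, 300), (748, 300)], [(300, 524), (524, 524)]],
}
grid, slide_idx = flatten_lib(lib)
print(slide_idx)  # [0, 0, 0, 1, 1]

# Dataset index 3 -> second slide, tile at (300, 524)
print(lib["slides"][slide_idx[3]], grid[3])  # /path/b.tif (300, 524)
```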
Also note that the `inference` function in the code above would normally also forward-pass each batch through the model. Here it only iterates over the batches and does nothing with them. There is no need to complicate things by training an actual model, because the issue can be seen just from iterating over the batches with PyTorch's `enumerate(loader)`. The `inference` function would normally return a tensor of probabilities for all tiles, but here it just returns `True`. Also, you should change `num_workers=8` to however many cores/CPUs you have to work with.
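One way to pick `num_workers` automatically is sketched below. It is an assumption, not part of the original script. `os.sched_getaffinity` respects CPU pinning, for example on cluster nodes, but exists only on some platforms such as Linux, so the sketch falls back to `os.cpu_count()`. `usable_cpu_count` is a hypothetical name.

```python
import os


def usable_cpu_count():
    """Number of CPUs this process may run on, for DataLoader(num_workers=...)."""
    if hasattr(os, "sched_getaffinity"):
        # Respects affinity masks set by schedulers or taskset (Linux).
        return len(os.sched_getaffinity(0))
    # os.cpu_count() can return None, so fall back to a single worker.
    return os.cpu_count() or 1


num_workers = usable_cpu_count()
print(num_workers)
```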
Another thing I should mention is that the memory growth is steady but quite slow: memory dips slightly as batches finish, but the general trend is an upward creep. For example, I'm currently running that script. It is on epoch 1, batch 713, and about 15.5 GB of memory is in use. At the beginning of the epoch, about 7 to 8 GB was used. As far as I can tell, growth continues at roughly this linear pace. I should probably make a graph. Update: it is now on epoch 1, batch 2208, and 28 GB of memory is in use.
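The two readings above already give a rough growth rate. The following sketch computes it and naively extrapolates to the end of the epoch. It assumes 1 GB = 1024 MB and that growth stays linear, which nobody has measured. It also recomputes batches per epoch from 269 × 20,000 tiles at `batch_size=512`. On those assumptions, a full epoch would need on the order of 100 GB, which would explain a 32 GB machine running out of memory.

```python
import math

# Memory readings reported above during epoch 1: (batch number, GB in use).
(b0, m0), (b1, m1) = (713, 15.5), (2208, 28.0)

gb_per_batch = (m1 - m0) / (b1 - b0)
mb_per_batch = gb_per_batch * 1024  # assumes 1 GB = 1024 MB

# Batches per epoch for 269 slides x 20,000 tiles at batch_size=512.
batches_per_epoch = math.ceil(269 * 20000 / 512)

# Naive linear extrapolation (assumption: growth stays linear all epoch).
projected_gb = m1 + gb_per_batch * (batches_per_epoch - b1)

print(f"{mb_per_batch:.2f} MB per batch")  # 8.56 MB per batch
print(f"about {projected_gb:.0f} GB by batch {batches_per_epoch}")  # about 97 GB by batch 10508
```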
I tried a very quick pyvips-only version:
```python
#!/usr/bin/python3

import sys
import pyvips
import random
import PIL.Image as Image
import numpy as np
import os
import psutil

process = psutil.Process(os.getpid())

image = pyvips.Image.new_from_file(sys.argv[1])

for i in range(20000):
    if i % 1000 == 0:
        mb = process.memory_info().rss / (1024 * 1024)
        print(f"iteration {i}, {mb} MB ...")

    tile = image.crop(random.randint(300, image.width - 300),
                      random.randint(300, image.height - 300),
                      224, 224)
    np_img = np.ndarray(buffer=tile.write_to_memory(),
                        dtype=np.uint8,
                        shape=[tile.height, tile.width, tile.bands])
    img = Image.fromarray(np_img).convert('RGB')
```
So this just fetches tiles. You don't need `access="sequential"` for tiled TIFF images, since they support true random access. I see:
```
$ ./leak4.py ~/pics/openslide/tumor_001.tif
iteration 0, 57.5234375 MB ...
iteration 1000, 149.71484375 MB ...
iteration 2000, 156.71875 MB ...
iteration 3000, 155.71875 MB ...
iteration 4000, 152.83984375 MB ...
iteration 5000, 160.71484375 MB ...
iteration 6000, 148.90234375 MB ...
iteration 7000, 162.90234375 MB ...
iteration 8000, 150.09375 MB ...
iteration 9000, 154.09375 MB ...
iteration 10000, 157.29296875 MB ...
iteration 11000, 154.48828125 MB ...
iteration 12000, 146.80859375 MB ...
iteration 13000, 148.125 MB ...
iteration 14000, 159.07421875 MB ...
iteration 15000, 157.1484375 MB ...
iteration 16000, 154.3359375 MB ...
iteration 17000, 152.40234375 MB ...
iteration 18000, 151.609375 MB ...
iteration 19000, 155.484375 MB ...
```
So memory use seems stable in this case.
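The following small stdlib sketch puts a number on "stable". It parses the leak4.py log above and drops the first reading, which is printed before any tile has been fetched. It then reports the spread of the remaining readings and a least-squares slope in MB per 1000 iterations. For this log the spread is about 16 MB and the slope is close to zero.

```python
import re

# Output of leak4.py above, copied verbatim.
LOG = """\
iteration 0, 57.5234375 MB ...
iteration 1000, 149.71484375 MB ...
iteration 2000, 156.71875 MB ...
iteration 3000, 155.71875 MB ...
iteration 4000, 152.83984375 MB ...
iteration 5000, 160.71484375 MB ...
iteration 6000, 148.90234375 MB ...
iteration 7000, 162.90234375 MB ...
iteration 8000, 150.09375 MB ...
iteration 9000, 154.09375 MB ...
iteration 10000, 157.29296875 MB ...
iteration 11000, 154.48828125 MB ...
iteration 12000, 146.80859375 MB ...
iteration 13000, 148.125 MB ...
iteration 14000, 159.07421875 MB ...
iteration 15000, 157.1484375 MB ...
iteration 16000, 154.3359375 MB ...
iteration 17000, 152.40234375 MB ...
iteration 18000, 151.609375 MB ...
iteration 19000, 155.484375 MB ...
"""

points = [(int(i), float(mb))
          for i, mb in re.findall(r"iteration (\d+), ([\d.]+) MB", LOG)]

# Skip the first reading: it is printed before any tile has been fetched.
steady = points[1:]
xs = [i / 1000 for i, _ in steady]
ys = [mb for _, mb in steady]

spread = max(ys) - min(ys)

# Ordinary least-squares slope, in MB per 1000 iterations.
mean_x = sum(xs) / len(xs)
mean_y = sum(ys) / len(ys)
slope = (sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
         / sum((x - mean_x) ** 2 for x in xs))

print(f"spread {spread:.2f} MB, slope {slope:+.3f} MB per 1000 iterations")
```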
I'll try with your pytorch version.
I tried a version with `fetch`, but it doesn't help speed with these large random tiles.
To preface, I am mostly new to pyvips and far from an expert, so apologies if something is unclear or I am neglecting something obvious!
I am training a convolutional neural network on 224x224 tiles extracted on the fly from 269 whole-slide images (WSIs), which are pyramidal TIFFs. I am using PyTorch with a custom dataset to do so. The code is adapted from this script, which I've modified partly by switching from openslide to pyvips: https://github.com/MSKCC-Computational-Pathology/MIL-nature-medicine-2019/blob/master/MIL_train.py
When the custom dataset is initialized, a list of 269 pyvips image objects is created with `new_from_file` from the WSI paths and stored on `self`. In the dataset's `__getitem__`, specific tiles are extracted from the pyvips objects in that list. During training, the PyTorch DataLoader requests millions of tiles per epoch through this `__getitem__`.
When I train my model, there appears to be a memory leak somewhere. Over the course of one epoch (about 5.3 million tiles), my memory use increases continuously until it fills up and the program crashes. I have run into this both on my work machine, which has 32 GB of memory, and on clusters that have significantly more. After doing some reading, I added `pyvips.cache_set_max(0)` near the top of my file, since I suspected that the growing memory consumption was just the cache filling up. However, memory still climbs steadily until it is full, even with the cache size set to zero. I am using the latest conda installation of pyvips with Python 3.7.6 on Ubuntu 18.04.4 LTS. I am also passing `access="sequential"` to my `new_from_file` calls, which has not stopped the leak either.
I'm fairly sure that pyvips is the source of the leak. When I move the pyvips `new_from_file` calls from `__init__` into `__getitem__`, so that the pyvips objects are re-created every time a tile is requested, my memory use stays constant. However, doing so substantially increases training time (roughly 4x slower), since `new_from_file` is then called every time a tile is accessed rather than only once per slide at the start of the program.
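The two options described above are the extremes: open all 269 slides once and keep them open, or re-open a slide for every tile. A possible middle ground, not something tried in this thread, is to keep only a bounded number of recently used slides open. The following is a minimal stdlib sketch of that idea. `make_slide_opener` and `fake_open` are hypothetical names, and in real use `open_fn` would wrap a call such as `pyvips.Image.new_from_file`. Whether this actually bounds memory in this setup is untested.

```python
from functools import lru_cache


def make_slide_opener(open_fn, max_open):
    """Return a function that opens slides through a bounded LRU cache.

    open_fn(path) creates the image object. Once more than max_open distinct
    paths have been requested, the least recently used one is evicted, which
    drops the cache's own reference to that image.
    """
    @lru_cache(maxsize=max_open)
    def open_slide(path):
        return open_fn(path)

    return open_slide


# Hypothetical opener that just records which paths were "opened".
opened = []


def fake_open(path):
    opened.append(path)
    return f"<image {path}>"


open_slide = make_slide_opener(fake_open, max_open=2)
for p in ["a.tif", "b.tif", "a.tif", "c.tif", "b.tif"]:
    open_slide(p)

print(opened)                   # ['a.tif', 'b.tif', 'c.tif', 'b.tif']
print(open_slide.cache_info())  # CacheInfo(hits=1, misses=4, maxsize=2, currsize=2)
```

In the dataset, `__init__` would store the slide paths and the opener, and `__getitem__` would call the opener with the path for the requested tile. Each DataLoader worker is a separate process with its own copy of the dataset, so each worker gets its own cache. A nested function like this cannot be pickled, so the sketch assumes the fork start method, which is the default on Linux.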
Here is a simplified version of the custom dataset I am using, adapted from Gabriele Campanella's code linked above and configured for use in a PyTorch DataLoader:
Note that `lib['slides']` is a 269-long list of paths to the pyramidal TIFFs. There are two relevant parts. The first is the loop in `__init__`, where the list of pyvips objects is created with `new_from_file` and stored in `self.slides`. The second is the part of `__getitem__` where `extract_area` yields a tile at the desired coordinates from the stored pyvips object in `self.slides`. I'm happy to include other parts of the training script if that would help, such as where the PyTorch DataLoader is created from this dataset and where it is iterated over. The script linked above also shows these parts, and I have barely changed them.
Thanks in advance for any help!