@jveitchmichaelis these images are mainly intended for validating your labels, as many custom dataset users incorrectly format their data. One thing we are doing is trying to use tensorboard more extensively in the new repo.
Yes, I understand that some of these features may be very useful, and we may want to look at implementing them. One concern, though, is that anecdotal results partway through training, from an incomplete model, would invite a great deal of uncertainty and issue-raising from (particularly inexperienced) users who don't like how their boxes look at epoch 30 of 300 and want answers about false positives, poor regressions, etc. We already get inundated with issues from people trying to shortcut training and not liking their results (i.e. training 10 epochs and then asking why mAP is low).
Another point is simply size considerations. The images are currently pngs because of a matplotlib bug from a few months ago, but this has been cleared up and they can now be reverted to jpgs. Even at jpg compression, each individual image is currently 1.3MB, so it is not feasible to save a separate image per epoch, nor to accumulate these in tensorboard every epoch, which would balloon one's tb file to over a GB for a full training run. Do you know if there's a way to simply display the most recent image in tensorboard without actually saving and growing the tb file?
And speed is also a factor. Plotting takes several seconds per image, so small datasets would see a significant slowdown in overall training time.
About the dataloader: we did use pytorch transforms initially, but they were extremely slow compared to the current implementation. This may have changed in the intervening time - if you have profiling results comparing the two, that would be very useful!
@jveitchmichaelis BTW, in terms of dataloading, our speed constraint now is simply the loading of the images from the hard drive / SSD, i.e. the line below. For smaller models this may dominate the total training time even beyond model forward and backward passes. We have a --cache option for training which caches smaller datasets into RAM to significantly speed up training, but this is not feasible for larger datasets.
We've had requests for a DALI implementation to address this. I think on the dataloader side this might have the biggest impact of all on the speed of coco training. https://docs.nvidia.com/deeplearning/sdk/dali-developer-guide/docs/examples/pytorch/pytorch-basic_example.html
In regards to the training images, if you can compact your code to the most minimal implementation as a plot function at the end of utils/utils.py we can review a PR there. Tensorboard functionality would be a plus also, but again being very attentive to profiling speed and file size considerations.
Will do - I thought it would be best to open a discussion here first. It seemed like a simple fix, but it actually brought up a lot of architecture questions...
Dealing with user queries is an interesting issue. Plotting images every val step is the default behaviour in the Tensorflow object detection API, for example, which is why I decided to add it here, but it only plots a single random image afaik. Not sure you get around that, aside from improving documentation and referring people to it. There is always going to be a balance between an automatic system that does everything but gives you little insight, and a more flexible system that can be confusing or let you shoot yourself in the foot. At the extreme end of this, Google's AutoML gives excellent results but is almost totally opaque to the user.
Size I can see might be a problem. The mitigation is to plot less often, plot fewer images (or a smaller figure), or not at all, and it would make sense for this to be optional. I work with predominantly non-RGB data, and this is also the case with satellite imagery. SSD read speed dominates over the forward/backward pass at some point there, and throwing more cores at it usually helps to buffer up data (provided your SSD can cope). That said, worrying about a second or two of plotting speed when an epoch takes 5-10 minutes seems like premature optimisation. I can do some benchmarking there.
I know if you publish to tensorboard with the same global step, it just adds the image to a sequence rather than overwriting it. Not sure there's a way around that.
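For example (a minimal sketch with a dummy image; the event file still grows even though the dashboard only shows one image per step):

from torch.utils.tensorboard import SummaryWriter
import numpy as np

writer = SummaryWriter('runs/exp')
img = np.zeros((3, 64, 64), dtype=np.uint8)  # dummy CHW image
for epoch in range(3):
    # Same tag, same step each time: each call appends another sample at step 0
    # to the event file rather than overwriting the previous one.
    writer.add_image('latest_batch', img, global_step=0)
writer.close()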
I suspect a large speed component is I/O: storing and reading the figure from disk. You could probably throw away matplotlib and plot the images into an array directly with OpenCV or Numpy primitives. You'd just need to annotate them separately and then cat them into a grid.
DALI looks fun, though I'd be curious to see if your constraint is IO or the augmentation speed.
PyTorch uses PIL, which isn't great, so I'm not surprised it's slow. I can do some testing, though Albumentations already has benchmarks against built-in implementations. I think this is partly why a lot of Kagglers use it: https://github.com/albumentations-team/albumentations#benchmarking-results
Btw, in terms of dataloading, I would expect you'd get a small boost in performance if you perform normalisation outside the training loop, since the work would be dispatched to the dataloader threads. Perhaps not if you're IO bound, though.
@jveitchmichaelis yes, lots of factors. Albumentations looks interesting, though of course we are wary about adding more dependencies. We don't have the time to test it out for ourselves, but if you can demonstrate improvements over https://github.com/ultralytics/yolov3#speed in before and after tests we'd be happy to do a PR.
For the images, I suppose the best intermediate solution is to simply plot predictions_latest.jpg instead of naming per batch; then it's automatically overwritten (not sure about tb, but that's extra anyway).
@jveitchmichaelis ah also, about one of your points: we used to normalize inside the dataloader, but we moved it out because the RAM-GPU transfer speed is a severe bottleneck; this way we only need to send images to the GPU as uint8 (i.e. 4x less data to transfer).
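On the training-loop side this looks roughly like (sketch):

# batch arrives from the dataloader as uint8 (1 byte per value instead of 4 for float32)
imgs = imgs.to(device).float() / 255.0  # transfer as uint8, then cast and normalise on-device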
We had our hopes up this would make some sort of significant impact, but in practice training speeds remained unchanged.
Ok I'll take a look at this today. I'll try an OpenCV solution, I would think it's almost certainly going to be faster.
Also on augmentation, I'll run some benchmarks against:
# Augment imagespace
if not self.mosaic:
    img, labels = random_affine(img, labels,
                                degrees=hyp['degrees'],
                                translate=hyp['translate'],
                                scale=hyp['scale'],
                                shear=hyp['shear'])

# Augment colorspace
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
So currently it's just HSV and random affine that you're running? Cutout is currently commented out. I'll try and get it set up on GCP - I have some credits which need burning. It should only need to run for a few epochs, I think.
> We had our hopes up this would make some sort of significant impact, but in practice training speeds remained unchanged.
Interesting on the normalisation front. So you expected an improvement from moving everything to the dataloader and there wasn't? Or the other way around?
I guess this is also easy to benchmark, since if I'm testing Albumentations I can move everything into the same composition pipeline.
@jveitchmichaelis ah yes, I could never get cutout to produce better training results. Every experiment I ran, it hurt the mAP a bit.
The current augmentation is:
The image-space transformations are done on a mosaic of 4 images, i.e. four 640x640 images are grouped into a 1280x1280 square, then random_affine() is applied to the 1280x1280 image with an argument to remove 320 pixels of border, producing a resultant 640x640 mosaicked, augmented image that goes into the batch. Importantly, the border removal is specified by the last argument and is also taken care of by random_affine().
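In numbers, for s = 640 (a sketch of the call; the degrees/translate/scale/shear args are omitted here):

# 4 images tiled into a 2s x 2s = 1280x1280 canvas; random_affine() then removes
# a border of s//2 = 320 px per side: 1280 + 2 * (-320) = 640 px final output
img4, labels4 = random_affine(img4, labels4, border=-s // 2)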
Doing some simple benchmarks, I found that Albumentations' HSV augmentation is about twice as fast. Their implementation:
def _shift_hsv_uint8(img, hue_shift, sat_shift, val_shift):
    dtype = img.dtype
    img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    hue, sat, val = cv2.split(img)

    lut_hue = np.arange(0, 256, dtype=np.int16)
    lut_hue = np.mod(lut_hue + hue_shift, 180).astype(dtype)
    lut_sat = np.arange(0, 256, dtype=np.int16)
    lut_sat = np.clip(lut_sat + sat_shift, 0, 255).astype(dtype)
    lut_val = np.arange(0, 256, dtype=np.int16)
    lut_val = np.clip(lut_val + val_shift, 0, 255).astype(dtype)

    hue = cv2.LUT(hue, lut_hue)
    sat = cv2.LUT(sat, lut_sat)
    val = cv2.LUT(val, lut_val)

    img = cv2.merge((hue, sat, val)).astype(dtype)
    img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
    return img
vs
def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    x = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    img_hsv = (cv2.cvtColor(img, cv2.COLOR_BGR2HSV) * x).clip(None, 255).astype(np.uint8)
    np.clip(img_hsv[:, :, 0], None, 179, out=img_hsv[:, :, 0])  # inplace hue clip (0 - 179 deg)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)
Theirs takes about 8ms on the full-resolution Zidane image vs 19ms for the one here. It would need a bit of modification to figure out what the gains should convert to, since theirs is additive, not multiplicative.
Affine warping is the same - they implement the affine transform in the same way using cv2.warpAffine, and in fact they don't provide the option for shearing (though there are more options for other types of warp). There are some minor differences in how translation is handled, but the timings are almost identical.
I didn't test speed for bounding box transforms yet, as I expected that would be minimal compared to image ops. I also didn't check vertical/horizontal flips, because I assume you're not going to beat np.flipud/fliplr.
My initial conclusion is probably to update the hsv augmenter (and test it). Otherwise it's potentially something to look into in the future to see if extra augmentation helps, as there's a richer variety of stuff to choose from (which might be useful for custom data). An intermediate (best?) solution would be to provide an example to users on how to write a custom dataloader with their own augmentation pipeline.
On flipping:
def vflip(img):
    return np.ascontiguousarray(img[::-1, ...])

def hflip(img):
    return np.ascontiguousarray(img[:, ::-1, ...])

%timeit vflip(np.random.randint(low=0, high=255, size=(640,512,3)))
%timeit np.flipud(np.random.randint(low=0, high=255, size=(640,512,3)))
%timeit cv2.flip(np.random.randint(low=0, high=255, size=(640,512,3)), 0)  # vflip

%timeit hflip(np.random.randint(low=0, high=255, size=(640,512,3)))
%timeit np.fliplr(np.random.randint(low=0, high=255, size=(640,512,3)))
%timeit cv2.flip(np.random.randint(low=0, high=255, size=(640,512,3)), 1)  # hflip
3.69 ms ± 22.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.71 ms ± 143 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) # Numpy
4.02 ms ± 32.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.51 ms ± 190 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.02 ms ± 1.13 ms per loop (mean ± std. dev. of 7 runs, 100 loops each) # Numpy
4.26 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
There definitely seems to be an advantage to using Numpy over OpenCV here, but I should check this further. I think this is only used in webcam loading atm.
@jveitchmichaelis ah your HSV speedup is huge! That's great news. We should definitely implement their version then.
About the flips, I see the ascontiguousarray call there. This can be very slow sometimes, which is why I only call it once at the very end, but in general contiguous data provides an overall speedup through all the later operations, so it is beneficial. Hflip and vflip are definitely booleans that should be added to the hyps in train.py as well.
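Something like this in __getitem__ (a sketch; 'fliplr' is a hypothetical hyp key):

# random left-right flip, gated by a hyp entry instead of the hard-coded lr_flip bool
if hyp['fliplr'] and random.random() < 0.5:
    img = np.fliplr(img)
    if nL:
        labels[:, 1] = 1 - labels[:, 1]  # mirror normalised x-centres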
A wider question I have is: do you have profiling capability with a GPU? I used a MacBook Pro with a 1080 Ti eGPU in the past, but lost that capability last year after a macOS update, so now I just develop on CPU and push to GCP for actual training. This means I've lost my profiling capability though, which I used to do with the Spyder line profiler, training a few batches with @profile decorators on the dataloader functions in datasets.py, for example.
EDIT: Now that I think about it, profiling all the datasets.py functions should be fine without GPU... as no data in the file ever makes its way to a cuda device. Hmm ok I'll give it a shot today to see where the bottlenecks are.
EDIT2: I see their HSV function completely avoids indexing operations, very nice. Indexing can be extremely slow in general; this probably explains the speedup compared to our in-house function.
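A quick notebook illustration of the difference (a sketch; sizes arbitrary):

import numpy as np
import cv2

channel = np.random.randint(0, 256, (1280, 1280), dtype=np.uint8)  # one HSV plane
lut = np.clip(np.arange(256) * 1.2, 0, 255).astype(np.uint8)       # 256-entry table
%timeit cv2.LUT(channel, lut)                                      # single table-lookup pass
%timeit (channel * 1.2).clip(0, 255).astype(np.uint8)              # float math on every pixel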
I have an Intel 6700K + 1080ti here (we also have a Ryzen 3900X with 2x2070s in the office), so if you have some specific stuff you'd like to profile I can give it a go. I just ran the tests in a notebook on some random numpy arrays + the sample images so my numbers are definitely CPU. Maybe this is where DALI might help?
I also tried caching the lookup tables to save allocating the arange arrays every time, but I don't think it made an enormous amount of difference - within the error of %timeit anyway.
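What I tried was roughly this (a sketch; sat stands in for one split HSV channel):

_LUT_BASE = np.arange(0, 256, dtype=np.int16)  # allocated once at module level

def _shift_sat_cached(sat, sat_shift):
    # reuse the cached base array instead of building a fresh arange per call
    lut_sat = np.clip(_LUT_BASE + sat_shift, 0, 255).astype(np.uint8)
    return cv2.LUT(sat, lut_sat)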
@jveitchmichaelis ok, I did some line profiling in datasets.py. Back to my old game. This is for 16 epochs of coco64.data at img size 640, so about 1000 images total. Sorry for the long message, we can delete this later.
Caching labels: 0%| | 0/64 [00:00<?, ?it/s]
Caching labels (63 found, 1 missing, 0 empty, 0 duplicate, for 64 images): 100%|██████████| 64/64 [00:00<00:00, 3017.01it/s]
Caching labels: 0%| | 0/64 [00:00<?, ?it/s]
Caching labels (63 found, 1 missing, 0 empty, 0 duplicate, for 64 images): 100%|██████████| 64/64 [00:00<00:00, 3307.00it/s]
Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex
Namespace(adam=False, batch_size=16, bucket='', cache_images=False, cfg='cfg/yolov3-spp.cfg', data='data/coco64.data', device='', epochs=16, evolve=False, img_size=[320, 640], multi_scale=False, name='', nosave=False, notest=False, rect=False, resume=False, single_cls=False, weights='weights/yolov3-spp-ultralytics.pt')
Using CPU
Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/
Model Summary: 225 layers, 6.29987e+07 parameters, 6.29987e+07 gradients, 118.0 GFLOPS
Image sizes 320 - 640 train, 640 test
Using 0 dataloader workers
Starting training for 16 epochs...
Wrote profile results to /Users/glennjocher/.spyder-py3/lineprofiler.results
Timer unit: 1e-06 s
Total time: 0.039052 s
File: /Users/glennjocher/PycharmProjects/yolov3/utils/datasets.py
Function: __init__ at line 259
Line # Hits Time Per Hit % Time Line Contents
==============================================================
259 @profile
260 def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
261 cache_labels=True, cache_images=False, single_cls=False):
262 2 107.0 53.5 0.3 path = str(Path(path)) # os-agnostic
263 2 54.0 27.0 0.1 assert os.path.isfile(path), 'File not found %s. See %s' % (path, help_url)
264 2 133.0 66.5 0.3 with open(path, 'r') as f:
265 2 643.0 321.5 1.6 self.img_files = [x.replace('/', os.sep) for x in f.read().splitlines() # os-agnostic
266 if os.path.splitext(x)[-1].lower() in img_formats]
267
268 2 6.0 3.0 0.0 n = len(self.img_files)
269 2 5.0 2.5 0.0 assert n > 0, 'No images found in %s. See %s' % (path, help_url)
270 2 72.0 36.0 0.2 bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
271 2 12.0 6.0 0.0 nb = bi[-1] + 1 # number of batches
272
273 2 8.0 4.0 0.0 self.n = n
274 2 5.0 2.5 0.0 self.batch = bi # batch index of image
275 2 5.0 2.5 0.0 self.img_size = img_size
276 2 5.0 2.5 0.0 self.augment = augment
277 2 5.0 2.5 0.0 self.hyp = hyp
278 2 6.0 3.0 0.0 self.image_weights = image_weights
279 2 7.0 3.5 0.0 self.rect = False if image_weights else rect
280 2 6.0 3.0 0.0 self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
281
282 # Define labels
283 2 6.0 3.0 0.0 self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
284 2 592.0 296.0 1.5 for x in self.img_files]
285
286 # Rectangular Training https://github.com/ultralytics/yolov3/issues/232
287 2 5.0 2.5 0.0 if self.rect:
288 # Read image shapes (wh)
289 1 2.0 2.0 0.0 sp = path.replace('.txt', '.shapes') # shapefile path
290 1 2.0 2.0 0.0 try:
291 1 43.0 43.0 0.1 with open(sp, 'r') as f: # read existing shapefile
292 1 48.0 48.0 0.1 s = [x.split() for x in f.read().splitlines()]
293 1 15.0 15.0 0.0 assert len(s) == n, 'Shapefile out of sync'
294 except:
295 s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
296 np.savetxt(sp, s, fmt='%g') # overwrites existing (if any)
297
298 # Sort by aspect ratio
299 1 50.0 50.0 0.1 s = np.array(s, dtype=np.float64)
300 1 11.0 11.0 0.0 ar = s[:, 1] / s[:, 0] # aspect ratio
301 1 17.0 17.0 0.0 i = ar.argsort()
302 1 29.0 29.0 0.1 self.img_files = [self.img_files[i] for i in i]
303 1 24.0 24.0 0.1 self.label_files = [self.label_files[i] for i in i]
304 1 15.0 15.0 0.0 self.shapes = s[i] # wh
305 1 3.0 3.0 0.0 ar = ar[i]
306
307 # Set training image shapes
308 1 7.0 7.0 0.0 shapes = [[1, 1]] * nb
309 5 14.0 2.8 0.0 for i in range(nb):
310 4 31.0 7.8 0.1 ari = ar[bi == i]
311 4 62.0 15.5 0.2 mini, maxi = ari.min(), ari.max()
312 4 13.0 3.2 0.0 if maxi < 1:
313 2 5.0 2.5 0.0 shapes[i] = [maxi, 1]
314 2 5.0 2.5 0.0 elif mini > 1:
315 1 3.0 3.0 0.0 shapes[i] = [1, 1 / mini]
316
317 1 26.0 26.0 0.1 self.batch_shapes = np.ceil(np.array(shapes) * img_size / 64.).astype(np.int) * 64
318
319 # Preload labels (required for weighted CE training)
320 2 8.0 4.0 0.0 self.imgs = [None] * n
321 2 6.0 3.0 0.0 self.labels = [None] * n
322 2 4.0 2.0 0.0 if cache_labels or image_weights: # cache labels for faster training
323 2 10.0 5.0 0.0 self.labels = [np.zeros((0, 5))] * n
324 2 4.0 2.0 0.0 extract_bounding_boxes = False
325 2 5.0 2.5 0.0 create_datasubset = False
326 2 2132.0 1066.0 5.5 pbar = tqdm(self.label_files, desc='Caching labels')
327 2 7.0 3.5 0.0 nm, nf, ne, ns, nd = 0, 0, 0, 0, 0 # number missing, found, empty, datasubset, duplicate
328 130 1603.0 12.3 4.1 for i, file in enumerate(pbar):
329 128 278.0 2.2 0.7 try:
330 128 5922.0 46.3 15.2 with open(file, 'r') as f:
331 126 5998.0 47.6 15.4 l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
332 2 5.0 2.5 0.0 except:
333 2 4.0 2.0 0.0 nm += 1 # print('missing labels for image %s' % self.img_files[i]) # file missing
334 2 6.0 3.0 0.0 continue
335
336 126 375.0 3.0 1.0 if l.shape[0]:
337 126 317.0 2.5 0.8 assert l.shape[1] == 5, '> 5 label columns: %s' % file
338 126 1914.0 15.2 4.9 assert (l >= 0).all(), 'negative labels: %s' % file
339 126 1691.0 13.4 4.3 assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
340 126 14170.0 112.5 36.3 if np.unique(l, axis=0).shape[0] < l.shape[0]: # duplicate rows
341 nd += 1 # print('WARNING: duplicate rows in %s' % self.label_files[i]) # duplicate rows
342 126 329.0 2.6 0.8 if single_cls:
343 l[:, 0] = 0 # force dataset into single-class mode
344 126 321.0 2.5 0.8 self.labels[i] = l
345 126 288.0 2.3 0.7 nf += 1 # file found
346
347 # Create subdataset (a smaller dataset)
348 126 290.0 2.3 0.7 if create_datasubset and ns < 1E4:
349 if ns == 0:
350 create_folder(path='./datasubset')
351 os.makedirs('./datasubset/images')
352 exclude_classes = 43
353 if exclude_classes not in l[:, 0]:
354 ns += 1
355 # shutil.copy(src=self.img_files[i], dst='./datasubset/images/') # copy image
356 with open('./datasubset/images.txt', 'a') as f:
357 f.write(self.img_files[i] + '\n')
358
359 # Extract object detection boxes for a second stage classifier
360 126 304.0 2.4 0.8 if extract_bounding_boxes:
361 p = Path(self.img_files[i])
362 img = cv2.imread(str(p))
363 h, w = img.shape[:2]
364 for j, x in enumerate(l):
365 f = '%s%sclassifier%s%g_%g_%s' % (p.parent.parent, os.sep, os.sep, x[0], j, p.name)
366 if not os.path.exists(Path(f).parent):
367 os.makedirs(Path(f).parent) # make new output folder
368
369 b = x[1:] * [w, h, w, h] # box
370 b[2:] = b[2:].max() # rectangle to square
371 b[2:] = b[2:] * 1.3 + 30 # pad
372 b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)
373
374 b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
375 b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
376 assert cv2.imwrite(f, img[b[1]:b[3], b[0]:b[2]]), 'Failure extracting classifier boxes'
377 else:
378 ne += 1 # print('empty labels for image %s' % self.img_files[i]) # file empty
379 # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i])) # remove
380
381 126 271.0 2.2 0.7 pbar.desc = 'Caching labels (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
382 126 662.0 5.3 1.7 nf, nm, ne, nd, n)
383 2 7.0 3.5 0.0 assert nf > 0, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
384
385 # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
386 2 5.0 2.5 0.0 if cache_images: # if training
387 gb = 0 # Gigabytes of cached images
388 pbar = tqdm(range(len(self.img_files)), desc='Caching images')
389 self.img_hw0, self.img_hw = [None] * n, [None] * n
390 for i in pbar: # max 10k images
391 self.imgs[i], self.img_hw0[i], self.img_hw[i] = load_image(self, i) # img, hw_original, hw_resized
392 gb += self.imgs[i].nbytes
393 pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)
394
395 # Detect corrupted images https://medium.com/joelthchao/programmatically-detect-corrupted-image-8c1b2006c3d3
396 2 4.0 2.0 0.0 detect_corrupted_images = False
397 2 5.0 2.5 0.0 if detect_corrupted_images:
398 from skimage import io # conda install -c conda-forge scikit-image
399 for file in tqdm(self.img_files, desc='Detecting corrupted images'):
400 try:
401 _ = io.imread(file)
402 except:
403 print('Corrupted image detected: %s' % file)
Total time: 39.4643 s
File: /Users/glennjocher/PycharmProjects/yolov3/utils/datasets.py
Function: __getitem__ at line 414
Line # Hits Time Per Hit % Time Line Contents
==============================================================
414 @profile
415 def __getitem__(self, index):
416 1024 992.0 1.0 0.0 if self.image_weights:
417 index = self.indices[index]
418
419 1024 770.0 0.8 0.0 hyp = self.hyp
420 1024 783.0 0.8 0.0 if self.mosaic:
421 # Load mosaic
422 1024 24187640.0 23620.7 61.3 img, labels = load_mosaic(self, index)
423 1024 882.0 0.9 0.0 shapes = None
424
425 else:
426 # Load image
427 img, (h0, w0), (h, w) = load_image(self, index)
428
429 # Letterbox
430 shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
431 img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
432 shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
433
434 # Load labels
435 labels = []
436 x = self.labels[index]
437 if x is not None and x.size > 0:
438 # Normalized xywh to pixel xyxy format
439 labels = x.copy()
440 labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0] # pad width
441 labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1] # pad height
442 labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
443 labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]
444
445 1024 1119.0 1.1 0.0 if self.augment:
446 # Augment imagespace
447 1024 744.0 0.7 0.0 if not self.mosaic:
448 img, labels = random_affine(img, labels,
449 degrees=hyp['degrees'],
450 translate=hyp['translate'],
451 scale=hyp['scale'],
452 shear=hyp['shear'])
453
454 # Augment colorspace
455 1024 14598738.0 14256.6 37.0 augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
456
457 # Apply cutouts
458 # if random.random() < 0.9:
459 # labels = cutout(img, labels)
460
461 1024 1708.0 1.7 0.0 nL = len(labels) # number of labels
462 1024 847.0 0.8 0.0 if nL:
463 # convert xyxy to xywh
464 1019 50403.0 49.5 0.1 labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])
465
466 # Normalize coordinates 0 - 1
467 1019 22027.0 21.6 0.1 labels[:, [2, 4]] /= img.shape[0] # height
468 1019 10981.0 10.8 0.0 labels[:, [1, 3]] /= img.shape[1] # width
469
470 1024 1429.0 1.4 0.0 if self.augment:
471 # random left-right flip
472 1024 671.0 0.7 0.0 lr_flip = True
473 1024 2369.0 2.3 0.0 if lr_flip and random.random() < 0.5:
474 506 4135.0 8.2 0.0 img = np.fliplr(img)
475 506 353.0 0.7 0.0 if nL:
476 503 2398.0 4.8 0.0 labels[:, 1] = 1 - labels[:, 1]
477
478 # random up-down flip
479 1024 618.0 0.6 0.0 ud_flip = False
480 1024 634.0 0.6 0.0 if ud_flip and random.random() < 0.5:
481 img = np.flipud(img)
482 if nL:
483 labels[:, 2] = 1 - labels[:, 2]
484
485 1024 31981.0 31.2 0.1 labels_out = torch.zeros((nL, 6))
486 1024 838.0 0.8 0.0 if nL:
487 1019 39957.0 39.2 0.1 labels_out[:, 1:] = torch.from_numpy(labels)
488
489 # Convert
490 1024 4148.0 4.1 0.0 img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
491 1024 489124.0 477.7 1.2 img = np.ascontiguousarray(img)
492
493 1024 8059.0 7.9 0.0 return torch.from_numpy(img), labels_out, self.img_files[index], shapes
Total time: 20.0311 s
File: /Users/glennjocher/PycharmProjects/yolov3/utils/datasets.py
Function: load_image at line 502
Line # Hits Time Per Hit % Time Line Contents
==============================================================
502 @profile
503 def load_image(self, index):
504 # loads 1 image from dataset, returns img, original hw, resized hw
505 4096 4031.0 1.0 0.0 img = self.imgs[index]
506 4096 1970.0 0.5 0.0 if img is None: # not cached
507 4096 2916.0 0.7 0.0 img_path = self.img_files[index]
508 4096 19327841.0 4718.7 96.5 img = cv2.imread(img_path) # BGR
509 4096 6204.0 1.5 0.0 assert img is not None, 'Image Not Found ' + img_path
510 4096 7073.0 1.7 0.0 h0, w0 = img.shape[:2] # orig hw
511 4096 10713.0 2.6 0.1 r = self.img_size / max(h0, w0) # resize image to img_size
512 4096 4143.0 1.0 0.0 if r < 1 or (self.augment and r != 1): # always resize down, only resize up if training with augmentation
513 702 655.0 0.9 0.0 interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
514 702 661470.0 942.3 3.3 img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
515 4096 4106.0 1.0 0.0 return img, (h0, w0), img.shape[:2] # img, hw_original, hw_resized
516 else:
517 return self.imgs[index], self.img_hw0[index], self.img_hw[index] # img, hw_original, hw_resized
Total time: 14.505 s
File: /Users/glennjocher/PycharmProjects/yolov3/utils/datasets.py
Function: augment_hsv at line 519
Line # Hits Time Per Hit % Time Line Contents
==============================================================
519 @profile
520 def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
521 1024 21550.0 21.0 0.1 x = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
522 1024 12402276.0 12111.6 85.5 img_hsv = (cv2.cvtColor(img, cv2.COLOR_BGR2HSV) * x).clip(None, 255).astype(np.uint8)
523 1024 335501.0 327.6 2.3 np.clip(img_hsv[:, :, 0], None, 179, out=img_hsv[:, :, 0]) # inplace hue clip (0 - 179 deg)
524 1024 1745712.0 1704.8 12.0 cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
Total time: 24.0236 s
File: /Users/glennjocher/PycharmProjects/yolov3/utils/datasets.py
Function: load_mosaic at line 531
Line # Hits Time Per Hit % Time Line Contents
==============================================================
531 @profile
532 def load_mosaic(self, index):
533 # loads images in a mosaic
534
535 1024 1102.0 1.1 0.0 labels4 = []
536 1024 1078.0 1.1 0.0 s = self.img_size
537 1024 10111.0 9.9 0.0 xc, yc = [int(random.uniform(s * 0.5, s * 1.5)) for _ in range(2)] # mosaic center x, y
538 1024 22212.0 21.7 0.1 indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)] # 3 additional image indices
539 5120 7778.0 1.5 0.0 for i, index in enumerate(indices):
540 # Load image
541 4096 20106198.0 4908.7 83.7 img, _, (h, w) = load_image(self, index)
542
543 # place img in img4
544 4096 4202.0 1.0 0.0 if i == 0: # top left
545 1024 316638.0 309.2 1.3 img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
546 1024 3165.0 3.1 0.0 x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
547 1024 1456.0 1.4 0.0 x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
548 3072 2804.0 0.9 0.0 elif i == 1: # top right
549 1024 2684.0 2.6 0.0 x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
550 1024 1622.0 1.6 0.0 x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
551 2048 1785.0 0.9 0.0 elif i == 2: # bottom left
552 1024 2675.0 2.6 0.0 x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
553 1024 1825.0 1.8 0.0 x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, max(xc, w), min(y2a - y1a, h)
554 1024 909.0 0.9 0.0 elif i == 3: # bottom right
555 1024 2560.0 2.5 0.0 x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
556 1024 1788.0 1.7 0.0 x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
557
558 4096 320291.0 78.2 1.3 img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
559 4096 4612.0 1.1 0.0 padw = x1a - x1b
560 4096 3669.0 0.9 0.0 padh = y1a - y1b
561
562 # Load labels
563 4096 7042.0 1.7 0.0 label_path = self.label_files[index]
564 4096 68768.0 16.8 0.3 if os.path.isfile(label_path):
565 4027 5609.0 1.4 0.0 x = self.labels[index]
566 4027 3988.0 1.0 0.0 if x is None: # labels not preloaded
567 with open(label_path, 'r') as f:
568 x = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
569
570 4027 5652.0 1.4 0.0 if x.size > 0:
571 # Normalized xywh to pixel xyxy format
572 4027 11452.0 2.8 0.0 labels = x.copy()
573 4027 82746.0 20.5 0.3 labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padw
574 4027 31467.0 7.8 0.1 labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + padh
575 4027 28048.0 7.0 0.1 labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padw
576 4027 26623.0 6.6 0.1 labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + padh
577 else:
578 labels = np.zeros((0, 5), dtype=np.float32)
579 4027 6306.0 1.6 0.0 labels4.append(labels)
580
581 # Concat/clip labels
582 1024 1616.0 1.6 0.0 if len(labels4):
583 1024 11568.0 11.3 0.0 labels4 = np.concatenate(labels4, 0)
584 # np.clip(labels4[:, 1:] - s / 2, 0, s, out=labels4[:, 1:]) # use with center crop
585 1024 66430.0 64.9 0.3 np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:]) # use with random_affine
586
587 # Augment
588 # img4 = img4[s // 2: int(s * 1.5), s // 2:int(s * 1.5)] # center crop (WARNING, requires box pruning)
589 1024 1263.0 1.2 0.0 img4, labels4 = random_affine(img4, labels4,
590 1024 1976.0 1.9 0.0 degrees=self.hyp['degrees'] * 1,
591 1024 1127.0 1.1 0.0 translate=self.hyp['translate'] * 1,
592 1024 1016.0 1.0 0.0 scale=self.hyp['scale'] * 1,
593 1024 1006.0 1.0 0.0 shear=self.hyp['shear'] * 1,
594 1024 2837772.0 2771.3 11.8 border=-s // 2) # border to remove
595
596 1024 976.0 1.0 0.0 return img4, labels4
Total time: 0 s
File: /Users/glennjocher/PycharmProjects/yolov3/utils/datasets.py
Function: letterbox at line 598
Line # Hits Time Per Hit % Time Line Contents
==============================================================
598 @profile
599 def letterbox(img, new_shape=(416, 416), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
600 # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
601 shape = img.shape[:2] # current shape [height, width]
602 if isinstance(new_shape, int):
603 new_shape = (new_shape, new_shape)
604
605 # Scale ratio (new / old)
606 r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
607 if not scaleup: # only scale down, do not scale up (for better test mAP)
608 r = min(r, 1.0)
609
610 # Compute padding
611 ratio = r, r # width, height ratios
612 new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
613 dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
614 if auto: # minimum rectangle
615 dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding
616 elif scaleFill: # stretch
617 dw, dh = 0.0, 0.0
618 new_unpad = new_shape
619 ratio = new_shape[0] / shape[1], new_shape[1] / shape[0] # width, height ratios
620
621 dw /= 2 # divide padding into 2 sides
622 dh /= 2
623
624 if shape[::-1] != new_unpad: # resize
625 img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
626 top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
627 left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
628 img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
629 return img, ratio, (dw, dh)
Total time: 2.77499 s
File: /Users/glennjocher/PycharmProjects/yolov3/utils/datasets.py
Function: random_affine at line 631
Line # Hits Time Per Hit % Time Line Contents
==============================================================
631 @profile
632 def random_affine(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10, border=0):
633 # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
634 # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
635
636 1024 923.0 0.9 0.0 if targets is None: # targets = [cls, xyxy]
637 targets = []
638 1024 1483.0 1.4 0.1 height = img.shape[0] + border * 2
639 1024 948.0 0.9 0.0 width = img.shape[1] + border * 2
640
641 # Rotation and Scale
642 1024 11823.0 11.5 0.4 R = np.eye(3)
643 1024 4167.0 4.1 0.2 a = random.uniform(-degrees, degrees)
644 # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
645 1024 1777.0 1.7 0.1 s = random.uniform(1 - scale, 1 + scale)
646 1024 11052.0 10.8 0.4 R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s)
647
648 # Translation
649 1024 5681.0 5.5 0.2 T = np.eye(3)
650 1024 2525.0 2.5 0.1 T[0, 2] = random.uniform(-translate, translate) * img.shape[0] + border # x translation (pixels)
651 1024 1617.0 1.6 0.1 T[1, 2] = random.uniform(-translate, translate) * img.shape[1] + border # y translation (pixels)
652
653 # Shear
654 1024 4268.0 4.2 0.2 S = np.eye(3)
655 1024 3381.0 3.3 0.1 S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg)
656 1024 1663.0 1.6 0.1 S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg)
657
658 # Combined rotation matrix
659 1024 21452.0 20.9 0.8 M = S @ T @ R # ORDER IS IMPORTANT HERE!!
660 1024 953.0 0.9 0.0 if (border != 0) or (M != np.eye(3)).any(): # image changed
661 1024 2464487.0 2406.7 88.8 img = cv2.warpAffine(img, M[:2], dsize=(width, height), flags=cv2.INTER_LINEAR, borderValue=(114, 114, 114))
662
663 # Transform label coordinates
664 1024 1843.0 1.8 0.1 n = len(targets)
665 1024 846.0 0.8 0.0 if n:
666 # warp points
667 1024 14153.0 13.8 0.5 xy = np.ones((n * 4, 3))
668 1024 20811.0 20.3 0.7 xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
669 1024 14404.0 14.1 0.5 xy = (xy @ M.T)[:, :2].reshape(n, 8)
670
671 # create new boxes
672 1024 6902.0 6.7 0.2 x = xy[:, [0, 2, 4, 6]]
673 1024 4465.0 4.4 0.2 y = xy[:, [1, 3, 5, 7]]
674 1024 33448.0 32.7 1.2 xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
675
676 # # apply angle-based reduction of bounding boxes
677 # radians = a * math.pi / 180
678 # reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
679 # x = (xy[:, 2] + xy[:, 0]) / 2
680 # y = (xy[:, 3] + xy[:, 1]) / 2
681 # w = (xy[:, 2] - xy[:, 0]) * reduction
682 # h = (xy[:, 3] - xy[:, 1]) * reduction
683 # xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T
684
685 # reject warped points outside of image
686 1024 56307.0 55.0 2.0 xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
687 1024 37507.0 36.6 1.4 xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
688 1024 4126.0 4.0 0.1 w = xy[:, 2] - xy[:, 0]
689 1024 1941.0 1.9 0.1 h = xy[:, 3] - xy[:, 1]
690 1024 1842.0 1.8 0.1 area = w * h
691 1024 6666.0 6.5 0.2 area0 = (targets[:, 3] - targets[:, 1]) * (targets[:, 4] - targets[:, 2])
692 1024 7517.0 7.3 0.3 ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16)) # aspect ratio
693 1024 13891.0 13.6 0.5 i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.2) & (ar < 10)
694
695 1024 4649.0 4.5 0.2 targets = targets[i]
696 1024 4707.0 4.6 0.2 targets[:, 1:5] = xy[i]
697
698 1024 769.0 0.8 0.0 return img, targets
@jveitchmichaelis 15f1343dfc203968b4c048ce7a6c5bd7e2387b13 cleans up the code a bit. I'll try your HSV augment now in the profiler. Ah, awesome, I see a big speedup as well. load_mosaic is about 10% faster now after the commit, though that might just be a random effect.
Ok, so to implement this hsv fix we can adopt their code, but it looks like their code accepts fixed hsv gains, so we need to keep the part of the existing function that generates these randomly from the hyps, which is super fast.
@jveitchmichaelis ok, I've merged the best parts of both HSV functions into v2 below. This looks like a great replacement; we just need to verify the actual outputs are similar.
Yeah, I think the current code is just a multiplier. This one does a shift, so probably something like random.randint(-max_hue, max_hue) for each. But we'd need to do a bit of hyperparameter optimisation to check what values work - from some messing around, something like 20-50 is probably OK.
@jveitchmichaelis ohhh you're right, that part flew over my head. 🤔 Gonna need to get creative then - we want similar outcomes if possible, otherwise people might train, get different results, and come asking questions.
@jveitchmichaelis looking at their code, I don't understand how shifts are beneficial here, because any sort of shift is going to result in clipping of pixel values. Gains between 0-1, like the current system uses, don't have this problem. I admit I'm not a colorspace expert, but I can't see the advantage of offsetting all the values by the same bias. What do you think?
EDIT: If I update their function to multiply, then it seems to produce exactly the same results as the current implementation, i.e. lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
Yeah, I wondered about this. I think image editors apply a gain (e.g. 0-100% saturation) rather than a shift, so if it doesn't incur a big perf penalty I would stick with gain - it's also less to change, and we know the results are good. (Though to be honest, visually I couldn't see a huge difference between shifted and gain-adjusted images.)
@jveitchmichaelis alright, so here's what I propose for the new function. It's about 10% faster than the one you posted, and about 2-3X faster than the current implementation. If you want, you can submit the PR for this so you get the credit - it's probably the right thing to do since you came up with the idea and did the initial profiling.
def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed
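A quick way to verify the outputs match (a sketch; augment_hsv_old is a hypothetical handle on the previous implementation, kept around just for comparison):

img = cv2.imread('data/samples/zidane.jpg')
a, b = img.copy(), img.copy()
np.random.seed(0); augment_hsv(a)      # new LUT version (modifies in place)
np.random.seed(0); augment_hsv_old(b)  # previous multiply version
print(np.abs(a.astype(np.int16) - b.astype(np.int16)).max())  # expect ~0 for typical gains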
Cool, thanks! I'll get back to the image plotting :)
Sketch of a much faster function - it takes under 30 ms to generate the image for me:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

# xywh2xyxy from utils.utils is assumed to be in scope
def plot_images(images, targets, paths=None, fname='images.jpg', class_labels=True):
    if isinstance(images, torch.Tensor):
        images = images.cpu().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.cpu().numpy()

    # un-normalise
    if np.max(images[0]) <= 1:
        images *= 255

    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, 16)  # limit plot to 16 images
    ns = np.ceil(bs ** 0.5)  # number of subplots per row/column

    mosaic_width = int(ns * w)
    mosaic_height = int(ns * h)
    mosaic = np.zeros((mosaic_height, mosaic_width, 3), dtype=np.uint8)

    # Box colours from matplotlib's property cycle
    # https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
    prop_cycle = plt.rcParams['axes.prop_cycle']
    hex2rgb = lambda c: tuple(int(c[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
    color_lut = [hex2rgb(c) for c in prop_cycle.by_key()['color']]

    for i, image in enumerate(images[:bs]):  # respect the 16-image cap
        block_x = int(w * (i // ns))
        block_y = int(h * (i % ns))

        mosaic[block_y:block_y + h, block_x:block_x + w, :] = image.transpose(1, 2, 0)

        if targets is not None:
            image_targets = targets[targets[:, 0] == i]  # targets for this image
            boxes = xywh2xyxy(image_targets[:, 2:6]).T
            classes = image_targets[:, 1].astype('int')

            # Scale relative coords to pixels and offset into the mosaic
            boxes[[0, 2]] *= w
            boxes[[0, 2]] += block_x
            boxes[[1, 3]] *= h
            boxes[[1, 3]] += block_y

            for j, box in enumerate(boxes.T):
                color = color_lut[int(classes[j]) % len(color_lut)]
                box = box.astype(int)
                cv2.rectangle(mosaic, (box[0], box[1]), (box[2], box[3]), color, thickness=3)

                if class_labels:
                    # Class label (ID only)
                    class_str = str(classes[j])
                    cv2.rectangle(mosaic, (box[0], box[1]), (box[0] + 20 * len(class_str), box[1] - 20), color, thickness=-1)
                    cv2.putText(mosaic, class_str, (box[0], box[1]), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.8, thickness=2, color=(255, 255, 255))

        # Draw image filename labels
        if paths is not None:
            label = os.path.basename(paths[i])
            cv2.rectangle(mosaic, (block_x, block_y), (block_x + 17 * len(label), block_y + 50), 0, thickness=-1)
            cv2.putText(mosaic, label, (block_x + 20, block_y + 40), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.8, thickness=2, color=(255, 255, 255))

        # Image border
        cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)

    if fname is not None:
        cv2.imwrite(fname, mosaic)

    return mosaic
Sample output:
Since it returns a numpy array directly, you could feed that straight into Tensorboard, or you can save it. I suspect cv2.imwrite is also pretty quick compared to savefig.
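e.g. something like this (a sketch; assumes a SummaryWriter called tb_writer is in scope):

mosaic = plot_images(imgs, targets, paths, fname=None)  # skip the disk write
rgb = np.ascontiguousarray(mosaic[..., ::-1])           # BGR -> RGB for TensorBoard
tb_writer.add_image('train_batch', rgb, global_step=epoch, dataformats='HWC')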
I'll add some labels next.
@jveitchmichaelis oh this looks really good, and fast! That's a better use of space than what I have now. I don't think the bounding boxes need any labels, the colors are great, but would it be possible to inline a filename/title maybe overlaid on the upper part of each image?
So then the plan might be to use 3 of these overlays:
Yeah small modification - I updated the code above:
The background for the image label is derived from the label size, but I need to check if that's actually robust... :)
With a call to imwrite, it's about 100-200ms (this is also a pretty enormous image...):
%timeit plot_image_label_mosaic(images, targets, fname="image.png")
%timeit plot_image_label_mosaic(images, targets, fname="image.jpg")
%timeit plot_image_label_mosaic(images, targets, fname=None)
193 ms ± 2.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
97.2 ms ± 901 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
25.9 ms ± 539 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
These are 640x512 images, so probably as large as most people will reasonably feed into a model.
@jveitchmichaelis awesome!! Yeah, one warning is that people sometimes have huge filenames. I think I have filename clipping in my implementation to guard against this. Even the coco names can get pretty long.
The jpg speed is great. I actually just made a commit to switch these to jpgs 37cbe89ef060c4b9789617b3e539a4c58a56d457
Ah, so the image is directly sized by the underlying images. Right now it's a fixed size, but it may not matter much.
This plots the basename at the moment, but could also add a hard limit. Also it would be trivial to just put a call to resize in, since the labels are all relative coords anyway. Might do that actually, e.g. with a max_size arg.
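Roughly this, inside the plotting function (a sketch of the hypothetical max_size handling):

# Downscale tiles if the input resolution is large; box coords are relative,
# so the labels need no adjustment.
scale = max_size / max(h, w)
if scale < 1:
    h, w = int(h * scale), int(w * scale)
    # ...and each image then gets cv2.resize(image, (w, h)) before placement in the mosaic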
@jveitchmichaelis yeah good idea. Someone might show up trying to train HD images with 8 GPUs and break the system.
We also need to make the grid selection smart with respect to the batch size; for example, I think I have code that defines the grid size as min(ceil(sqrt(batch_size)), 4). So bs 4 would be 2x2, bs 5 would be 3x3, and bs 10 or higher would be 4x4.
The grid size calculation is from your current code, so it should scale properly with different batch sizes. The default max is 16, and it can be changed via argument.
I've made some more updates, the font sizes now scale a bit better. There's a maximum size argument which is currently set to 640.
plot_images(images, targets, paths, max_subplots=16)
plot_images(images, targets, paths, max_subplots=9, max_size=300)
If you pass in an odd number, for example, the remaining subplots are blank
plot_images(images, targets, paths, max_subplots=3, max_size=300)
With optional confidence labels (if box size > 6):
@jveitchmichaelis looks great! Can you pass a white image for the empty ones? Then I suppose good defaults would be plot_images(images, targets, paths, max_subplots=16, max_size=640). Other than that it looks really good!
You should also test with a few coco batches to make sure that different shapes play well with the function. test.py uses rectangular inference, for example, and sorts the dataset by aspect ratio for minimum overall computation, so some batches may have extreme aspect ratios, i.e. the first batch will be all long horizontal rectangles and the last batch will be long vertical rectangular images.
Sure, will do (I don't actually have COCO downloaded at the moment - in progress...). I can run it with some test cases in the meantime.
I assume all images passed in a batch are the same size?
Background change is easy, will do that as well.
Here you can use this code to download the 2017 val images (<1GB) and create test_batch0.jpg:

# python, run from outside the yolov3 folder
from yolov3.utils.google_utils import gdrive_download
gdrive_download('1Y6Kou6kEB0ZEMCCpJSKStCor4KAReE43', 'coco2017val.zip')

# bash
cd yolov3
python3 test.py --img 608 --data coco2017.data
Seems to be OK. The benchmarks are also better - I had a job running before, so they were a lot slower than they should have been (again for batch size 16, 640x512 images). Practically real-time!
%timeit plot_images(images, targets, fname="image.png")
%timeit plot_images(images, targets, fname="image.jpg")
%timeit plot_images(images, targets, fname=None)
76.7 ms ± 2.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
47.6 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.25 ms ± 62.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Ground truth:
Prediction (0.25 conf thresh):
Had to do a bit of fiddling to get the labels to look good on narrow images:
I'll clean up the code tomorrow and PR it. Only needs minimal changes, and I've added a small utility function to convert outputs from NMS back to targets for plotting.
@jveitchmichaelis ok awesome! Looking forward to it.
@jveitchmichaelis yeah speed looks awesome too!
Should be good to go, I've tested it with train/test modes with various aspect ratios but do check it before merging!
@jveitchmichaelis thanks bud, I'm excited. I've got a few fires to put out today, but I should be able to get this in sometime later this afternoon!
Cool! Thanks for merging, I guess this can be closed.
@jveitchmichaelis you're welcome! I'll get this merged soon. If you have any other suggestions or feedback, feel free to let me know. Thanks for your contribution!
🚀 Feature
Improved image batch plots during training
Motivation
It's nice to see labeling performance as training progresses, but there are a few issues:
The last point is really important I think. Seeing what the model is labeling is arguably more informative than looking at the curves in tensorboard, and it's also a sanity check to see that your labels are loading correctly.
Currently with a non-standard dataset, plotting images looks a bit like this:
Pitch
Here is the result on the FLIR ADAS dataset - for context after an epoch of training from scratch:
This function improves on the above. If means and stds are provided, then we invert the normalisation. If you provide a multi-dimensional image, for example RGBD, you can select which channels are plotted. Boxes are coloured using matplotlib's property cycle, so colours should be consistent within a dataset. Boxes are plotted with confidence labelled as well, and by default the alpha of the box is set to the confidence (this can be disabled).
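Hypothetical usage (parameter names illustrative only):

# Un-normalise with dataset stats and plot only the RGB planes of an RGBD batch
plot_images(imgs, targets, paths, mean=dataset_mean, std=dataset_std, channels=[0, 1, 2])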
Note we need to pass the actual tensor in from imgs in the training loop, because we may have performed augmentation (therefore the path alone isn't enough). It handles ground truth/predictions appropriately.
I've opted to pass in the mean/std and not have a default value. We could just set 255 and 1 by default, but this may still break for custom datasets due to this line:
https://github.com/ultralytics/yolov3/blob/3554ab07fbedc05d91d9e6907b96a62512d931d5/train.py#L237
Required modifications
This should work out of the box as a drop-in replacement, with some extra code to adjust the target array above. However, for best results (and for extensibility to other datasets) there need to be a few extra infrastructure changes when calling the test code.
Mainly, the test() function needs to accept some extra parameters - for example dataset means/standard deviations if used, the summary writer so we can push to the same tensorboard instance, etc. This is mostly because these parameters are out of scope when train() is called. I would suggest looking into a more tightly coupled train/test framework so that these sorts of things could be shared more efficiently, but I guess that's coming anyway! https://github.com/ultralytics/yolov3/issues/1093
Other thoughts
In fact, I would suggest we could move all the current augmentation code over (flips, hsv, affine) rather than using Torch's built-in stuff. The main reason I don't like the transforms in PyTorch is that the image transformations force you to use PIL, which craps out with > 3 channels. Albumentations does most stuff in Numpy and is also very fast.
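For reference, a minimal Albumentations pipeline covering the same ops might look like this (a sketch; parameter values are illustrative only):

import albumentations as A

transform = A.Compose([
    A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=30),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=10),
    A.HorizontalFlip(p=0.5),
], bbox_params=A.BboxParams(format='yolo', label_fields=['class_ids']))

# bboxes in normalised xywh ('yolo') format, class_ids a list of ints per box
augmented = transform(image=img, bboxes=bboxes, class_ids=class_ids)
img, bboxes = augmented['image'], augmented['bboxes']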
Summary
Basically this is good to PR, but it'd be good to get your input on how you want to handle e.g. adding extra args to train(), and whether this is OK.