Randl / MobileNetV2-pytorch

Implementation of MobileNetV2 in PyTorch
https://arxiv.org/abs/1801.04381
MIT License

How did you compute a metric like images/second? #9

Open issue, opened by skalyan 5 years ago

skalyan commented 5 years ago

I am trying to use a standardized metric such as images/sec to compare the relative training speed of different frameworks (e.g. PyTorch and TF). Have you computed such a metric? If not, what do you think of this approach?
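One detail matters when images/sec is used to compare frameworks: averaging per-batch rates (which is what feeding `batch_size / t` into `images_per_sec` gives, if `AverageMeter` keeps a plain running mean) is not the same as total images divided by total time, because slow batches get under-weighted. A small illustration with made-up timings (not measured data); the two function names are hypothetical:

```python
def mean_of_batch_rates(batch_sizes, batch_times):
    """Arithmetic mean of per-batch images/sec, i.e. averaging batch_size / t."""
    rates = [n / t for n, t in zip(batch_sizes, batch_times)]
    return sum(rates) / len(rates)


def aggregate_rate(batch_sizes, batch_times):
    """Total images processed divided by total time spent on them."""
    return sum(batch_sizes) / sum(batch_times)


# Hypothetical timings: two batches of 128 images, one fast and one slow.
sizes = [128, 128]
times = [0.5, 1.5]
print(mean_of_batch_rates(sizes, times))  # about 170.67
print(aggregate_rate(sizes, times))       # 128.0
```

The aggregate form answers "how many images did training consume per second", which is the figure you would want when comparing a PyTorch run against a TF run over the same number of steps.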

I modified run.py to compute an images_per_sec rate; the modified lines are marked with "[KAL]".

=========================================

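```python
import time

import torch
from tqdm import tqdm

# AverageMeter, correct and CyclicLR are this repository's own helpers,
# already imported by run.py; they are not redefined here.


def train(model, loader, epoch, optimizer, criterion, device, dtype, batch_size, log_interval, scheduler):
    model.train()
    correct1, correct5 = 0, 0
    batch_time = AverageMeter()
    images_per_sec = AverageMeter()
    # [KAL] CUDA kernels run asynchronously, so GPU work must be
    # synchronized before the clock is read, otherwise the timer undercounts.
    use_cuda = torch.device(device).type == 'cuda'
    for batch_idx, (data, target) in enumerate(tqdm(loader)):
        if isinstance(scheduler, CyclicLR):
            scheduler.batch_step()
        data, target = data.to(device=device, dtype=dtype), target.to(device=device)

        # [KAL] Take timestamp (after pending copies/kernels have finished)
        if use_cuda:
            torch.cuda.synchronize(device)
        end = time.perf_counter()
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        corr = correct(output, target, topk=(1, 5))
        correct1 += corr[0]
        correct5 += corr[1]

        # [KAL] compute processing time for batch (read the clock only once,
        # so both meters use the same interval)
        if use_cuda:
            torch.cuda.synchronize(device)
        elapsed = time.perf_counter() - end
        batch_time.update(elapsed)

        # [KAL] Based on the actual batch size (the last batch may be
        # smaller than batch_size), calculate the images/sec rate
        images_per_sec.update(data.size(0) / elapsed)
        # ... rest of train() as in run.py
```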
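A framework-agnostic way to collect the same number is to accumulate images and seconds separately, skip a few warm-up iterations (cuDNN autotuning, memory-allocator growth and data-loader start-up typically make the first batches slower), and force pending GPU work to finish before each clock read. The sketch below is hypothetical: `ThroughputMeter`, its `warmup` count and the `sync` hook are illustrative names, not part of this repository. With PyTorch on a GPU you would pass `sync=torch.cuda.synchronize`; on CPU the default no-op is enough.

```python
import time


class ThroughputMeter:
    """Hypothetical helper: aggregate images/sec, ignoring the first `warmup` batches."""

    def __init__(self, warmup=10, sync=None):
        self.warmup = warmup
        # Called before every clock read, e.g. torch.cuda.synchronize for GPU runs.
        self.sync = sync or (lambda: None)
        self.batches_seen = 0
        self.images = 0
        self.seconds = 0.0
        self._start = None

    def start(self):
        self.sync()
        self._start = time.perf_counter()

    def stop(self, batch_size):
        self.sync()
        elapsed = time.perf_counter() - self._start
        self.batches_seen += 1
        if self.batches_seen > self.warmup:  # only count post-warm-up batches
            self.images += batch_size
            self.seconds += elapsed
        return elapsed

    @property
    def rate(self):
        """Total counted images divided by total counted seconds."""
        return self.images / self.seconds if self.seconds > 0 else 0.0


# Usage with a stand-in workload: 5 batches of 32 images, first 2 treated as warm-up.
meter = ThroughputMeter(warmup=2)
for _ in range(5):
    meter.start()
    time.sleep(0.01)  # stand-in for zero_grad / forward / backward / step
    meter.stop(batch_size=32)
print(f"{meter.rate:.0f} images/sec over {meter.images} images")
```

In `train()`, `meter.start()` would sit where the `[KAL] Take timestamp` line is and `meter.stop(data.size(0))` right after `optimizer.step()`. Keeping images and seconds as two running sums means the reported rate is the aggregate one rather than a mean of per-batch rates.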
Randl commented 5 years ago

If you run the code, you'll see a tqdm progress bar showing the average time per batch, the elapsed time, and the approximate time to finish. A second progress bar shows the same for epochs.
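If you would rather have the progress bar itself report images/sec, tqdm can count images instead of batches: give it the number of images as `total`, a custom `unit`, and advance it by each batch's size. By default tqdm displays an exponential moving average of the rate (`smoothing=0.3`); `smoothing=0` shows the overall average instead. A minimal sketch with a stand-in loader (a plain list of fake batches, assumed for illustration); in `train()` the total would be `len(loader.dataset)` and the update would take `data.size(0)`.

```python
import io

from tqdm import tqdm

# Stand-in for a DataLoader: 10 batches of 32 "images" each (hypothetical data).
fake_loader = [list(range(32)) for _ in range(10)]
num_images = sum(len(batch) for batch in fake_loader)

# file=io.StringIO() only keeps this demo quiet; drop it to see the bar.
pbar = tqdm(total=num_images, unit="img", unit_scale=True,
            smoothing=0,  # 0 = overall average rate instead of a moving average
            file=io.StringIO())
for batch in fake_loader:
    # ... forward / backward / optimizer.step() would go here ...
    pbar.update(len(batch))  # advance by images, so the displayed rate is img/s
pbar.close()

print(pbar.n, "images counted")
```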

However, you might prefer to use a distributed framework for PyTorch, which is supposed to provide better performance even on a single PC.