skhu101 / GM-NAS

Code for our ICLR'2022 paper "Generalizing Few-Shot NAS with Gradient Matching"
MIT License

iter(dataset) in training loop for NASBench201 #4

Open alec-flowers opened 2 years ago

alec-flowers commented 2 years ago

In WS-GM/nasbench201/train_search.py, the train function calls `input, target = next(iter(train_queue))` inside the per-batch for loop. This creates a new iterator on every batch, which significantly slows down training. It also means one epoch no longer covers the entire dataset: because `iter(train_queue)` is re-initialized on every batch, each call only draws the first batch of a freshly shuffled pass before the iterator is thrown away and rebuilt. You still get random samples from the dataset, but there is no guarantee that the batches seen within an epoch are disjoint.
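
A minimal sketch of the pattern being described and one possible fix, using a toy `TensorDataset` standing in for the real `train_queue` (the toy data, sizes, and variable names here are illustrative, not taken from the repo):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the real train_queue (which wraps CIFAR data in the repo).
dataset = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
train_queue = DataLoader(dataset, batch_size=2, shuffle=True)

# Pattern described in the issue: a fresh iterator is built on every step,
# so each step only draws the first batch of a newly shuffled pass and an
# "epoch" never visits every batch exactly once.
for step in range(len(train_queue)):
    input, target = next(iter(train_queue))

# One possible fix: iterate the DataLoader directly, so a single iterator
# is reused and one epoch covers each batch exactly once.
for step, (input, target) in enumerate(train_queue):
    pass  # training step would use input/target here
```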