Zj-BinXia / SSL

This project is the official implementation of 'Structured Sparsity Learning for Efficient Video Super-Resolution', CVPR 2023

Code confusion #2

Closed: zxd-cqu closed this issue 1 year ago

zxd-cqu commented 1 year ago
    def train(self):
        self.model.train()
        self.train_sampler.set_epoch(100000)
        self.prefetcher.reset()
        train_data = self.prefetcher.next()
        while train_data is not None:
            self.total_iter += 1

            self.lq = train_data['lq'].to(self.device)
            self.gt = train_data['gt'].to(self.device)

            finished = self.optimize_parameters(self.total_iter)
            if finished:
                return True

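In this code snippet (from SSL/blob/master/basicsr/pruner/SSL_pruner.py), the model is trained inside a while loop, but train_data = self.prefetcher.next() is only called once, before the loop starts. It might be necessary to add that call within the loop. I don't have the complete code context, but based on your description, it appears that the same batch is used repeatedly for training and no new data is ever fetched. Because train_data never changes, the loop condition train_data is not None also never becomes false, so the loop only exits when optimize_parameters() returns True.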

Should it be the following instead?

    def train(self):
        self.model.train()
        self.train_sampler.set_epoch(100000)
        self.prefetcher.reset()
        train_data = self.prefetcher.next()
        while train_data is not None:
            self.total_iter += 1

            self.lq = train_data['lq'].to(self.device)
            self.gt = train_data['gt'].to(self.device)

            finished = self.optimize_parameters(self.total_iter)
            train_data = self.prefetcher.next()  # add this? (fetch the next batch)
            if finished:
                return True
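A minimal sketch of the two loop shapes, runnable without the repository. It assumes a hypothetical ToyPrefetcher that only mimics the reset()/next() interface used above (next() returns one batch at a time and None once the data is exhausted); it is not the project's actual prefetcher. The max_iters cap is a hypothetical stand-in for optimize_parameters() reporting that training is finished.

```python
class ToyPrefetcher:
    """Hypothetical stand-in for the prefetcher used in SSL_pruner.py:
    next() returns one batch at a time and None once the data is exhausted."""

    def __init__(self, batches):
        self.batches = list(batches)
        self.reset()

    def reset(self):
        self._it = iter(self.batches)

    def next(self):
        return next(self._it, None)


def train_without_refetch(prefetcher, max_iters):
    """Original loop shape: train_data is fetched once, before the loop."""
    seen = []
    prefetcher.reset()
    train_data = prefetcher.next()
    while train_data is not None:
        seen.append(train_data)  # "train" on the current batch
        # hypothetical stand-in for optimize_parameters() returning True
        if len(seen) >= max_iters:
            return seen
    return seen


def train_with_refetch(prefetcher, max_iters):
    """Proposed loop shape: the next batch is fetched on every iteration."""
    seen = []
    prefetcher.reset()
    train_data = prefetcher.next()
    while train_data is not None:
        seen.append(train_data)  # "train" on the current batch
        train_data = prefetcher.next()  # the line added by the fix
        if len(seen) >= max_iters:
            return seen
    return seen


batches = ["batch0", "batch1", "batch2"]
print(train_without_refetch(ToyPrefetcher(batches), max_iters=5))
# ['batch0', 'batch0', 'batch0', 'batch0', 'batch0']
print(train_with_refetch(ToyPrefetcher(batches), max_iters=5))
# ['batch0', 'batch1', 'batch2']
```

Without the cap, the first loop would never terminate, since train_data never becomes None. With the fix, the loop walks through each batch once and stops when the prefetcher is exhausted; fetching before the finished check only means one extra batch may be loaded on the final iteration.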

Zj-BinXia commented 1 year ago

Yes, you are right. I just realized that I accidentally deleted this line while removing the debug code. I have added it back now.