ShiqiYu / OpenGait

A flexible and extensible framework for gait recognition. You can focus on designing your own models and comparing with state-of-the-arts easily with the help of OpenGait.
724 stars 166 forks source link

The problem of freezing part of the network structure during training #230

Open enemy1205 opened 3 months ago

enemy1205 commented 3 months ago

When I refer to the network freezing mechanism in biggait, a certain part of the network is frozen after a certain round of training, but the following errors are reported when the training goes to the specified iterations

reference

    def forward(self, inputs):
        if self.training:
            if self.iteration == 500 and '.pt' in self.pretrained_mask_branch:
                self.init_Mask_Branch()
            if self.iteration >= 500:
                self.Mask_Branch.eval()
                self.Mask_Branch.requires_grad_(False)

Mine version

    def forward(self, inputs):
        if self.training:
            if self.iteration >= 60000:
                self.netA.eval()
                self.netA.requires_grad_(False)

Error message

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by 
making sure all `forward` function outputs participate in calculating loss. 
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Parameter indices which did not receive grad for rank 1: 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
 In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error
    if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by 
making sure all `forward` function outputs participate in calculating loss. 
If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
Parameter indices which did not receive grad for rank 0: 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
 In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1556588) of binary: /home/xxx/miniconda3/envs/gait/bin/python
Bugjudger commented 1 month ago

add find_unused_parameters=True