marrlab / DomainLab

modular domain generalization: https://pypi.org/project/domainlab/
https://marrlab.github.io/DomainLab/
MIT License

fishr re-pr backprop error #770

Closed: smilesun closed this issue 8 months ago

smilesun commented 8 months ago

In the new PR

https://github.com/marrlab/DomainLab/pull/769

which replaces the accidentally merged PR https://github.com/marrlab/DomainLab/pull/311/files, running

`sh test_fishr.sh`

fails with the following backprop error:

```
domainlab/algos/trainers/train_fishr.py:70: in var_grads_and_loss
    dict_var_grads_single_domain = self.cal_dict_variance_grads(tensor_x, vec_y)
domainlab/algos/trainers/train_fishr.py:160: in cal_dict_variance_grads
    loss.backward(
../anaconda3/lib/python3.9/site-packages/torch/_tensor.py:396: in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
../anaconda3/lib/python3.9/site-packages/torch/autograd/__init__.py:166: in backward
    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

outputs = (tensor([0.4516, 1.0019], grad_fn=<BackwardHookFunctionBackward>),), grads = (None,), is_grads_batched = False

    def _make_grads(outputs: Sequence[torch.Tensor], grads: Sequence[_OptionalTensor],
                    is_grads_batched: bool) -> Tuple[_OptionalTensor, ...]:
        new_grads: List[_OptionalTensor] = []
        for out, grad in zip(outputs, grads):
            if isinstance(grad, torch.Tensor):
                grad_shape = grad.shape if not is_grads_batched else grad.shape[1:]
                if not out.shape == grad_shape:
                    if is_grads_batched:
                        raise RuntimeError("If `is_grads_batched=True`, we interpret the first "
                                           "dimension of each grad_output as the batch dimension. "
                                           "The sizes of the remaining dimensions are expected to match "
                                           "the shape of corresponding output, but a mismatch "
                                           "was detected: grad_output["
                                           + str(grads.index(grad)) + "] has a shape of "
                                           + str(grad.shape) + " and output["
                                           + str(outputs.index(out)) + "] has a shape of "
                                           + str(out.shape) + ". "
                                           "If you only want some tensors in `grad_output` to be considered "
...
```
smilesun commented 8 months ago

Update train_fishr.py: sum loss before backprop 34b3507

smilesun commented 8 months ago

The loss passed to `backward()` should be a scalar; adding `loss = loss.sum()` before the backward call fixed the issue.
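A minimal sketch of the pattern behind the fix; the helper, model, and loss below are illustrative placeholders, not the actual `cal_dict_variance_grads` code:

```python
import torch
import torch.nn.functional as F


def scalar_loss_backward(model: torch.nn.Module,
                         tensor_x: torch.Tensor,
                         vec_y: torch.Tensor) -> None:
    """Hypothetical helper illustrating the 'sum loss before backprop' fix."""
    logits = model(tensor_x)
    # reduction="none" returns one loss value per sample (a vector)
    loss = F.cross_entropy(logits, vec_y, reduction="none")
    loss = loss.sum()  # scalar: backward() can now create the implicit gradient
    loss.backward()


# Usage sketch:
# model = torch.nn.Linear(4, 2)
# scalar_loss_backward(model, torch.randn(8, 4), torch.randint(0, 2, (8,)))
```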