smilesun closed this 8 months ago
In the new PR
https://github.com/marrlab/DomainLab/pull/769
(replacing the accidentally merged PR https://github.com/marrlab/DomainLab/pull/311/files), running
sh test_fishr.sh
produces the following error:
domainlab/algos/trainers/train_fishr.py:70: in var_grads_and_loss
    dict_var_grads_single_domain = self.cal_dict_variance_grads(tensor_x, vec_y)
domainlab/algos/trainers/train_fishr.py:160: in cal_dict_variance_grads
    loss.backward(
../anaconda3/lib/python3.9/site-packages/torch/_tensor.py:396: in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
../anaconda3/lib/python3.9/site-packages/torch/autograd/__init__.py:166: in backward
    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

outputs = (tensor([0.4516, 1.0019], grad_fn=<BackwardHookFunctionBackward>),)
grads = (None,)
is_grads_batched = False

    def _make_grads(outputs: Sequence[torch.Tensor], grads: Sequence[_OptionalTensor], is_grads_batched: bool) -> Tuple[_OptionalTensor, ...]:
        new_grads: List[_OptionalTensor] = []
        for out, grad in zip(outputs, grads):
            if isinstance(grad, torch.Tensor):
                grad_shape = grad.shape if not is_grads_batched else grad.shape[1:]
                if not out.shape == grad_shape:
                    if is_grads_batched:
                        raise RuntimeError("If `is_grads_batched=True`, we interpret the first "
                                           "dimension of each grad_output as the batch dimension. "
                                           "The sizes of the remaining dimensions are expected to match "
                                           "the shape of corresponding output, but a mismatch "
                                           "was detected: grad_output["
                                           + str(grads.index(grad)) + "] has a shape of "
                                           + str(grad.shape) + " and output["
                                           + str(outputs.index(out)) + "] has a shape of "
                                           + str(out.shape) + ". "
                                           "If you only want some tensors in `grad_output` to be considered "
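The traceback shows backward() being reached with grads = (None,) on a two-element output, i.e. no gradient argument for a non-scalar loss. A minimal sketch that reproduces the same kind of failure outside DomainLab, assuming only PyTorch; the linear model and random data are hypothetical stand-ins, not the Fishr trainer:

```python
import torch

# Hypothetical toy model and data standing in for the Fishr network (not DomainLab code).
torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
x = torch.randn(2, 3)  # batch of 2 samples, like the 2-element tensor in the traceback
y = torch.randn(2)

# reduction="none" keeps one loss value per sample, so loss has shape [2], not [].
loss = torch.nn.functional.mse_loss(model(x).squeeze(1), y, reduction="none")

try:
    # No gradient argument: autograd can only create the implicit gradient for a scalar.
    loss.backward()
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(loss.shape)
print(error_message)
```

The exact wording of the RuntimeError may differ between PyTorch versions, but it refers to implicitly created gradients requiring scalar outputs.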
Update train_fishr.py: sum loss before backprop 34b3507
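The loss should be a scalar: backward() called without a gradient argument can only create the implicit gradient for a scalar output, but here the loss had two elements. Using loss = loss.sum() before backprop fixed the issue.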
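A minimal sketch of the fix, again assuming only PyTorch with a hypothetical toy model (not the actual train_fishr.py code). It reduces the per-sample loss to a scalar with .sum() before backprop, and checks that this gives the same gradient as keeping the vector loss and passing an explicit all-ones gradient to backward():

```python
import torch

# Hypothetical toy model and data (not DomainLab code).
torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
x = torch.randn(2, 3)
y = torch.randn(2)

def per_sample_loss():
    # One loss value per sample, shape [2].
    return torch.nn.functional.mse_loss(model(x).squeeze(1), y, reduction="none")

# Fix used in the commit: sum to a scalar, then backprop.
loss = per_sample_loss().sum()
loss.backward()
grad_from_sum = model.weight.grad.clone()

# Equivalent alternative: keep the vector loss and supply d(sum)/d(loss_i) = 1 explicitly.
model.zero_grad()
vector_loss = per_sample_loss()
vector_loss.backward(gradient=torch.ones_like(vector_loss))
grad_from_ones = model.weight.grad.clone()

print(loss.dim())
print(torch.allclose(grad_from_sum, grad_from_ones))
```

Summing (rather than averaging) keeps the gradient equal to the sum of the per-sample gradients; whether sum or mean is the better reduction depends on how the Fishr gradient variances are normalized downstream, which this sketch does not model.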