Unless I'm doing something wrong, there is a corner case in __wct_core when one of the feature matrices is effectively a vector (a single column).
When calculating:
contentConv = torch.mm(cont_feat, cont_feat.t()).div(cFSize[1] - 1) + iden
cFSize[1] is 1, so cFSize[1] - 1 is 0 and we divide by zero. The result is a matrix full of NaNs, which makes the SVD fail.
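A minimal NumPy sketch that reproduces the problem outside the model. It assumes cont_feat is a C x N matrix (channels by masked positions) that is mean-centered along the position axis before the covariance, as in the code that precedes this line; NumPy stands in for torch, and content_cov is a hypothetical helper name. With N = 1 the centered column is all zeros, so every entry becomes 0/0 = NaN, and adding the identity does not fix it.

```python
import numpy as np


def content_cov(cont_feat: np.ndarray) -> np.ndarray:
    """Mirror of the contentConv line: C x N features -> C x C covariance.

    Hypothetical helper; the mean-centering is assumed to match the
    step that comes before contentConv in __wct_core.
    """
    c, n = cont_feat.shape
    centered = cont_feat - cont_feat.mean(axis=1, keepdims=True)
    iden = np.eye(c)
    # Unbiased estimate: divide by N - 1, which is 0 when N == 1.
    # errstate only silences the warning; the NaNs are still produced.
    with np.errstate(divide="ignore", invalid="ignore"):
        return centered @ centered.T / (n - 1) + iden


rng = np.random.default_rng(0)
single = rng.standard_normal((3, 1))   # only one position carries this label
several = rng.standard_normal((3, 5))  # five positions carry this label

print(np.isnan(content_cov(single)).all())      # True: every entry is 0/0
print(np.isfinite(content_cov(several)).all())  # True: well defined
```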
For now, as a workaround, I've changed the condition inside __feature_wct
if cont_mask[0].size <= 0 or styl_mask[0].size <= 0:
    continue
to
if cont_mask[0].size <= 1 or styl_mask[0].size <= 1:
    continue
to ignore the labels that cause this issue (labels that cover only a single position in the content or style mask).
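The workaround amounts to requiring at least two positions per label in both masks, since the unbiased covariance needs N - 1 > 0. A hedged NumPy sketch of that check, assuming the segmentation maps have already been resized to the feature resolution; labels_to_transfer and min_pixels are hypothetical names, iterating over the content labels is an assumption, and min_pixels=2 reproduces the <= 1 condition above.

```python
import numpy as np


def labels_to_transfer(cont_seg: np.ndarray, styl_seg: np.ndarray,
                       min_pixels: int = 2) -> list:
    """Return the labels with at least min_pixels positions in BOTH maps.

    Hypothetical helper: min_pixels=2 matches `size <= 1: continue`.
    """
    cont_flat = cont_seg.reshape(-1)
    styl_flat = styl_seg.reshape(-1)
    kept = []
    for label in np.unique(cont_flat):
        cont_mask = np.where(cont_flat == label)
        styl_mask = np.where(styl_flat == label)
        if cont_mask[0].size < min_pixels or styl_mask[0].size < min_pixels:
            continue  # too few samples for a covariance; skip this label
        kept.append(int(label))
    return kept


cont_seg = np.array([[0, 0, 1],
                     [0, 2, 2]])
styl_seg = np.array([[0, 0, 1],
                     [1, 2, 2]])
print(labels_to_transfer(cont_seg, styl_seg))  # [0, 2]: label 1 has one content pixel
```

Raising min_pixels further would also drop regions that are technically valid but whose covariance is estimated from very few samples; whether that is desirable is a judgment call.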
Any idea why it happens and what is the best approach to fix it?