If I'm reading this right, the permutation orders for the zy and xz views are swapped. The snippet labelled as performing the axis permutation for zy actually performs the corresponding permutation for xz, and vice versa. Since `target_zy` and `target_xz` are extracted in the correct order, this mismatch could affect the pixelwise heatmap losses.
https://github.com/anibali/margipose/blob/e96d59187dc17651ab184ca263f9a1a150cfa201/src/margipose/models/margipose_model.py#L94-L97
`mid_in` has axes in BCHW order. After permuting, `mid_out` for zy is BWHC (henceforth treated as BCHD) and `mid_out` for xz is BHCW (henceforth treated as BCDW).
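To make the axis bookkeeping easier to check, here is a small sketch that tracks axis *labels* instead of tensors. `permute_axes` mimics the semantics of `torch.Tensor.permute(*dims)` (output axis `i` is input axis `dims[i]`). The `dims` tuples are my assumption, inferred from the BWHC and BHCW orders described above rather than copied from the linked lines, so please compare them against the actual code. `permute_axes` and `roles` are hypothetical helpers for illustration only.

```python
def permute_axes(labels: str, dims: tuple) -> str:
    """Return axis labels after a torch-style permute (output axis i = input axis dims[i])."""
    if sorted(dims) != list(range(len(labels))):
        raise ValueError(f"dims {dims} is not a permutation of {len(labels)} axes")
    return "".join(labels[d] for d in dims)


def roles(permuted: str, treated_as: str) -> dict:
    """Map each role the downstream code assumes (e.g. BCHD) to the original axis in that slot."""
    return dict(zip(treated_as, permuted))


mid_in = "BCHW"

# Assumed permutation orders, inferred from the axis orders described above.
zy_dims = (0, 3, 2, 1)  # BCHW -> BWHC
xz_dims = (0, 2, 1, 3)  # BCHW -> BHCW

zy = permute_axes(mid_in, zy_dims)
xz = permute_axes(mid_in, xz_dims)

# Show which original axis ends up in each slot once the result is reinterpreted.
print("zy:", zy, roles(zy, "BCHD"))
print("xz:", xz, roles(xz, "BCDW"))
```

Reading the `roles` output for each view shows which original spatial axis (H or W) is folded into the channel slot and which one survives as a heatmap axis, which is the thing to compare against how `target_zy` and `target_xz` are laid out.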