cherubicXN / hawp

Holistically-Attracted Wireframe Parsing [TPAMI'23] & [CVPR'20]
MIT License

Predicting the model with rectangular input #7

Closed. alwc closed this issue 4 years ago.

alwc commented 4 years ago

Hi @cherubicXN

I'm trying to train on the wireframe dataset with a rectangular input shape, i.e. in hawp.yaml I set

DATASETS:
  IMAGE:
    HEIGHT: 576
    WIDTH: 448
  TARGET:
    HEIGHT: 144
    WIDTH: 112

but it doesn't seem to train properly at all. The highest sAP10.0 I'm getting is around 7. Do you know what I should be aware of if I want to train my model with rectangular input? Thanks!
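(For reference, the TARGET dimensions above are the IMAGE dimensions divided by 4, i.e. 576 / 4 = 144 and 448 / 4 = 112, which appears to match the 4x downsampling between the input image and the predicted heatmaps; the default config pairs a 512 input with a 128 target.)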

cherubicXN commented 4 years ago

Have you tried modifying the hard-coded constants, such as the 128 in the following code?

https://github.com/cherubicXN/hawp/blob/0ec5a8beeb3e6080c0030f463ba23d82c4c5a18e/parsing/detector.py#L40
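
For context, the referenced lines presumably hard-code the 128x128 target size when decomposing the top-k indices, roughly along these lines (a reconstruction based on the modified version quoted in the next comment, not a verbatim copy of the repository code):

y = (index / 128).float() + torch.gather(joff[1], 0, index) + 0.5
x = (index % 128).float() + torch.gather(joff[0], 0, index) + 0.5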

alwc commented 4 years ago

Have you tried modifying the hard-coded constants, such as the 128 in the following code?

https://github.com/cherubicXN/hawp/blob/0ec5a8beeb3e6080c0030f463ba23d82c4c5a18e/parsing/detector.py#L40

Yup, I changed the get_junctions function to the following:

def get_junctions(jloc, joff, topk=300, th=0):
    """ NOTE:
    > jloc.shape
    torch.Size([1, 144, 112])
    """
    height, width = jloc.size(1), jloc.size(2)
    jloc = jloc.reshape(-1)
    joff = joff.reshape(2, -1)

    scores, index = torch.topk(jloc, k=topk)
    y = (index / width).float() + torch.gather(joff[1], 0, index) + 0.5
    x = (index % height).float() + torch.gather(joff[0], 0, index) + 0.5

    junctions = torch.stack((x, y)).t()

    return junctions[scores > th], scores[scores > th]

I also trained another model, just for the sake of testing, with

def get_junctions(jloc, joff, topk=300, th=0):
    """ NOTE:
    > jloc.shape
    torch.Size([1, 144, 112])
    """
    ...
    y = (index / height).float() + torch.gather(joff[1], 0, index) + 0.5
    x = (index % width).float() + torch.gather(joff[0], 0, index) + 0.5
    ...

and I'm getting sAP10.0 of roughly 7 for both models.
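
Both variants mix height and width when decomposing the flat index. Since jloc is reshaped row-major, a flat index is index = row * width + col, so the width has to be used for both coordinates; the mismatch only shows up once height != width, which would also explain why square inputs train fine. A quick arithmetic check with the 144x112 target size (the specific row/col values are arbitrary, just for illustration):

H, W = 144, 112           # heatmap size from the config above
row, col = 100, 37        # an arbitrary junction location
index = row * W + col     # flat index after jloc.reshape(-1), row-major

# the correct decomposition uses the width for both coordinates
assert index // W == row and index % W == col
# mixing in the height only gives the right answer when H == W
assert index // H != row and index % H != col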

alwc commented 4 years ago

Hi @cherubicXN, I did some more experiments. Based on what I observed, the losses for dis, jloc, and res look fine, but something is not quite right for joff, pos, and neg.

(screenshot: training loss curves)

I also tried visualizing the heatmaps of jloc_nms_nms and res, and they look fine too: (screenshots: jloc and res heatmaps)

but the end results look quite bad for non-square input: (screenshot: predicted wireframe on a non-square image)

I think there are assumptions somewhere that the inputs have to be square. The models I trained with different square dimensions (e.g. 576x576, 448x448) have almost identical sAP10 at every epoch. Right now I suspect the assumption is in the LOI pooling / verification step. I would be grateful if you could provide some insights, thanks!

cherubicXN commented 4 years ago

Thanks very much for your efforts. Let me train it and see what happens.

alwc commented 4 years ago

@cherubicXN I think I figured out the problem. If I change the following lines in get_junctions:

 y = (index / width).float() + torch.gather(joff[1], 0, index) + 0.5
 x = (index % width).float() + torch.gather(joff[0], 0, index) + 0.5

then the model can be trained properly!
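
For reference, a minimal sketch of the full function with this fix applied (same signature as the version quoted earlier; note that on recent PyTorch the row index needs explicit floor division, index // width, since / now performs true division on integer tensors):

import torch

def get_junctions(jloc, joff, topk=300, th=0):
    # jloc: (1, H, W) junction heatmap, joff: (2, H, W) sub-pixel offsets
    height, width = jloc.size(1), jloc.size(2)
    jloc = jloc.reshape(-1)
    joff = joff.reshape(2, -1)

    scores, index = torch.topk(jloc, k=topk)
    # the flat index is row * width + col, so width recovers both coordinates
    y = (index // width).float() + torch.gather(joff[1], 0, index) + 0.5
    x = (index % width).float() + torch.gather(joff[0], 0, index) + 0.5

    junctions = torch.stack((x, y)).t()

    return junctions[scores > th], scores[scores > th]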

cherubicXN commented 4 years ago

@cherubicXN I think I figured out the problem. If I change the following lines in get_junctions:

 y = (index / width).float() + torch.gather(joff[1], 0, index) + 0.5
 x = (index % width).float() + torch.gather(joff[0], 0, index) + 0.5

then the model can be trained properly!

That's awesome. Thanks for your hard work!