XuelianCheng / LEAStereo

Hierarchical Neural Architecture Search for Deep Stereo Matching (NeurIPS 2020)
MIT License

Cannot use code to train model obtained by NAS #11

Closed: lebionick closed this issue 3 years ago

lebionick commented 3 years ago

Hi! I'm using your code to find the best model on SceneFlow. I need lower latency, so I reduced the search parameters to 4 for the feature net and 8 for the matching net. NAS training worked well (despite an out-of-memory error at epoch 11). I decoded the model (which also caused problems, because the checkpoint had been saved with DataParallel) and tried to retrain it, but the script crashed at https://github.com/XuelianCheng/LEAStereo/blob/master/retrain/skip_model_3d.py#L148 because of a channel-count mismatch. As far as I understand, that is not the only such place: the code is hardcoded to use 12 cells. How can I adapt it to use different parameters?
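The DataParallel problem mentioned above typically arises because `torch.nn.DataParallel` stores the wrapped network as `self.module`, so every key in the saved `state_dict` starts with `module.` and no longer matches an unwrapped model. A minimal sketch of one common workaround is to strip that prefix before loading. The helper name `strip_dataparallel_prefix` is hypothetical, and it works on any mapping, so it does not depend on a particular PyTorch version.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_dataparallel_prefix(state_dict: Mapping[str, Any],
                              prefix: str = "module.") -> "OrderedDict[str, Any]":
    """Return a copy of state_dict with a leading DataParallel prefix removed.

    Hypothetical helper: keys without the prefix are kept unchanged, so it is
    safe on checkpoints saved with or without DataParallel. Key order is kept.
    """
    cleaned: "OrderedDict[str, Any]" = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned
```

Assuming the decoded checkpoint keeps its weights under a `'state_dict'` key (an assumption; check the keys of your own checkpoint), it would be applied as `model.load_state_dict(strip_dataparallel_prefix(torch.load(path, map_location='cpu')['state_dict']))`. Recent PyTorch releases also provide `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`, which removes the prefix in place.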

XuelianCheng commented 3 years ago

Hi Nikolay,

There is another file, 'new_model_3d.py', that builds the model without skip connections. It could probably solve your problem.
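Why a model without skip connections need not be tied to a fixed cell count can be seen from the chaining pattern in the 10-cell forward method later in this thread: each cell consumes only the pair returned by the previous cell, so the whole chain reduces to a loop. A minimal sketch, assuming that two-in, pair-out cell interface; the names `Cell` and `run_plain_chain` are hypothetical, and this is an illustration, not the actual contents of 'new_model_3d.py'.

```python
from typing import Any, Callable, Sequence, Tuple

# Hypothetical cell interface: (prev_prev, prev) -> (next_prev_prev, next_prev).
Cell = Callable[[Any, Any], Tuple[Any, Any]]


def run_plain_chain(cells: Sequence[Cell], stem0: Any, stem1: Any) -> Any:
    """Run cells back to back with no skip connections; return the last output.

    Works for any number of cells because each cell sees only the pair
    produced by the cell before it.
    """
    if not cells:
        raise ValueError("need at least one cell")
    out: Tuple[Any, Any] = (stem0, stem1)
    for cell in cells:
        out = cell(out[0], out[1])
    return out[-1]
```

In a real model, `cells` would be the network's `nn.ModuleList` of cells and `stem0`, `stem1` the stem outputs, so the depth could follow the search configuration instead of being written out by hand.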

lebionick commented 3 years ago

Hi @XuelianCheng, yes, it worked! So the model with skip connections is hardcoded to a specific number of cells, but should work better? Is there a reason to adapt this idea to a smaller model?

XuelianCheng commented 3 years ago

Yes, skip-connections between cells would give a better result.

lebionick commented 3 years ago

I've written a skip-connection model for a 10-cell structure. Here is the code for the forward method; the rest is the same:

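    def forward(self, x):
        # x is a 5-D tensor (N, C, D, H, W). Only forward() differs from the
        # 12-cell skip model; the rest of the class is unchanged.
        stem0 = self.stem0(x)
        stem1 = self.stem1(stem0)
        out = (stem0, stem1)
        # Each cell returns a pair; cells[i + 1] consumes (out_i[0], out_i[1]),
        # and out_i[-1] is the cell's new feature map.
        out0 = self.cells[0](out[0], out[1])
        out1 = self.cells[1](out0[0], out0[1])
        out2 = self.cells[2](out1[0], out1[1])
        out3 = self.cells[3](out2[0], out2[1])
        # Skip connection 0 -> 3: concatenate along channels, fuse with conv1,
        # and feed the result to cell 4 as its second input.
        cat3_0 = torch.cat((out0[-1], out3[-1]), 1)
        out3_cat = self.conv1(cat3_0)
        out4 = self.cells[4](out3[0], out3_cat)
        out5 = self.cells[5](out4[0], out4[1])
        out6 = self.cells[6](out5[0], out5[1])
        # Skip connection 3 -> 6, fused by the same conv1 and fed to cell 7.
        cat3_6 = torch.cat((out3[-1], out6[-1]), 1)
        out6_cat = self.conv1(cat3_6)
        out7 = self.cells[7](out6[0], out6_cat)
        out8 = self.cells[8](out7[0], out7[1])
        out9 = self.cells[9](out8[0], out8[1])
        last_output = out9[-1]

        # Upsampling layers for bringing the last feature map back to full size in stages.
        d, h, w = x.size()[2], x.size()[3], x.size()[4]
        upsample_6  = nn.Upsample(size=x.size()[2:], mode='trilinear', align_corners=True)
        upsample_12 = nn.Upsample(size=[d//2, h//2, w//2], mode='trilinear', align_corners=True)
        upsample_24 = nn.Upsample(size=[d//4, h//4, w//4], mode='trilinear', align_corners=True)

        # Choose the decoding path from the height of the last feature map.
        if last_output.size()[3] == h:
            mat = self.last_3(last_output)
        elif last_output.size()[3] == h//2:
            mat = self.last_3(upsample_6(self.last_6(last_output)))
        elif last_output.size()[3] == h//4:
            mat = self.last_3(upsample_6(self.last_6(upsample_12(self.last_12(last_output)))))
        elif last_output.size()[3] == h//8:
            mat = self.last_3(upsample_6(self.last_6(upsample_12(self.last_12(upsample_24(self.last_24(last_output)))))))
        else:
            # Without this branch, an unexpected resolution would fail with an UnboundLocalError on mat.
            raise ValueError(f"unexpected last_output height {last_output.size()[3]} for input height {h}")
        return mat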
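The hand-unrolled forward above could instead derive its skip connections from the number of cells. A minimal sketch, assuming the 10-cell pattern generalizes as "fuse the outputs of cells k and k + 3 and feed the result to cell k + 4, for k = 0, 3, 6, ..."; that rule and the names `skip_schedule`, `run_skip_chain` and `fuse` are hypothetical and not necessarily what skip_model_3d.py does for 12 cells.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple

# Hypothetical cell interface: (prev_prev, prev) -> (next_prev_prev, next_prev).
Cell = Callable[[Any, Any], Tuple[Any, Any]]


def skip_schedule(num_cells: int, span: int = 3) -> Dict[int, Tuple[int, int]]:
    """Map each destination cell to the (src, dst - 1) pair fused for it.

    Hypothetical rule generalizing the 10-cell example: for k = 0, span,
    2*span, ... the outputs of cells k and k + span are fused and fed to
    cell k + span + 1, as long as that cell exists.
    """
    if span < 1:
        raise ValueError("span must be at least 1")
    schedule: Dict[int, Tuple[int, int]] = {}
    k = 0
    while k + span + 1 < num_cells:
        schedule[k + span + 1] = (k, k + span)
        k += span
    return schedule


def run_skip_chain(cells: Sequence[Cell], stem0: Any, stem1: Any,
                   fuse: Callable[[Any, Any], Any], span: int = 3) -> Any:
    """Run the cells, replacing the second input of scheduled cells by a fused skip."""
    schedule = skip_schedule(len(cells), span)
    outs: List[Tuple[Any, Any]] = []
    prev: Tuple[Any, Any] = (stem0, stem1)
    for i, cell in enumerate(cells):
        second = prev[1]
        if i in schedule:
            src, last = schedule[i]
            second = fuse(outs[src][-1], outs[last][-1])
        prev = cell(prev[0], second)
        outs.append(prev)
    return prev[-1]
```

For 10 cells the schedule is `{4: (0, 3), 7: (3, 6)}`, matching the code above. In the real module, `fuse` would be something like `lambda a, b: self.conv1(torch.cat((a, b), 1))`. Because `torch.cat` along dim 1 requires every other dimension to match, each fused pair of cells must output feature maps at the same resolution, and reusing a single `conv1` for every skip also requires the concatenated channel counts to agree. Violating either condition is one way a channel-count mismatch like the one reported at the start of this thread can arise.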