Closed by lebionick 3 years ago
Hi Nikolay,
There is another file named 'new_model_3d.py' without skip-connections. This could probably solve your problem.
Hi @XuelianCheng, yes, it worked! So the model with skip-connections is hardcoded to a specific number of cells, but should work better? Is there a reason to adapt this idea to a smaller model?
Yes, skip-connections between cells would give a better result.
I've written a skip-connection model for a 10-cell structure. Here is the code for the forward method; the rest is the same:
def forward(self, x):
    stem0 = self.stem0(x)
    stem1 = self.stem1(stem0)
    out = (stem0, stem1)

    out0 = self.cells[0](out[0], out[1])
    out1 = self.cells[1](out0[0], out0[1])
    out2 = self.cells[2](out1[0], out1[1])
    out3 = self.cells[3](out2[0], out2[1])

    # Skip connection: concatenate the outputs of cells 0 and 3 along dim 1, then apply conv1.
    cat3_0 = torch.cat((out0[-1], out3[-1]), 1)
    out3_cat = self.conv1(cat3_0)

    out4 = self.cells[4](out3[0], out3_cat)
    out5 = self.cells[5](out4[0], out4[1])
    out6 = self.cells[6](out5[0], out5[1])

    # Skip connection: concatenate the outputs of cells 3 and 6 along dim 1, then apply conv1.
    cat3_6 = torch.cat((out3[-1], out6[-1]), 1)
    out6_cat = self.conv1(cat3_6)

    out7 = self.cells[7](out6[0], out6_cat)
    out8 = self.cells[8](out7[0], out7[1])
    out9 = self.cells[9](out8[0], out8[1])
    last_output = out9[-1]

    # Upsample modules targeting full, half, and quarter input resolution.
    d, h, w = x.size()[2], x.size()[3], x.size()[4]
    upsample_6 = nn.Upsample(size=x.size()[2:], mode='trilinear', align_corners=True)
    upsample_12 = nn.Upsample(size=[d//2, h//2, w//2], mode='trilinear', align_corners=True)
    upsample_24 = nn.Upsample(size=[d//4, h//4, w//4], mode='trilinear', align_corners=True)

    # Choose the upsampling chain from the height of the last cell's output.
    if last_output.size()[3] == h:
        mat = self.last_3(last_output)
    elif last_output.size()[3] == h//2:
        mat = self.last_3(upsample_6(self.last_6(last_output)))
    elif last_output.size()[3] == h//4:
        mat = self.last_3(upsample_6(self.last_6(upsample_12(self.last_12(last_output)))))
    elif last_output.size()[3] == h//8:
        mat = self.last_3(upsample_6(self.last_6(upsample_12(self.last_12(upsample_24(self.last_24(last_output)))))))
    return mat
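The cell wiring in the forward method above could also be written as a loop, so the number of cells and the skip positions are data rather than hand-unrolled lines. The following is only a rough sketch of that idea, not code from the repository: `forward_cells`, `skips`, `fuse` and `make_toy_cell` are hypothetical names, and the toy cells work on strings purely to show which inputs each cell receives. It assumes, as in the posted code, that each cell returns a tuple whose first and last elements feed the next cell.

```python
def forward_cells(cells, stem0, stem1, skips, fuse):
    """Run cells in sequence; `skips` maps a cell index to an earlier cell index.

    After cell i runs, if i is in `skips`, its last output is fused with the
    last output of cell skips[i], and the fused value replaces the second
    input of cell i + 1 (mirroring out3_cat / out6_cat in the posted code).
    """
    outs = []
    prev = (stem0, stem1)
    for i, cell in enumerate(cells):
        out = cell(prev[0], prev[1])
        outs.append(out)
        if i in skips:
            fused = fuse(outs[skips[i]][-1], out[-1])
            prev = (out[0], fused)
        else:
            prev = out
    return outs[-1][-1]


def make_toy_cell(i, log):
    """Hypothetical stand-in for a real cell: records its inputs in `log`."""
    def cell(a, b):
        log.append((i, a, b))
        return (b, f"out{i}")
    return cell


# Same wiring as the 10-cell forward method: cell 3 <- cell 0, cell 6 <- cell 3.
log = []
cells = [make_toy_cell(i, log) for i in range(10)]
last = forward_cells(cells, "stem0", "stem1", {3: 0, 6: 3},
                     lambda a, b: f"fuse({a},{b})")
print(last)    # out9
print(log[4])  # (4, 'out2', 'fuse(out0,out3)')
```

In the real model, `fuse` would presumably be something like `lambda a, b: self.conv1(torch.cat((a, b), 1))`. If the fused tensors have different channel counts at different skip positions, a separate convolution per skip may be needed.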
Hi! I'm using your code to find the best model on SceneFlow. I need lower latency, so I reduced the search parameters to 4 for the feature net and 8 for the matching net. NAS training worked well (despite an OOM at epoch 11). I decoded the model (there were also problems caused by DataParallel saving) and tried to retrain it. But the script crashed on this line: https://github.com/XuelianCheng/LEAStereo/blob/master/retrain/skip_model_3d.py#L148 because of a channel-count mismatch. As far as I understand, this is not the only such place; the code is hardcoded to use 12 cells. How can I adapt this code to use different parameters?
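The post does not say what the DataParallel problem was. One common cause, assumed here, is that a state dict saved from a model wrapped in `torch.nn.DataParallel` has every key prefixed with `module.`, so it fails to load into an unwrapped model. A minimal sketch of a workaround follows; `strip_module_prefix` is a hypothetical helper, not part of the repository, and it works on any plain mapping.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with a leading `prefix` removed from keys.

    Keys that do not start with the prefix are kept unchanged.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Toy example with placeholder values instead of tensors.
saved = {"module.stem0.weight": 1, "module.stem1.bias": 2}
print(list(strip_module_prefix(saved)))  # ['stem0.weight', 'stem1.bias']
```

The cleaned dict could then be passed to the unwrapped model's `load_state_dict`. Saving `model.module.state_dict()` instead of `model.state_dict()` avoids the prefix in the first place.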