Last time I only got a crop of the input image when displaying the ground truth and x_hat, but I want to see the complete image, so I tried to adapt networks.py for inference. In the encode part of class NetHigh() I ran into an obstacle: the x_hat is wrong.
In `compresslow()` I loop over the blocks:

```python
for block in tqdm.tqdm(blocks):
    with torch.no_grad():
        block = block.to(device)
        ret = net.encode(block)
```

and in `encode()`:

```python
def encode(self, inputs):
    b, c, h, w = inputs.shape
    print("inputs shape is ", inputs.shape)
    print("inputs is ", inputs)
```
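For context, this is roughly how I split the full image into blocks and stitch the reconstructions back together. It is a minimal sketch assuming non-overlapping fixed-size blocks with replicate padding; the block size and helper names are illustrative, not the actual values from networks.py:

```python
import torch
import torch.nn.functional as F

def split_into_blocks(x, block=256):
    # x: (1, C, H, W); pad so H and W become multiples of `block`
    _, _, h, w = x.shape
    pad_h = (block - h % block) % block
    pad_w = (block - w % block) % block
    x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")
    blocks = []
    for i in range(0, x.shape[2], block):
        for j in range(0, x.shape[3], block):
            blocks.append(x[:, :, i:i + block, j:j + block])
    return blocks, x.shape[2], x.shape[3]

def stitch_blocks(blocks, padded_h, padded_w, h, w, block=256):
    # Inverse of split_into_blocks: place each block back in raster
    # order, then crop away the padding to recover the original size.
    c = blocks[0].shape[1]
    out = torch.zeros(1, c, padded_h, padded_w)
    k = 0
    for i in range(0, padded_h, block):
        for j in range(0, padded_w, block):
            out[:, :, i:i + block, j:j + block] = blocks[k]
            k += 1
    return out[:, :, :h, :w]
```

Splitting an image and stitching the blocks straight back should reproduce it exactly, so any discrepancy in the full reconstruction would come from the network, not the tiling.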
I return gt, x_hat, x_tilde, pf, h2, and h1 in order to display them.
The gt is displayed correctly, but x_hat and x_tilde are both wrong (the first image below is x_tilde, the second is x_hat). x_tilde is produced by:
```python
x_tilde = self.side_recon_model(pf, h2, h1)
```
I printed some information about pf, h2, and h1; the shape of pf is (1, 384, 384, 384).
I am confused about why x_tilde and x_hat cannot reconstruct the input. Kindly give me your advice.
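For displaying the tensors I convert them like this. It is a sketch that assumes the network outputs floats in roughly [0, 1], which may not match what networks.py actually produces; if the model emits a different range, skipping the denormalization here could make a correct x_hat look like garbage:

```python
import numpy as np
import torch

def tensor_to_image(x):
    # x: (1, C, H, W) float tensor; clamp to [0, 1] and convert to a
    # HWC uint8 array suitable for display with matplotlib/PIL.
    x = x.detach().cpu().clamp(0.0, 1.0)
    img = (x[0].permute(1, 2, 0).numpy() * 255.0).round().astype(np.uint8)
    return img
```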