lanpa / tensorboardX

tensorboard for pytorch (and chainer, mxnet, numpy, ...)
https://tensorboardx.readthedocs.io/en/latest/tensorboard.html
MIT License

add_graph error with SiameseNet #295

Open: Tmcsn opened this issue 5 years ago

Tmcsn commented 5 years ago

import torch
import torch.nn as nn
from tensorboardX import SummaryWriter

class Siamese(nn.Module):

    def __init__(self):
        super(Siamese, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 10),     # 64@96*96
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),          # 64@48*48
            nn.Conv2d(64, 128, 7),
            nn.ReLU(),                # 128@42*42
            nn.MaxPool2d(2),          # 128@21*21
            nn.Conv2d(128, 128, 4),
            nn.ReLU(),                # 128@18*18
            nn.MaxPool2d(2),          # 128@9*9
            nn.Conv2d(128, 256, 4),
            nn.ReLU(),                # 256@6*6
        )
        self.liner = nn.Sequential(nn.Linear(9216, 4096), nn.Sigmoid())
        self.out = nn.Linear(4096, 1)

    def forward_one(self, x):
        x = self.conv(x)
        x = x.view(x.size()[0], -1)
        x = self.liner(x)
        return x

    def forward(self, x1, x2):
        out1 = self.forward_one(x1)
        out2 = self.forward_one(x2)
        dis = torch.abs(out1 - out2)
        out = self.out(dis)
        #  return self.sigmoid(out)
        return out

si = Siamese()
img = torch.rand(1, 3, 96, 96)
img2 = torch.rand(1, 3, 96, 96)
writer = SummaryWriter(comment='bcnn')
writer.add_graph(si(img, img2), (img, img2))

The following error occurred: /opt/conda/conda-bld/pytorch_1532581333611/work/torch/csrc/jit/tracer.h:143: getTracingState: Assertion var_state == state failed.
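The layer comments in the model (64@96*96 after the first convolution, 256@6*6 at the end) and the hard-coded nn.Linear(9216, 4096) only work out for a single-channel 105×105 input. A quick way to check is to trace the spatial size by hand. This is a minimal sketch: the helper names are hypothetical, and it assumes stride-1, unpadded convolutions and MaxPool2d(2) with PyTorch's default floor rounding, which is what the model above uses.

```python
# Hypothetical helpers: trace the spatial size of a square image through
# Siamese.conv, assuming stride-1 / padding-0 convolutions and
# MaxPool2d(2) with the default floor rounding (ceil_mode=False).

def conv_out(size, kernel):
    # Conv2d(..., kernel) with stride 1, padding 0, dilation 1
    return size - kernel + 1


def pool_out(size, kernel=2):
    # MaxPool2d(kernel) uses stride=kernel and floors the result
    return size // kernel


def flattened_features(size):
    """Number of features that x.view(x.size(0), -1) hands to nn.Linear."""
    size = pool_out(conv_out(size, 10))  # Conv2d(1, 64, 10) + MaxPool2d(2)
    size = pool_out(conv_out(size, 7))   # Conv2d(64, 128, 7) + MaxPool2d(2)
    size = pool_out(conv_out(size, 4))   # Conv2d(128, 128, 4) + MaxPool2d(2)
    size = conv_out(size, 4)             # Conv2d(128, 256, 4)
    return 256 * size * size


print(flattened_features(105))  # 9216: matches nn.Linear(9216, 4096)
print(flattened_features(96))   # 4096: does not match
```

So even with the channel count fixed, a 96×96 image would reach the linear layer with 4096 features instead of 9216 and fail there.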

lanpa commented 5 years ago

It seems that si(img, img2) cannot complete a forward pass. Can you check the code again? Thanks. (I tested your code with PyTorch 1.0.0.)

RuntimeError: Given groups=1, weight of size [64, 1, 10, 10], expected input[1, 3, 96, 96] to have 1 channels, but got 3 channels instead