hildazzz closed this issue 12 months ago.
```python
x_ = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1, 2])
return x, logdet
```
Hi, I completely misread your question earlier, and you are right: it is a typo. It should be `x` instead of `x_` in the first line. Thanks!!
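For context, here is a minimal sketch of where the corrected line could sit. It assumes the snippet is the tail of an affine coupling layer's forward pass on `(batch, channels, time)` tensors, which is what the channel-wise `torch.cat(..., 1)` and the `torch.sum(logs, [1, 2])` reduction suggest. The class name `AffineCoupling`, the 1x1 `Conv1d` conditioner, and the half/half channel split are hypothetical choices for illustration. Only the last three lines of `forward` come from the snippet, with `x_` corrected to `x`.

```python
import torch
import torch.nn as nn


class AffineCoupling(nn.Module):
    """Hypothetical minimal affine coupling layer for (batch, channels, time) inputs."""

    def __init__(self, channels: int):
        super().__init__()
        half = channels // 2
        # Hypothetical conditioner: predicts a shift m and a log-scale logs
        # for the second half of the channels from the first half.
        self.net = nn.Conv1d(half, 2 * (channels - half), kernel_size=1)

    def forward(self, x: torch.Tensor):
        half = x.size(1) // 2
        x0, x1 = x[:, :half], x[:, half:]
        m, logs = self.net(x0).chunk(2, dim=1)
        x1 = m + torch.exp(logs) * x1  # affine transform of the second half
        # Corrected line: assign to x (not x_) so the returned tensor is the
        # transformed one rather than the layer's unchanged input.
        x = torch.cat([x0, x1], 1)
        logdet = torch.sum(logs, [1, 2])  # one log-determinant per batch item
        return x, logdet
```

If `x` is the layer's input, as is typical, the original typo does not raise an error. The function silently returns the untransformed input while `logdet` still reports the log-determinant of the transform, so the mismatch only shows up as wrong likelihoods or outputs downstream.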