[Open] yueanga opened this issue 2 years ago
Sorry to bother you, but what is the purpose of the `Concat` class in `Text2LIVE-main/models/backbone/common.py`? Thank you so much!
class Concat(nn.Module):
    """Run several sub-modules on the same input and concatenate their outputs.

    Every module given to the constructor receives the *same* tensor in
    ``forward``. Their outputs are concatenated along ``dim``. ``torch.cat``
    requires all non-concatenated dimensions to match. Outputs whose sizes in
    dimensions 2 and 3 differ are therefore first center-cropped to the
    smallest size found in each of those two dimensions.

    Args:
        dim: Dimension along which the outputs are concatenated (passed
            straight to ``torch.cat``).
        *args: Sub-modules to run in parallel. They are registered as child
            modules named ``"0"``, ``"1"``, ... in the order given.
    """

    def __init__(self, dim, *args):
        super().__init__()
        self.dim = dim
        for idx, module in enumerate(args):
            # Registering (instead of keeping a plain list) makes the
            # sub-modules visible to .parameters(), .to(), state_dict(), etc.
            self.add_module(str(idx), module)

    def forward(self, input):
        """Apply every sub-module to ``input`` and concatenate the results.

        Args:
            input: Tensor fed unchanged to every sub-module. The outputs are
                indexed on dimensions 2 and 3, so each must have at least four
                dimensions (presumably ``(N, C, H, W)`` — TODO confirm).

        Returns:
            ``torch.cat`` of the center-cropped outputs along ``self.dim``.

        Raises:
            ValueError: if no sub-modules were registered.
        """
        # Iterate _modules directly rather than .children(): children()
        # de-duplicates, so a module passed twice would only run once.
        outputs = [module(input) for module in self._modules.values()]
        if not outputs:
            raise ValueError("Concat requires at least one sub-module")

        target_h = min(out.shape[2] for out in outputs)
        target_w = min(out.shape[3] for out in outputs)

        cropped = []
        for out in outputs:
            # Center crop. When the excess is odd, the extra row/column is
            # dropped from the bottom/right. Outputs that already match the
            # target size get offsets of 0, so this slice leaves them unchanged.
            top = (out.shape[2] - target_h) // 2
            left = (out.shape[3] - target_w) // 2
            cropped.append(out[:, :, top:top + target_h, left:left + target_w])

        return torch.cat(cropped, dim=self.dim)
Sorry to bother you, but what is the purpose of the `Concat` class in `Text2LIVE-main/models/backbone/common.py`? Thank you so much!