omerbt / Text2LIVE

Official PyTorch Implementation for "Text2LIVE: Text-Driven Layered Image and Video Editing" (ECCV 2022 Oral)
https://text2live.github.io/
MIT License
878 stars 80 forks source link

What's the function of class "Concat"? #2

Open yueanga opened 2 years ago

yueanga commented 2 years ago

Sorry, but I wonder: what is the function of class "Concat" in Text2LIVE-main/models/backbone/common.py? Thank you so much!

class Concat(nn.Module):
    """Apply several branch modules to one input and concatenate the results.

    Every module passed to the constructor receives the *same* input tensor.
    Their outputs are joined with ``torch.cat`` along ``self.dim``. This lets
    a network express parallel branches as a single module.

    Branch outputs may differ in size along dims 2 and 3. In that case each
    output is center-cropped to the smallest size found on each of those two
    dims, so ``torch.cat`` does not fail on mismatched shapes. Outputs are
    therefore assumed to have at least 4 dims, presumably laid out as
    ``(N, C, H, W)``; TODO confirm against callers.

    Args:
        dim: Dimension along which the branch outputs are concatenated
            (``1`` concatenates channels for ``(N, C, H, W)`` tensors).
        *args: Branch modules. They are registered as child modules named
            ``"0"``, ``"1"``, ... in the order given.
    """

    def __init__(self, dim, *args):
        super().__init__()
        self.dim = dim

        for idx, module in enumerate(args):
            # Register each branch as a child so its parameters take part in
            # .parameters(), .to(...) and state_dict().
            self.add_module(str(idx), module)

    def forward(self, input):
        """Run every branch on ``input`` and concatenate the outputs along ``self.dim``.

        Raises:
            ValueError: if no branch modules were registered.
        """
        # Iterate _modules directly (not children()) so a module registered
        # twice is still applied twice, one output per registration.
        outputs = [module(input) for module in self._modules.values()]
        if not outputs:
            raise ValueError("Concat has no sub-modules to apply")

        target_h = min(out.shape[2] for out in outputs)
        target_w = min(out.shape[3] for out in outputs)

        # Crop only when some branch is larger than the smallest one.
        if any(out.shape[2] != target_h or out.shape[3] != target_w for out in outputs):
            outputs = [self._center_crop(out, target_h, target_w) for out in outputs]

        return torch.cat(outputs, dim=self.dim)

    @staticmethod
    def _center_crop(tensor, height, width):
        """Return the central ``height`` x ``width`` window of dims 2 and 3 of ``tensor``."""
        # Floor division: with an odd surplus, the extra row/column is
        # dropped from the bottom/right edge.
        top = (tensor.shape[2] - height) // 2
        left = (tensor.shape[3] - width) // 2
        return tensor[:, :, top:top + height, left:left + width]