jlian2 opened 4 years ago
```python
def forward(self, inputs):
    # convert inputs from BCHW -> BHWC
    inputs = inputs.permute(0, 2, 3, 1).contiguous()
    input_shape = inputs.shape

    # Flatten input
    flat_input = inputs.view(-1, self._embedding_dim)
```
My understanding is: the dimension of flat_input should be (B\*H\*W\*C, embedding_dim), so one dimension seems to be missing? Or are you saying that the number of channels equals embedding_dim?
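For concreteness, here is a minimal sketch of the shape flow, assuming (as the second reading of the question suggests) that the encoder's channel count C equals `self._embedding_dim`; all sizes below are hypothetical:

```python
import torch

B, C, H, W = 4, 64, 8, 8            # hypothetical sizes; C == embedding_dim
embedding_dim = 64

inputs = torch.randn(B, C, H, W)    # BCHW, as produced by the encoder

# BCHW -> BHWC, so the channel (embedding) axis comes last
inputs = inputs.permute(0, 2, 3, 1).contiguous()
print(inputs.shape)                 # torch.Size([4, 8, 8, 64])

# Flatten B, H, W together: each row is one embedding vector
flat_input = inputs.view(-1, embedding_dim)
print(flat_input.shape)             # torch.Size([256, 64]) == (B*H*W, embedding_dim)
```

Under that assumption no dimension is lost: the view collapses B, H, and W into a single axis of B\*H\*W rows, and each row is a C-dimensional vector that is matched against the codebook. The view would fail (or silently mis-slice) if C differed from embedding_dim.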