Closed: jjeremy40 closed this issue 1 year ago
Hi !
Great work !
I have just one question: how did you compute the bits allocation map in Fig. 3 of your article?
Thanks !
Hi! The code to compute the bits allocation map of the maximal-entropy channel (selected by scale) is:

```python
likelihood = self._likelihood(y, scale, means)
allocated_bits = torch.log(likelihood) / -math.log(2)
```

with `self._likelihood` defined here.
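In other words, each element's likelihood is the probability of its quantized value, and −log2 of that probability is its ideal code length in bits. A minimal standalone sketch of the conversion (the likelihood values here are made up for illustration, not the model's actual output):

```python
import math
import torch

# Made-up per-element likelihoods, standing in for the output of
# self._likelihood / gaussian_conditional. Each entry is the probability
# of the quantized symbol, so -log2(p) is its ideal code length in bits.
likelihood = torch.tensor([[0.50, 0.25],
                           [0.90, 0.01]])

allocated_bits = torch.log(likelihood) / -math.log(2)
print(allocated_bits)
# tensor([[1.0000, 2.0000],
#         [0.1520, 6.6439]])  -- rare symbols cost more bits
```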
Hi! Thanks for answering so quickly.
I'm not sure I understand... Would it be something like this?
```python
def bit_map_alloc(self, x, scales, means=None):
    # same as def forward in SymmetricalTransFormer
    x = self.patch_embed(x)
    Wh, Ww = x.size(2), x.size(3)
    x = x.flatten(2).transpose(1, 2)
    x = self.pos_drop(x)
    for i in range(self.num_layers):
        layer = self.layers[i]
        x, Wh, Ww = layer(x, Wh, Ww)
    y = x
    C = self.embed_dim * 8
    y = y.view(-1, Wh, Ww, C).permute(0, 3, 1, 2).contiguous()
    y_shape = y.shape[2:]
    z = self.h_a(y)
    _, z_likelihoods = self.entropy_bottleneck(z)
    z_offset = self.entropy_bottleneck._get_medians()
    z_tmp = z - z_offset
    z_hat = ste_round(z_tmp) + z_offset
    latent_scales = self.h_scale_s(z_hat)
    latent_means = self.h_mean_s(z_hat)

    y_slices = y.chunk(self.num_slices, 1)
    y_hat_slices = []
    y_likelihood = []
    for slice_index, y_slice in enumerate(y_slices):
        support_slices = (y_hat_slices if self.max_support_slices < 0
                          else y_hat_slices[:self.max_support_slices])
        mean_support = torch.cat([latent_means] + support_slices, dim=1)
        mu = self.cc_mean_transforms[slice_index](mean_support)
        mu = mu[:, :, :y_shape[0], :y_shape[1]]
        scale_support = torch.cat([latent_scales] + support_slices, dim=1)
        scale = self.cc_scale_transforms[slice_index](scale_support)
        scale = scale[:, :, :y_shape[0], :y_shape[1]]
        _, y_slice_likelihood = self.gaussian_conditional(y_slice, scale, mu)
        y_likelihood.append(y_slice_likelihood)
        y_hat_slice = ste_round(y_slice - mu) + mu
        lrp_support = torch.cat([mean_support, y_hat_slice], dim=1)
        lrp = self.lrp_transforms[slice_index](lrp_support)
        lrp = 0.5 * torch.tanh(lrp)
        y_hat_slice += lrp
        y_hat_slices.append(y_hat_slice)

    y_hat = torch.cat(y_hat_slices, dim=1)
    y_likelihoods = torch.cat(y_likelihood, dim=1)
    y_hat = y_hat.permute(0, 2, 3, 1).contiguous().view(-1, Wh * Ww, C)
    for i in range(self.num_layers):
        layer = self.syn_layers[i]
        y_hat, Wh, Ww = layer(y_hat, Wh, Ww)
    likelihood = self._likelihood(y_hat, scales)
    allocated_bits = torch.log(likelihood) / -math.log(2)
    return allocated_bits

def _likelihood(self, inputs, scales, means=None):
    half = float(0.5)
    if means is not None:
        values = inputs - means
    else:
        values = inputs
    scales = torch.max(scales, torch.tensor(0.11))
    values = torch.abs(values)
    upper = self._standardized_cumulative((half - values) / scales)
    lower = self._standardized_cumulative((-half - values) / scales)
    likelihood = upper - lower
    return likelihood

def _standardized_cumulative(self, inputs):
    half = float(0.5)
    const = float(-(2 ** -0.5))
    # Using the complementary error function maximizes numerical precision.
    return half * torch.erfc(const * inputs)
```
Not exactly. If you want to acquire the bit allocation map in the 'forward', it should actually be like this:
```python
def bit_map_alloc(self, x, scales, means=None):
    # same as def forward in SymmetricalTransFormer
    x = self.patch_embed(x)
    Wh, Ww = x.size(2), x.size(3)
    x = x.flatten(2).transpose(1, 2)
    x = self.pos_drop(x)
    for i in range(self.num_layers):
        layer = self.layers[i]
        x, Wh, Ww = layer(x, Wh, Ww)
    y = x
    C = self.embed_dim * 8
    y = y.view(-1, Wh, Ww, C).permute(0, 3, 1, 2).contiguous()
    y_shape = y.shape[2:]
    z = self.h_a(y)
    _, z_likelihoods = self.entropy_bottleneck(z)
    z_offset = self.entropy_bottleneck._get_medians()
    z_tmp = z - z_offset
    z_hat = ste_round(z_tmp) + z_offset
    latent_scales = self.h_scale_s(z_hat)
    latent_means = self.h_mean_s(z_hat)

    y_slices = y.chunk(self.num_slices, 1)
    y_hat_slices = []
    y_likelihood = []
    for slice_index, y_slice in enumerate(y_slices):
        support_slices = (y_hat_slices if self.max_support_slices < 0
                          else y_hat_slices[:self.max_support_slices])
        mean_support = torch.cat([latent_means] + support_slices, dim=1)
        mu = self.cc_mean_transforms[slice_index](mean_support)
        mu = mu[:, :, :y_shape[0], :y_shape[1]]
        scale_support = torch.cat([latent_scales] + support_slices, dim=1)
        scale = self.cc_scale_transforms[slice_index](scale_support)
        scale = scale[:, :, :y_shape[0], :y_shape[1]]
        _, y_slice_likelihood = self.gaussian_conditional(y_slice, scale, mu)
        y_likelihood.append(y_slice_likelihood)
        y_hat_slice = ste_round(y_slice - mu) + mu
        lrp_support = torch.cat([mean_support, y_hat_slice], dim=1)
        lrp = self.lrp_transforms[slice_index](lrp_support)
        lrp = 0.5 * torch.tanh(lrp)
        y_hat_slice += lrp
        y_hat_slices.append(y_hat_slice)

    y_hat = torch.cat(y_hat_slices, dim=1)
    # here the 'likelihoods' have been computed
    y_likelihoods = torch.cat(y_likelihood, dim=1)
    # -log2 of each element's likelihood is its code length in bits
    allocated_bits = torch.log(y_likelihoods) / -math.log(2)
    return allocated_bits
```
Here we get the allocated bits (or entropy) of the entire `y_hat`. To visualize it, we choose the maximal-entropy channel, so there remains some extra post-processing to do.
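For illustration, one possible post-processing sketch, assuming "maximal entropy channel" means the channel that consumes the most total bits (that ranking criterion is my assumption, not spelled out in the thread):

```python
# allocated_bits: [1, 384, 48, 80], per-element bits from bit_map_alloc.
# Assumption: rank channels by their total bit count over the spatial dims.
channel_bits = allocated_bits.sum(dim=(2, 3))   # [1, 384], bits per channel
max_ch = channel_bits.argmax(dim=1).item()      # index of the max-entropy channel
bit_map = allocated_bits[0, max_ch]             # [48, 80] map to visualize
```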
Oh, OK! Indeed, I get allocated_bits with shape torch.Size([1, 384, 48, 80]).
So, if I understand correctly, we select the max-entropy channel out of the 384 channels; do I then resize that channel to the original image size?
Yes, and directly visualize the channel (torch.Size([48, 80])) without resizing.
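For reference, a minimal visualization sketch using matplotlib with the `bit_map` tensor from the post-processing above; the colormap and output file name are arbitrary choices:

```python
import matplotlib.pyplot as plt

# bit_map: the [48, 80] max-entropy channel selected above.
plt.imshow(bit_map.detach().cpu().numpy(), cmap='viridis')
plt.colorbar(label='bits')
plt.axis('off')
plt.savefig('bit_allocation_map.png', bbox_inches='tight')
```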
Ok got it !
Thank you so much !!