function2-llx / PUMIT


Need guidance on using pretrained tokenizer #8

Open Masaaki-75 opened 6 months ago

Masaaki-75 commented 6 months ago

Hi! I am trying to use the pretrained tokenizer to obtain latent codes for my input CT images.

However, I don't get the near-identity reconstruction demonstrated in Figure 3 of your paper. I suspect something is wrong with the way I handle the input.

Here's the process:

import numpy as np
import torch

# VectorQuantizer, SimpleVQTokenizer and SpatialTensor come from the PUMIT codebase,
# and ckpt is the loaded pretrained checkpoint (e.g. from torch.load).

# Step 1: Define the network
quantize = VectorQuantizer(num_embeddings=1024, embedding_dim=512, mode='soft')
tokenizer = SimpleVQTokenizer(quantize=quantize, in_channels=3, start_stride=4)
tokenizer.load_state_dict(ckpt['model'], strict=True)
tokenizer.eval()

# Step 2: Prepare the input
def get_rescaled_ct(npy_path, new_range=(-1, 1)):
    x = np.load(npy_path).clip(-1024, 3071)  # typical CT range in HU
    x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)  # [H, W] -> [1, 1, H, W]
    # rescale_tensor (not shown) linearly maps (x_min, x_max) to (y_min, y_max)
    x = rescale_tensor(x, y_min=new_range[0], y_max=new_range[1], x_min=-1024, x_max=3071)
    return x

def prepare_input(x: torch.Tensor):
    if x.ndim == 4:  # 2D -> 3D: add a depth axis, [N, C, H, W] -> [N, C, 1, H, W]
        x = x.unsqueeze(2)
    if x.shape[1] == 1:  # 1-channel -> 3-channel
        x = x.repeat(1, 3, 1, 1, 1)

    if not isinstance(x, SpatialTensor):
        # force aniso_d=6 for 2D input
        x = SpatialTensor(x, aniso_d=6)
    return x

# Step 3: Test the tokenizer
img_path = ".../some_ct_slice.npy"
x0 = get_rescaled_ct(img_path)  # [1, 1, 512, 512], ranging within [-1, 1]
x = prepare_input(x0)  # [1, 3, 1, 512, 512], ranging within [-1, 1]

with torch.no_grad():
    z = tokenizer.encode(x)
    y = tokenizer.decode(z)
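The snippet above relies on a rescale_tensor helper that isn't shown. Assuming it is a plain affine map, as its comment says, a minimal NumPy stand-in could look like the sketch below; rescale_array is a hypothetical name, not part of PUMIT. Swapping the source and target ranges gives the inverse map, which is handy for reading normalized values back in HU.

```python
import numpy as np

def rescale_array(x, y_min=-1.0, y_max=1.0, x_min=-1024.0, x_max=3071.0):
    """Affinely map [x_min, x_max] onto [y_min, y_max].

    Hypothetical NumPy stand-in for the rescale_tensor helper used above.
    """
    x = np.asarray(x, dtype=np.float64)
    scale = (y_max - y_min) / (x_max - x_min)
    return (x - x_min) * scale + y_min

# Clip raw HU values to the CT range used above, then map them to [-1, 1]
hu = np.array([-2000, -1024, 0, 1000, 3071, 4000]).clip(-1024, 3071)
print(rescale_array(hu))

# The inverse map (ranges swapped) converts a normalized value back to HU
print(rescale_array(0.4971, y_min=-1024, y_max=3071, x_min=-1, x_max=1))
```

Under this assumption, the x0 maximum of 0.4971 reported below corresponds to about 2041 HU, well inside the clipping window, so the intensity normalization itself looks reasonable.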

I expected y to look similar to x, but the visualization shows otherwise: [image]

Any advice on that? Thanks!

BTW, here are the shapes and value ranges of x0, x, z and y, in case they help:

x0: Shape: (1, 1, 512, 512), Range: (-1., 0.4971).
x: Shape: (1, 3, 1, 512, 512), Range: (-1., 0.4971).
z: Shape: (1, 512, 1, 32, 32), Range: (-0.2395, 17.5686).
y: Shape: (1, 3, 1, 512, 512), Range: (-1.6264, 5.3129).
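As a sanity check on these numbers, the NumPy sketch below reproduces only the shape logic of prepare_input; the SpatialTensor wrapping and its aniso_d metadata are PUMIT-specific and left out, and to_5d_three_channel is a hypothetical name. It shows that going from x0 to x only adds a depth axis and repeats the channel, so the values are untouched, which is consistent with x0 and x reporting the same range.

```python
import numpy as np

def to_5d_three_channel(x):
    """Mimic prepare_input's shape handling (without SpatialTensor)."""
    if x.ndim == 4:        # [N, C, H, W] -> [N, C, 1, H, W]: add a depth axis of size 1
        x = x[:, :, np.newaxis]
    if x.shape[1] == 1:    # repeat the single CT channel to match in_channels=3
        x = np.repeat(x, 3, axis=1)
    return x

# Random stand-in for a normalized 512 x 512 slice in [-1, 0.4971)
x0 = np.random.uniform(-1.0, 0.4971, size=(1, 1, 512, 512)).astype(np.float32)
x = to_5d_three_channel(x0)
print(x0.shape, x.shape)
```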
function2-llx commented 5 months ago

Hello, sorry for keeping you waiting so long; we have been busy with other work. Did you solve this issue? I suspect it may be caused by a code version mismatch. Which version of the code are you using?

Masaaki-75 commented 5 months ago

I'm not sure about the exact version. I think it was from the submit branch in January, but that branch seems to be gone now. Here's what I can confirm:

Also, here is the detailed architecture of SimpleVQTokenizer, in case it helps:

SimpleVQTokenizer(
  (quantize): VectorQuantizer(
    (proj): Linear(in_features=512, out_features=1024, bias=True)
    (embedding): Embedding(1024, 512)
  )
  (encoder): Sequential(
    (0): InflatableConv3d(3, 128, kernel_size=(4, 4, 4), stride=(4, 4, 4))
    (1): LayerNormNd(
      (0): ChannelLast('n c ... -> n ... c')
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): ChannelFirst('n ... c -> n c ...')
      (3): Contiguous()
    )
    (2): GELU(approximate='none')
    (3): InflatableConv3d(128, 256, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    (4): LayerNormNd(
      (0): ChannelLast('n c ... -> n ... c')
      (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (2): ChannelFirst('n ... c -> n c ...')
      (3): Contiguous()
    )
    (5): GELU(approximate='none')
    (6): InflatableConv3d(256, 512, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    (7): LayerNormNd(
      (0): ChannelLast('n c ... -> n ... c')
      (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (2): ChannelFirst('n ... c -> n c ...')
      (3): Contiguous()
    )
    (8): GELU(approximate='none')
    (9): InflatableConv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (10): GroupNorm(8, 512, eps=1e-05, affine=True)
    (11): LeakyReLU(negative_slope=0.01, inplace=True)
    (12): InflatableConv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (13): GroupNorm(8, 512, eps=1e-05, affine=True)
    (14): LeakyReLU(negative_slope=0.01, inplace=True)
  )
  (decoder): Sequential(
    (0): InflatableConv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): GroupNorm(8, 512, eps=1e-05, affine=True)
    (2): LeakyReLU(negative_slope=0.01, inplace=True)
    (3): InflatableConv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (4): GroupNorm(8, 512, eps=1e-05, affine=True)
    (5): LeakyReLU(negative_slope=0.01, inplace=True)
    (6): AdaptiveTransposedConvUpsample(
      (transposed_conv): InflatableTransposedConv3d(512, 256, kernel_size=(2, 2, 2), stride=(2, 2, 2))
      (conv): Sequential(
        (0): InflatableConv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): GroupNorm(8, 256, eps=1e-05, affine=True)
        (2): LeakyReLU(negative_slope=0.01, inplace=True)
      )
    )
    (7): AdaptiveTransposedConvUpsample(
      (transposed_conv): InflatableTransposedConv3d(256, 128, kernel_size=(2, 2, 2), stride=(2, 2, 2))
      (conv): Sequential(
        (0): InflatableConv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): GroupNorm(8, 128, eps=1e-05, affine=True)
        (2): LeakyReLU(negative_slope=0.01, inplace=True)
      )
    )
    (8): InflatableTransposedConv3d(128, 3, kernel_size=(4, 4, 4), stride=(4, 4, 4))
  )
)
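The printout also accounts for the reported shapes. The sketch below applies the standard PyTorch output-size formulas for Conv3d and ConvTranspose3d to the in-plane (H/W) axes, assuming InflatableConv3d and InflatableTransposedConv3d follow those formulas in H and W. The depth axis is left out because how these layers treat a depth-1 input with aniso_d=6 isn't shown here; the reported depth of 1 for z suggests it is not downsampled.

```python
def conv_out(n, kernel, stride, padding=0):
    """Conv output length along one axis: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1

def tconv_out(n, kernel, stride, padding=0):
    """ConvTranspose output length along one axis (dilation 1, no output_padding)."""
    return (n - 1) * stride - 2 * padding + kernel

# In-plane layers of the printed encoder: (kernel, stride, padding)
encoder = [(4, 4, 0), (2, 2, 0), (2, 2, 0), (3, 1, 1), (3, 1, 1)]
# In-plane layers of the printed decoder: (kernel, stride, padding, is_transposed)
decoder = [(3, 1, 1, False), (3, 1, 1, False),
           (2, 2, 0, True), (3, 1, 1, False),
           (2, 2, 0, True), (3, 1, 1, False),
           (4, 4, 0, True)]

n = 512
for k, s, p in encoder:
    n = conv_out(n, k, s, p)
latent = n
for k, s, p, transposed in decoder:
    n = tconv_out(n, k, s, p) if transposed else conv_out(n, k, s, p)
print(latent, n)  # 32 512
```

The total encoder stride is 4 × 2 × 2 = 16, giving a 32 × 32 latent grid for a 512 × 512 slice, and the decoder maps it back to 512 × 512. Both match the shapes of z and y, which suggests the problem is not geometric.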
function2-llx commented 3 months ago

@Masaaki-75 My dear friend, you forgot to perform the vector quantization. You should call tokenizer.quantize(z) before decoding. Sorry again for the late reply.
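In the snippet above, that means decoding the output of tokenizer.quantize(z) rather than z itself. The thread does not show what quantize returns (a tensor, or a tuple with extra terms such as a loss), so it is worth checking before passing it to decode. To illustrate what the step does, the NumPy sketch below implements two generic variants: hard nearest-neighbour lookup as in VQ-VAE, and a soft assignment. The soft variant is an assumption about what mode='soft' means, inferred only from the printed proj: Linear(512, 1024) and embedding: Embedding(1024, 512); hard_vq, soft_vq, the helpers and all weights are hypothetical stand-ins, not PUMIT's implementation.

```python
import numpy as np

def _flatten(z):
    """[N, C, *spatial] -> ([N*prod(spatial), C], spatial shape)."""
    return np.moveaxis(z, 1, -1).reshape(-1, z.shape[1]), z.shape[2:]

def _unflatten(flat, n, spatial):
    """Inverse of _flatten: [N*prod(spatial), C] -> [N, C, *spatial]."""
    return np.moveaxis(flat.reshape(n, *spatial, flat.shape[1]), -1, 1)

def hard_vq(z, codebook):
    """Nearest-neighbour VQ (VQ-VAE style); codebook has shape [K, C]."""
    flat, spatial = _flatten(z)
    # Squared Euclidean distance from every latent vector to every code: [M, K]
    d = (flat ** 2).sum(1, keepdims=True) - 2 * flat @ codebook.T + (codebook ** 2).sum(1)
    idx = d.argmin(axis=1)
    return _unflatten(codebook[idx], z.shape[0], spatial), idx.reshape(z.shape[0], *spatial)

def soft_vq(z, proj_w, proj_b, codebook):
    """ASSUMED reading of mode='soft': softmax over proj(z) logits,
    then a probability-weighted mix of codebook vectors."""
    flat, spatial = _flatten(z)
    logits = flat @ proj_w.T + proj_b                  # [M, K], like Linear(C, K)
    logits -= logits.max(axis=1, keepdims=True)        # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    return _unflatten(probs @ codebook, z.shape[0], spatial)

rng = np.random.default_rng(0)
K, C = 1024, 512                                       # sizes from the printed VectorQuantizer
codebook = rng.normal(size=(K, C))                     # hypothetical random codebook
proj_w = rng.normal(scale=0.02, size=(K, C))           # hypothetical projection weights
proj_b = np.zeros(K)
z = rng.normal(size=(1, C, 1, 4, 4))                   # small stand-in for the [1, 512, 1, 32, 32] latent

z_hard, idx = hard_vq(z, codebook)
z_soft = soft_vq(z, proj_w, proj_b, codebook)
print(z_hard.shape, z_soft.shape, idx.shape)
```

Either way, the latent keeps its [N, C, D, H, W] shape and only its values change, which is consistent with the raw z decoding to a y of the correct shape (1, 3, 1, 512, 512) but with an unexpected range.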