lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers

Implement 3D bias injection #92

Closed · danieltudosiu closed this 2 years ago

danieltudosiu commented 2 years ago

Hi @lucidrains,

Thanks for this fantastic trove of transformers <3

I am mainly working with VQ-VAEs and, in my experience, the method from this paper [1] consistently improved the morphological correctness of the samples.

If you need help I can give you some pointers on how to start and what's needed. My coworker and I already got the "Bias" version working with x-transformers, but the "Contextual" one required too much modification of the Attention to warrant implementing it.

Cheers!

[1] Wu, K., Peng, H., Chen, M., Fu, J. and Chao, H., 2021. Rethinking and improving relative position encoding for vision transformer. In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp. 10033-10041).

lucidrains commented 2 years ago

@danieltudosiu Hi Petru-Daniel and thanks for the kind words

This looks really interesting! Let me try a few experiments against rotary embeddings and T5 relative positional bias and see how it fares before adding it

lucidrains commented 2 years ago

@danieltudosiu which version of the contextual did you and your team have the most success with? (queries, queries + keys, or all qkv?)

danieltudosiu commented 2 years ago

Hi @lucidrains,

Currently, we are using "Bias Mode", which is shown in Figure 1.a and, from my understanding, is the queries + keys variant.

One feature that I think is central to this is an Ordering object that projects 2D/3D coordinates to 1D and back. I can provide you with that class if you want; we tested it quite a bit on our side.

Cheers!

danieltudosiu commented 2 years ago

@lucidrains the paper I referenced can easily be generalised to 3D, and if you want, we can work together on developing the feature. Let me know if you want me to chip in.

lucidrains commented 2 years ago

@danieltudosiu ohh i see, when you say 3d, do you mean you are working with voxel to voxel attention? if you are only doing the bias mode, you are just doing something like t5 relative positional bias, but extended to 3d?

lucidrains commented 2 years ago

@danieltudosiu basically something like this logic https://github.com/lucidrains/vit-pytorch/blob/c7bb5fc43fda45de09b8abf4e9bb23f2d39a1639/vit_pytorch/regionvit.py#L141 for 2d relative positional bias, but extended to 3d?
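
a rough sketch of what that could look like in 3d (illustrative only - the class name and shapes are made up, not anything in x-transformers):

import torch
from torch import nn

class RelPosBias3d(nn.Module):
    # illustrative sketch: learned relative positional bias over an (x, y, z) grid,
    # same idea as the 2d version but with a third axis folded into the index
    def __init__(self, x, y, z, heads):
        super().__init__()
        self.bias = nn.Embedding((2 * x - 1) * (2 * y - 1) * (2 * z - 1), heads)

        # (n, 3) grid coordinates of every voxel, n = x * y * z
        pos = torch.stack(torch.meshgrid(
            torch.arange(x), torch.arange(y), torch.arange(z), indexing="ij"
        ), dim=-1).reshape(-1, 3)

        rel = pos[:, None, :] - pos[None, :, :]              # (n, n, 3) axis-wise offsets
        rel += torch.tensor([x - 1, y - 1, z - 1])           # shift offsets to be non-negative
        # collapse the three offsets into one index into the embedding table
        rel_idx = (rel[..., 0] * (2 * y - 1) + rel[..., 1]) * (2 * z - 1) + rel[..., 2]
        self.register_buffer("rel_idx", rel_idx)

    def forward(self, sim):
        # sim: attention logits of shape (batch, heads, n, n)
        bias = self.bias(self.rel_idx)                       # (n, n, heads)
        return sim + bias.permute(2, 0, 1)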

danieltudosiu commented 2 years ago

> @danieltudosiu ohh i see, when you say 3d, do you mean you are working with voxel to voxel attention? if you are only doing the bias mode, you are just doing something like t5 relative positional bias, but extended to 3d?

I mean that the quantized representation of the VQ-VAE is a tensor of shape [B, 1, X, Y, Z]; we feed the flattened version, of shape [B, XYZ], into the transformer, which requires a way to map each (i, j, k) position to its position in the sequence.

> @danieltudosiu basically something like this logic https://github.com/lucidrains/vit-pytorch/blob/c7bb5fc43fda45de09b8abf4e9bb23f2d39a1639/vit_pytorch/regionvit.py#L141 for 2d relative positional bias, but extended to 3d?

From my understanding of your code, yes something similar.
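
To make it concrete, a quick sketch of the flattening (the shapes below are just an example, not our actual configuration):

import torch

b, x, y, z = 2, 8, 8, 8
latent = torch.randint(0, 512, (b, 1, x, y, z))  # quantized VQ-VAE indices

tokens = latent.reshape(b, -1)                   # (b, x * y * z) fed to the transformer

# recover the (i, j, k) voxel position of each sequence position (raster-scan order)
seq_pos = torch.arange(x * y * z)
i, j, k = seq_pos // (y * z), (seq_pos // z) % y, seq_pos % z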

danieltudosiu commented 2 years ago

@lucidrains here is the code that I and my coworker wrote to calculate the bias for 3D.

import math
from enum import Enum
from typing import Union, Tuple

import numpy as np
import torch

# generalized Hilbert ("gilbert") space-filling curve helpers
# (e.g. https://github.com/jakubcerveny/gilbert)
from gilbert.gilbert2d import gilbert2d
from gilbert.gilbert3d import gilbert3d

class OrderingType(Enum):
    RASTER_SCAN = "raster_scan"
    S_CURVE = "s_curve"
    RANDOM = "random"
    HILBERT = "hilbert_curve"

class OrderingTransformations(Enum):
    ROTATE_90 = "rotate_90"
    TRANSPOSE = "transpose"
    REFLECT = "reflect"

class Ordering:
    """Produces a 1D sequence ordering of a 2D/3D grid (raster scan, S-curve, random or
    Hilbert curve), with optional transpose/rotate/reflect transforms, plus its inverse."""
    def __init__(
        self,
        ordering_type: str,
        spatial_dims: int,
        dimensions: Union[Tuple[int, int, int], Tuple[int, int, int, int]],
        reflected_spatial_dims: Union[Tuple[bool, bool], Tuple[bool, bool, bool]],
        transpositions_axes: Union[
            Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int, int], ...]
        ],
        rot90_axes: Union[
            Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int, int], ...]
        ],
        transformation_order: Tuple[str, ...] = (
            OrderingTransformations.TRANSPOSE.value,
            OrderingTransformations.ROTATE_90.value,
            OrderingTransformations.REFLECT.value,
        ),
    ):
        super().__init__()
        self.ordering_type = ordering_type

        assert self.ordering_type in [
            e.value for e in OrderingType
        ], f"ordering_type must be one of the following {[e.value for e in OrderingType]}, but got {self.ordering_type}."

        self.spatial_dims = spatial_dims
        self.dimensions = dimensions

        assert (
            len(dimensions) == self.spatial_dims + 1
        ), f"Dimensions must have length {self.spatial_dims + 1}."

        self.reflected_spatial_dims = reflected_spatial_dims
        self.transpositions_axes = transpositions_axes
        self.rot90_axes = rot90_axes
        if len(set(transformation_order)) != len(transformation_order):
            raise ValueError(
                f"No duplicates are allowed. Received {transformation_order}."
            )

        for transformation in transformation_order:
            if transformation not in [t.value for t in OrderingTransformations]:
                raise ValueError(
                    f"Valid transformations are {[t.value for t in OrderingTransformations]} but received {transformation}."
                )
        self.transformation_order = transformation_order

        self.template = self._create_template()
        self._sequence_ordering = self._create_ordering()
        self._revert_sequence_ordering = np.argsort(self._sequence_ordering)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        x = x[self._sequence_ordering]

        return x

    def get_sequence_ordering(self) -> np.ndarray:
        return self._sequence_ordering

    def get_revert_sequence_ordering(self) -> np.ndarray:
        return self._revert_sequence_ordering

    def _create_ordering(self):
        self.template = self._transform_template()
        order = self._order_template(template=self.template)

        return order

    def _create_template(self) -> np.ndarray:
        spatial_dimensions = self.dimensions[1:]
        template = np.arange(np.prod(spatial_dimensions)).reshape(*spatial_dimensions)

        return template

    def _transform_template(self) -> np.ndarray:
        for transformation in self.transformation_order:
            if transformation == OrderingTransformations.TRANSPOSE.value:
                self.template = self._transpose_template(template=self.template)
            elif transformation == OrderingTransformations.ROTATE_90.value:
                self.template = self._rot90_template(template=self.template)
            elif transformation == OrderingTransformations.REFLECT.value:
                self.template = self._flip_template(template=self.template)

        return self.template

    def _transpose_template(self, template: np.ndarray) -> np.ndarray:
        for axes in self.transpositions_axes:
            template = np.transpose(template, axes=axes)

        return template

    def _flip_template(self, template: np.ndarray) -> np.ndarray:
        for axis, to_reflect in enumerate(self.reflected_spatial_dims):
            template = np.flip(template, axis=axis) if to_reflect else template

        return template

    def _rot90_template(self, template: np.ndarray) -> np.ndarray:
        for axes in self.rot90_axes:
            template = np.rot90(template, axes=axes)

        return template

    def _order_template(self, template: np.ndarray) -> np.ndarray:
        depths = None
        if self.spatial_dims == 2:
            rows, columns = template.shape[0], template.shape[1]
        else:
            rows, columns, depths = (
                template.shape[0],
                template.shape[1],
                template.shape[2],
            )

        sequence = getattr(self, f"{self.ordering_type}_idx")(rows, columns, depths)

        ordering = np.array([template[tuple(e)] for e in sequence])

        return ordering

    @staticmethod
    def raster_scan_idx(rows: int, cols: int, depths: int = None) -> np.ndarray:
        idx = []

        for r in range(rows):
            for c in range(cols):
                if depths:
                    for d in range(depths):
                        idx.append((r, c, d))
                else:
                    idx.append((r, c))

        idx = np.array(idx)

        return idx

    @staticmethod
    def s_curve_idx(rows: int, cols: int, depths: int = None) -> np.ndarray:
        idx = []

        for r in range(rows):
            col_idx = range(cols) if r % 2 == 0 else range(cols - 1, -1, -1)
            for c in col_idx:
                if depths:
                    depth_idx = (
                        range(depths) if c % 2 == 0 else range(depths - 1, -1, -1)
                    )

                    for d in depth_idx:
                        idx.append((r, c, d))
                else:
                    idx.append((r, c))

        idx = np.array(idx)

        return idx

    @staticmethod
    def random_idx(rows: int, cols: int, depths: int = None) -> np.ndarray:
        idx = []

        for r in range(rows):
            for c in range(cols):
                if depths:
                    for d in range(depths):
                        idx.append((r, c, d))
                else:
                    idx.append((r, c))

        idx = np.array(idx)
        np.random.shuffle(idx)

        return idx

    @staticmethod
    def hilbert_curve_idx(rows: int, cols: int, depths: int = None) -> np.ndarray:
        t = list(gilbert3d(rows, cols, depths) if depths else gilbert2d(rows, cols))
        idx = np.array(t)

        return idx

class RelativeSpatialPositioning(torch.nn.Module):
    """Precomputes pairwise 3D relative-position ids (optionally bucketed, iRPE-style)
    and reorders them to match the given sequence ordering."""
    def __init__(
        self,
        spatial_dims: int,
        ordering: np.ndarray,
        dimensions: Union[Tuple[int, int, int], Tuple[int, int, int, int]],
        bucket_values: bool = False,
        bucket_beta: int = 50,
        conditioning_length: int = 0,
    ):
        super().__init__()

        self.spatial_dims = spatial_dims
        self.dimensions = dimensions
        self.ordering = ordering
        self.conditioning_length = conditioning_length
        assert (
            len(dimensions) == self.spatial_dims + 1
        ), f"Dimensions must have length {self.spatial_dims + 1}."
        self.bucket_values = bucket_values
        self.bucket_beta = bucket_beta

        self.dist_array = self._get_distance_array()
        self.quantized_distances, self.num_buckets = self._rp_3d_product_and_quantize()

        self.ordered_distance_matrix = self.reorder()
        self.ordered_distance_matrix = self.account_conditionings()

    def get_pid_array(self):
        return self.ordered_distance_matrix

    def get_num_pids(self):
        return self.num_buckets

    def account_conditionings(self):
        if self.conditioning_length > 0:
            ordered_distance_matrix = (
                torch.ones(
                    self.ordered_distance_matrix.shape[0] + self.conditioning_length,
                    self.ordered_distance_matrix.shape[0] + self.conditioning_length,
                    dtype=self.ordered_distance_matrix.dtype,
                    device=self.ordered_distance_matrix.device,
                )
                * self.num_buckets
            )

            ordered_distance_matrix[
                self.conditioning_length :, self.conditioning_length :
            ] = self.ordered_distance_matrix

            self.num_buckets += 1

            return ordered_distance_matrix

        return self.ordered_distance_matrix

    def reorder(self):
        pid_rel_pos = self.quantized_distances.reshape(
            self.dimensions[1] * self.dimensions[2] * self.dimensions[3], -1
        )

        dim_1_reordered = torch.zeros_like(pid_rel_pos)
        for i in range(len(self.ordering)):
            dim_1_reordered[i] = pid_rel_pos[self.ordering[i]]

        dim_2_reordered = torch.zeros_like(pid_rel_pos)
        for i in range(len(self.ordering)):
            dim_2_reordered[:, i] = dim_1_reordered[:, self.ordering[i]]

        return dim_2_reordered

    def _get_distance_array(self):
        coord_array = torch.zeros(
            (self.dimensions[1], self.dimensions[2], self.dimensions[3], 3),
            dtype=torch.int,
        )
        height = coord_array.shape[0]
        width = coord_array.shape[1]
        depth = coord_array.shape[2]

        for i in range(height):
            for j in range(width):
                for k in range(depth):
                    coord_array[i, j, k, 0] = i
                    coord_array[i, j, k, 1] = j
                    coord_array[i, j, k, 2] = k

        dist_array = torch.zeros(
            (height, width, depth, height, width, depth, 3), dtype=torch.int
        )
        coord_array_widths = coord_array[:, :, :, 1]
        coord_array_heights = coord_array[:, :, :, 0]
        coord_array_depths = coord_array[:, :, :, 2]

        for i in range(height):
            for j in range(width):
                for k in range(depth):
                    dist_array[i, j, k, :, :, :, 0] = coord_array_heights - i
                    dist_array[i, j, k, :, :, :, 1] = coord_array_widths - j
                    dist_array[i, j, k, :, :, :, 2] = coord_array_depths - k

        return dist_array

    # Code adapted from iRPE in 2D:
    # https://github.com/microsoft/Cream/blob/6fb89a2f93d6d97d2c7df51d600fe8be37ff0db4/iRPE/DETR-with-iRPE/models/rpe_attention/irpe.py#L19
    def _rp_3d_product_and_quantize(self):

        alpha = self.bucket_beta / 2
        gamma = self.bucket_beta * 4

        if self.bucket_values:
            r = (
                self.piecewise_index(
                    self.dist_array[:, :, :, :, :, :, 0], alpha, self.bucket_beta, gamma
                )
                + self.bucket_beta
            )
            c = (
                self.piecewise_index(
                    self.dist_array[:, :, :, :, :, :, 1], alpha, self.bucket_beta, gamma
                )
                + self.bucket_beta
            )
            d = (
                self.piecewise_index(
                    self.dist_array[:, :, :, :, :, :, 2], alpha, self.bucket_beta, gamma
                )
                + self.bucket_beta
            )
        else:
            r = self.dist_array[:, :, :, :, :, :, 0]
            c = self.dist_array[:, :, :, :, :, :, 1]
            d = self.dist_array[:, :, :, :, :, :, 2]

        r = r - torch.min(r)
        c = c - torch.min(c)
        d = d - torch.min(d)

        max_dim = max(torch.max(r), torch.max(c), torch.max(d)) + 1

        pid = r + (c * max_dim) + (d * max_dim ** 2)

        return pid, torch.max(pid)

    @staticmethod
    def piecewise_index(relative_position, alpha, beta, gamma, dtype=torch.int):
        """piecewise index function
        Parameters
        ----------
        relative_position: torch.Tensor, dtype: long or float
            The shape of `relative_position` is (L, L).
        alpha, beta, gamma: float
            The coefficients of piecewise index function.
        Returns
        -------
        idx: torch.Tensor, dtype: long
            A tensor indexing relative distances to corresponding encodings.
            `idx` is a long tensor, whose shape is (L, L) and each element is in [-beta, beta].
        """
        rp_abs = relative_position.abs()
        mask = rp_abs <= alpha
        not_mask = ~mask
        rp_out = relative_position[not_mask]
        rp_abs_out = rp_abs[not_mask]
        y_out = (
            torch.sign(rp_out)
            * (
                alpha
                + torch.log(rp_abs_out / alpha)
                / math.log(gamma / alpha)
                * (beta - alpha)
            )
            .round()
            .clip(max=beta)
        ).to(dtype)

        idx = relative_position.clone()
        if idx.dtype in [torch.float32, torch.float64]:
            # round(x) when |x| <= alpha
            idx = idx.round().to(dtype)

        # assign the value when |x| > alpha
        idx[not_mask] = y_out
        return idx
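
And roughly how it gets wired up on our side (the sizes below are illustrative, not our actual configuration):

dims = (1, 8, 8, 8)  # (channels, X, Y, Z) of the quantized latent

ordering = Ordering(
    ordering_type=OrderingType.RASTER_SCAN.value,
    spatial_dims=3,
    dimensions=dims,
    reflected_spatial_dims=(False, False, False),
    transpositions_axes=(),
    rot90_axes=(),
)

rel_pos = RelativeSpatialPositioning(
    spatial_dims=3,
    ordering=ordering.get_sequence_ordering(),
    dimensions=dims,
    bucket_values=True,
    bucket_beta=50,
)

pid = rel_pos.get_pid_array()      # (X*Y*Z, X*Y*Z) relative position ids, in sequence order
num_pids = rel_pos.get_num_pids()  # used to size the bias embedding table
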
lucidrains commented 2 years ago

@danieltudosiu yeah it is the same

i think relative positional bias, especially generalizing it to N-dimensions, may have to belong in a separate repository

lucidrains commented 2 years ago

@danieltudosiu are you working with video?

danieltudosiu commented 2 years ago

> @danieltudosiu yeah it is the same
>
> i think relative positional bias, especially generalizing it to N-dimensions, may have to belong in a separate repository

Nice, yes, that would be best, but it is quite niche, mainly useful in ensemble anomaly detection.

> @danieltudosiu are you working with video?

Volumetric medical imaging

lucidrains commented 2 years ago

@danieltudosiu ohh got it

yea, let me think about how to tackle this in another repository then, for both space (3d) and potentially even space + time (4d) - don't think we'll have to go beyond that

i've been tired of redoing the code for 2d relative positional bias across repositories anyways, so probably a good time to generalize it haha

danieltudosiu commented 2 years ago

> @danieltudosiu ohh got it
>
> yea, let me think about how to tackle this in another repository then, for both space (3d) and potentially even space + time (4d) - don't think we'll have to go beyond that
>
> i've been tired of redoing the code for 2d relative positional bias across repositories anyways, so probably a good time to generalize it haha

I do not know how useful the next idea would be, but I would start by decomposing the attention mechanism into all of its additive components and writing some kind of wrapper, let's call it an "Attention Enhancer", that calls those addition methods one by one and modifies their inputs/outputs :-?
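
Roughly what I have in mind, as a sketch (the names here are made up):

import torch
from torch import nn

class AttentionEnhancer(nn.Module):
    # sketch of the wrapper idea: a plain attention core plus a list of modules
    # that each get a chance to modify the pre-softmax logits (e.g. a 3D bias)
    def __init__(self, bias_modules):
        super().__init__()
        self.biases = nn.ModuleList(bias_modules)

    def forward(self, q, k, v):
        scale = q.shape[-1] ** -0.5
        sim = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale
        for bias in self.biases:
            sim = bias(sim)                      # each addition modifies the logits
        attn = sim.softmax(dim=-1)
        return torch.einsum("b h i j, b h j d -> b h i d", attn, v)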

danieltudosiu commented 2 years ago

Think of it like Figure 1 from the paper I referenced in the issue.

lucidrains commented 2 years ago

> @danieltudosiu ohh got it. yea, let me think about how to tackle this in another repository then, for both space (3d) and potentially even space + time (4d) - don't think we'll have to go beyond that. i've been tired of redoing the code for 2d relative positional bias across repositories anyways, so probably a good time to generalize it haha

> I do not know how useful the next idea would be, but I would start by decomposing the attention mechanism into all of its additive components and writing some kind of wrapper, let's call it an "Attention Enhancer", that calls those addition methods one by one and modifies their inputs/outputs :-?

do you mean something like https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/dalle2_pytorch.py#L1259 ?

btw, have you ever tried the contextual bias that was presented in the paper? i'm curious how that performs

danieltudosiu commented 2 years ago

@lucidrains I don't follow with regard to DALLE 2 :-? Could you help me understand?

We didn't try the contextual bias since it showed only marginal improvement compared to the bias one, so we deemed it not worth the engineering time.

lucidrains commented 2 years ago

ohh thanks for letting me know! yeah I guess static bias is good enough 🤔

ok maybe I'll build out the relative positional bias library when I find some time, as a separate repo, and with all the features like max cutoffs etc
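
for the max cutoff, e.g. one way could be clamping the per-axis offsets past some maximum distance - just a sketch, not the final api:

import torch

def clamped_rel_index(coords, max_dist):
    # sketch: per-axis offsets beyond +/- max_dist all share the edge bucket
    # coords: (n, dims) integer grid coordinates
    rel = coords[:, None, :] - coords[None, :, :]      # (n, n, dims)
    rel = rel.clamp(-max_dist, max_dist) + max_dist    # now in [0, 2 * max_dist]
    num_per_axis = 2 * max_dist + 1
    idx = torch.zeros_like(rel[..., 0])
    for d in range(rel.shape[-1]):
        # mixed-radix collapse of the per-axis buckets into a single index
        idx = idx * num_per_axis + rel[..., d]
    return idx                                          # (n, n)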