Closed: danieltudosiu closed this issue 2 years ago.
@danieltudosiu Hi Petru-Daniel and thanks for the kind words
This looks really interesting! Let me try a few experiments against rotary embeddings and T5 relative positional bias and see how it fares before adding it
@danieltudosiu which version of the contextual did you and your team have the most success with? (queries, queries + keys, or all qkv?)
Hi @lucidrains,
Currently, we are using the "Bias Mode" shown in Figure 1.a, which from my understanding is the queries + keys variant.
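Concretely, as I read the paper, bias mode just adds a learned scalar, indexed by the (bucketed) relative position of tokens i and j, to the query-key logits, with no extra interaction with the values: e_ij = (x_i W^Q)(x_j W^K)^T / sqrt(d_z) + b_ij.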
One feature that I think is central to this is an Ordering object that maps 2D/3D positions to a 1D sequence and back. I can provide you with that class if you want; we tested it quite a bit on our side.
Cheers!
@lucidrains the paper I referenced can easily be generalised to 3D, and if you want we can work together on the development of the feature. Let me know if you want me to chip in.
@danieltudosiu ohh i see, when you say 3d, do you mean you are working with voxel to voxel attention? if you are only doing the bias mode, you are just doing something like t5 relative positional bias, but extended to 3d?
@danieltudosiu basically something like this logic https://github.com/lucidrains/vit-pytorch/blob/c7bb5fc43fda45de09b8abf4e9bb23f2d39a1639/vit_pytorch/regionvit.py#L141 for 2d relative positional bias, but extended to 3d?
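i.e. something along these lines, just to pin down what i mean - a rough, untested sketch, with all names and shapes made up:

import torch
from torch import nn

class RelPosBias3D(nn.Module):
    # learned relative positional bias over a (d, h, w) grid of tokens, one value per head,
    # indexed by the (shifted) coordinate differences between every pair of positions
    def __init__(self, d, h, w, heads):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros((2 * d - 1) * (2 * h - 1) * (2 * w - 1), heads))
        coords = torch.stack(torch.meshgrid(
            torch.arange(d), torch.arange(h), torch.arange(w), indexing="ij"), dim=-1).reshape(-1, 3)
        rel = coords[:, None, :] - coords[None, :, :]        # (N, N, 3) coordinate offsets
        rel += torch.tensor([d - 1, h - 1, w - 1])           # shift offsets to be non-negative
        index = (rel[..., 0] * (2 * h - 1) + rel[..., 1]) * (2 * w - 1) + rel[..., 2]
        self.register_buffer("index", index)                 # (N, N) flat bucket ids

    def forward(self, attn_logits):
        # attn_logits: (batch, heads, N, N) with N = d * h * w
        return attn_logits + self.bias[self.index].permute(2, 0, 1)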
> @danieltudosiu ohh i see, when you say 3d, do you mean you are working with voxel to voxel attention? if you are only doing the bias mode, you are just doing something like t5 relative positional bias, but extended to 3d?
I mean that the quantized representation of the VQ-VAE is a tensor of shape [B, 1, X, Y, Z], and we feed the flattened version, of shape [B, XYZ], into the transformer. This requires a way to map each (i, j, k) position to its position in the sequence.
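A minimal sketch of that flattening step (the sizes here are made up):

import torch

b, x, y, z = 2, 8, 8, 8
quantized = torch.randint(0, 512, (b, 1, x, y, z))    # VQ-VAE token indices, shape [B, 1, X, Y, Z]

tokens = quantized.view(b, -1)                        # raster-scan flattening to [B, X*Y*Z]

# explicit map from (i, j, k) to the position in the flattened sequence
seq_index = torch.arange(x * y * z).view(x, y, z)     # seq_index[i, j, k] = position in tokens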
> @danieltudosiu basically something like this logic https://github.com/lucidrains/vit-pytorch/blob/c7bb5fc43fda45de09b8abf4e9bb23f2d39a1639/vit_pytorch/regionvit.py#L141 for 2d relative positional bias, but extended to 3d?
From my understanding of your code, yes something similar.
@lucidrains here is the code that my coworker and I wrote to calculate the bias for 3D:
import math
from enum import Enum
from typing import Optional, Tuple, Union
import numpy as np
import torch
from gilbert.gilbert2d import gilbert2d
from gilbert.gilbert3d import gilbert3d
class OrderingType(Enum):
RASTER_SCAN = "raster_scan"
S_CURVE = "s_curve"
RANDOM = "random"
HILBERT = "hilbert_curve"
class OrderingTransformations(Enum):
ROTATE_90 = "rotate_90"
TRANSPOSE = "transpose"
REFLECT = "reflect"
class Ordering:
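    """Maps a 2D/3D grid of latent positions to a 1D sequence ordering (and back), optionally
    after transposing, rotating and/or reflecting the grid first. Supported orderings are
    raster scan, S-curve, random, and the generalized Hilbert curve (via the gilbert package).
    """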
def __init__(
self,
ordering_type: str,
spatial_dims: int,
dimensions: Union[Tuple[int, int, int], Tuple[int, int, int, int]],
reflected_spatial_dims: Union[Tuple[bool, bool], Tuple[bool, bool, bool]],
transpositions_axes: Union[
Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int, int], ...]
],
rot90_axes: Union[
Tuple[Tuple[int, int], ...], Tuple[Tuple[int, int, int], ...]
],
transformation_order: Tuple[str, ...] = (
OrderingTransformations.TRANSPOSE.value,
OrderingTransformations.ROTATE_90.value,
OrderingTransformations.REFLECT.value,
),
):
super().__init__()
self.ordering_type = ordering_type
assert self.ordering_type in [
e.value for e in OrderingType
], f"ordering_type must be one of the following {[e.value for e in OrderingType]}, but got {self.ordering_type}."
self.spatial_dims = spatial_dims
self.dimensions = dimensions
assert (
len(dimensions) == self.spatial_dims + 1
), f"Dimensions must have length {self.spatial_dims + 1}."
self.reflected_spatial_dims = reflected_spatial_dims
self.transpositions_axes = transpositions_axes
self.rot90_axes = rot90_axes
if len(set(transformation_order)) != len(transformation_order):
raise ValueError(
f"No duplicates are allowed. Received {transformation_order}."
)
for transformation in transformation_order:
if transformation not in [t.value for t in OrderingTransformations]:
raise ValueError(
f"Valid transformations are {[t.value for t in OrderingTransformations]} but received {transformation}."
)
self.transformation_order = transformation_order
self.template = self._create_template()
self._sequence_ordering = self._create_ordering()
self._revert_sequence_ordering = np.argsort(self._sequence_ordering)
def __call__(self, x: torch.Tensor) -> torch.Tensor:
x = x[self._sequence_ordering]
return x
def get_sequence_ordering(self) -> np.ndarray:
return self._sequence_ordering
def get_revert_sequence_ordering(self) -> np.ndarray:
return self._revert_sequence_ordering
def _create_ordering(self):
self.template = self._transform_template()
order = self._order_template(template=self.template)
return order
def _create_template(self) -> np.ndarray:
spatial_dimensions = self.dimensions[1:]
template = np.arange(np.prod(spatial_dimensions)).reshape(*spatial_dimensions)
return template
def _transform_template(self) -> np.ndarray:
for transformation in self.transformation_order:
if transformation == OrderingTransformations.TRANSPOSE.value:
self.template = self._transpose_template(template=self.template)
elif transformation == OrderingTransformations.ROTATE_90.value:
self.template = self._rot90_template(template=self.template)
elif transformation == OrderingTransformations.REFLECT.value:
self.template = self._flip_template(template=self.template)
return self.template
def _transpose_template(self, template: np.ndarray) -> np.ndarray:
for axes in self.transpositions_axes:
template = np.transpose(template, axes=axes)
return template
def _flip_template(self, template: np.ndarray) -> np.ndarray:
for axis, to_reflect in enumerate(self.reflected_spatial_dims):
template = np.flip(template, axis=axis) if to_reflect else template
return template
def _rot90_template(self, template: np.ndarray) -> np.ndarray:
for axes in self.rot90_axes:
template = np.rot90(template, axes=axes)
return template
def _order_template(self, template: np.ndarray) -> np.ndarray:
depths = None
if self.spatial_dims == 2:
rows, columns = template.shape[0], template.shape[1]
else:
rows, columns, depths = (
template.shape[0],
template.shape[1],
template.shape[2],
)
        sequence = getattr(self, f"{self.ordering_type}_idx")(rows, columns, depths)
ordering = np.array([template[tuple(e)] for e in sequence])
return ordering
@staticmethod
    def raster_scan_idx(rows: int, cols: int, depths: Optional[int] = None) -> np.ndarray:
idx = []
for r in range(rows):
for c in range(cols):
if depths:
for d in range(depths):
idx.append((r, c, d))
else:
idx.append((r, c))
idx = np.array(idx)
return idx
@staticmethod
    def s_curve_idx(rows: int, cols: int, depths: Optional[int] = None) -> np.ndarray:
idx = []
for r in range(rows):
col_idx = range(cols) if r % 2 == 0 else range(cols - 1, -1, -1)
for c in col_idx:
if depths:
depth_idx = (
range(depths) if c % 2 == 0 else range(depths - 1, -1, -1)
)
for d in depth_idx:
idx.append((r, c, d))
else:
idx.append((r, c))
idx = np.array(idx)
return idx
@staticmethod
    def random_idx(rows: int, cols: int, depths: Optional[int] = None) -> np.ndarray:
idx = []
for r in range(rows):
for c in range(cols):
if depths:
for d in range(depths):
idx.append((r, c, d))
else:
idx.append((r, c))
idx = np.array(idx)
np.random.shuffle(idx)
return idx
@staticmethod
    def hilbert_curve_idx(rows: int, cols: int, depths: Optional[int] = None) -> np.ndarray:
t = list(gilbert3d(rows, cols, depths) if depths else gilbert2d(rows, cols))
idx = np.array(t)
return idx
class RelativeSpatialPositioning(torch.nn.Module):
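    """Computes the matrix of (optionally bucketed) 3D relative-position ids between every
    pair of flattened latent positions, reordered to match a given sequence ordering, so the
    result can be used to index a learned relative-position bias table inside attention.
    An optional conditioning prefix is accounted for by padding the matrix.
    """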
def __init__(
self,
spatial_dims: int,
ordering: np.ndarray,
dimensions: Union[Tuple[int, int, int], Tuple[int, int, int, int]],
bucket_values: bool = False,
bucket_beta: int = 50,
conditioning_length: int = 0,
):
super().__init__()
self.spatial_dims = spatial_dims
self.dimensions = dimensions
self.ordering = ordering
self.conditioning_length = conditioning_length
assert (
len(dimensions) == self.spatial_dims + 1
), f"Dimensions must have length {self.spatial_dims + 1}."
self.bucket_values = bucket_values
self.bucket_beta = bucket_beta
self.dist_array = self._get_distance_array()
self.quantized_distances, self.num_buckets = self._rp_3d_product_and_quantize()
self.ordered_distance_matrix = self.reorder()
self.ordered_distance_matrix = self.account_conditionings()
def get_pid_array(self):
return self.ordered_distance_matrix
def get_num_pids(self):
return self.num_buckets
def account_conditionings(self):
if self.conditioning_length > 0:
ordered_distance_matrix = (
torch.ones(
self.ordered_distance_matrix.shape[0] + self.conditioning_length,
self.ordered_distance_matrix.shape[0] + self.conditioning_length,
dtype=self.ordered_distance_matrix.dtype,
device=self.ordered_distance_matrix.device,
)
* self.num_buckets
)
ordered_distance_matrix[
self.conditioning_length :, self.conditioning_length :
] = self.ordered_distance_matrix
self.num_buckets += 1
return ordered_distance_matrix
return self.ordered_distance_matrix
def reorder(self):
pid_rel_pos = self.quantized_distances.reshape(
self.dimensions[1] * self.dimensions[2] * self.dimensions[3], -1
)
dim_1_reordered = torch.zeros_like(pid_rel_pos)
for i in range(len(self.ordering)):
dim_1_reordered[i] = pid_rel_pos[self.ordering[i]]
dim_2_reordered = torch.zeros_like(pid_rel_pos)
for i in range(len(self.ordering)):
dim_2_reordered[:, i] = dim_1_reordered[:, self.ordering[i]]
return dim_2_reordered
def _get_distance_array(self):
coord_array = torch.zeros(
(self.dimensions[1], self.dimensions[2], self.dimensions[3], 3),
dtype=torch.int,
)
height = coord_array.shape[0]
width = coord_array.shape[1]
depth = coord_array.shape[2]
for i in range(height):
for j in range(width):
for k in range(depth):
coord_array[i, j, k, 0] = i
coord_array[i, j, k, 1] = j
coord_array[i, j, k, 2] = k
dist_array = torch.zeros(
(height, width, depth, height, width, depth, 3), dtype=torch.int
)
coord_array_widths = coord_array[:, :, :, 1]
coord_array_heights = coord_array[:, :, :, 0]
coord_array_depths = coord_array[:, :, :, 2]
for i in range(height):
for j in range(width):
for k in range(depth):
dist_array[i, j, k, :, :, :, 0] = coord_array_heights - i
dist_array[i, j, k, :, :, :, 1] = coord_array_widths - j
dist_array[i, j, k, :, :, :, 2] = coord_array_depths - k
return dist_array
# Code adapted from iRPE in 2D:
# https://github.com/microsoft/Cream/blob/6fb89a2f93d6d97d2c7df51d600fe8be37ff0db4/iRPE/DETR-with-iRPE/models/rpe_attention/irpe.py#L19
def _rp_3d_product_and_quantize(self):
alpha = self.bucket_beta / 2
gamma = self.bucket_beta * 4
if self.bucket_values:
r = (
self.piecewise_index(
self.dist_array[:, :, :, :, :, :, 0], alpha, self.bucket_beta, gamma
)
+ self.bucket_beta
)
c = (
self.piecewise_index(
self.dist_array[:, :, :, :, :, :, 1], alpha, self.bucket_beta, gamma
)
+ self.bucket_beta
)
d = (
self.piecewise_index(
self.dist_array[:, :, :, :, :, :, 2], alpha, self.bucket_beta, gamma
)
+ self.bucket_beta
)
else:
r = self.dist_array[:, :, :, :, :, :, 0]
c = self.dist_array[:, :, :, :, :, :, 1]
d = self.dist_array[:, :, :, :, :, :, 2]
r = r - torch.min(r)
c = c - torch.min(c)
d = d - torch.min(d)
max_dim = max(torch.max(r), torch.max(c), torch.max(d)) + 1
pid = r + (c * max_dim) + (d * max_dim ** 2)
return pid, torch.max(pid)
@staticmethod
def piecewise_index(relative_position, alpha, beta, gamma, dtype=torch.int):
"""piecewise index function
Parameters
----------
relative_position: torch.Tensor, dtype: long or float
The shape of `relative_position` is (L, L).
alpha, beta, gamma: float
The coefficients of piecewise index function.
Returns
-------
idx: torch.Tensor, dtype: long
A tensor indexing relative distances to corresponding encodings.
`idx` is a long tensor, whose shape is (L, L) and each element is in [-beta, beta].
"""
rp_abs = relative_position.abs()
mask = rp_abs <= alpha
not_mask = ~mask
rp_out = relative_position[not_mask]
rp_abs_out = rp_abs[not_mask]
y_out = (
torch.sign(rp_out)
* (
alpha
+ torch.log(rp_abs_out / alpha)
/ math.log(gamma / alpha)
* (beta - alpha)
)
.round()
.clip(max=beta)
).to(dtype)
idx = relative_position.clone()
if idx.dtype in [torch.float32, torch.float64]:
# round(x) when |x| <= alpha
idx = idx.round().to(dtype)
# assign the value when |x| > alpha
idx[not_mask] = y_out
return idx
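For reference, a rough usage sketch of the two classes above (the latent dimensions are made up, and the gilbert package only matters for the Hilbert-curve option):

dimensions = (1, 8, 8, 8)  # [C, X, Y, Z] of the VQ-VAE latent

ordering = Ordering(
    ordering_type=OrderingType.RASTER_SCAN.value,
    spatial_dims=3,
    dimensions=dimensions,
    reflected_spatial_dims=(False, False, False),
    transpositions_axes=(),
    rot90_axes=(),
)

rel_pos = RelativeSpatialPositioning(
    spatial_dims=3,
    ordering=ordering.get_sequence_ordering(),
    dimensions=dimensions,
    bucket_values=False,
)

pid = rel_pos.get_pid_array()      # (X*Y*Z, X*Y*Z) relative-position ids, in sequence order
num_pids = rel_pos.get_num_pids()  # maximum id produced above (ids run from 0 to num_pids here)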
@danieltudosiu yeah it is the same
i think relative positional bias, especially generalizing it to N-dimensions, may have to belong in a separate repository
@danieltudosiu are you working with video?
> @danieltudosiu yeah it is the same
> i think relative positional bias, especially generalizing it to N-dimensions, may have to belong in a separate repository
Nice, yes, that would be best, but it is quite niche, mainly useful in ensemble anomaly detection.
> @danieltudosiu are you working with video?
Volumetric medical imaging
@danieltudosiu ohh got it
yea, let me think about how to tackle this in another repository then, for both space (3d) and potentially even space + time (4d) - don't think we'll have to go beyond that
i've been tired of redoing the code for 2d relative positional bias across repositories anyways, so probably a good time to generalize it haha
> @danieltudosiu ohh got it
> yea, let me think about how to tackle this in another repository then, for both space (3d) and potentially even space + time (4d) - don't think we'll have to go beyond that
> i've been tired of redoing the code for 2d relative positional bias across repositories anyways, so probably a good time to generalize it haha
I do not know how useful the next idea would be, but I would start by decomposing the attention mechanism into all of its additive components and writing some kind of wrapper, let's call it an "Attention Enhancer", that calls those addition methods one by one and modifies their inputs/outputs :-?
Think of it like Figure 1 from the paper I referenced in the issue.
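Something along these lines, maybe - a very rough, untested sketch, with all the names made up:

import torch
from torch import nn

class AttentionEnhancer(nn.Module):
    # wraps a plain attention core and passes each intermediate (qkv, logits, output) through
    # optional hook modules, so additions like a relative positional bias can be bolted on
    # without rewriting the attention itself
    def __init__(self, dim, heads=8, enhancers=()):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.enhancers = nn.ModuleList(enhancers)  # each hook may define on_qkv / on_logits / on_out

    def _run(self, stage, value):
        # call every hook that implements this stage, in order
        for hook in self.enhancers:
            fn = getattr(hook, stage, None)
            if callable(fn):
                value = fn(value)
        return value

    def forward(self, x):
        b, n, _ = x.shape
        q, k, v = (t.reshape(b, n, self.heads, -1).transpose(1, 2)
                   for t in self.to_qkv(x).chunk(3, dim=-1))
        q, k, v = self._run("on_qkv", (q, k, v))
        logits = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        logits = self._run("on_logits", logits)    # e.g. the relative positional bias adds itself here
        attn = logits.softmax(dim=-1)
        out = torch.matmul(attn, v).transpose(1, 2).reshape(b, n, -1)
        return self._run("on_out", self.to_out(out))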
> @danieltudosiu ohh got it yea, let me think about how to tackle this in another repository then, for both space (3d) and potentially even space + time (4d) - don't think we'll have to go beyond that i've been tired of redoing the code for 2d relative positional bias across repositories anyways, so probably a good time to generalize it haha
> I do not know how useful the next idea would be, but I would start by decomposing the attention mechanism into all of its additive components and writing some kind of wrapper, let's call it an "Attention Enhancer", that calls those addition methods one by one and modifies their inputs/outputs :-?
do you mean something like https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/dalle2_pytorch.py#L1259 ?
btw, have you ever tried the contextual bias that was presented in the paper? i'm curious how that performs
@lucidrains I don't follow with regard to DALLE 2 :-? Could you help me understand?
We didn't try the contextual bias since it showed only a marginal improvement over the bias one, and we deemed it not worth the engineering time.
ohh thanks for letting me know! yeah I guess static bias is good enough 🤔
ok maybe I'll build out the relative positional bias library when I find some time, as a separate repo, and with all the features like max cutoffs etc
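(for the max cutoff, i'm thinking of something roughly like this - just a sketch:)

import torch

def clamp_relative_positions(rel_pos: torch.Tensor, max_distance: int) -> torch.Tensor:
    # offsets beyond +/- max_distance all share the boundary bucket, so the bias table
    # stays a fixed size no matter how long the sequence gets
    return rel_pos.clamp(min=-max_distance, max=max_distance) + max_distance  # ids in [0, 2 * max_distance]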
Hi @lucidrains,
Thanks for this fantastic trove of transformers <3
I am mainly working with VQ-VAEs, and in my experience the method from this paper [1] gave consistent improvements in the morphological correctness of the samples.
If you need help, I can give you some pointers on how to start and what's needed. My coworker and I already got the "Bias" version working with x-transformers, but the "Context" one required too much modification of the Attention module to warrant implementing it.
Cheers!
[1] Wu, K., Peng, H., Chen, M., Fu, J. and Chao, H., 2021. Rethinking and improving relative position encoding for vision transformer. In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp. 10033-10041).