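In the implementation of the DPC (density peaks clustering) algorithm, consider the line below. The shape of dist_matrix is (B, N, N), and we want the max distance of each token: in DPC, the distance indicator of the highest-density token is defined as its own maximum distance to the other tokens. But if we flatten dist_matrix, we only get the max distance of each batch element: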
dist_max = dist_matrix.flatten(1).max(dim=-1)[0][:, None, None]
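We can change the code to the following, which keeps one maximum per token (shape (B, N, 1)) instead of one per batch element (shape (B, 1, 1)):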
dist_max = dist_matrix.max(dim=-1)[0][:, :, None]
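A minimal sketch comparing the two versions on a toy input. It assumes the surrounding code masks out tokens that are not denser and then takes a row-wise min, which is the usual way the DPC distance indicator is computed; the helper name `distance_indicator` and the `per_token` flag are hypothetical and only exist for this comparison.

```python
import torch


def distance_indicator(dist_matrix: torch.Tensor, density: torch.Tensor,
                       per_token: bool = True) -> torch.Tensor:
    """Hypothetical helper: DPC distance indicator, shape (B, N).

    dist_matrix: (B, N, N) pairwise distances.
    density:     (B, N) local densities (assumed distinct).
    """
    # mask[b, i, j] = 1 where token j is denser than token i (assumed masking step)
    mask = (density[:, None, :] > density[:, :, None]).to(dist_matrix.dtype)
    if per_token:
        # proposed fix: max distance of each token, shape (B, N, 1)
        dist_max = dist_matrix.max(dim=-1)[0][:, :, None]
    else:
        # original line: a single max per batch element, shape (B, 1, 1)
        dist_max = dist_matrix.flatten(1).max(dim=-1)[0][:, None, None]
    # entries with no denser token are filled with dist_max, then take the row-wise min;
    # for the densest token every entry is dist_max, so its indicator equals dist_max
    dist, _ = (dist_matrix * mask + dist_max * (1 - mask)).min(dim=-1)
    return dist


# Three tokens on a line at positions 0, 1 and 5; token 1 is the densest.
d = torch.tensor([[[0., 1., 5.],
                   [1., 0., 4.],
                   [5., 4., 0.]]])
rho = torch.tensor([[2., 3., 1.]])
print(distance_indicator(d, rho, per_token=False).tolist())  # [[1.0, 5.0, 4.0]]
print(distance_indicator(d, rho, per_token=True).tolist())   # [[1.0, 4.0, 4.0]]
```

With the per-batch maximum, the densest token (token 1) gets 5, which is the distance between tokens 0 and 2 and does not involve token 1 at all; with the per-token maximum it gets 4, its actual farthest distance. The densest token still keeps the largest indicator after the change, because every other token's indicator is at most its distance to the densest token, which is at most that token's row maximum.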