pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Torchvision Normalize zero checks and tensor creation in constructor #8565

Closed: heth27 closed this issue 1 month ago.

heth27 commented 1 month ago

🚀 The feature

Move the zero checks on std and the creation of tensors from mean and std into the __init__ of the torchvision Normalize transform.

https://pytorch.org/vision/main/_modules/torchvision/transforms/transforms.html#Normalize

Motivation, pitch

The mean and std attributes normally don't change. Validating them and creating tensors from the arguments once, in the constructor, and then doing the calculations on those tensors directly instead of calling F.normalize on every call, would therefore be faster. Making the attributes private should be enough to discourage changing their values.

Alternatives

Leave as is

Additional context

No response

NicolasHug commented 1 month ago

Hi @heth27, thanks for opening the issue. Unfortunately, avoiding these checks is not as simple as it may sound, because we do intend the lower-level functionals to be usable as well. E.g. we want torchvision.transforms[.v2].functional.resize() to be usable, and it needs to contain the same value checks. Because the transforms rely on the functionals, we put the validation checks at the functional level to avoid duplication or complex code paths.
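A minimal sketch of the trade-off described above, assuming a hypothetical public functional normalize that always validates, a hypothetical private fast path _normalize_unchecked, and a hypothetical CachedNormalize transform that validates once in __init__ and then calls the fast path. None of these names are torchvision's API; the sketch only shows that skipping per-call checks in the transform requires a second, unchecked code path next to the checked one.

```python
import torch
from torch import Tensor


def _check_std(std: Tensor) -> None:
    # Shared validation, so the functional and the transform raise the same error.
    if (std == 0).any():
        raise ValueError("std contains zero, leading to division by zero.")


def _normalize_unchecked(x: Tensor, mean: Tensor, std: Tensor) -> Tensor:
    # Hypothetical fast path: trusts that mean/std are validated (C, 1, 1) tensors.
    return (x - mean) / std


def normalize(x: Tensor, mean, std) -> Tensor:
    # Hypothetical public functional: converts and validates on every call.
    mean_t = torch.as_tensor(mean, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=x.dtype, device=x.device).view(-1, 1, 1)
    _check_std(std_t)
    return _normalize_unchecked(x, mean_t, std_t)


class CachedNormalize(torch.nn.Module):
    # Hypothetical transform: converts and validates once, then reuses the fast path.
    def __init__(self, mean, std, dtype=torch.float32):
        super().__init__()
        std_t = torch.as_tensor(std, dtype=dtype).view(-1, 1, 1)
        _check_std(std_t)
        # Non-persistent buffers follow .to(device) but are not saved in state_dict.
        self.register_buffer("_mean", torch.as_tensor(mean, dtype=dtype).view(-1, 1, 1), persistent=False)
        self.register_buffer("_std", std_t, persistent=False)

    def forward(self, x: Tensor) -> Tensor:
        return _normalize_unchecked(x, self._mean, self._std)
```

The cost is the one named in the comment: two entry points that must be kept consistent, and a fast path that relies on its caller having validated the arguments.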

I would be curious, however, to see a benchmark showing whether such checks have a significant impact on perf.
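One way to isolate that cost is to time only the per-call work that moving the checks into __init__ would remove: converting mean and std to tensors and testing std for zeros. The sketch below is an approximation of that work under stated assumptions, not a copy of torchvision's kernel; per_call_setup and measure_overhead are hypothetical names. It uses torch.utils.benchmark.Timer, whose timeit(number) returns a Measurement whose mean is the average time per run in seconds.

```python
import torch
import torch.utils.benchmark as benchmark

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def per_call_setup(mean, std, device):
    # Hypothetical approximation of the per-call work: list -> tensor conversion
    # (plus a host-to-device copy on CUDA) and a zero check on std.
    mean_t = torch.as_tensor(mean, dtype=torch.float32, device=device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=torch.float32, device=device).view(-1, 1, 1)
    # Turning a tensor comparison into a Python bool forces a device sync on CUDA.
    if (std_t == 0).any():
        raise ValueError("std contains zero")
    return mean_t, std_t


def measure_overhead(device: str = "cpu", number: int = 1000) -> float:
    """Average seconds per call of per_call_setup, measured with torch.utils.benchmark."""
    timer = benchmark.Timer(
        stmt="per_call_setup(mean, std, device)",
        globals={"per_call_setup": per_call_setup, "mean": MEAN, "std": STD, "device": device},
    )
    return timer.timeit(number).mean


if __name__ == "__main__":
    devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    for dev in devices:
        print(f"{dev}: {measure_overhead(dev) * 1e6:.2f} us per call")
```

Because this timing excludes the arithmetic itself, it can be compared directly with the full per-call times of a normalize transform on small inputs.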

heth27 commented 1 month ago
from itertools import product
import torch
import torch.utils.benchmark as benchmark
from torch import Tensor
from torchvision.transforms.v2 import Normalize

class CustomNormalize(torch.nn.Module):
    """
    Intended to be faster than the torchvision implementation
    (https://pytorch.org/vision/main/_modules/torchvision/transforms/transforms.html#Normalize),
    because the zero check and the creation of tensors from the arguments are done
    only once, in __init__, instead of on every call.
    """

    def __init__(self, mean, std, dtype=torch.float32):
        super().__init__()

        self.register_buffer(
            "mean", torch.as_tensor(mean, dtype=dtype).view(-1, 1, 1), persistent=False
        )

        std = torch.as_tensor(std, dtype=dtype)
        if (std == 0).any():
            raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
        self.register_buffer("std", std.view(-1, 1, 1), persistent=False)

    def forward(self, tensor: Tensor) -> Tensor:
        return torch.div(tensor - self.mean, self.std)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"

initialized_custom_normalize = CustomNormalize(mean=1.0, std=1.0)
initialized_normalize = Normalize(mean=[1.0], std=[1.0])

# initialized_custom_normalize.to("cuda")  # for the CUDA run; x must then also be created on cuda

def custom_normalize(a):
    return initialized_custom_normalize(a)

def default_normalize(a):
    return initialized_normalize(a)

# Sanity check: both implementations must produce the same output.
# x is created on the same device as the custom module's buffers.
x = torch.rand(2, 3, 8, 8, device=initialized_custom_normalize.mean.device)
assert torch.allclose(custom_normalize(x), default_normalize(x))

# Compare takes a list of measurements which we'll save in results.
results = []

sizes = [1, 64, 1024]  # , 10000]
channel_sizes = [3]
height = 256
width = 256
for batch, channel_size in product(sizes, channel_sizes):
    # label and sub_label are the rows
    # description is the column
    label = 'Normalize'
    sub_label = f'[{batch}, {channel_size}, {height}, {width}]'
    # x = torch.ones((batch, channel_size, height, width), device='cuda')
    x = torch.ones((batch, channel_size, height, width))
    for num_threads in [1, 4, 16, 32]:
        results.append(benchmark.Timer(
            stmt='custom_normalize(x)',
            setup='from __main__ import custom_normalize',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='custom_normalize',
        ).blocked_autorange(min_run_time=1))
        results.append(benchmark.Timer(
            stmt='default_normalize(x)',
            setup='from __main__ import default_normalize',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='default_normalize',
        ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()

# results on cpu

# [--------------------------- Normalize ----------------------------]
#                            |  custom_normalize  |  default_normalize
# 1 threads: ---------------------------------------------------------
#       [1, 3, 256, 256]     |         165.1      |          193.9
#       [64, 3, 256, 256]    |       49127.9      |        29943.7
#       [1024, 3, 256, 256]  |      750098.6      |       458307.3
# 4 threads: ---------------------------------------------------------
#       [1, 3, 256, 256]     |          59.4      |           88.7
#       [64, 3, 256, 256]    |       17043.1      |        10409.8
#       [1024, 3, 256, 256]  |      253177.0      |       154099.5
# 16 threads: --------------------------------------------------------
#       [1, 3, 256, 256]     |         138.0      |          172.6
#       [64, 3, 256, 256]    |       13264.9      |         8819.0
#       [1024, 3, 256, 256]  |      181596.8      |       125650.9
# 32 threads: --------------------------------------------------------
#       [1, 3, 256, 256]     |         189.0      |          224.6
#       [64, 3, 256, 256]    |       11939.1      |         8131.4
#       [1024, 3, 256, 256]  |      168515.7      |       117112.4
#
# Times are in microseconds (us).

# results on cuda

# [--------------------------- Normalize ----------------------------]
#                            |  custom_normalize  |  default_normalize
# 1 threads: ---------------------------------------------------------
#       [1, 3, 256, 256]     |         30.9       |         113.9
#       [64, 3, 256, 256]    |        562.4       |         617.5
#       [1024, 3, 256, 256]  |       8926.5       |        8995.0
# 4 threads: ---------------------------------------------------------
#       [1, 3, 256, 256]     |         31.0       |         113.6
#       [64, 3, 256, 256]    |        562.0       |         618.3
#       [1024, 3, 256, 256]  |       8923.3       |        8994.7
# 16 threads: --------------------------------------------------------
#       [1, 3, 256, 256]     |         31.3       |         116.4
#       [64, 3, 256, 256]    |        562.0       |         618.5
#       [1024, 3, 256, 256]  |       8923.0       |        9025.7
# 32 threads: --------------------------------------------------------
#       [1, 3, 256, 256]     |         31.3       |         114.3
#       [64, 3, 256, 256]    |        562.0       |         616.4
#       [1024, 3, 256, 256]  |       8951.4       |        9020.7
#
# Times are in microseconds (us).

# results on cpu with v1 Normalize (from torchvision.transforms import Normalize)

# [--------------------------- Normalize ----------------------------]
#                            |  custom_normalize  |  default_normalize
# 1 threads: ---------------------------------------------------------
#       [1, 3, 256, 256]     |         164.7      |          246.4    
#       [64, 3, 256, 256]    |       48191.7      |        33828.5    
#       [1024, 3, 256, 256]  |      758001.0      |       532972.7    
# 4 threads: ---------------------------------------------------------
#       [1, 3, 256, 256]     |          58.6      |          119.8    
#       [64, 3, 256, 256]    |       16709.8      |        12200.3    
#       [1024, 3, 256, 256]  |      253394.2      |       185741.1    
# 16 threads: --------------------------------------------------------
#       [1, 3, 256, 256]     |         140.8      |          242.2    
#       [64, 3, 256, 256]    |       13298.5      |        10944.7    
#       [1024, 3, 256, 256]  |      181970.6      |       154560.6    
# 32 threads: --------------------------------------------------------
#       [1, 3, 256, 256]     |         189.1      |          321.0    
#       [64, 3, 256, 256]    |       12012.4      |        10306.9    
#       [1024, 3, 256, 256]  |      169448.9      |       148928.1    
# 
# Times are in microseconds (us).
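In the CPU tables, custom_normalize is slower than default_normalize for the [64, 3, 256, 256] and [1024, 3, 256, 256] inputs. One plausible explanation, which is an assumption not verified by profiling here, is memory traffic: torch.div(tensor - self.mean, self.std) allocates two full-size outputs, whereas dividing the freshly allocated difference in place allocates only one. The hypothetical InplaceDivNormalize below keeps the constructor-time checks and uses Tensor.div_ on that intermediate, so the caller's input is never modified.

```python
import torch
from torch import Tensor


class InplaceDivNormalize(torch.nn.Module):
    # Hypothetical variant of CustomNormalize: same __init__ checks,
    # but only one full-size allocation per forward call.
    def __init__(self, mean, std, dtype=torch.float32):
        super().__init__()
        std_t = torch.as_tensor(std, dtype=dtype)
        if (std_t == 0).any():
            raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
        self.register_buffer("mean", torch.as_tensor(mean, dtype=dtype).view(-1, 1, 1), persistent=False)
        self.register_buffer("std", std_t.view(-1, 1, 1), persistent=False)

    def forward(self, tensor: Tensor) -> Tensor:
        # `tensor - self.mean` is a new tensor, so dividing it in place is safe
        # and avoids allocating a second output of the same size.
        return (tensor - self.mean).div_(self.std)
```

It can be dropped into the same benchmark loop in place of initialized_custom_normalize to check whether the gap on CPU closes.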
heth27 commented 1 month ago

So the performance hit on CUDA seems to be the roughly 70 us for tensor creation; on CPU I don't know. The benchmark I ran inside my model was probably not properly synchronized. So on CUDA the overhead is not significant enough to worry about. Thank you.
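Because CUDA kernels launch asynchronously, timing code inside a model with a plain wall clock can measure little more than launch overhead unless the device is synchronized before the clock is read. torch.utils.benchmark.Timer, used in the benchmark above, already takes care of this. For ad-hoc timing, a minimal sketch relying only on time.perf_counter and torch.cuda.synchronize could look like this; time_call is a hypothetical helper name, and on a CPU-only machine it simply skips the synchronization.

```python
import time

import torch


def _sync(device: torch.device) -> None:
    # Block until all queued CUDA work on this device has finished; no-op on CPU.
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def time_call(fn, x: torch.Tensor, warmup: int = 10, iters: int = 100) -> float:
    """Average wall-clock seconds per fn(x) call, with CUDA synchronization."""
    for _ in range(warmup):
        fn(x)  # warm up lazy initialization and caching allocators
    _sync(x.device)
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    _sync(x.device)  # without this, queued kernels would not be counted
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(-1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(-1, 1, 1)
    x = torch.rand(1, 3, 256, 256, device=device)
    avg = time_call(lambda t: (t - mean) / std, x)
    print(f"{avg * 1e6:.1f} us per call on {device}")
```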