pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Add batch visualization function to `torchvision.utils` #8201

Open vahvero opened 9 months ago

vahvero commented 9 months ago

🚀 The feature

Currently, torchvision detection models return a list of dictionaries, one per input image:

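import torch
import torchvision.transforms.functional as f
from torchvision import io, utils
from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_V2_Weights,
    fasterrcnn_resnet50_fpn_v2,
)

model = fasterrcnn_resnet50_fpn_v2(
    weights=FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1,
)
model.eval()  # in training mode, detection models also require targets
dog_int = io.read_image("dog.jpeg")  # uint8 tensor of shape [3, H, W]
batch = f.convert_image_dtype(dog_int)  # float tensor with values in [0, 1]
with torch.no_grad():
    response = model(batch.unsqueeze(0))
# [{'boxes': tensor([[ 61.9160,  49.2223, 185.7204, 184.4109],
#          [143.5477, 143.5992, 175.1468, 184.4712],
#          [136.6003, 217.3529, 161.4755, 223.8254]]),
#  'labels': tensor([18, 20, 57]),
#  'scores': tensor([0.9989, 0.1069, 0.0699])}]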

This output can be visualized with torchvision's drawing utilities:

# Keep only detections scoring above 0.3 and draw them on the uint8 image
dog_image = utils.draw_bounding_boxes(
    dog_int, response[0]["boxes"][response[0]["scores"] > 0.3]
)
f.to_pil_image(dog_image)

Motivation, pitch

However, there is currently no utility function for visualizing a whole batch. For example, continuing from the previous snippet:

with torch.no_grad():
    # A batch of 16 copies of the same float image
    imgs = [batch.detach().clone() for _ in range(16)]
    model.eval()
    response = model(imgs)
# The input is a list of float image tensors;
# the output is a list of the same length with one output dictionary per image

There is currently no method available for users to visualize this batch.

I suggest a utils function:

from __future__ import annotations

from torch import Tensor


def visualize_batch(
    image_batch: Tensor | list[Tensor],
    batch_response: list[dict[str, Tensor]],
    **visualization_arguments,
) -> Tensor:
    """Visualize an image batch in a suitable grid and return the result as a tensor.

    Arguments:
        image_batch (Tensor | list[Tensor]): Float image batch; internally converted to uint8 for the visualization utils.
        batch_response (list[dict[str, Tensor]]): List of response dictionaries. Works for mask and R-CNN models as well as training batches by inferring the correct behavior from the dictionary keys.
    """
    # ... implementation

where the keyword arguments correspond to arguments of the existing utils functions.

Alternatives

Users could follow the utilities example, which implements a matplotlib-based `show` function. This is not ideal, as users seem to need similar functionality without the related boilerplate code.

Additional context

I am willing to contribute if given the green light.

NicolasHug commented 9 months ago

Thanks for the feature request @vahvero

I think what you're trying to do should be reasonably achievable by:

LMK if this isn't what you're looking for

vahvero commented 9 months ago

@NicolasHug I personally feel that the parameter count would not be an issue here. Torchvision models already take many keyword parameters with defaults, so this function's call signature would not be out of place for users. With reasonable defaults, most of the functionality could be abstracted away in much the same way as in the `__init__` methods of models such as `FasterRCNN`.

I think this type of function would have tangible benefits for most users. Instead of the self-implemented functionality mentioned above, which I suspect every single user has either copied from the example or written themselves to achieve behavior similar to this feature request, the library would offer a standardized function for it.

To me, this seems like very commonly used functionality that, for some reason, has not been included in the library, even though the library already ships some drawing utilities.

Having thought about this further, I think a function with the signature

def visualize_batch(
    image_batch: Tensor | list[Tensor],
    batch_response: list[dict[str, Tensor]],
    **visualization_arguments,  # these are naturally expanded
) -> list[Tensor]:
    ...  # returns one annotated image tensor per input image

would allow users to pass the resulting tensors to whichever visualization library they use.