gfxdisp / ColorVideoVDP

pycvvdp.cvvdp.predict/loss not working for batch size > 1 #15

Open zhou-wb opened 3 months ago

zhou-wb commented 3 months ago
import torch
import pycvvdp

I_ref = torch.rand((2, 1, 1080, 1920), dtype=torch.float32)
I_test = torch.rand((2, 1, 1080, 1920), dtype=torch.float32)

cvvdp = pycvvdp.cvvdp(display_name='standard_fhd')
JOD = cvvdp.loss(I_test, I_ref, dim_order="BCHW")
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[76], line 8
      5 I_test = torch.rand((2, 1, 1080, 1920), dtype=torch.float32)
      7 cvvdp = pycvvdp.cvvdp(display_name='standard_fhd')
----> 8 JOD = cvvdp.loss(I_test, I_ref, dim_order="BCHW" )
     10 print(f'JOD: {JOD}')

File ~/anaconda3/envs/TORCH2.2CUDA11.8/lib/python3.12/site-packages/pycvvdp/cvvdp_metric.py:279, in cvvdp.loss(self, test_cont, reference_cont, dim_order, frames_per_second)
    276 def loss(self, test_cont, reference_cont, dim_order="BCFHW", frames_per_second=0):
    278     test_vs = video_source_array( test_cont, reference_cont, frames_per_second, dim_order=dim_order, display_photometry=self.display_photometry )
--> 279     (Q_jod, stats) = self.predict_video_source(test_vs)
    280     return (10.-Q_jod)

File ~/anaconda3/envs/TORCH2.2CUDA11.8/lib/python3.12/site-packages/pycvvdp/cvvdp_metric.py:372, in cvvdp.predict_video_source(self, vid_source)
    370 if is_image:
    371     R = torch.empty((1, 6, 1, height, width), device=self.device)
--> 372     R[:,0::2, :, :, :] = vid_source.get_test_frame(0, device=self.device, colorspace=met_colorspace)
    373     R[:,1::2, :, :, :] = vid_source.get_reference_frame(0, device=self.device, colorspace=met_colorspace)
    375 else: # This is video
    376     #if self.debug: print("Frame %d:\n----" % ff)

RuntimeError: The expanded size of the tensor (1) must match the existing size (2) at non-singleton dimension 0.  Target sizes: [1, 3, 1, 1080, 1920].  Tensor sizes: [2, 1, 1, 1080, 1920]
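
The trace points at the cause: predict_video_source allocates R with a fixed leading (batch) dimension of 1, so a test frame with batch size 2 cannot be copied into it.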
mantiuk commented 3 months ago

That is right. Batch sizes greater than 1 are currently not supported. Could you please explain your use case, so that we can prioritize adding support for batches?

zhou-wb commented 3 months ago

Thanks for your great work, and for the reply! I am trying to evaluate the images from a focal stack, so I use the batch dimension to represent images at different focal distances. Native batch support would be faster, but for now I can loop over the batch and average the per-image losses, as in the sketch below.
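
A minimal sketch of that workaround (the helper name batched_cvvdp_loss is hypothetical; it assumes BCHW inputs and that cvvdp.loss returns a scalar tensor):

import torch
import pycvvdp

def batched_cvvdp_loss(metric, I_test, I_ref):
    # Hypothetical helper: run the metric on one image pair at a time
    # (batch size 1 works) and average the per-image losses.
    losses = [metric.loss(t.unsqueeze(0), r.unsqueeze(0), dim_order="BCHW")
              for t, r in zip(I_test, I_ref)]
    return torch.stack(losses).mean()

cvvdp = pycvvdp.cvvdp(display_name='standard_fhd')
I_test = torch.rand((2, 1, 1080, 1920), dtype=torch.float32)
I_ref = torch.rand((2, 1, 1080, 1920), dtype=torch.float32)
loss = batched_cvvdp_loss(cvvdp, I_test, I_ref)

Since the per-image losses are stacked and averaged with torch operations, gradients should still flow through the result; it is just slower than true batched evaluation would be.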

mantiuk commented 1 month ago

Batch support is on our TODO list, but it may take some time before we can implement it.