open-mmlab / mmaction2

OpenMMLab's Next Generation Video Understanding Toolbox and Benchmark
https://mmaction2.readthedocs.io
Apache License 2.0

[Bug] The label generation process seems to mismatch the code released by the ActionCLIP authors. #2746

Status: Open. NICE-FUTURE opened this issue 11 months ago

NICE-FUTURE commented 11 months ago

Branch

main branch (1.x version, such as v1.0.0, or dev-1.x branch)

Prerequisite

Environment

sys.platform: win32
Python: 3.10.10 | packaged by Anaconda, Inc. | (main, Mar 21 2023, 18:39:17) [MSC v.1916 64 bit (AMD64)]
CUDA available: True
numpy_random_seed: 2147483648
GPU 0: NVIDIA GeForce RTX 4080
CUDA_HOME: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8
NVCC: Cuda compilation tools, release 11.8, V11.8.89
MSVC: Microsoft (R) C/C++ Optimizing Compiler Version 19.36.32535 for x64
GCC: n/a
PyTorch: 2.0.1
PyTorch compiling details: PyTorch built with:
  - C++ Version: 199711
  - MSVC 193431937
  - Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.7.3 (Git Hash 6dbeffbae1f23cbbeae17adb7b5b13f1f37c080e)
  - OpenMP 2019
  - LAPACK is enabled (usually provided by MKL)
  - CPU capability usage: AVX2
  - CUDA Runtime 11.8
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_37,code=compute_37
  - CuDNN 8.7
  - Magma 2.5.4
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=C:/cb/pytorch_1000000000000/work/tmp_bin/sccache-cl.exe, CXX_FLAGS=/DWIN32 /D_WINDOWS /GR /EHsc /w /bigobj /FS -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=OFF, TORCH_VERSION=2.0.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=OFF, USE_OPENMP=ON, USE_ROCM=OFF,

TorchVision: 0.15.2
OpenCV: 4.8.1
MMEngine: 0.7.3
MMAction2: 1.1.0+f1cda75
MMCV: 2.0.0
MMDetection: 3.0.0
MMPose: 1.0.0

Describe the bug

1. Implementation in MMAction2

1.1 Label Generation

link: https://github.com/open-mmlab/mmaction2/blob/4d6c93474730cad2f25e51109adcf96824efc7a3/projects/actionclip/models/actionclip.py#L162-L163

If i == j, then gt[i,j] = 1; otherwise gt[i,j] = 0.

That is, only the diagonal elements of the similarity matrix have a target value of 1, regardless of the samples' class labels.

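from typing import List
import numpy as np

def generate_gt_v1(labels: List[int]) -> np.ndarray:
    """The label generation process in MMAction2.
    Args:
        labels (List[int]): each element is the class id of a sample.
    """
    # The class ids are ignored: sample i only matches text i (the diagonal).
    gt = np.arange(len(labels))
    return gt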
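To make the consequence concrete, here is a small NumPy check. The example batch is hypothetical. The index targets returned by `generate_gt_v1` are the compact form of an identity matrix, so two samples of the same class in one batch are still treated as a negative pair.

```python
from typing import List

import numpy as np


def generate_gt_v1(labels: List[int]) -> np.ndarray:
    # Same as the MMAction2-style targets above: one index per row, class ids ignored.
    return np.arange(len(labels))


labels = [0, 1, 0]  # hypothetical batch where samples 0 and 2 share a class
gt_idx = generate_gt_v1(labels)
# One-hot view of the index targets: row i has a single 1 in column i.
gt_onehot = np.eye(len(labels))[gt_idx]
print(gt_onehot)
# The same-class pair (0, 2) gets a target of 0, i.e. it is treated as a negative.
print(gt_onehot[0, 2])
```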

1.2 Loss Function

MMAction2 uses cross-entropy as the default similarity loss function.

link: https://github.com/open-mmlab/mmaction2/blob/4d6c93474730cad2f25e51109adcf96824efc7a3/projects/actionclip/models/actionclip.py#L165-L166

loss_imgs = CrossEntropy(logits_per_image, gt)
loss_texts = CrossEntropy(logits_per_text, gt)
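The pseudocode above can be sketched in plain NumPy as a symmetric cross-entropy over index targets. This is an illustration, not MMAction2's code. The random logits are made up, and averaging the two directions is an assumption.

```python
import numpy as np


def log_softmax(x: np.ndarray, axis: int = 1) -> np.ndarray:
    # Subtract the row max for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def cross_entropy(logits: np.ndarray, target_idx: np.ndarray) -> float:
    # Mean negative log-probability of the target column in each row.
    logp = log_softmax(logits, axis=1)
    return float(-logp[np.arange(len(target_idx)), target_idx].mean())


rng = np.random.default_rng(0)
batch_size = 4
logits_per_image = rng.normal(size=(batch_size, batch_size))  # illustrative scores
logits_per_text = logits_per_image.T  # CLIP-style: text-to-video scores are the transpose
gt = np.arange(batch_size)

loss_imgs = cross_entropy(logits_per_image, gt)
loss_texts = cross_entropy(logits_per_text, gt)
# Assumption: the two directions are averaged; the real config may weight them differently.
loss = (loss_imgs + loss_texts) / 2
print(loss_imgs, loss_texts, loss)
```

Indexing with `gt` is equivalent to multiplying the log-probabilities by the identity matrix and summing each row. This is why the index form and the diagonal one-hot form describe the same targets.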

2. Implementation in the code released by the ActionCLIP authors

2.1 Label Generation

link: https://github.com/sallymmx/ActionCLIP/blob/31c34df17dce917d67127b7fb155922c4744f680/utils/tools.py#L7

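If samples x_i and x_j have the same label, then gt[i,j] = 1; otherwise gt[i,j] = 0. Unlike section 1.1, off-diagonal entries can also be 1 when several samples in the batch share a class.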

from typing import List
import numpy as np

def generate_gt_v2(labels: List[int]) -> np.ndarray:
    """The label generation process in ActionCLIP.
    Args:
        labels (List[int]): each element is the class id of a sample.
    """
    num = len(labels)
    gt = np.zeros(shape=(num, num))
    for i in range(num):
        for j in range(num):
            # Any pair of samples sharing a class is a positive pair.
            if labels[j] == labels[i]:
                gt[i, j] = 1
    return gt
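The double loop can be replaced by NumPy broadcasting, which builds the same matrix in one step. `generate_gt_v2_vectorized` is a hypothetical helper name, not a function from either codebase. It also shows that the two schemes coincide when every label in the batch is distinct.

```python
from typing import List

import numpy as np


def generate_gt_v2_vectorized(labels: List[int]) -> np.ndarray:
    # Broadcast a column of labels against a row of labels: entry (i, j) is
    # True exactly when samples i and j share a class id.
    arr = np.asarray(labels)
    return (arr[:, None] == arr[None, :]).astype(np.float64)


print(generate_gt_v2_vectorized([0, 1, 0]))
# With all-distinct labels the matrix is the identity, matching the MMAction2 targets.
print(np.array_equal(generate_gt_v2_vectorized([3, 1, 2]), np.eye(3)))
```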

2.2 Loss Function

ActionCLIP uses a KL-divergence loss (KLLoss) as the default similarity loss function.

link: https://github.com/sallymmx/ActionCLIP/blob/31c34df17dce917d67127b7fb155922c4744f680/train.py#L177

loss_imgs = KLLoss(logits_per_image, gt)
loss_texts = KLLoss(logits_per_text, gt)
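Below is a plain KL-divergence sketch for comparing the two schemes. It is not ActionCLIP's `KLLoss` class. It assumes each multi-hot row of `gt` is row-normalized into a target distribution. ActionCLIP's own class may build the targets differently, for example with a temperature. The batch and logits are hypothetical.

```python
import numpy as np


def log_softmax(x: np.ndarray, axis: int = 1) -> np.ndarray:
    # Subtract the row max for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def kl_loss(logits: np.ndarray, gt: np.ndarray) -> float:
    # Assumption: each multi-hot row of gt is row-normalized into a target
    # distribution; ActionCLIP's own KLLoss may build the targets differently.
    target = gt / gt.sum(axis=1, keepdims=True)
    logp = log_softmax(logits, axis=1)
    # KL(target || softmax(logits)) per row, treating 0 * log(0) as 0.
    with np.errstate(divide="ignore"):
        log_target = np.where(target > 0, np.log(target), 0.0)
    return float((target * (log_target - logp)).sum(axis=1).mean())


labels = np.array([0, 1, 0, 2])  # hypothetical batch with a repeated class
gt = (labels[:, None] == labels[None, :]).astype(np.float64)  # symmetric
rng = np.random.default_rng(0)
logits_per_image = rng.normal(size=(4, 4))
logits_per_text = logits_per_image.T
loss_imgs = kl_loss(logits_per_image, gt)
loss_texts = kl_loss(logits_per_text, gt)
print(loss_imgs, loss_texts)
```

When every label in the batch is distinct, `gt` is the identity matrix and the entropy term of the target vanishes. This KL loss then equals the cross-entropy of section 1.2. Under this sketch, the two implementations only diverge when a batch contains repeated classes.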

Does the current implementation in MMAction2 achieve better performance or higher efficiency than the original one? Looking forward to the answer~

Reproduces the problem - code sample

No response

Reproduces the problem - command or script

No response

Reproduces the problem - error message

No response

Additional information

No response

thancaocuong commented 4 months ago

@NICE-FUTURE any update on this issue? I also see the difference between these two implementations.