matlab-deep-learning / MATLAB-Deep-Learning-Model-Hub

Discover pretrained models for deep learning in MATLAB
https://www.mathworks.com/solutions/deep-learning.html

Convert Video classification network into dlnetwork #3

Open mrclasalvia opened 1 year ago

mrclasalvia commented 1 year ago

Hi everyone! I'm writing to ask for help with the video classification models on the MATLAB Model Hub: how can I convert them into dlnetwork objects and add or change layers? I couldn't find any way to do this. Specifically, I would like to visualize the networks and possibly modify them, but I don't know how.

mwbpatel commented 1 year ago

@mrclasalvia

If you are asking about the slowFastVideoClassifier, inflated3dVideoClassifier, or r2plus1dVideoClassifier, then the underlying networks are not user accessible.

Besides visualizing the network, you mentioned modifying the network. Could you say more about why you’d like to do that?

Is it to improve the accuracy or is it to reduce the memory usage? Do you want to design your own network that is completely different from the type used in SlowFast, for example, or do you just want to make small tweaks?

This information will help us enhance our functionality to better suit your needs.

mrclasalvia commented 1 year ago

@mwbpatel Yes, those are the networks I mean. I would like to visualize them to understand which layers make up each stage and what the dimensions are at each stage. I might also like to customize them, for instance by inserting custom attention layers I have developed, to improve accuracy. Could you also add the Multiscale ViT for video classification? Thank you for your assistance.

cuixing158 commented 3 months ago

Hi @mrclasalvia, MATLAB currently lacks flexibility for video classification: it does not let users manipulate the underlying network structure, and the classifiers can only be used through the limited set of object functions provided. I think TMW will take users' needs into account in the future and improve this area.

That said, it is possible to get past the outer wrapper and load the pretrained network data directly, which gives you a dlnetwork object you can work with. For example, for the R(2+1)D network, first locate the classifier's source file:

which r2plus1dVideoClassifier.m

Open that file and locate its last local function, iTripwireR2Plus1DResnet3D18(). It checks that the support package is installed and then loads a MAT-file of pretrained weights that contains the dlnetwork object:

%------------------------------------------------------------------
function data = iTripwireR2Plus1DResnet3D18()
    % Check if support package is installed
    breadcrumbFile = 'vision.internal.cnn.supportpackages.IsR2Plus1DInstalled';
    fullPath = which(breadcrumbFile);
    if isempty(fullPath)
        name     = 'Computer Vision Toolbox Model for R(2+1)D Video Classification';
        basecode = 'RD_VIDEO';

        throwAsCaller(MException(message('nnet_cnn:supportpackages:InstallRequired', mfilename, name, basecode)));
    else
        pattern = fullfile(filesep, '+vision','+internal','+cnn','+supportpackages','IsR2Plus1DInstalled.m');
        idx     = strfind(fullPath, pattern);
        matfile = fullfile(fullPath(1:idx), 'data', 'r2plus1dPretrained_3d18.mat');
        data    = load(matfile);
    end
end
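Because iTripwireR2Plus1DResnet3D18() is a local function, it cannot be called from outside r2plus1dVideoClassifier.m. A minimal sketch that repeats the same steps as a standalone script follows. It assumes the R(2+1)D support package is installed and that the internal folder layout and MAT-file name are the same as in the function shown; these are undocumented internals and may change between releases.

```matlab
% Sketch: reproduce the loading steps of iTripwireR2Plus1DResnet3D18()
% outside r2plus1dVideoClassifier.m (relies on undocumented internal paths).

% Locate the breadcrumb file that the support package installs
breadcrumbFile = 'vision.internal.cnn.supportpackages.IsR2Plus1DInstalled';
fullPath = which(breadcrumbFile);
if isempty(fullPath)
    error('Install the "Computer Vision Toolbox Model for R(2+1)D Video Classification" support package first.');
end

% The support package root is the part of the path before the +vision folder
pattern = fullfile(filesep, '+vision', '+internal', '+cnn', '+supportpackages', 'IsR2Plus1DInstalled.m');
idx = strfind(fullPath, pattern);
assert(~isempty(idx), 'Unexpected support package folder layout.');
rootDir = fullPath(1:idx(1));

% Load the pretrained weights; the MAT-file stores the network in the field "Network"
data = load(fullfile(rootDir, 'data', 'r2plus1dPretrained_3d18.mat'));
net  = data.Network;
```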

Displaying the loaded network shows the dlnetwork object:

data.Network

  dlnetwork with properties:

         Layers: [120×1 nnet.cnn.layer.Layer]
    Connections: [127×2 table]
     Learnables: [150×3 table]
          State: [74×3 table]
     InputNames: {'VideoInput'}
    OutputNames: {'ClassificationLayer_Gemm_118_Softmax'}
    Initialized: 1

  View summary with summary.
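Once you have the dlnetwork, the standard dlnetwork functions should let you inspect and edit it. The sketch below assumes `data` was loaded from the MAT-file as described above. It also assumes the classification head is a `fullyConnectedLayer`: the output name suggests an imported ONNX Gemm layer, but this is not confirmed, so the code checks for it. `numClasses` and the layer name `fc_new` are hypothetical placeholders for your own task.

```matlab
% Sketch: inspect and modify the R(2+1)D dlnetwork loaded from data.Network.
net = data.Network;

% Text summary plus an interactive per-layer view of activation sizes
summary(net)
analyzeNetwork(net)

% Table of layer names and classes, useful for choosing where to edit
layerInfo = table(string({net.Layers.Name})', ...
    string(arrayfun(@(l) class(l), net.Layers, 'UniformOutput', false)), ...
    'VariableNames', {'Name', 'Class'});
disp(tail(layerInfo, 5))

% Hypothetical edit: retarget the last fully connected layer to numClasses outputs
numClasses = 10;   % hypothetical class count for your own data set
isFC  = arrayfun(@(l) isa(l, 'nnet.cnn.layer.FullyConnectedLayer'), net.Layers);
fcIdx = find(isFC, 1, 'last');
assert(~isempty(fcIdx), 'No fullyConnectedLayer found; check layerInfo.');
net = replaceLayer(net, net.Layers(fcIdx).Name, ...
    fullyConnectedLayer(numClasses, 'Name', 'fc_new'));

% Fill in the new layer's empty weights; existing learnables are kept
net = initialize(net);
```

To insert a custom attention layer instead, the same approach should work with `addLayers`, `disconnectLayers`, and `connectLayers` on the dlnetwork, using layer names taken from `layerInfo`. This has not been tested against this particular network.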