serengil / deepface

A Lightweight Face Recognition and Facial Attribute Analysis (Age, Gender, Emotion and Race) Library for Python
https://www.youtube.com/watch?v=WnUVYQP4h44&list=PLsS_1RYmYQQFdWqxQggXHynP1rqaYXv_E&index=1
MIT License
14.24k stars 2.19k forks source link

Is there a function to download all the weights? #481

Closed ambitious-octopus closed 2 years ago

ambitious-octopus commented 2 years ago

I had a look at the code, but there seems to be no function to download the weights of all the networks. I made one. In case there is a need, I can share it.

serengil commented 2 years ago

deepface downloads only the weights that are actually required. It intentionally does not download all of them.

ambitious-octopus commented 2 years ago

Yes, okay. In some circumstances, you want the weights before you make inference to have the model ready for later. In my case: building a docker container with the weights already in place.

dre2004 commented 2 years ago

@Kubasinska, care to share that code? I'm effectively doing the same thing: when building a container, I would like to include the weight files pre-downloaded so the first run doesn't take minutes to respond.

Terminazor commented 2 years ago

I had a look at the code, but there seems to be no function to download the weights of all the networks. I made one. In case there is a need, I can share it.

@Kubasinska, can you please share your code? Like @dre2004, I would like to have the models preloaded before the first use.

ambitious-octopus commented 2 years ago

Sorry, I'm a bit late; here is the code. Basically, these are the same functions you find in the source code but grouped! If the maintainers think this is a good feature, I can submit a PR.

from deepface.commons import functions
import gdown
import os

def _download_weight(filename, url):
    """Download ``url`` to <deepface home>/.deepface/weights/<filename> if missing.

    Does nothing when the file already exists. The weights folder is created
    first so the download also works on a machine (or Docker image) where
    deepface has never been run and the folder is not there yet.
    """
    home = functions.get_deepface_home()
    weights_dir = os.path.join(home, '.deepface', 'weights')
    output = os.path.join(weights_dir, filename)
    if not os.path.isfile(output):
        print("{} will be downloaded...".format(filename))
        os.makedirs(weights_dir, exist_ok=True)
        gdown.download(url, output, quiet=False)

def get_age_model(url='https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5'):
    """Fetch age_model_weights.h5 into the deepface weights folder unless present."""
    _download_weight('age_model_weights.h5', url)

def get_emotion_model(url='https://github.com/serengil/deepface_models/releases/download/v1.0/facial_expression_model_weights.h5'):
    """Fetch facial_expression_model_weights.h5 into the deepface weights folder unless present."""
    _download_weight('facial_expression_model_weights.h5', url)

def get_gender_model(url='https://github.com/serengil/deepface_models/releases/download/v1.0/gender_model_weights.h5'):
    """Fetch gender_model_weights.h5 into the deepface weights folder unless present."""
    _download_weight('gender_model_weights.h5', url)

def get_race_model(url='https://github.com/serengil/deepface_models/releases/download/v1.0/race_model_single_batch.h5'):
    """Fetch race_model_single_batch.h5 into the deepface weights folder unless present."""
    _download_weight('race_model_single_batch.h5', url)
nunenuh commented 2 years ago

Another example of a download script:

from deepface.commons import functions
import gdown

from pathlib import Path

# All deepface-hosted weights live under this single GitHub release.
_DEEPFACE_RELEASE_URL = 'https://github.com/serengil/deepface_models/releases/download/v1.0/'

# Short model key -> download URL. The key spelling 'deefpid' is kept as-is
# because get_deefpid() looks it up by that exact string.
data = {
    key: _DEEPFACE_RELEASE_URL + filename
    for key, filename in (
        ('vgg_face', 'vgg_face_weights.h5'),
        ('retinaface', 'retinaface.h5'),
        ('arcface', 'arcface_weights.h5'),
        ('deefpid', 'deepid_keras_weights.h5'),
        ('facenet', 'facenet_weights.h5'),
        ('facenet512', 'facenet512_weights.h5'),
        ('openface', 'openface_weights.h5'),
        ('age_model', 'age_model_weights.h5'),
        ('emotion_model', 'facial_expression_model_weights.h5'),
        ('gender_model', 'gender_model_weights.h5'),
        ('race_model', 'race_model_single_batch.h5'),
    )
}
# OpenCV's SSD face detector files are hosted outside the deepface release.
data['ssd_iter'] = 'https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel'
data['ssd_proto'] = 'https://github.com/opencv/opencv/raw/3.4.0/samples/dnn/face_detector/deploy.prototxt'

# deepface's home directory, and the weights folder inside it where every
# downloaded file is stored.
home_dir = Path(functions.get_deepface_home())
base_dir = home_dir / '.deepface' / 'weights'

def download(name_from_data: str):
    """Download the weight file registered under ``name_from_data`` in ``data``.

    The file is saved in ``base_dir`` under the last path segment of its URL.
    Nothing is downloaded if that file already exists. An unknown key prints an
    error and returns; it no longer crashes with ``TypeError``.
    """
    url = data.get(name_from_data)
    # Check for a missing key before using the URL. The original called
    # Path(None) here and raised before it could reach its error branch.
    if url is None:
        print("Error: no URL registered for {}".format(name_from_data))
        return

    path = base_dir.joinpath(Path(url).name)
    if path.is_file():
        print("Cancel Download: {} already exists".format(path))
        return

    print("{} will be downloaded...".format(url))
    # On a fresh install the weights folder may not exist yet.
    base_dir.mkdir(parents=True, exist_ok=True)
    gdown.download(url, str(path), quiet=False)

def get_ssd():
    """Fetch both OpenCV SSD face-detector files: the caffemodel and its deploy prototxt."""
    for key in ("ssd_iter", "ssd_proto"):
        download(key)

def get_vgg_face():
    """Download the VGG-Face weights (key 'vgg_face')."""
    download('vgg_face')

def get_retinaface():
    """Download the RetinaFace weights (key 'retinaface')."""
    download("retinaface")

def get_arcface():
    """Download the ArcFace weights (key 'arcface')."""
    download("arcface")

def get_deefpid():
    """Download the DeepID weights (key 'deefpid').

    The misspelled name is kept so existing callers keep working;
    new code can use get_deepid().
    """
    download("deefpid")

def get_deepid():
    """Alias of get_deefpid() with the correct spelling."""
    get_deefpid()

def get_facenet():
    """Download the FaceNet weights (key 'facenet')."""
    download("facenet")

def get_facenet512():
    """Download the FaceNet-512 weights (key 'facenet512')."""
    download("facenet512")

def get_openface():
    """Download the OpenFace weights (key 'openface')."""
    download("openface")

if __name__ == "__main__":
    # When run as a script, fetch only the SSD detector and VGG-Face weights.
    for fetch in (get_ssd, get_vgg_face):
        fetch()