leondgarse / Keras_insightface

Insightface Keras implementation
MIT License

[Question] How to train with .tfrecord #125

whysetiawan closed this issue 7 months ago

whysetiawan commented 7 months ago

Hi, I found this dataset on Kaggle, but I couldn't use it with tt.Train.

Dataset: https://www.kaggle.com/datasets/jasonhcwong/faces-ms1m-refine-v2-112x112-tfrecord

leondgarse commented 7 months ago

That needs some modifications.

whysetiawan commented 7 months ago

I have modified data.py and train.py, since a *.tfrecord path is currently treated as a knowledge-distillation dataset. Do you mind if I create a PR for these changes so others can use them too?

leondgarse commented 7 months ago

That would be welcome! I'm just thinking about how to add it so that it stays compatible with the other current data_path formats. The issue with tfrecord is that we cannot get the required info, num_classes=85742, from it. That mostly needs a separate config file, or reading it from a specific header field. For now, we can just modify things to fit this particular dataset, without considering more general cases.
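
The separate config file idea could look roughly like the sketch below: a small JSON file kept next to the *.tfrecord shards that stores the two numbers the shards cannot provide. Everything here is an assumption for illustration: the file name `dataset_info.json`, its keys, and both helper functions are hypothetical and are not read by the repository.

```python
import json
import os

# Hypothetical sidecar file name; nothing in Keras_insightface looks for it.
INFO_FILE_NAME = "dataset_info.json"


def write_dataset_info(tfrecord_dir, num_classes, total_images):
    """Store the metadata that the TFRecord shards themselves do not carry."""
    info = {"num_classes": int(num_classes), "total_images": int(total_images)}
    with open(os.path.join(tfrecord_dir, INFO_FILE_NAME), "w") as ff:
        json.dump(info, ff)
    return info


def read_dataset_info(tfrecord_dir):
    """Load the metadata back as a (num_classes, total_images) tuple."""
    with open(os.path.join(tfrecord_dir, INFO_FILE_NAME)) as ff:
        info = json.load(ff)
    return info["num_classes"], info["total_images"]
```

The numbers only have to be produced once; later runs could read them back and pass them on as the class count and image count a training setup needs.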

whysetiawan commented 7 months ago

What I'm doing is looping over the entire dataset using ds.as_numpy_iterator(), since the header doesn't contain num_classes or num_images.
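
A full pass like that amounts to counting images and distinct labels. A minimal framework-free sketch, assuming each element has already been parsed into an `(image, label)` pair (for instance by mapping a parse function over the TFRecord dataset before calling `as_numpy_iterator()`); the function name is hypothetical:

```python
def count_classes_and_images(pairs):
    """Count classes and images from an iterable of (image, label) pairs.

    Returns (max_label + 1, number of distinct labels, number of images).
    """
    labels = set()
    num_images = 0
    for _, label in pairs:
        labels.add(int(label))  # works for Python ints and NumPy integer scalars
        num_images += 1
    # If label ids are contiguous 0..N-1, max + 1 equals the distinct count;
    # a mismatch between the two reveals gaps in the label ids.
    num_classes = max(labels) + 1 if labels else 0
    return num_classes, len(labels), num_images
```

Returning both `max + 1` and the distinct-label count is a deliberate choice: `max + 1` is the size a classifier indexed directly by label id would need, while the distinct count flags non-contiguous ids. The pass reads every record, so it is worth running once and caching the result.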

leondgarse commented 7 months ago

You may also check whether it's obviously better than extracting the images to a folder, like faces_ms1m_refine_v2_112x112, and using that for training:

import os
import tensorflow as tf

# List all TFRecord shards of the Kaggle dataset
filenames = tf.data.Dataset.list_files("/kaggle/input/faces-ms1m-refine-v2-112x112-tfrecord/faces_ms1m_refine_v2_112x112-*.tfrecord")
train_ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.AUTOTUNE)
# Each record stores the encoded image bytes and an integer class label
feature_description = {'image_raw': tf.io.FixedLenFeature([], tf.string), 'label': tf.io.FixedLenFeature([], tf.int64)}

def parse_tfrecord_fn(example):
    example = tf.io.parse_single_example(example, feature_description)
    # Return the label tensor directly instead of calling int() on it inside map()
    return example["image_raw"], example["label"]

train_ds = train_ds.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)

save_path = "faces_ms1m_refine_v2_112x112"
count_dict = {}  # Record how many images have been written to each class folder
for image, label in train_ds.as_numpy_iterator():
    image_save_path = os.path.join(save_path, str(label))
    if not os.path.exists(image_save_path):
        os.makedirs(image_save_path)
    count_dict[label] = count_dict.get(label, -1) + 1
    # image_raw is written out unchanged, assuming it already holds JPEG bytes
    with open(os.path.join(image_save_path, '{}.jpg'.format(count_dict[label])), 'wb') as ff:
        ff.write(image)

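
After the extraction finishes, a quick sanity check is to count class folders and files and compare them with the numbers quoted in this thread. A minimal sketch assuming the `save_path/<label>/<n>.jpg` layout produced by the loop above; the helper name is hypothetical:

```python
import os


def summarize_image_folder(save_path):
    """Return (num_classes, total_images) for a save_path/<label>/<file> layout."""
    num_classes, total_images = 0, 0
    with os.scandir(save_path) as class_dirs:
        for class_dir in class_dirs:
            if not class_dir.is_dir():
                continue  # ignore stray files at the top level
            num_classes += 1
            with os.scandir(class_dir.path) as files:
                total_images += sum(1 for ff in files if ff.is_file())
    return num_classes, total_images
```

For example, assuming the Kaggle export contains the full MS1M-refine-v2 set, one would expect `summarize_image_folder("faces_ms1m_refine_v2_112x112")` to give 85742 classes and 5822653 images once every record has been written.
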
whysetiawan commented 7 months ago

Here is my PR: #126

leondgarse commented 7 months ago

Merged. I may check the formatting later. :)

leondgarse commented 7 months ago

I just reformatted that part and added a function build_basic_dataset_from_tfrecord, parallel to build_basic_dataset_from_data_path, that supports data_path=xxxx/*.tfrecord as input. The *.tfrecord pattern is what distinguishes it from the distillation format. data_path can now also be an already-built dataset, like:

import losses, train, models, data
dataset = data.build_basic_dataset_from_tfrecord(data_path, classes=85742, total_images=5822653)[0]
tt = train.Train(data_path=dataset, ...)

But I haven't tested the tfrecord one...
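
The routing described in that comment (a *.tfrecord glob selects the TFRecord builder, an already-built dataset is used as-is, an image folder goes to the folder builder) can be sketched as below. This is a hypothetical illustration only: the function name and the returned labels are made up, treating a single non-glob .tfrecord path as the distillation format is an assumption, and the actual checks in data.py may differ.

```python
def pick_dataset_builder(data_path):
    """Hypothetical dispatch on data_path, mirroring the description in this thread."""
    if not isinstance(data_path, str):
        # Already-built dataset object: use it directly.
        return "prebuilt"
    if data_path.endswith("*.tfrecord"):
        # Glob over TFRecord shards -> build_basic_dataset_from_tfrecord
        return "tfrecord"
    if data_path.endswith(".tfrecord"):
        # Assumption: a single .tfrecord file is the distillation format.
        return "distill"
    # Otherwise assume an image folder -> build_basic_dataset_from_data_path
    return "folder"
```

Checking the glob suffix before the plain `.tfrecord` suffix matters here, since every `*.tfrecord` pattern also ends with `.tfrecord`.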