leondgarse / Keras_insightface

Insightface Keras implementation
MIT License
240 stars, 56 forks

[Question] How to train with .tfrecord #125

Closed: whysetiawan closed this issue 11 months ago

whysetiawan commented 12 months ago

Hi, I found this dataset on Kaggle, but I couldn't use it with tt.Train.

datasets: https://www.kaggle.com/datasets/jasonhcwong/faces-ms1m-refine-v2-112x112-tfrecord

leondgarse commented 11 months ago

That needs some modification.

whysetiawan commented 11 months ago

I have modified data.py and train.py, since a *.tfrecord data_path is currently identified as a knowledge distillation dataset. Do you mind if I create a PR for these changes so others can use them too?

leondgarse commented 11 months ago

That would be welcome! I'm just thinking about how to make it work alongside the other current data_path formats. The issue with using tfrecord is that we cannot get required info, like num_classes=85742, from it. That mostly needs a separate config file, or reading it from a specific header field. For now, we can just modify the code to fit this particular dataset, without considering more general situations.
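One way to supply the info a TFRecord file cannot carry is a small sidecar file stored next to the shards. The sketch below is hypothetical: the file name dataset_info.json and both helper functions are assumptions for illustration, not part of Keras_insightface.

```python
import json
import os


def write_dataset_info(tfrecord_dir, classes, total_images, filename="dataset_info.json"):
    """Store metadata that the TFRecord shards themselves do not contain."""
    path = os.path.join(tfrecord_dir, filename)
    with open(path, "w") as ff:
        json.dump({"classes": classes, "total_images": total_images}, ff)
    return path


def read_dataset_info(tfrecord_dir, filename="dataset_info.json"):
    """Return (classes, total_images) from the sidecar file."""
    with open(os.path.join(tfrecord_dir, filename)) as ff:
        info = json.load(ff)
    return int(info["classes"]), int(info["total_images"])
```

The values read back could then be passed as classes and total_images wherever the dataset is built, so nothing has to scan the records at startup.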

whysetiawan commented 11 months ago

What I'm doing is looping over the entire dataset using ds.as_numpy_iterator(), since the header doesn't contain num_classes or num_images.
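A full counting pass like that can be sketched in plain Python. The function below is hypothetical (not from the repo) and accepts any iterable of (image, label) pairs, for example ds.as_numpy_iterator(); note that it still has to touch every record once, which for a dataset of millions of images is slow.

```python
def count_classes_and_images(samples):
    """Count images and classes from an iterable of (image, label) pairs.

    Returns (max_label + 1, number of distinct labels, number of images).
    """
    num_images = 0
    max_label = -1
    seen = set()
    for _, label in samples:
        label = int(label)  # numpy int64 from as_numpy_iterator() -> Python int
        num_images += 1
        seen.add(label)
        max_label = max(max_label, label)
    # A classifier head is usually sized max_label + 1; this equals len(seen)
    # only when labels are contiguous starting from 0.
    return max_label + 1, len(seen), num_images
```

Returning both class counts makes a gap in the label range easy to spot before training.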

leondgarse commented 11 months ago

You may also test whether it's noticeably better than extracting the images to a folder, like faces_ms1m_refine_v2_112x112, and training from that folder:

```python
import os
import tensorflow as tf

# Collect all shards of the Kaggle dataset
filenames = tf.data.Dataset.list_files("/kaggle/input/faces-ms1m-refine-v2-112x112-tfrecord/faces_ms1m_refine_v2_112x112-*.tfrecord")
train_ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.AUTOTUNE)
# Each record holds the encoded image bytes and an integer class label
feature_description = {'image_raw': tf.io.FixedLenFeature([], tf.string), 'label': tf.io.FixedLenFeature([], tf.int64)}

def parse_tfrecord_fn(example):
    example = tf.io.parse_single_example(example, feature_description)
    # Keep the label as a tensor: Python int() fails on symbolic tensors inside map()
    return example["image_raw"], example["label"]
train_ds = train_ds.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)

save_path = "faces_ms1m_refine_v2_112x112"
count_dict = {}  # Record how many have been written to a sub-class folder
for image, label in train_ds.as_numpy_iterator():
    # image is the raw encoded bytes, label is a numpy int64
    image_save_path = os.path.join(save_path, str(label))
    os.makedirs(image_save_path, exist_ok=True)
    count_dict[label] = count_dict.get(label, -1) + 1
    with open(os.path.join(image_save_path, '{}.jpg'.format(count_dict[label])), 'wb') as ff:
        ff.write(image)
```

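After extraction it is cheap to sanity-check the result before training. The helper below is a hypothetical sketch (not part of the repo) that counts class folders and images in the save_path/<label>/<n>.jpg layout written by the extraction loop; for this dataset the totals would be expected to match classes=85742 and total_images=5822653, the values used later in this thread.

```python
import os


def summarize_image_folder(save_path, exts=(".jpg",)):
    """Return (num_classes, num_images) for a save_path/<label>/<n>.jpg layout."""
    num_classes = 0
    num_images = 0
    for entry in sorted(os.listdir(save_path)):
        class_dir = os.path.join(save_path, entry)
        if not os.path.isdir(class_dir):
            continue  # skip stray files at the top level
        images = [ff for ff in os.listdir(class_dir) if ff.lower().endswith(exts)]
        if images:  # count only class folders that actually contain images
            num_classes += 1
            num_images += len(images)
    return num_classes, num_images
```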
whysetiawan commented 11 months ago

Here is my PR: #126

leondgarse commented 11 months ago

Merged. I may check the format later. :)

leondgarse commented 11 months ago

I just reformatted that part and added a function build_basic_dataset_from_tfrecord, parallel to build_basic_dataset_from_data_path, that supports data_path=xxxx/*.tfrecord as input. The *.tfrecord pattern is what distinguishes it from the distillation format. data_path can now also be an already built dataset, like:

```python
import losses, train, models, data
...
dataset = data.build_basic_dataset_from_tfrecord(data_path, classes=85742, total_images=5822653)[0]
tt = train.Train(data_path=dataset, ...)
...
```

But I haven't tested the tfrecord one yet...
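The format detection described in this thread (a built dataset, a *.tfrecord glob, a distillation .tfrecord, or a plain folder) can be illustrated with a small dispatcher. This is a hypothetical sketch of the rules as stated in the comments, not the repo's actual implementation; the returned kind names are made up for illustration.

```python
def detect_data_path_kind(data_path):
    """Classify a data_path value (hypothetical rules, not the repo's code)."""
    if not isinstance(data_path, str):
        return "built_dataset"  # e.g. an already constructed tf.data.Dataset
    if data_path.endswith("*.tfrecord"):
        return "tfrecord_glob"  # plain image/label TFRecord shards
    if data_path.endswith(".tfrecord"):
        return "distill"  # a single .tfrecord file, kept for the distillation format
    return "folder"  # one sub-folder per class
```

Checking the glob pattern before the plain suffix matters, since "xxxx/*.tfrecord" also ends with ".tfrecord".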