whysetiawan closed this 11 months ago
That needs some modification.
Add a process_func that converts labels to one-hot, something like:
```python
...
train_ds = train_ds.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
# After train_ds
num_classes = 85742
total_images = 5822653
# Convert integer labels to one-hot vectors of length num_classes
process_func = lambda imm, label: (imm, tf.one_hot(label, depth=num_classes, dtype=tf.int32))
ds = train_ds.map(process_func, num_parallel_calls=tf.data.AUTOTUNE)
```
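Why num_classes has to be exact: tf.one_hot turns each integer label into a vector of length depth, and per the TensorFlow docs a label outside [0, depth), such as -1, gives an all-off_value (zero) vector rather than an error. The pure-Python function below is a hypothetical stand-in that mirrors only that behavior (it is not TensorFlow code), so it can be tried without TF installed:

```python
def one_hot(label, depth, on_value=1, off_value=0):
    """Pure-Python mirror of tf.one_hot for a single scalar label.

    Like tf.one_hot, an index outside [0, depth) yields an all-off_value
    vector instead of raising an error.
    """
    return [on_value if i == label else off_value for i in range(depth)]


# Tiny depth for illustration; the thread uses num_classes = 85742
in_range = one_hot(2, depth=5)
out_of_range = one_hot(7, depth=5)
print(in_range)      # [0, 0, 1, 0, 0]
print(out_of_range)  # [0, 0, 0, 0, 0]  (silently all zeros)
```

So if num_classes is set too small, the affected samples are not rejected, they just train against an all-zero target.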
I have modified data.py
and also train.py,
since *.tfrecord is currently treated as knowledge distillation data. Do you mind if I create a PR for these changes so others can also use them?
That would be welcome!
I'm just thinking about how to make it compatible with the other data_path formats we currently support. The issue with tfrecord is that we cannot get the required info, like num_classes=85742,
from it. That mostly needs a separate config file, or reading from a specific header field. For now we can just modify the code to fit this particular dataset, without considering other general situations.
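One way around the missing metadata is a small sidecar config stored next to the shards. A minimal sketch, assuming a hypothetical dataset_info.json file holding num_classes and total_images (neither the file name nor the helpers exist in the repo):

```python
import json
import os
import tempfile

INFO_FILE = "dataset_info.json"  # hypothetical sidecar file name


def save_dataset_info(data_dir, num_classes, total_images):
    """Write the metadata that the TFRecord shards themselves do not carry."""
    info = {"num_classes": num_classes, "total_images": total_images}
    with open(os.path.join(data_dir, INFO_FILE), "w") as ff:
        json.dump(info, ff)


def load_dataset_info(data_dir):
    """Return (num_classes, total_images), or None if no sidecar file exists."""
    path = os.path.join(data_dir, INFO_FILE)
    if not os.path.exists(path):
        return None
    with open(path) as ff:
        info = json.load(ff)
    return info["num_classes"], info["total_images"]


# Round trip in a throwaway directory, using the numbers from this thread
with tempfile.TemporaryDirectory() as tmp:
    before = load_dataset_info(tmp)  # None: nothing written yet
    save_dataset_info(tmp, num_classes=85742, total_images=5822653)
    after = load_dataset_info(tmp)
print(before, after)  # None (85742, 5822653)
```

Returning None when the file is missing leaves room to fall back to a slower counting pass, so the general case still works without a config.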
What I'm doing is looping over the entire dataset using ds.as_numpy_iterator(),
since the header doesn't contain num_classes or num_images.
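A sketch of that one-off counting pass, written against any iterable of (image, label) pairs such as ds.as_numpy_iterator(). count_classes_and_images is a hypothetical helper; it assumes 0-based integer labels, so num_classes is taken as max(label) + 1 (the depth tf.one_hot needs), and it also reports the number of distinct labels so gaps in the id range show up:

```python
def count_classes_and_images(pairs):
    """Count images and classes in one pass over (image, label) pairs.

    Assumes 0-based integer labels. Returns (num_classes, total_images, num_distinct),
    where num_classes = max(label) + 1 is the one-hot depth, and num_distinct is
    the number of label ids actually seen (smaller if some ids are unused).
    """
    total_images = 0
    max_label = -1
    seen = set()
    for _, label in pairs:
        label = int(label)  # numpy int64 from as_numpy_iterator() -> Python int
        total_images += 1
        max_label = max(max_label, label)
        seen.add(label)
    return max_label + 1, total_images, len(seen)


# Toy stand-in for train_ds.as_numpy_iterator(): label 2 appears twice
fake_pairs = [(b"img0", 0), (b"img1", 2), (b"img2", 2), (b"img3", 1)]
num_classes, total_images, num_distinct = count_classes_and_images(fake_pairs)
print(num_classes, total_images, num_distinct)  # 3 4 3
```

This still reads every record, so on ~5.8M images it is worth running once and caching the two numbers rather than repeating it per training run.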
You may also check whether it's obviously better than extracting the images to a folder, like faces_ms1m_refine_v2_112x112,
and using that for training.
```python
import os
import tensorflow as tf

filenames = tf.data.Dataset.list_files("/kaggle/input/faces-ms1m-refine-v2-112x112-tfrecord/faces_ms1m_refine_v2_112x112-*.tfrecord")
train_ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.AUTOTUNE)
feature_description = {
    "image_raw": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
}

def parse_tfrecord_fn(example):
    example = tf.io.parse_single_example(example, feature_description)
    # int() cannot be applied to a symbolic tensor inside Dataset.map, so keep the label as a tensor
    return example["image_raw"], example["label"]

train_ds = train_ds.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)

save_path = "faces_ms1m_refine_v2_112x112"
count_dict = {}  # Record how many images have been written to each class sub-folder
for image, label in train_ds.as_numpy_iterator():
    image_save_path = os.path.join(save_path, str(label))
    if not os.path.exists(image_save_path):
        os.makedirs(image_save_path)
    count_dict[label] = count_dict.get(label, -1) + 1
    # Assuming image_raw holds already-encoded JPEG bytes, write them out unchanged
    with open(os.path.join(image_save_path, "{}.jpg".format(count_dict[label])), "wb") as ff:
        ff.write(image)
```
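With the extracted folder layout (save_path/<label>/<index>.jpg), the numbers tfrecord cannot provide become cheap to recover: num_classes is the number of sub-folders and total_images the number of files, assuming every label has at least one image. A minimal sketch with a hypothetical scan_image_folder helper, demonstrated on a throwaway temp folder:

```python
import os
import tempfile


def scan_image_folder(save_path):
    """Return (num_classes, total_images) for a save_path/<label>/<index>.jpg layout."""
    class_dirs = [
        d for d in os.listdir(save_path) if os.path.isdir(os.path.join(save_path, d))
    ]
    total_images = sum(
        len([f for f in os.listdir(os.path.join(save_path, d)) if f.endswith(".jpg")])
        for d in class_dirs
    )
    return len(class_dirs), total_images


# Build a tiny fake layout: class 0 with 2 images, class 1 with 1 image
with tempfile.TemporaryDirectory() as tmp:
    for label, count in [(0, 2), (1, 1)]:
        class_dir = os.path.join(tmp, str(label))
        os.makedirs(class_dir)
        for idx in range(count):
            open(os.path.join(class_dir, "{}.jpg".format(idx)), "wb").close()
    result = scan_image_folder(tmp)
print(result)  # (2, 3)
```

This only lists directory entries and never opens an image, which is why the folder format avoids the full pass over the records.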
Merged, I may check the format later. :)
I just reformatted that part and added a function build_basic_dataset_from_tfrecord,
parallel to build_basic_dataset_from_data_path,
that supports data_path=xxxx/*.tfrecord
as input. The *.tfrecord
pattern is used to distinguish it from the distillation one. data_path
can also be an already-built dataset now, like:
```python
import losses, train, models, data
...
dataset = data.build_basic_dataset_from_tfrecord(data_path, classes=85742, total_images=5822653)[0]
tt = train.Train(data_path=dataset, ...)
...
```
But I didn't test the tfrecord one...
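The path-based selection described above can be sketched as a small dispatcher. This is a hypothetical illustration, not the code in data.py: it assumes a data_path ending in *.tfrecord selects build_basic_dataset_from_tfrecord, any other .tfrecord path is still treated as distillation data, a non-string value is an already-built dataset, and everything else goes to build_basic_dataset_from_data_path:

```python
def pick_builder(data_path):
    """Hypothetical dispatcher: return the name of the builder to use for data_path."""
    if not isinstance(data_path, str):
        return "prebuilt"  # already a built dataset, pass it through as-is
    if data_path.endswith("*.tfrecord"):
        return "build_basic_dataset_from_tfrecord"  # glob over plain tfrecord shards
    if data_path.endswith(".tfrecord"):
        return "distill"  # assumption: a single .tfrecord file is distillation data
    return "build_basic_dataset_from_data_path"  # image folder or other formats


for path in [
    "/kaggle/input/faces-ms1m-refine-v2-112x112-tfrecord/faces_ms1m_refine_v2_112x112-*.tfrecord",
    "some_distill_data.tfrecord",  # hypothetical file name
    "faces_ms1m_refine_v2_112x112",
]:
    print(pick_builder(path))
```

Checking the literal *.tfrecord suffix keeps the old single-file distillation behavior intact while letting a glob opt into the new builder.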
Hi, I found this dataset on Kaggle, but I couldn't use it with
tt.Train.
Dataset: https://www.kaggle.com/datasets/jasonhcwong/faces-ms1m-refine-v2-112x112-tfrecord