2021-Creative-Study / creative-study

1 star · 0 forks · source link

TFRecord file format #13

Open mingxoxo opened 3 years ago

mingxoxo commented 3 years ago

출처: https://digitalbourgeois.tistory.com/50

TFRecord file format

TFRecord 파일은 텐서플로우로 딥러닝 학습을 하는 데 필요한 데이터를 보관하기 위한 데이터 포맷이다.

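장점

- 데이터를 `tf.train.Example`(프로토콜 버퍼)로 직렬화한 바이너리 형식으로 저장하므로, 수많은 이미지 파일을 하나씩 읽는 것보다 디스크 I/O가 효율적이다.
- `tf.data.TFRecordDataset`으로 읽어 학습용 입력 파이프라인에 바로 연결할 수 있다.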

공부하고 있는 코드에서 데이터를 TFRecord로 변환하는 부분:


def _serialize_image(path, transform=None):
    """Read an image file, resize it to a square, optionally augment it, re-encode as JPEG.

    Args:
        path: Path of the image file to read (decoded as JPEG, 3 channels).
        transform: Optional callable invoked as ``transform(image=<ndarray>)``
            whose result is indexed with ``'image'`` (presumably an
            albumentations-style transform — confirm with the caller).

    Returns:
        The JPEG-encoded image as ``bytes``.
    """
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    side = CFG.img_size
    # tf.image.resize produces floats; cast back to uint8 (truncating) so the
    # tensor can be JPEG-encoded.
    pixels = tf.cast(tf.image.resize(decoded, [side, side]), tf.uint8)

    if transform is None:
        return tf.image.encode_jpeg(pixels).numpy()

    augmented = transform(image=pixels.numpy())['image']
    return tf.image.encode_jpeg(augmented).numpy()

# Feature names for the per-class labels, in the positional order in which
# their values appear in the ``label`` sequence passed to _serialize_sample.
_LABEL_KEYS = (
    'complex',
    'frog_eye_leaf_spot',
    'powdery_mildew',
    'rust',
    'scab',
    'healthy',
)


def _serialize_sample(image, image_name, label):
    """Build a serialized ``tf.train.Example`` for one image and its labels.

    Args:
        image: Encoded image bytes (e.g. the output of ``_serialize_image``).
        image_name: Image name as ``bytes``.
        label: Sequence (list, tuple, or pandas Series) holding at least six
            label values, taken positionally in the order of ``_LABEL_KEYS``.
            Extra values are ignored.

    Returns:
        The ``Example`` serialized to a byte string.

    Raises:
        IndexError: If ``label`` holds fewer than six values.
    """
    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    # list() reads the values positionally for both plain sequences and pandas
    # Series. Indexing a column-labelled Series with ``label[0]`` relies on
    # pandas' deprecated positional fallback for integer keys.
    values = list(label)
    if len(values) < len(_LABEL_KEYS):
        raise IndexError(
            f'expected at least {len(_LABEL_KEYS)} label values, got {len(values)}'
        )

    feature = {
        'image': _bytes_feature(image),
        'image_name': _bytes_feature(image_name),
    }
    for key, value in zip(_LABEL_KEYS, values):
        feature[key] = _int64_feature(value)

    sample = tf.train.Example(features=tf.train.Features(feature=feature))
    return sample.SerializeToString()

def serialize_fold(fold, name, transform=None, bar=None):
    """Serialize every row of ``fold`` into a single ``<name>.tfrec`` file.

    Args:
        fold: DataFrame whose index holds image file names (``str``, joined
            onto ``CFG.root`` to locate each file). Each row's values are the
            labels, presumably in the column order ``_serialize_sample``
            expects — confirm against the caller.
        name: Output path without extension; ``'.tfrec'`` is appended.
        transform: Optional augmentation forwarded to ``_serialize_image``.
        bar: Optional progress bar; ``bar.update(1)`` is called once after the
            whole fold has been written.
    """
    samples = []
    for image_name, labels in fold.iterrows():
        path = os.path.join(CFG.root, image_name)
        image = _serialize_image(path, transform=transform)
        samples.append(_serialize_sample(image, image_name.encode(), labels))

    # Every sample is serialized before the writer is opened, so an error
    # while loading or augmenting an image never creates the output file.
    with tf.io.TFRecordWriter(name + '.tfrec') as writer:
        for record in samples:
            writer.write(record)

    if bar is not None:
        bar.update(1)