Open mingxoxo opened 3 years ago
출처 : https://digitalbourgeois.tistory.com/50
TFRecord 파일은 텐서플로우로 딥러닝 학습을 하는 데 필요한 데이터를 보관하기 위한 데이터 포맷이다.
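장점: 데이터를 Protocol Buffer 기반의 바이너리 형식으로 직렬화해 저장하므로 읽기가 효율적이고, 이미지와 라벨 같은 여러 필드를 하나의 파일에 함께 묶을 수 있으며, `tf.data.TFRecordDataset`으로 학습 파이프라인에 바로 연결할 수 있다.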
공부 중인 코드에서 데이터를 TFRecord로 변환하는 부분:
# Label columns stored in every tf.train.Example, in the positional order of
# the label values passed to _serialize_sample.
_LABEL_KEYS = (
    'complex',
    'frog_eye_leaf_spot',
    'powdery_mildew',
    'rust',
    'scab',
    'healthy',
)


def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    """Wrap a single integer value in a tf.train.Feature."""
    # int() turns numpy integer scalars (e.g. values from a pandas row) into a
    # plain Python int before handing them to the protobuf field.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))


def _serialize_image(path, transform=None):
    """Read a JPEG, resize it to a square, optionally transform it, re-encode it.

    Args:
        path: path of the JPEG file to read.
        transform: optional callable invoked as ``transform(image=ndarray)``
            that must return a mapping with an ``'image'`` key
            (albumentations-style, judging by the call shape; confirm
            against the caller).

    Returns:
        The JPEG-encoded image as ``bytes``.
    """
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [CFG.img_size, CFG.img_size])
    # tf.image.resize returns float32; encode_jpeg needs uint8.
    image = tf.cast(image, tf.uint8)
    if transform is not None:
        image = transform(image=image.numpy())['image']
    return tf.image.encode_jpeg(image).numpy()


def _serialize_sample(image, image_name, label):
    """Build a serialized tf.train.Example for one image.

    Args:
        image: JPEG-encoded image bytes.
        image_name: image file name as bytes.
        label: sequence of at least six integer label values, ordered like
            _LABEL_KEYS. A pandas Series is accepted and read positionally.

    Returns:
        The serialized Example as ``bytes``.

    Raises:
        IndexError: if ``label`` holds fewer than six values.
    """
    # list() reads a pandas Series positionally. This avoids the deprecated
    # integer-key fallback of Series.__getitem__ on a string-labelled index.
    values = list(label)
    feature = {
        'image': _bytes_feature(image),
        'image_name': _bytes_feature(image_name),
    }
    for position, key in enumerate(_LABEL_KEYS):
        feature[key] = _int64_feature(values[position])
    sample = tf.train.Example(features=tf.train.Features(feature=feature))
    return sample.SerializeToString()


def serialize_fold(fold, name, transform=None, bar=None):
    """Write every row of ``fold`` to ``name + '.tfrec'`` as Example records.

    Args:
        fold: DataFrame iterated with ``iterrows()``. Its index values are
            image file names joined onto ``CFG.root``, and each row holds the
            label values in _LABEL_KEYS order (presumably; confirm against
            the caller).
        name: output path without the ``.tfrec`` extension.
        transform: optional transform forwarded to _serialize_image.
        bar: optional progress object; ``bar.update(1)`` is called once after
            the whole fold has been written.
    """
    # Stream each record to the writer instead of first holding every encoded
    # image of the fold in memory. If serialization fails part-way, a partially
    # written .tfrec file is left behind.
    with tf.io.TFRecordWriter(name + '.tfrec') as writer:
        for image_name, labels in fold.iterrows():
            path = os.path.join(CFG.root, image_name)
            image = _serialize_image(path, transform=transform)
            writer.write(_serialize_sample(image, image_name.encode(), labels))
    if bar is not None:
        bar.update(1)
TFRecord file format
TFRecord 파일은 텐서플로우로 딥러닝 학습을 하는 데 필요한 데이터를 보관하기 위한 데이터 포맷이다.
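장점: 데이터를 Protocol Buffer 기반의 바이너리 형식으로 직렬화해 저장하므로 읽기가 효율적이고, 이미지와 라벨 같은 여러 필드를 하나의 파일에 함께 묶을 수 있으며, `tf.data.TFRecordDataset`으로 학습 파이프라인에 바로 연결할 수 있다.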
공부 중인 코드에서 데이터를 TFRecord로 변환하는 부분: