google-deepmind / deepmind-research

This repository contains implementations and illustrative code to accompany DeepMind publications
Apache License 2.0
13.29k stars, 2.61k forks

Creating Own Dataset For 'Learning to simulate complex physics with Graph networks' #199

Open piba941 opened 3 years ago

piba941 commented 3 years ago

I am trying to evaluate the model using my own dataset, but I am facing issues while generating the data and converting it into TFRecords. Could you please help?

zxy-ml84 commented 3 years ago

Same purpose here. What is your current progress?

jg8610 commented 3 years ago

Hi there, we are happy to help if we can.

Could you provide some more information on what your issues are?

alvarosg commented 3 years ago

Note that you can use your own datasets without writing them as TFRecords. tf.data.Dataset has a from_generator method that lets you load a dataset using arbitrary Python code, so you can replace these two lines here with a from_generator call that loads your dataset from whatever format it is stored in. Simply make the tensors look like those read from our records datasets :)
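
As a rough illustration of the from_generator route, the sketch below yields (context, features) pairs using the field names that appear in this thread. The names trajectory_generator, make_dataset and toy_trajectories are hypothetical, and the dtypes and shapes are assumptions that should be checked against an example parsed from the provided records.

```python
import numpy as np


def trajectory_generator(trajectories):
    """Yield (context, features) pairs, one per trajectory.

    `trajectories` is a hypothetical list of (particle_type, position) pairs:
    particle_type has shape [num_particles] and position has shape
    [num_steps, num_particles, num_dimensions].
    """
    for key, (particle_type, position) in enumerate(trajectories):
        context = {
            "key": np.int64(key),
            "particle_type": np.asarray(particle_type, dtype=np.int64),
        }
        features = {"position": np.asarray(position, dtype=np.float32)}
        yield context, features


def make_dataset(trajectories, num_dimensions):
    """Wrap the generator in a tf.data.Dataset."""
    # Imported lazily so the generator itself can be tried without TensorFlow.
    import tensorflow.compat.v1 as tf

    return tf.data.Dataset.from_generator(
        lambda: trajectory_generator(trajectories),
        output_types=({"key": tf.int64, "particle_type": tf.int64},
                      {"position": tf.float32}),
        output_shapes=({"key": tf.TensorShape([]),
                        "particle_type": tf.TensorShape([None])},
                       {"position": tf.TensorShape(
                           [None, None, num_dimensions])}),
    )


# Toy data: one trajectory with 2 steps, 4 particles and 2 dimensions.
toy_trajectories = [(np.full(4, 3), np.random.rand(2, 4, 2))]

context, features = next(trajectory_generator(toy_trajectories))
print(context["particle_type"].shape, features["position"].shape)
```

Under these assumptions, make_dataset(toy_trajectories, num_dimensions=2) would stand in for the dataset that the record-reading lines produce.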

yjchoi1 commented 3 years ago

First, thanks to @alvarosg for the answer.

I have made a very simple, arbitrary dataset that resembles the structure of the parsed test.tfrecord:

particle_type = np.array([3, 3, 3, 3])
key = 0
position = np.array([[[0.42, 0.55],
                      [0.41, 0.55],
                      [0.40, 0.54],
                      [0.39, 0.53]],
                     [[0.56, 0.14],
                      [0.56, 0.13],
                      [0.54, 0.12],
                      [0.53, 0.11]]])

What is a good way to turn this into a compatible dataset? I first tried writing a generator and passing it to from_generator, but I am not sure how to make the resulting dataset compatible with the training data.

kks32 commented 3 years ago

The TFRecord uses the SequenceExample format, which can be generated as shown below:

# Import modules. This script should live outside the learning_to_simulate code folder.
import functools
import json
import os

import numpy as np
import tensorflow.compat.v1 as tf

from learning_to_simulate import reading_utils

# The .numpy() calls below need eager execution, which TF 1.x does not enable by default.
tf.enable_eager_execution()

# Set datapath and validation set
data_path = './datasets/WaterDropSample'
filename = 'valid.tfrecord'

# Read metadata
def _read_metadata(data_path):
    with open(os.path.join(data_path, 'metadata.json'), 'rt') as fp:
        return json.loads(fp.read())

# Fetch metadata
metadata = _read_metadata(data_path)

print(metadata)

# Read TFRecord
ds_org = tf.data.TFRecordDataset([os.path.join(data_path, filename)])
ds = ds_org.map(functools.partial(reading_utils.parse_serialized_simulation_example, metadata=metadata))

# Collect the parsed arrays of every trajectory in the file
particle_types = []
keys = []
positions = []
for context, features in ds:
    particle_types.append(context["particle_type"].numpy().astype(np.int64))
    keys.append(context["key"].numpy().astype(np.int64))
    positions.append(features["position"].numpy().astype(np.float32))

# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Write TF Record
with tf.python_io.TFRecordWriter('test.tfrecord') as writer:

    # Note: `step` indexes trajectories here, not time steps
    for step, (particle_type, key, position) in enumerate(zip(particle_types, keys, positions)):
        seq = tf.train.SequenceExample(
            context=tf.train.Features(feature={
                # int64 array stored as raw bytes
                "particle_type": _bytes_feature(particle_type.tobytes()),
                "key": _int64_feature(key)
            }),
            feature_lists=tf.train.FeatureLists(feature_list={
                # Whole float32 trajectory stored as a single bytes entry
                'position': tf.train.FeatureList(
                    feature=[_bytes_feature(position.flatten().tobytes())],
                ),
                # Placeholder value; reading_utils appears to parse step_context
                # only when metadata.json contains a 'context_mean' entry
                'step_context': tf.train.FeatureList(
                    feature=[_bytes_feature(np.float32(step).tobytes())]
                ),
            })
        )

        writer.write(seq.SerializeToString())

dt = tf.data.TFRecordDataset(['test.tfrecord'])
dt = dt.map(functools.partial(reading_utils.parse_serialized_simulation_example, metadata=metadata))

# Check whether the original TFRecord and the newly generated TFRecord match
for ((_ds_context, _ds_feature), (_dt_context, _dt_feature)) in zip(ds, dt):
    if not np.allclose(_ds_context["key"].numpy(), _dt_context["key"].numpy()):
        break

    if not np.allclose(_ds_context["particle_type"].numpy(), _dt_context["particle_type"].numpy()):
        break

    if not np.allclose(_ds_feature["position"].numpy(), _dt_feature["position"].numpy()):
        break

else:
    # for/else: reached only if no comparison above triggered a break
    print("TFRecords are similar!")
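
To start from your own NumPy arrays rather than an existing record, the same encoding can be sketched as follows, using the toy trajectory posted earlier in the thread. The helper names encode_trajectory and write_tfrecord and the output path are hypothetical. The sketch assumes int64 particle types and float32 positions (the casts used above), and it omits step_context on the assumption that your metadata.json has no context_mean entry. The reader also appears to reshape positions using sequence_length and dim from metadata.json, so those values presumably need to match your arrays.

```python
import numpy as np


def encode_trajectory(particle_type, position):
    """Serialize one trajectory to the raw bytes stored in the record."""
    return {
        "particle_type": np.asarray(particle_type, dtype=np.int64).tobytes(),
        "position": np.asarray(position, dtype=np.float32).tobytes(),
    }


def write_tfrecord(path, trajectories):
    """Write (particle_type, position) pairs as SequenceExamples."""
    # Imported lazily; TensorFlow is only needed for the actual writing.
    import tensorflow.compat.v1 as tf

    def bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    with tf.io.TFRecordWriter(path) as writer:
        for key, (particle_type, position) in enumerate(trajectories):
            raw = encode_trajectory(particle_type, position)
            example = tf.train.SequenceExample(
                context=tf.train.Features(feature={
                    "particle_type": bytes_feature(raw["particle_type"]),
                    "key": tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[key])),
                }),
                feature_lists=tf.train.FeatureLists(feature_list={
                    "position": tf.train.FeatureList(
                        feature=[bytes_feature(raw["position"])]),
                }),
            )
            writer.write(example.SerializeToString())


# Toy trajectory: 2 steps, 4 particles, 2 dimensions.
particle_type = np.array([3, 3, 3, 3])
position = np.array([[[0.42, 0.55], [0.41, 0.55], [0.40, 0.54], [0.39, 0.53]],
                     [[0.56, 0.14], [0.56, 0.13], [0.54, 0.12], [0.53, 0.11]]])

raw = encode_trajectory(particle_type, position)
# Round-trip check: decoding the bytes recovers the original values.
decoded = np.frombuffer(raw["position"], dtype=np.float32).reshape(2, 4, 2)
print(np.allclose(decoded, position))
```

Calling write_tfrecord('my_test.tfrecord', [(particle_type, position)]) would then produce a file that can be checked with the same parsing code as above.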

yq60523 commented 2 years ago

Thanks. Can you provide an explanation about the format of your datasets?

alvarosg commented 2 years ago

This varies a bit from dataset to dataset, but in the simplest case (a dataset with only positions), each example is essentially two dicts with these fields:

context["particle_type"]  # shape [num_particles]
features["position"]  # shape [num_steps, num_particles, num_dimensions]

If you want to connect your own data to this, I would recommend that you run these two lines:

    ds = tf.data.TFRecordDataset([os.path.join(data_path, f'{split}.tfrecord')])
    ds = ds.map(functools.partial(
        reading_utils.parse_serialized_simulation_example, metadata=metadata))

Then inspect the spec of the output and write your own tf.data.Dataset that outputs your data in the same format.
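
One way to make that comparison concrete is sketched below with hypothetical helpers describe and same_layout, which work on elements already converted to NumPy. The reference element here is a stand-in with made-up shapes; in practice it would come from the parsed dataset, for example next(iter(ds)) under eager execution. The static spec can also be printed with tf.data.get_output_types(ds) and tf.data.get_output_shapes(ds) from tensorflow.compat.v1.

```python
import numpy as np


def describe(element):
    """Map 'context.<field>' / 'features.<field>' to (dtype name, shape)."""
    context, features = element
    spec = {}
    for prefix, fields in (("context", context), ("features", features)):
        for name, value in fields.items():
            array = np.asarray(value)
            spec[f"{prefix}.{name}"] = (array.dtype.name, array.shape)
    return spec


def same_layout(reference, candidate):
    """True if both elements have the same fields, dtypes and ranks."""
    ref, cand = describe(reference), describe(candidate)
    if ref.keys() != cand.keys():
        return False
    return all(ref[k][0] == cand[k][0] and len(ref[k][1]) == len(cand[k][1])
               for k in ref)


# Stand-in for one element read from the provided records (made-up shapes).
reference = ({"key": np.int64(0), "particle_type": np.zeros(100, np.int64)},
             {"position": np.zeros((11, 100, 2), np.float32)})
# An element produced by your own loader.
mine = ({"key": np.int64(7), "particle_type": np.full(4, 3, np.int64)},
        {"position": np.zeros((2, 4, 2), np.float32)})

for field, (dtype, shape) in describe(mine).items():
    print(field, dtype, shape)
print(same_layout(reference, mine))
```

The check only compares field names, dtypes and ranks, since the number of particles and steps can legitimately differ between trajectories.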

yq60523 commented 2 years ago

Many thanks to Alvaro. That's what I want.

Vesuvius6 commented 1 year ago

The GNS article mentioned generating the Water-3D dataset using SPlisHSPlasH. How did you do it? I want to make some test cases myself. Thanks a lot!

BoyuanTang331 commented 1 year ago

What is the step_context argument for in the data? Is it required when creating the dataset?

When I debug by opening the tfrecord, I see that in each element the dict at index 1 has only position information; step_context is not included.