This is the code I wrote to train the VAE on the ble_wind_field/small dataset.
The reported error is as follows.
Can you tell me where the error might be?
import os.path as osp
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state
import tensorflow as tf #导入tensorflow模块用于解析数据集,数据集为tfrecord格式较为特殊,要用tensorflow库解析
from tensorflow.data import TFRecordDataset
from balloon_learning_environment.generative.vae import WindFieldVAE, FieldShape, VAEOutput #导入VAE中的重点模块,好让之后代码参数与之匹配
from absl import app
from absl import flags
from balloon_learning_environment import train_lib
from balloon_learning_environment.env.rendering import matplotlib_renderer
from balloon_learning_environment.utils import run_helpers
import gym
import matplotlib
import numpy as np
# Define the command-line flags.
# Command-line flags for this script. Each call registers one flag with absl;
# the values are read through FLAGS once the command line has been parsed.
flags.DEFINE_string(
    name='agent',
    default='dqn',
    help='Type of agent to create.',
)
flags.DEFINE_string(
    name='env_name',
    default='BalloonLearningEnvironment-v0',
    help='Name of environment to create.',
)
flags.DEFINE_integer(
    name='num_iterations',
    default=200,
    help='Number of episodes to train for.',
)
flags.DEFINE_integer(
    name='max_episode_length',
    default=960,
    help=(
        'Maximum number of steps per episode. Assuming 2 days, '
        'with each step lasting 3 minutes.'
    ),
)
flags.DEFINE_string(
    name='base_dir',
    default=None,
    help='Directory where to store statistics/images.',
)
flags.DEFINE_integer(
    name='run_number',
    default=1,
    help=(
        'When running multiple agents in parallel, this number '
        'differentiates between the runs. It is appended to base_dir.'
    ),
)
# NOTE(review): the help text mentions a _WIND_FIELDS dict, but no such dict
# is defined in the part of the file visible here -- confirm.
flags.DEFINE_string(
    name='wind_field',
    default='generative',
    help=(
        'The wind field type to use. See the _WIND_FIELDS dict below for '
        'options.'
    ),
)
flags.DEFINE_string(
    name='agent_gin_file',
    default=None,
    help='Gin file for agent configuration.',
)
flags.DEFINE_multi_string(
    name='collectors',
    default=['console'],
    help='Collectors to include in metrics collection.',
)
flags.DEFINE_multi_string(
    name='gin_bindings',
    default=[],
    help='Gin bindings to override default values.',
)
flags.DEFINE_string(
    name='renderer',
    default=None,
    help=(
        'The renderer to use. Note that it is fastest to have this set to '
        'None.'
    ),
)
flags.DEFINE_integer(
    name='render_period',
    default=10,
    help=(
        'The period to render with. Only has an effect if renderer is not '
        'None.'
    ),
)
flags.DEFINE_integer(
    name='episodes_per_iteration',
    default=50,
    help=(
        'The number of episodes to run in one iteration. Checkpointing occurs '
        'at the end of each iteration.'
    ),
)

# base_dir has no usable default (None), so it must be given explicitly.
flags.mark_flag_as_required('base_dir')

FLAGS = flags.FLAGS
This is the code I wrote to train the VAE on the ble_wind_field/small dataset. The reported error is as follows. Can you tell me where the error might be?
import os.path as osp import jax import jax.numpy as jnp import optax from flax import linen as nn from flax.training import train_state import tensorflow as tf #导入tensorflow模块用于解析数据集,数据集为tfrecord格式较为特殊,要用tensorflow库解析 from tensorflow.data import TFRecordDataset from balloon_learning_environment.generative.vae import WindFieldVAE, FieldShape, VAEOutput #导入VAE中的重点模块,好让之后代码参数与之匹配 from absl import app from absl import flags from balloon_learning_environment import train_lib from balloon_learning_environment.env.rendering import matplotlib_renderer from balloon_learning_environment.utils import run_helpers import gym import matplotlib import numpy as np
# Define the command-line flags.
# Command-line flags for this script. Each call registers one flag with absl;
# the values are read through FLAGS once the command line has been parsed.
flags.DEFINE_string(
    name='agent',
    default='dqn',
    help='Type of agent to create.',
)
flags.DEFINE_string(
    name='env_name',
    default='BalloonLearningEnvironment-v0',
    help='Name of environment to create.',
)
flags.DEFINE_integer(
    name='num_iterations',
    default=200,
    help='Number of episodes to train for.',
)
flags.DEFINE_integer(
    name='max_episode_length',
    default=960,
    help=(
        'Maximum number of steps per episode. Assuming 2 days, '
        'with each step lasting 3 minutes.'
    ),
)
flags.DEFINE_string(
    name='base_dir',
    default=None,
    help='Directory where to store statistics/images.',
)
flags.DEFINE_integer(
    name='run_number',
    default=1,
    help=(
        'When running multiple agents in parallel, this number '
        'differentiates between the runs. It is appended to base_dir.'
    ),
)
# NOTE(review): the help text mentions a _WIND_FIELDS dict, but no such dict
# is defined in the part of the file visible here -- confirm.
flags.DEFINE_string(
    name='wind_field',
    default='generative',
    help=(
        'The wind field type to use. See the _WIND_FIELDS dict below for '
        'options.'
    ),
)
flags.DEFINE_string(
    name='agent_gin_file',
    default=None,
    help='Gin file for agent configuration.',
)
flags.DEFINE_multi_string(
    name='collectors',
    default=['console'],
    help='Collectors to include in metrics collection.',
)
flags.DEFINE_multi_string(
    name='gin_bindings',
    default=[],
    help='Gin bindings to override default values.',
)
flags.DEFINE_string(
    name='renderer',
    default=None,
    help=(
        'The renderer to use. Note that it is fastest to have this set to '
        'None.'
    ),
)
flags.DEFINE_integer(
    name='render_period',
    default=10,
    help=(
        'The period to render with. Only has an effect if renderer is not '
        'None.'
    ),
)
flags.DEFINE_integer(
    name='episodes_per_iteration',
    default=50,
    help=(
        'The number of episodes to run in one iteration. Checkpointing occurs '
        'at the end of each iteration.'
    ),
)

# base_dir has no usable default (None), so it must be given explicitly.
flags.mark_flag_as_required('base_dir')

FLAGS = flags.FLAGS
# Define the renderers.
# Maps a renderer name to the class that implements it.
# NOTE(review): presumably indexed by the --renderer flag; no lookup is
# visible in this part of the file -- confirm.
_RENDERERS = {
    'matplotlib': matplotlib_renderer.MatplotlibRenderer,
}
# Parse the TFRecord file.
def _parse_function(proto):
    """Parse one serialized example record into a dict of tensors.

    Args:
        proto: One serialized record, as produced by iterating a
            ``tf.data.TFRecordDataset``.

    Returns:
        A dict with one parsed tensor per key of the feature spec below.
    """
    # NOTE(review): 'feature1' / 'feature2' read like placeholder names, and
    # both are declared as scalar float32. The keys, shapes and dtypes here
    # must match what was actually written into the TFRecord file, otherwise
    # parsing fails -- confirm against the dataset's schema.
    keys_to_features = {
        'feature1': tf.io.FixedLenFeature([], tf.float32),
        'feature2': tf.io.FixedLenFeature([], tf.float32),
    }
    # The original definition stopped before closing the dict; without this
    # call the record is never decoded and nothing is returned.
    return tf.io.parse_single_example(proto, keys_to_features)
# Parses a feature named feature1 out of the TFRecord file; its data type is float32.
# Load the dataset.
def load_dataset(tfrecord_path, batch_size=32, shuffle_buffer_size=10000):
    """Build the input pipeline: read, parse, shuffle, batch, prefetch.

    Args:
        tfrecord_path: Path (or paths) handed to ``tf.data.TFRecordDataset``.
        batch_size: Number of parsed examples per batch.
        shuffle_buffer_size: Buffer size passed to ``Dataset.shuffle``.

    Returns:
        The batched, prefetching ``tf.data.Dataset``.
    """
    # Dataset object over the raw serialized records.
    records = tf.data.TFRecordDataset(tfrecord_path)
    parsed = records.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
    # NOTE(review): the author reported an out-of-memory error around the
    # batching step with batch_size=32. shuffle() buffers up to
    # shuffle_buffer_size parsed examples, so with large examples the buffer
    # may be the real memory consumer -- try a smaller buffer and confirm.
    shuffled = parsed.shuffle(shuffle_buffer_size)
    batched = shuffled.batch(batch_size)
    return batched.prefetch(tf.data.AUTOTUNE)
# Define the training state.
class TrainState(train_state.TrainState):
    """Training state used by this script.

    Adds no fields or methods of its own; everything is inherited from
    ``train_state.TrainState``.

    NOTE(review): the original comment said the parent class already defines
    ``mean`` and ``logvar``. Nothing in this file establishes that -- confirm
    against the flax ``TrainState`` definition.
    """
# Before training proper, initialise and create the training state.
def create_train_state(rng, model, learning_rate, input_shape):
    """Initialise the model parameters and wrap them in a ``TrainState``.

    Args:
        rng: PRNG key; split once into an init key and a second key that is
            passed to the model as its extra positional argument.
        model: Module providing ``init`` and ``apply``.
        learning_rate: Learning rate given to ``optax.adam``.
        input_shape: Shape of the all-ones dummy input used for ``init``.

    Returns:
        A ``TrainState`` holding ``model.apply``, the initial ``'params'``
        collection and an Adam optimizer.
    """
    init_key, z_key = jax.random.split(rng)
    dummy_input = jnp.ones(input_shape)
    variables = model.init(init_key, dummy_input, z_key)
    return TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optax.adam(learning_rate),
    )
# Training-step function.
@jax.jit
def train_step(state, batch, rng):
    """Run one gradient update of the VAE on a single batch.

    Args:
        state: Training state; its ``apply_fn``, ``params`` and
            ``apply_gradients`` are used.
        batch: Input array; it is also the reconstruction target, so it must
            be shape-compatible with the model's ``reconstruction`` output.
        rng: PRNG key forwarded to the model as ``z_rng``.

    Returns:
        ``(new_state, metrics)`` where ``metrics`` is a dict with the scalar
        entries ``'loss'``, ``'recon_loss'`` and ``'kl_loss'``.
    """

    def loss_fn(params):
        vae_output = state.apply_fn({'params': params}, batch, z_rng=rng)
        encoded = vae_output.encoder_output
        # Squared error: `** 2`, not `* 2` (the second `*` was lost).
        recon_loss = jnp.mean((vae_output.reconstruction - batch) ** 2)
        # NOTE(review): recon_loss is a mean while kl_loss is a sum over every
        # element (batch included), so their relative weight changes with
        # batch size -- confirm this is intended.
        kl_loss = -0.5 * jnp.sum(
            1 + encoded.logvar - encoded.mean ** 2 - jnp.exp(encoded.logvar)
        )
        total_loss = recon_loss + kl_loss
        return total_loss, (recon_loss, kl_loss)

    # The original stopped after defining loss_fn, so no gradient was ever
    # computed or applied and the step returned None.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (total_loss, (recon_loss, kl_loss)), grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    metrics = {
        'loss': total_loss,
        'recon_loss': recon_loss,
        'kl_loss': kl_loss,
    }
    return new_state, metrics
def main(_) -> None:
    """Entry point passed to ``app.run`` at the bottom of this file.

    The single positional argument is named ``_`` and is not used.

    NOTE(review): no statements follow the comment below in this source, so
    the function currently does nothing -- the body appears to be missing;
    confirm against the author's original script.
    """
    # Prepare the metric-collector gin files and constructors.
# Script entry guard. The double underscores were lost in the paste:
# `name == 'main'` raises NameError, the intended test is on __name__.
if __name__ == '__main__':
    app.run(main)