google / balloon-learning-environment

The Balloon Learning Environment - flying stratospheric balloons with deep reinforcement learning.
Apache License 2.0
119 stars 14 forks source link

Can you tell me where the error might be? This is the code I wrote to train the VAE using the dataset ble_wind_field/small. #25

Closed 1jskk closed 5 days ago

1jskk commented 3 months ago

This is the code I wrote to train the VAE using the dataset ble_wind_field/small. The error mentioned in the report is as follows. Can you tell me where the error might be?

import os.path as osp import jax import jax.numpy as jnp import optax from flax import linen as nn from flax.training import train_state import tensorflow as tf #导入tensorflow模块用于解析数据集,数据集为tfrecord格式较为特殊,要用tensorflow库解析 from tensorflow.data import TFRecordDataset from balloon_learning_environment.generative.vae import WindFieldVAE, FieldShape, VAEOutput #导入VAE中的重点模块,好让之后代码参数与之匹配 from absl import app from absl import flags from balloon_learning_environment import train_lib from balloon_learning_environment.env.rendering import matplotlib_renderer from balloon_learning_environment.utils import run_helpers import gym import matplotlib import numpy as np

# Define the command-line flags.

# Command-line flags for the training script. One definition per statement;
# the help strings are what absl prints for --help.
flags.DEFINE_string('agent', 'dqn', 'Type of agent to create.')
flags.DEFINE_string('env_name', 'BalloonLearningEnvironment-v0',
                    'Name of environment to create.')
flags.DEFINE_integer('num_iterations', 200, 'Number of episodes to train for.')
flags.DEFINE_integer('max_episode_length', 960,
                     'Maximum number of steps per episode. Assuming 2 days, '
                     'with each step lasting 3 minutes.')
flags.DEFINE_string('base_dir', None,
                    'Directory where to store statistics/images.')
flags.DEFINE_integer(
    'run_number', 1,
    'When running multiple agents in parallel, this number '
    'differentiates between the runs. It is appended to base_dir.')
flags.DEFINE_string(
    'wind_field', 'generative',
    'The wind field type to use. See the _WIND_FIELDS dict below for options.')
flags.DEFINE_string('agent_gin_file', None,
                    'Gin file for agent configuration.')
flags.DEFINE_multi_string('collectors', ['console'],
                          'Collectors to include in metrics collection.')
flags.DEFINE_multi_string('gin_bindings', [],
                          'Gin bindings to override default values.')
flags.DEFINE_string(
    'renderer', None,
    'The renderer to use. Note that it is fastest to have this set to None.')
flags.DEFINE_integer(
    'render_period', 10,
    'The period to render with. Only has an effect if renderer is not None.')
flags.DEFINE_integer(
    'episodes_per_iteration', 50,
    'The number of episodes to run in one iteration. Checkpointing occurs '
    'at the end of each iteration.')
# base_dir has no usable default, so the script refuses to start without it.
flags.mark_flag_as_required('base_dir')
FLAGS = flags.FLAGS

# Define the available renderers.

# Maps the value of the --renderer flag to the renderer class to use.
_RENDERERS = {
    'matplotlib': matplotlib_renderer.MatplotlibRenderer,
}

# Parse the TFRecord file.

def _parse_function(proto):
    """Decode one serialized record into a dict of tensors.

    Args:
        proto: A single serialized example, as yielded by a TFRecord dataset.

    Returns:
        The mapping produced by ``tf.io.parse_single_example`` for the keys
        ``'feature1'`` and ``'feature2'``.
    """
    # Each entry reads the feature of that name from the record as a scalar
    # float32 value.
    # NOTE(review): 'feature1' / 'feature2' look like placeholder names --
    # confirm the keys, shapes and dtypes against the real schema of the
    # ble_wind_field records; parsing fails if a required key is missing.
    keys_to_features = {
        'feature1': tf.io.FixedLenFeature([], tf.float32),
        'feature2': tf.io.FixedLenFeature([], tf.float32),
    }
    # Turn the serialized bytes into tensors.
    parsed_features = tf.io.parse_single_example(proto, keys_to_features)
    return parsed_features

# Load the dataset.

def load_dataset(tfrecord_path, batch_size=32, shuffle_buffer_size=10000):
    """Build a parsed, shuffled, batched and prefetching TFRecord pipeline.

    Args:
        tfrecord_path: Path (or paths) handed to ``tf.data.TFRecordDataset``.
        batch_size: Number of parsed records per batch.
        shuffle_buffer_size: Size of the shuffle buffer.

    Returns:
        A ``tf.data`` dataset whose elements are batches of whatever
        ``_parse_function`` returns for a single record.
    """
    # Create the dataset object over the raw serialized records.
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.shuffle(shuffle_buffer_size)
    # Group records into training batches. Author's note: with the default of
    # 32 an out-of-memory error was reported at this point.
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

# Training-state class.

class TrainState(train_state.TrainState):
    """Train state used for the VAE; adds nothing to the parent class.

    NOTE(review): the original comment said ``mean`` and ``logvar`` are
    inherited from the parent class. In this file they are only read from the
    model output (``encoder_output``), never from the train state -- confirm
    against ``flax.training.train_state.TrainState``.
    """

# Initialize and create the training state before the actual training starts.

def create_train_state(rng, model, learning_rate, input_shape):
    """Initialise model parameters and wrap them in a ``TrainState``.

    Args:
        rng: PRNG key; it is split into one key for parameter initialisation
            and one passed to the model as its third ``init`` argument.
        model: Module exposing ``init`` and ``apply``.
        learning_rate: Learning rate for the Adam optimiser.
        input_shape: Shape of the all-ones dummy input used to initialise the
            model (anything accepted by ``jnp.ones``).

    Returns:
        A ``TrainState`` holding ``model.apply``, the initial parameters and
        the Adam optimiser.
    """
    rng, z_rng = jax.random.split(rng)
    # Run the model once on a dummy input to create its parameters.
    params = model.init(rng, jnp.ones(input_shape), z_rng)['params']
    tx = optax.adam(learning_rate)
    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

# Training step function.

@jax.jit
def train_step(state, batch, rng):
    """Run one gradient step of the VAE on a single batch.

    Args:
        state: Train state exposing ``apply_fn``, ``params`` and
            ``apply_gradients``.
        batch: Model input; it is also the reconstruction target.
        rng: PRNG key forwarded to the model as ``z_rng``.

    Returns:
        A tuple ``(state, loss, recon_loss, kl_loss)``: the updated state, the
        total loss, and its reconstruction and KL components.
    """
    def loss_fn(params):
        vae_output = state.apply_fn({'params': params}, batch, z_rng=rng)
        encoded = vae_output.encoder_output
        # Mean squared reconstruction error. The exponent must be ``** 2``:
        # ``* 2`` would give a signed, doubled mean difference, not an error.
        recon_loss = jnp.mean((vae_output.reconstruction - batch) ** 2)
        # KL divergence between the encoder's diagonal Gaussian and N(0, I).
        # NOTE(review): this term is summed over every element while the
        # reconstruction term is averaged -- confirm the relative weighting
        # is intended.
        kl_loss = -0.5 * jnp.sum(
            1 + encoded.logvar - encoded.mean ** 2 - jnp.exp(encoded.logvar))
        total_loss = recon_loss + kl_loss
        # The components are returned as aux data so they can be logged.
        return total_loss, (recon_loss, kl_loss)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (recon_loss, kl_loss)), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss, recon_loss, kl_loss

def main(_) -> None:
    """Train a BLE agent, then train a WindFieldVAE, then save a render.

    The steps run one after another: the environment and agent are built from
    the command-line flags and trained with ``train_lib.run_training_loop``;
    afterwards the VAE is trained for a fixed number of epochs on batches read
    from a TFRecord file; finally one frame of the environment is rendered and
    written to ``<base_dir>/balloon_path.png``.

    Args:
        _: Unused positional arguments supplied by ``app.run``.
    """
    # Prepare the metric-collector gin files and constructors.
    collector_constructors = train_lib.get_collector_data(FLAGS.collectors)
    run_helpers.bind_gin_variables(FLAGS.agent,
                                   FLAGS.agent_gin_file,
                                   FLAGS.gin_bindings)

    renderer = None
    if FLAGS.renderer is not None:
        # Look the renderer class up by flag value and instantiate it.
        renderer = _RENDERERS[FLAGS.renderer]()

    wf_factory = run_helpers.get_wind_field_factory(FLAGS.wind_field)
    env = gym.make(FLAGS.env_name,
                   wind_field_factory=wf_factory,
                   renderer=renderer)

    agent = run_helpers.create_agent(
        FLAGS.agent,
        env.action_space.n,
        observation_shape=env.observation_space.shape)

    base_dir = osp.join(FLAGS.base_dir, FLAGS.agent, str(FLAGS.run_number))

    # Initialise the VAE model and its training state.
    rng = jax.random.PRNGKey(0)
    learning_rate = 1e-3
    num_latents = 64
    batch_size = 16
    num_epochs = 10
    # Location of the dataset on disk.
    tfrecord_path = '/home/ble_wind_field-train.tfrecord-00000-of-00001'

    dataset = load_dataset(tfrecord_path, batch_size=batch_size)
    # Turn the dataset into an iterator so batches can be pulled with next().
    dataset = iter(dataset)

    field_shape = FieldShape()
    vae_model = WindFieldVAE(num_latents=num_latents, field_shape=field_shape)
    # Use a dedicated key for initialisation so the keys consumed inside
    # create_train_state are not handed out again by the training loop below.
    rng, init_rng = jax.random.split(rng)
    # NOTE(review): the last argument is used as the *shape* of the dummy
    # input, so passing batch_size yields a 1-D array of length 16. Confirm
    # this matches the input the model expects (probably the batch size
    # followed by the wind-field grid dimensions).
    state = create_train_state(init_rng, vae_model, learning_rate, batch_size)

    # Run the main Balloon Learning Environment training loop.
    train_lib.run_training_loop(
        base_dir,
        env,
        agent,
        FLAGS.num_iterations,
        FLAGS.max_episode_length,
        collector_constructors,
        render_period=FLAGS.render_period,
        episodes_per_iteration=FLAGS.episodes_per_iteration)

    # Train the VAE model once the agent training loop has finished.
    for epoch in range(num_epochs):
        # Assumes every epoch consists of 100 batches.
        # NOTE(review): the dataset is never repeated, so next() raises
        # StopIteration if the file holds fewer than
        # num_epochs * 100 batches -- verify.
        for _ in range(100):
            # Fetch the next batch from the dataset iterator.
            # NOTE(review): the elements are batches of what _parse_function
            # returns, i.e. a mapping of feature name to tensor, which has no
            # .numpy() method -- the wanted feature probably has to be
            # selected first. Confirm.
            batch = next(dataset).numpy()
            rng, step_rng = jax.random.split(rng)
            state, loss, recon_loss, kl_loss = train_step(state, batch, step_rng)
            # Print the losses directly so progress is easy to follow.
            print(f'Epoch {epoch}, Loss: {loss}, 重构损失: {recon_loss}, KL损失: {kl_loss}')

    # Save a rendered image if a base directory is set.
    if FLAGS.base_dir is not None:
        image_save_path = osp.join(FLAGS.base_dir, 'balloon_path.png')
        img = env.render(mode='rgb_array')
        # Only write the file when the render call produced an image array.
        if isinstance(img, np.ndarray):
            matplotlib.image.imsave(image_save_path, img)

# Script entry point: let absl parse the flags and then call main().
if __name__ == '__main__':
    app.run(main)
# Trailing text that followed the code in the original post: "最新报错2"
# ("latest error 2").