will-maclean / sb3-burn

An implementation of Stable-Baselines3 in Rust, built with Burn
MIT License
11 stars 0 forks source link

Implement SAC #34

Open will-maclean opened 2 weeks ago

will-maclean commented 2 weeks ago

Starting point:

/// Serializable description of how the SAC entropy coefficient is configured.
///
/// This is the plain-data counterpart of [`EntCoef`]: it carries only `f32`
/// payloads so it can be stored in `SACConfig` and (de)serialized.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum EntCoefSetup {
    /// Use a fixed entropy coefficient with the given value.
    Static(f32),
    /// Learn the entropy coefficient during training.
    ///
    /// NOTE(review): the payload is presumably the initial coefficient value;
    /// confirm against the code that builds an `EntCoef` from this setup.
    Trainable(f32),
}

/// Runtime representation of the SAC entropy coefficient.
///
/// Either a fixed scalar, or a rank-1 tensor bundled with the optimizer that
/// is meant to update it. Use `to_float` to read the current value as `f32`.
pub enum EntCoef<B: AutodiffBackend> {
    /// Fixed coefficient; `to_float` returns this value unchanged.
    Static(f32),
    /// Trainable coefficient: a rank-1 tensor and the Adam optimizer adaptor
    /// paired with it.
    ///
    /// `to_float` applies `exp()` to the tensor before extracting the scalar,
    /// so the tensor holds the logarithm of the coefficient.
    /// NOTE(review): assumes the tensor contains exactly one element — confirm
    /// where it is constructed.
    Trainable(Tensor<B, 1>, OptimizerAdaptor<Adam<B::InnerBackend>, Tensor<B, 1>, B>)
}

impl<B: AutodiffBackend> EntCoef<B> {
    /// Returns the current entropy coefficient as a plain `f32`.
    ///
    /// * `Static` — the stored value is returned as-is.
    /// * `Trainable` — the stored tensor is cloned, exponentiated, and its
    ///   scalar content converted to `f32`; the optimizer is not touched.
    fn to_float(&self) -> f32 {
        match self {
            Self::Static(coef) => *coef,
            Self::Trainable(log_coef, _) => {
                // The tensor is kept in log space, hence the `exp()`.
                // Cloning is needed because `exp` consumes its receiver.
                let coef = log_coef.clone().exp();
                coef.into_scalar().elem()
            }
        }
    }
}

/// Soft Actor-Critic agent state.
///
/// Bundles the network, its optimizer, the hyperparameter configuration, the
/// entropy-coefficient state and the environment's observation/action spaces.
///
/// Type parameters:
/// * `O` — the simple optimizer wrapped by `optim`, operating on `B`'s inner backend.
/// * `B` — the autodiff backend used by the network and tensors.
pub struct SACAgent<O, B>
where
    O: SimpleOptimizer<B::InnerBackend>,
    B: AutodiffBackend,
{
    /// The SAC network.
    pub net: SACNet<B>,
    /// Optimizer adaptor that updates `net`.
    pub optim: OptimizerAdaptor<O, SACNet<B>, B>,
    /// Hyperparameters for this agent.
    pub config: SACConfig,
    /// Target entropy as a concrete value.
    ///
    /// NOTE(review): `SACConfig::target_entropy` is an `Option<f32>`; this field
    /// is presumably the resolved value — confirm in the constructor.
    pub target_entropy: f32,
    /// Current entropy coefficient (static or trainable).
    pub ent_coef: EntCoef<B>,
    /// Observation space the agent was built for.
    pub observation_space: ObsSpace,
    /// Action space the agent was built for.
    pub action_space: ActionSpace,
}

/// Serializable selection of the action-noise scheme stored in `SACConfig`.
///
/// Only the "no noise" option exists so far.
/// NOTE(review): presumably further variants (e.g. Gaussian noise) are planned;
/// `PartialEq` is derived rather than `Eq` so float payloads can be added later.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum ActionNoise {
    /// Do not add any noise.
    None,
}

/// Hyperparameters for the SAC agent.
///
/// NOTE(review): none of these fields is consumed in the visible code. The
/// descriptions below assume they mirror the stable-baselines3 SAC arguments
/// of the same names — confirm against the training loop once it exists.
///
/// Field order and the `#[config(...)]` attributes are inputs to the `Config`
/// derive; presumably reordering fields changes the generated constructor, so
/// keep the order stable.
#[derive(Config)]
pub struct SACConfig {
    /// Presumably the soft-update (Polyak) coefficient for the target network.
    // NOTE(review): stable-baselines3 defaults `tau` to 0.005; 0.05 here may be
    // a typo — confirm before relying on this default.
    #[config(default = 0.05)]
    tau: f32,
    /// Presumably the discount factor.
    #[config(default = 0.99)]
    gamma: f32,
    /// Action-noise scheme; has no default.
    action_noise: ActionNoise,
    /// Presumably how often (in gradient steps) the target network is updated.
    #[config(default = 1)]
    target_update_interval: usize,
    /// How the entropy coefficient is set up (static or trainable); has no default.
    ent_coef: EntCoefSetup,
    /// Optional target entropy.
    // NOTE(review): presumably `None` means "derive it from the action space",
    // as stable-baselines3 does with "auto" — confirm.
    target_entropy: Option<f32>,
    /// Presumably enables generalized state-dependent exploration (gSDE).
    #[config(default = false)]
    use_sde: bool,
    /// Presumably how often the gSDE noise matrix is resampled.
    // NOTE(review): stable-baselines3 defaults this to -1 (sample only at the
    // start of a rollout); `usize` cannot express that — confirm intended meaning of 1.
    #[config(default = 1)]
    sde_sample_freq: usize,
    /// Presumably whether gSDE is used instead of uniform sampling during warm-up.
    #[config(default = false)]
    use_sde_at_warmup: bool,
}