jiweiqi / CellBox.jl

CellBox in Julia
MIT License

Use a config file to record hyperparameters #8

Closed: jiweiqi closed this issue 3 years ago

jiweiqi commented 3 years ago

We can use a YAML file both as the input and as a log of the run, using the YAML.jl package: https://github.com/JuliaData/YAML.jl
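For the log side, one option is to write out the hyperparameters a run actually used (after any overrides) next to its results. Below is a minimal plain-Julia sketch, assuming a config with at most one level of nesting; `write_config_log` and `fmt` are hypothetical helpers, not part of CellBox.jl or YAML.jl, and the quoting only covers simple string values.

```julia
# Format a scalar as YAML: strings quoted, numbers and Bools as Julia prints them.
fmt(v::AbstractString) = repr(v)
fmt(v) = string(v)

# Write `conf` as YAML-style "key: value" lines, sorted by key; nested Dicts
# (such as a `drop_range` block) are written as indented sub-blocks.
function write_config_log(io::IO, conf::AbstractDict; indent::Int = 0)
    pad = " "^indent
    for key in sort!(collect(keys(conf)); by = string)
        val = conf[key]
        if val isa AbstractDict
            println(io, pad, key, ":")
            write_config_log(io, val; indent = indent + 2)
        else
            println(io, pad, key, ": ", fmt(val))
        end
    end
end

# Example: log to stdout; in a real run one would open a file under ./results.
write_config_log(stdout, Dict{Any,Any}("expr_name" => "demo", "lr" => 1.0e-3,
    "drop_range" => Dict{Any,Any}("lb" => -0.1, "hb" => 0.1)))
```

Writing a sorted, regenerated file (rather than copying the input) means the log reflects values changed at run time, not just what was in the file on disk.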

Proposed contents:

```yaml
is_restart: false
ns: 5

tfinal: 20.0
ntotal: 20
batch_size: 8

lr: 1.e-3
weight_decay: 1.e-4

n_mu: 3

n_exp_train: 20
n_exp_val: 10
n_exp_test: 10
noise: 0.01

n_iter_max: 1000
n_plot: 20  # frequency of the callback

n_iter_buffer: 50
n_iter_burnin: 100
n_iter_tol: 10
convergence_tol: 1.e-8  # keep the decimal point: YAML 1.1 resolvers may read 1e-8 as a string

sparsity: 0.6
drop_range:
  lb: -0.1
  hb: 0.1
```
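A minimal sketch of turning the parsed file into typed values, assuming the `Dict` with string keys that `YAML.load_file` returns. `HyperParams`, `tofloat` and the chosen subset of fields are hypothetical names for illustration only. `tofloat` also accepts strings, because a YAML 1.1 resolver may leave a number written without a decimal point (such as `1e-8`) as text.

```julia
# Hypothetical container for a subset of the proposed hyperparameters.
struct HyperParams
    is_restart::Bool
    ns::Int
    lr::Float64
    weight_decay::Float64
    batch_size::Int
    convergence_tol::Float64
    drop_lb::Float64
    drop_hb::Float64
end

# Numbers convert directly; strings such as "1e-8" are parsed.
tofloat(x::Real) = Float64(x)
tofloat(x::AbstractString) = parse(Float64, x)

# Build HyperParams from a parsed config Dict (e.g. YAML.load_file("config.yaml")).
function HyperParams(conf::AbstractDict)
    drop = conf["drop_range"]
    return HyperParams(
        Bool(conf["is_restart"]),
        Int(conf["ns"]),
        tofloat(conf["lr"]),
        tofloat(conf["weight_decay"]),
        Int(conf["batch_size"]),
        tofloat(conf["convergence_tol"]),
        tofloat(drop["lb"]),
        tofloat(drop["hb"]),
    )
end

# Example with a hand-built Dict standing in for the parsed YAML file.
conf = Dict{Any,Any}("is_restart" => false, "ns" => 5, "lr" => 1.0e-3,
    "weight_decay" => 1.0e-4, "batch_size" => 8, "convergence_tol" => "1e-8",
    "drop_range" => Dict{Any,Any}("lb" => -0.1, "hb" => 0.1))
hp = HyperParams(conf)
```

Converting everything once, in one place, means a missing key or a wrongly typed value fails at startup instead of partway through training.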
jiweiqi commented 3 years ago

Implemented at https://github.com/jiweiqi/CellBox.jl/blob/93c044488d73e16b392ec87c183cef59c6720e5e/spilt/cellbox.jl#L8

jiweiqi commented 3 years ago

@DesmondYuan, I am not sure whether the argument parsing will cause problems when running locally. Here are some proposals to solve the issue, based on my experience with other projects.

The code takes its configuration from two sources: the command line (or a shell script) and config.yaml. We can use the following rule: if an argument is specified on the command line, take it from there; otherwise, take it from config.yaml. This should make the code work both locally and on the server.
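The precedence rule can be sketched in plain Julia, assuming hypothetical `--key=value` command-line arguments; the thread does not show how arguments are actually parsed, and `merge_cli` and `coerce` are illustrative names, not CellBox.jl functions. Each override is converted to the type of the value already in the config, and unknown keys raise an error so that a typo does not silently fall back to the YAML value.

```julia
# Convert a command-line string to the type of the existing config value.
# (Bool <: Integer, but the Bool method is more specific and wins.)
coerce(::Bool, s::AbstractString) = parse(Bool, s)
coerce(::Integer, s::AbstractString) = parse(Int, s)
coerce(::AbstractFloat, s::AbstractString) = parse(Float64, s)
coerce(::Any, s::AbstractString) = String(s)   # strings and anything else stay text

# Return a copy of `conf` in which every "--key=value" argument overrides `key`;
# arguments not in that form are ignored.
function merge_cli(conf::AbstractDict, args::AbstractVector{<:AbstractString})
    merged = copy(conf)
    for arg in args
        m = match(r"^--([^=]+)=(.*)$", arg)
        m === nothing && continue
        key = String(m.captures[1])
        haskey(merged, key) || error("unknown config key: $key")
        merged[key] = coerce(merged[key], m.captures[2])
    end
    return merged
end
```

In a script this would be called as `conf = merge_cli(YAML.load_file("./config.yaml"), ARGS)`, so running without arguments reproduces the YAML settings exactly.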

Some changes to the YAML file are also proposed.

jiweiqi commented 3 years ago

An example header.jl that I am using:

```julia
using Random, Plots
using Zygote, ForwardDiff
using OrdinaryDiffEq, DiffEqSensitivity
using LinearAlgebra
using Statistics
using ProgressBars, Printf
using Flux
using Flux.Optimise: update!
using Flux.Losses: mae
using BSON: @save, @load
using DelimitedFiles
using YAML

# Let the GR plotting backend run without a display (e.g. on a headless server)
ENV["GKSwstype"] = "100"

# Run from the project root, i.e. the parent of this script's directory
cd(dirname(@__DIR__))
conf = YAML.load_file("./config.yaml")

expr_name = conf["expr_name"]
fig_path = string("./results/", expr_name, "/figs")
ckpt_path = string("./results/", expr_name, "/checkpoint")
config_path = "./results/$expr_name/config.yaml"

is_restart = Bool(conf["is_restart"])
ns = Int64(conf["ns"])
nr = Int64(conf["nr"])
lb = Float64(conf["lb"])
n_epoch = Int64(conf["n_epoch"])
n_plot = Int64(conf["n_plot"])
grad_max = Float64(conf["grad_max"])
maxiters = Int64(conf["maxiters"])

lr_max = Float64(conf["lr_max"])
lr_min = Float64(conf["lr_min"])
lr_decay = Float64(conf["lr_decay"])
lr_decay_step = Int64(conf["lr_decay_step"])
w_decay = Float64(conf["w_decay"])

llb = lb

# Experiments 2, 6, 9 and 12 are held out for validation; the rest are used for training
const l_exp = 1:14
n_exp = length(l_exp)

l_train = Int[]
l_val = Int[]
for i = 1:n_exp
    j = l_exp[i]
    if !(j in [2, 6, 9, 12])
        push!(l_train, i)
    else
        push!(l_val, i)
    end
end

# ExpDecay is applied last so that it scales ADAMW's step; placed first, its
# scaling is largely cancelled by ADAM's gradient normalization. ADAMW's own
# learning rate is 1.0 so that lr_max/lr_decay/lr_min control the step size.
opt = Flux.Optimiser(
    ADAMW(1.0, (0.9, 0.999), w_decay),
    ExpDecay(lr_max, lr_decay, length(l_train) * lr_decay_step, lr_min),
)

# On a fresh (non-restart) run, remove figures and checkpoints from earlier runs
if !is_restart
    rm(fig_path; recursive = true, force = true)
    rm(ckpt_path; recursive = true, force = true)
end

# mkpath also creates missing parents (./results and ./results/$expr_name)
mkpath(joinpath(fig_path, "conditions"))
mkpath(ckpt_path)

# Keep a copy of the config next to the results
cp("./config.yaml", config_path; force = true)
```
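To see what the `lr_*` entries do, here is a sketch of the decay rule Flux documents for `ExpDecay`: multiply the learning rate by `decay` every `decay_step` updates, never going below the clip value. It presumably assumes one optimizer update per training experiment per epoch, which is why the header passes `length(l_train) * lr_decay_step`. `lr_at` is a hypothetical helper, Flux's internal step counting may differ from it by one, and the numbers are placeholders rather than CellBox settings.

```julia
# Learning rate after `step` optimizer updates under the ExpDecay rule:
# multiply by `decay` every `decay_step` updates, with a floor of `lr_min`.
lr_at(step::Integer; lr_max, decay, decay_step, lr_min) =
    max(lr_max * decay^fld(step, decay_step), lr_min)

# Same split as header.jl: 14 experiments, 4 held out for validation.
n_train = count(j -> !(j in (2, 6, 9, 12)), 1:14)   # 10 training experiments

# Placeholder schedule: decay every 100 epochs.
lr_decay_step = 100
decay_step = n_train * lr_decay_step   # number of updates passed to ExpDecay

for epoch in (0, 99, 100, 250)
    lr = lr_at(epoch * n_train; lr_max = 1e-3, decay = 0.5,
               decay_step = decay_step, lr_min = 1e-5)
    println("epoch $epoch: lr = $lr")
end
```

Expressing `lr_decay_step` in epochs and converting to updates in code keeps the YAML value meaningful even if the number of training experiments changes.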
DesmondYuan commented 3 years ago

Implemented in commit 5d910a322725ed8240fae6ffbee0531fc30c2baf:

https://github.com/jiweiqi/CellBox.jl/blob/5d910a322725ed8240fae6ffbee0531fc30c2baf/runtime.yaml#L1-L2