Orbax provides common checkpointing and persistence utilities for JAX users
Weights of Restored model after checkpointing do not give same loss as model before saving #1402

I have a class that contains an nnx.Module and trains it. I try to save and restore by accessing this attribute, but, as the title says, I find that when I restore the model, its loss is as bad as that of a randomly initialized model.

I have no way to describe the problem other than what the title says: I train a model, halve the loss from its initialization, save the model using the instructions in the tutorial on saving and loading models (or the instructions given here, or the instructions on the Orbax website), and then restore it in another file and re-run the training loop. However, at the final step my loss is the same as the loss I got at initialization. Note that the parameters are not the ones I had at initialization, but completely different ones that are equally poor when evaluated on my objective function.

I have attached the code for my model, my training file, and my loading function.

Model file:

def training_step(model, optimizer, key, X, Y, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, maxT):
    """Run one noise-prediction training step and return its loss.

    A diffusion step index is drawn per sample, ``X`` is noised with the
    matching schedule coefficients, and the model is trained to predict the
    injected noise (mean squared error).

    Args:
        model: NNX module called as ``model(x_t, Y, t)``.
        optimizer: ``nnx.Optimizer`` wrapping ``model``; its ``update`` is
            called with the gradients of the loss.
        key: JAX PRNG key; split internally for step indices and noise.
        X: Clean samples; the leading axis is the batch axis.
        Y: Conditioning inputs passed through to the model.
        alphas_bar_sqrt: Per-step coefficients multiplying ``X``.
        one_minus_alphas_bar_sqrt: Per-step coefficients multiplying the noise.
        maxT: Exclusive upper bound for the sampled step indices.

    Returns:
        The scalar loss evaluated before the parameter update.
    """
    batch_size = X.shape[0]

    # Draw half of the step indices and mirror them (t, maxT - t - 1) so each
    # batch covers both ends of the schedule. The shape must be a tuple.
    t_key, key = random.split(key)
    half_t = random.randint(t_key, (batch_size // 2 + 1,), 0, maxT)
    t = jnp.concatenate([half_t, maxT - half_t - 1], axis=-1)[:batch_size]

    # Trailing axis added so the coefficients broadcast against X.
    ab_s = jnp.take(alphas_bar_sqrt, t)[..., None]
    am1 = jnp.take(one_minus_alphas_bar_sqrt, t)[..., None]

    noise_key, key = random.split(key)
    noise = random.normal(noise_key, X.shape)

    # Forward diffusion: noised sample at step t.
    x_ts = ab_s * X + am1 * noise

    def conditional_model_estimation_loss(model):
        preds = model(x_ts, Y, t[..., None])
        return jnp.mean(jnp.square(preds - noise))

    loss, grads = nnx.value_and_grad(conditional_model_estimation_loss)(model)

    # Apply the gradients; without this the parameters never change and the
    # `optimizer` argument is unused.
    # NOTE(review): newer Flax releases expect `optimizer.update(model, grads)`
    # -- confirm against the installed version.
    optimizer.update(grads)

    return loss

def make_beta_schedule(schedule='linear', T=100, start=1e-5, end=0.5e-2):
    """Build the array of ``T`` diffusion betas for the named schedule.

    Args:
        schedule: One of ``'linear'``, ``'cosine'`` or ``'sigmoid'``.
        T: Number of diffusion steps (length of the returned array).
        start: First beta for ``'linear'``; lower bound for ``'sigmoid'``.
        end: Last beta for ``'linear'``; upper bound for ``'sigmoid'``.

    Returns:
        A 1-D JAX array of length ``T``.

    Raises:
        ValueError: If ``schedule`` is not one of the supported names.
    """
    if schedule == 'linear':
        betas = jnp.linspace(start, end, T)
    elif schedule == 'cosine':
        # JAX arrays have no `.pow` method (that is a torch idiom); square
        # with the ** operator instead. `start`/`end` are not used here.
        betas = jnp.cos(jnp.linspace(0, 1, T) / T + np.pi / 2) ** 2
    elif schedule == "sigmoid":
        # Squash an evenly spaced ramp on [-6, 6] into the (start, end) range.
        betas = jnp.linspace(-6, 6, T)
        betas = jax.nn.sigmoid(betas) * (end - start) + start
    else:
        # Only unknown names are rejected; previously this raise sat inside
        # the sigmoid branch and fired unconditionally.
        raise ValueError("schedule not implemented yet")
    return betas

class ConditionalLinear(nnx.Module):
    """Linear layer whose output is multiplied by a per-step embedding."""

    def __init__(self, num_in, num_out, n_steps, rngs):
        """Create the linear map and the step-embedding table.

        Args:
            num_in: Input feature size of the linear layer.
            num_out: Output feature size; also the embedding width.
            n_steps: Number of rows in the embedding table.
            rngs: ``nnx.Rngs`` used to initialise both sub-layers.
        """
        self.lin = nnx.Linear(num_in, num_out, rngs=rngs)
        self.embed = nnx.Embed(n_steps, num_out, rngs=rngs)

    def __call__(self, x, t):
        """Return ``lin(x)`` scaled elementwise by the embedding of ``t``."""
        projected = self.lin(x)
        # The embedding is reshaped to the projection's shape so a trailing
        # singleton axis on `t` does not introduce an extra dimension.
        step_scale = self.embed(t).reshape(projected.shape)
        return projected * step_scale

class ConditionalDiffusionModel(nnx.Module):
    """Noise-prediction network conditioned on ``y`` and the step index."""

    def __init__(self, dim, conditioning_dim, hidden_dim, T, rngs):
        """Build the conditioning branch, the trunk and the output layer.

        Args:
            dim: Size of the sample ``x`` and of the predicted noise.
            conditioning_dim: Size of the conditioning input ``y``.
            hidden_dim: Width of every hidden layer.
            T: Number of diffusion steps (size of each step-embedding table).
            rngs: ``nnx.Rngs`` used to initialise all layers.
        """
        self.dim = dim

        # Conditioning branch: each layer maps y (conditioning_dim) to hidden_dim.
        self.cond_emb1 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs)
        self.cond_emb2 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs)
        self.cond_emb3 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs)

        # Trunk: consumes [x, conditioning features] and stays at hidden_dim.
        self.l1 = ConditionalLinear(dim + hidden_dim, hidden_dim, T, rngs)
        self.l2 = ConditionalLinear(hidden_dim, hidden_dim, T, rngs)
        self.l3 = ConditionalLinear(hidden_dim, hidden_dim, T, rngs)
        self.last_layer = nnx.Linear(hidden_dim, dim, rngs=rngs)

    def __call__(self, x, y, ts):
        """Predict the noise in ``x`` given conditioning ``y`` at steps ``ts``."""
        # NOTE(review): every conditioning layer is applied to `y` itself and
        # the result is overwritten each time, so only cond_emb3's output
        # reaches the trunk. Confirm whether chaining was intended.
        for cond_layer in (self.cond_emb1, self.cond_emb2, self.cond_emb3):
            cond_features = nnx.softplus(cond_layer(y, ts))

        hidden = jnp.concatenate([x, cond_features], axis=-1)
        for trunk_layer in (self.l1, self.l2, self.l3):
            hidden = nnx.softplus(trunk_layer(hidden, ts))

        return self.last_layer(hidden)

class ConditionalDiffuser:
    """Wraps a ConditionalDiffusionModel with its noise schedule.

    Holds the network (``self.model``), the beta schedule and derived
    coefficients, a training loop, and the reverse (sampling) process.
    """

    def __init__(self, dim, conditioning_dim, hidden_dim, beta_schedule, T, sigma, rngs):
        """Create the network and precompute the schedule coefficients.

        Args:
            dim: Size of the samples being diffused.
            conditioning_dim: Size of the conditioning vector.
            hidden_dim: Hidden width of the network.
            beta_schedule: Schedule name passed to ``make_beta_schedule``.
            T: Number of diffusion steps.
            sigma: Stored as ``self.sigma``; not read by the sampler below.
            rngs: ``nnx.Rngs`` used to initialise the network.
        """
        self.model = ConditionalDiffusionModel(dim, conditioning_dim, hidden_dim, T, rngs)
        self.betas = jnp.array(make_beta_schedule(beta_schedule, T))
        self.alphas = 1 - self.betas
        self.alpha_bars = jnp.cumprod(self.alphas)
        self.alphas_bar_sqrt = jnp.sqrt(self.alpha_bars)
        self.one_minus_alphas_bar_sqrt = jnp.sqrt(1 - self.alpha_bars)
        # alpha_bars shifted right by one step, with 1 prepended for step 0.
        self.alpha_bars_p = jnp.concatenate([jnp.array([1]), self.alpha_bars[:-1]])
        self.T = T
        self.sigma = sigma

    def __call__(self, x, *args, **kwargs):
        """Forward all arguments to the wrapped network.

        The network is called as ``model(x, y, ts)``; forwarding only ``x``
        (as before) could never succeed.
        """
        return self.model(x, *args, **kwargs)

    def get_model(self):
        """Return the wrapped network (same object as ``self.model``)."""
        return self.model

    def train(self, key, dataset, opt, batch_size, epochs):
        """Train the network for ``epochs`` passes over ``dataset``.

        Args:
            key: JAX PRNG key; split for shuffling and for every step.
            dataset: Mapping with ``'samples'`` and ``'conditioners'`` arrays
                that share their leading axis.
            opt: Optimizer handed to ``training_step``.
            batch_size: Number of rows per step (the last batch may be smaller).
            epochs: Number of passes over the data.
        """
        epoch_loss = [0]
        ema_loss = None
        ema = 0.9
        n_samples = dataset['samples'].shape[0]

        for e in range(epochs):
            if e % 10 == 0:
                print(f"Epoch: {e}\t: epoch loss {epoch_loss[-1]}, ema loss {ema_loss}")

            el = []
            key, perm_key = random.split(key)
            permutation = random.permutation(perm_key, n_samples)

            with tqdm(range(0, n_samples, batch_size)) as tp:
                for i in tp:
                    indices = permutation[i:i + batch_size]
                    X = dataset['samples'][indices]
                    Y = dataset['conditioners'][indices]

                    step_key, key = random.split(key)
                    loss = float(training_step(
                        self.model, opt, step_key, X, Y,
                        self.alphas_bar_sqrt, self.one_minus_alphas_bar_sqrt, self.T,
                    ))
                    # Record the loss; previously `el` stayed empty, so the
                    # progress bar and the epoch mean were NaN.
                    el.append(loss)

                    if ema_loss is None:
                        ema_loss = loss
                    else:
                        ema_loss = ema * ema_loss + (1 - ema) * loss

                    tp.set_postfix(loss=np.mean(el[-40:]))

            if el:
                epoch_loss.append(float(np.mean(el)))

    def backward(self, x, y, ts, key):
        """Take one reverse-diffusion step from ``x`` at step indices ``ts``.

        Args:
            x: Current noisy samples; the leading axis is the batch axis.
            y: Conditioning inputs for the network.
            ts: Step index per sample (cast to int32 before use).
            key: JAX PRNG key for the noise added in this step.

        Returns:
            The samples after one denoising step.
        """
        t_idx = ts.astype(jnp.int32)

        # alpha_t and beta_t are read at the SAME index (alphas = 1 - betas);
        # the previous code read alphas at ts - 1.
        a_roots = jnp.sqrt(jnp.take(self.alphas, t_idx)[..., None])
        betas = jnp.take(self.betas, t_idx)[..., None]
        one_minus_abar_roots = jnp.take(self.one_minus_alphas_bar_sqrt, t_idx)[..., None]

        # The step noise scale is sqrt(beta_t). `self.sigma` is not used: the
        # earlier `sigmas` branch was always overwritten and never read.
        sigma_ts = jnp.sqrt(betas)

        pred_noise = self.model(x, y, t_idx)

        n_key, key = random.split(key)
        step_noise = random.normal(n_key, x.shape)
        x_t = (1 / a_roots) * (x - (betas / one_minus_abar_roots) * pred_noise) + sigma_ts * step_noise

        return x_t

    def complete_backward(self, x, y, T, key):
        """Run the reverse process on ``x`` and return the final samples.

        NOTE(review): the loop visits step indices ``T-1`` down to ``1``;
        index 0 is never visited and noise is added at every step. Confirm
        this is intended.
        """
        for t in range(1, T):
            diff_key, key = random.split(key)
            ts = jnp.ones((x.shape[0],)) * (T - t)
            x = self.backward(x, y, ts, diff_key)
        return x

Training file

def main(args):
    """Train a ConditionalDiffuser, checkpoint its network, and verify the restore.

    Reads ``args.env``, ``args.dataset_path``, ``args.restore``,
    ``args.batch_size`` and ``args.epochs``. ``args.lr`` is used as the Adam
    learning rate when present.
    """
    # The environment is only needed for the action/observation sizes.
    env = gym.make(args.env)
    try:
        action_dim = env.action_space.shape[0]
        state_dim = env.observation_space.shape[0]
    finally:
        env.close()

    dataset = np.load(args.dataset_path)
    obs = dataset['observations']
    actions = dataset['actions']
    rew_to_go = dataset['reward_to_go']

    # Conditioning vector = observation concatenated with reward-to-go.
    Y = jnp.array(np.concatenate([obs, rew_to_go], axis=-1))
    X = jnp.array(actions)
    data = {'samples': X, 'conditioners': Y}

    ckpt_path = Path('saved_models/my_checkpoints/').resolve()

    model = ConditionalDiffuser(dim=action_dim, conditioning_dim=state_dim+1, hidden_dim=300, beta_schedule='linear', T=252, sigma=0.1, rngs=nnx.Rngs(0, noise=1))
    if args.restore:
        print("Loading from Checkpoint")
        # Swap the freshly initialised network for the restored one. Without
        # the `else` below, a fresh model replaced the restored one every time.
        model.model = load_model(ckpt_path, env_name=args.env)
    else:
        print("Building Fresh Model")

    # NOTE(review): the original learning-rate argument is cut off in the
    # source; `args.lr` with a 1e-3 default is a placeholder -- confirm.
    # Newer Flax releases also require `wrt=nnx.Param` here.
    opt = nnx.Optimizer(model.model, optax.adam(getattr(args, 'lr', 1e-3)))

    train_key = jax.random.key(1)
    model.train(train_key, data, opt, args.batch_size, args.epochs)

    # Only the network state is checkpointed, not the wrapper.
    _, state = nnx.split(model.model)

    checkpointer = ocp.StandardCheckpointer()
    # force=True so that a run started with --restore can overwrite the
    # checkpoint it was loaded from.
    checkpointer.save(ckpt_path / 'attempt_8', state, force=True)
    # The save is asynchronous; block until the files are fully written
    # before reading them back.
    checkpointer.wait_until_finished()

    # Round-trip check: the restored parameters must equal the saved ones.
    sus_model = load_model(ckpt_path, env_name=args.env)
    _, restored_state = nnx.split(sus_model)
    jax.tree.map(np.testing.assert_array_equal, restored_state, state)

    print("Done checkpointing")

Load function

def load_model(checkpoint_dir, env_name='Pendulum-v1'):
    """Restore the diffusion network saved under ``checkpoint_dir/'attempt_8'``.

    Args:
        checkpoint_dir: Directory that contains the ``attempt_8`` checkpoint.
        env_name: Gym environment used only to read the action and
            observation sizes; it must match the one used for training.

    Returns:
        A ``ConditionalDiffusionModel`` holding the restored parameters.
    """
    # The environment is only needed for the action/observation sizes.
    env = gym.make(env_name)
    try:
        action_dim = env.action_space.shape[0]
        state_dim = env.observation_space.shape[0]
    finally:
        env.close()

    # The restore target is the network itself (ConditionalDiffusionModel),
    # because the checkpoint is written from `nnx.split(model.model)` and
    # callers assign the result to `model.model`. The ConditionalDiffuser
    # wrapper is not an nnx.Module and has a different state tree.
    # eval_shape builds shapes and dtypes only, without allocating parameters.
    abstract_model = nnx.eval_shape(
        lambda: ConditionalDiffusionModel(
            dim=action_dim,
            conditioning_dim=state_dim + 1,
            hidden_dim=300,
            T=252,
            rngs=nnx.Rngs(0, noise=1),
        )
    )
    graphdef, abstract_state = nnx.split(abstract_model)

    checkpointer = ocp.StandardCheckpointer()
    loaded_state = checkpointer.restore(Path(checkpoint_dir) / 'attempt_8', abstract_state)

    return nnx.merge(graphdef, loaded_state)