google / orbax

Orbax provides common checkpointing and persistence utilities for JAX users
https://orbax.readthedocs.io/
Apache License 2.0
305 stars 36 forks source link

Weights of Restored model after checkpointing do not give same loss as model before saving #1402

Open delara38 opened 10 hours ago

delara38 commented 10 hours ago

Hello,

I have a class that contains an nnx.Module and trains it. I try to save and restore by accessing this attribute but as the title says I find that when I restore the model, it's loss is as bad as a randomly initialized model.

I have no way to describe the problem as anything or than the title says because I will train a model, halve the loss from it's initialization, save the model using the instructions in the tutorial on saving and loading models (or the instructions given here https://github.com/google/flax/issues/4383, or the instructions on the orbax website) and then restore them in another file and re-run the training loop. However at the final step my loss is the same as the loss I got at initialization. Note, that the parameters are not the ones I had at initialization but completely different ones that are equally poor when evaluated on my objective function.

I have attatched the code for my model, my training file, and my loading function.

Model file:

@nnx.jit
def training_step(model, optimizer, key, X, Y, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, maxT):
    batch_size = X.shape[0] 
    t_key, key = random.split(key)
    t = random.randint(t_key, (batch_size//2+1), 0, maxT)

    t = jnp.concatenate([t, maxT-t-1], axis=-1)[:batch_size]

    ab_s = jnp.take(alphas_bar_sqrt, t)[..., None]
    am1 = jnp.take(one_minus_alphas_bar_sqrt, t)[..., None]

    noise_key, key = random.split(key)
    noise = random.normal(noise_key, X.shape)

    x_ts = ab_s*X + am1*noise

    def conditional_model_estimation_loss(model):

        preds = model(x_ts, Y, t[..., None])

        loss = jnp.mean(jnp.square(preds-noise))
        return loss 

    loss, grads = nnx.value_and_grad(conditional_model_estimation_loss)(model)

    optimizer.update(grads)
    return loss

def make_beta_schedule(schedule = 'linear', T=100, start=1e-5, end = 0.5e-2):
    if schedule == 'linear':
        betas = jnp.linspace(start, end, T)
    elif schedule == 'cosine':
        fn = lambda x: jnp.cos(x/T + np.pi/2).pow(2)
        betas = fn(jnp.linspace(0,1, T))
    elif schedule == "sigmoid":
        betas = jnp.linspace(-6, 6, T)

        betas = jax.nn.sigmoid(betas) * (end - start) + start

    else:
        raise ValueError("schedule not implemented yet")
    return betas

class ConditionalLinear(nnx.Module):
    def __init__(self, num_in, num_out, n_steps, rngs):
        self.lin = nnx.Linear(num_in, num_out, rngs = rngs)
        self.embed = nnx.Embed(n_steps, num_out, rngs = rngs)

    def __call__(self, x, t):
        xout = self.lin(x)
        em = jnp.reshape(self.embed(t), xout.shape)
        #em = jnp.squeeze(self.embed(t), -2)
        out = xout*em
        #print(xout.shape, out.shape, em.shape, x.shape)
        return out 

class ConditionalDiffusionModel(nnx.Module):
    def __init__(self, dim, conditioning_dim, hidden_dim, T, rngs):
        self.dim = dim 

        self.cond_emb1 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs)
        self.cond_emb2 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs)
        self.cond_emb3 = ConditionalLinear(conditioning_dim, hidden_dim, T, rngs) 

        self.l1 = ConditionalLinear(dim+hidden_dim, hidden_dim, T, rngs)
        self.l2 = ConditionalLinear(hidden_dim, hidden_dim, T, rngs)
        self.l3 = ConditionalLinear(hidden_dim, hidden_dim, T, rngs)
        self.last_layer = nnx.Linear(hidden_dim, dim, rngs=rngs)

    def __call__(self, x, y, ts):
        yemb = nnx.softplus(self.cond_emb1(y, ts))
        yemb = nnx.softplus(self.cond_emb2(y, ts))
        yemb = nnx.softplus(self.cond_emb3(y, ts))

        xus = jnp.concatenate([x, yemb], axis=-1)
        xus = nnx.softplus(self.l1(xus, ts))
        xus = nnx.softplus(self.l2(xus, ts))
        xus = nnx.softplus(self.l3(xus, ts))

        preds = self.last_layer(xus)

        return preds

class ConditionalDiffuser:
    def __init__(self, dim, conditioning_dim, hidden_dim, beta_schedule, T, sigma, rngs):
        self.model = ConditionalDiffusionModel(dim, conditioning_dim, hidden_dim, T, rngs)
        self.betas = jnp.array(make_beta_schedule(beta_schedule, T))
        self.alphas = 1-self.betas 
        self.alpha_bars = jnp.cumprod(self.alphas)
        self.alphas_bar_sqrt = jnp.sqrt(self.alpha_bars)
        self.one_minus_alphas_bar_sqrt = jnp.sqrt(1-self.alpha_bars)
        self.alpha_bars_p = jnp.concatenate([jnp.array([1]), self.alpha_bars[:-1]])
        self.T = T
        self.sigma = sigma 

    def __call__(self, x):
        return self.model(x)

    def get_model(self):
        return self.model
    def train(self, key, dataset, opt, batch_size, epochs):
        epoch_loss = [0]
        ema_loss = None 
        ema = 0.9

        for e in range(epochs):
            if e % 10 == 0:
                print(f"Epoch: {e}\\t: epoch loss {epoch_loss[-1]}, ema loss {ema_loss}")

            el = []
            key, perm_key = random.split(key)
            permutation = random.permutation(perm_key, dataset['samples'].shape[0])

            with tqdm(range(0, dataset['samples'].shape[0], batch_size)) as tp:
                for i in tp:

                    indices = permutation[i:i+batch_size]
                    X = dataset['samples'][indices]
                    Y = dataset['conditioners'][indices]

                    step_key, key = random.split(key)
                    loss = training_step(self.model, opt, step_key, X, Y, self.alphas_bar_sqrt, self.one_minus_alphas_bar_sqrt, self.T)

                    el.append(loss)
                    if ema_loss is None:
                        ema_loss = loss 
                    else:
                        ema_loss = ema*ema_loss + (1-ema)*loss 

                    tp.set_postfix(loss = np.mean(el[-40:]))
            el = np.mean(el)
            epoch_loss.append(el)

    def backward(self, x, y, ts, key):

        if type(self.sigma) == list:
            sigmas = np.choose(self.sigma, ts)
        else:
            sigmas = self.sigma

        a_roots = jnp.sqrt(jnp.take(self.alphas,(ts-1).astype(jnp.int32))[..., None])
        betas = jnp.take(self.betas, ts.astype(jnp.int32))[..., None]
        one_minus_abar_roots = jnp.sqrt(1 - jnp.take(self.alpha_bars, ts.astype(jnp.int32))[..., None])

        sigma_ts = jnp.sqrt( betas )

        pred_noise = self.model(x, y, ts.astype(jnp.int32))

        n_key, key = random.split(key)
        x_t = (1/a_roots) * (x - ( betas / one_minus_abar_roots )*pred_noise) + sigma_ts * random.normal(n_key, x.shape)

        return x_t

    def complete_backward(self, x,y, T, key):
        for t in range(1,T):

            #xus = torch.concatenate([x, torch.ones(x.shape[0], 1)*(T-t)], dim=-1)
            diff_key, key = random.split(key)
            ts = (jnp.ones((x.shape[0], ))*(T-t))
            x = self.backward(x,y, ts, diff_key )
        return x

Training file

def main(args):

    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    env = gym.make(args.env)
    action_dim = env.action_space.shape[0]
    state_dim = env.observation_space.shape[0]

    dataset = np.load(args.dataset_path)

    obs = dataset['observations']
    actions = dataset['actions']
    rew_to_go = dataset['reward_to_go']
    print(rew_to_go.shape)
    Y = jnp.array(np.concatenate([obs, rew_to_go], axis=-1))

    X = jnp.array(actions)
    data = {'samples': X, 'conditioners': Y}

    if args.restore:
            print("Loading from Checkpoint")
            ckpt_path = Path('saved_models/my_checkpoints/').resolve()
            diff_model = load_model(ckpt_path)
            model = ConditionalDiffuser(dim=action_dim, conditioning_dim=state_dim+1, hidden_dim= 300, beta_schedule='linear', T=252, sigma=0.1, rngs=nnx.Rngs(0, noise=1))
            model.model = diff_model

    else:
            print("Building Fresh Model")
            model = ConditionalDiffuser(dim=action_dim, conditioning_dim=state_dim+1, hidden_dim= 300, beta_schedule='linear', T=252, sigma=0.1, rngs=nnx.Rngs(0, noise=1))
        #model = ConditionalDiffusionModel(action_dim, state_dim+1, 1000, 8, 'linear', 52, 0.05).to(device)
        #model.load_state_dict(torch.load(args.model_out))

    opt = nnx.Optimizer(model.model, optax.adam(args.lr))

    Y = jnp.array(np.concatenate([obs, rew_to_go], axis=-1))

    X = jnp.array(actions)
    data = {'samples': X, 'conditioners': Y}

     train_key = jax.random.key(1)
     model.train(train_key, data, opt, args.batch_size, args.epochs)

        # Prepare state to save
     _,  state = nnx.split(model.model)

     print("Checkpointing...")
        # Save using checkpoint manager

        # Checkpointing
     ckpt_path = Path('saved_models/my_checkpoints/').resolve()
        #ckpt_path.mkdir(parents=True, exist_ok=True)  # Ensure directory exists

     checkpointer = ocp.StandardCheckpointer()
     checkpointer.save(ckpt_path/'attempt_8', state)

     sus_model = load_model(ckpt_path)
     _, restored_state = nnx.split(sus_model)
     assert(jax.tree.map(np.testing.assert_array_equal, restored_state, state))
        #print(other_state)

     print("Done checkpointing")
    return

Load function

def load_model(checkpoint_dir, env_name='Pendulum-v1'):
    # Initialize environment to get dimensions
    env = gym.make(env_name)
    action_dim = env.action_space.shape[0]
    state_dim = env.observation_space.shape[0]

    # Initialize model with same parameters as during training
    model = ConditionalDiffuser(
        dim=action_dim,
        conditioning_dim=state_dim+1,
        hidden_dim=1000,
        beta_schedule='linear',
        T=52,
        sigma=0.05,
        rngs=nnx.Rngs(0, noise=1)
    )

    abstract_model = nnx.eval_shape(lambda: ConditionalDiffuser(dim=action_dim, conditioning_dim=state_dim+1, hidden_dim= 300, beta_schedule='linear', T=252, sigma=0.1, rngs=nnx.Rngs(0, noise=1)))
    graphdef, abstract_state = nnx.split(abstract_model)

    checkpointer = ocp.StandardCheckpointer()
    loaded_state = checkpointer.restore(checkpoint_dir/'attempt_8', abstract_state)

    #jax.tree.map(np.testing.assert_array_equal, abstract_state, loaded_state)
    #print(loaded_state)
    #print(nnx.display(loaded_state))
    model = nnx.merge(graphdef, loaded_state)

    return model