datamllab / rlcard

Reinforcement Learning / AI Bots in Card (Poker) Games - Blackjack, Leduc, Texas, DouDizhu, Mahjong, UNO.
http://www.rlcard.org
MIT License

Missing fields for checkpoint in DQNAgent? #285

Closed · billh0420 closed this issue 1 year ago

billh0420 commented 1 year ago

I think the following changes need to be made for the DQNAgent checkpoint:

A) For `def checkpoint_attributes(self)`:

```python
    'use_raw': self.use_raw,
    'replay_memory_init_size': self.replay_memory_init_size,
    'model_name': self.model_name,
    'save_path': self.save_path,
    'save_every': self.save_every
```

B) For `def from_checkpoint(cls, checkpoint)`:

```python
    replay_memory_init_size=checkpoint['replay_memory_init_size'],
    save_every=checkpoint['save_every'],
    learning_rate=checkpoint['q_estimator']['learning_rate'],
    model_name=checkpoint['model_name'],
    save_path=checkpoint['save_path'],
    use_raw=checkpoint['use_raw'],
```

daochenzha commented 1 year ago

@billh0420 Just trying to ensure I understand it correctly. Do you mean adding these fields to the dictionary?

billh0420 commented 1 year ago

Since these field values are not checkpointed (except for the learning rate), they will not be restored when the checkpoint is reloaded. The learning rate probably is restored (I didn't check, but it seems like it would be), and it wouldn't hurt to pass the learning rate to the DQNAgent init again anyway.
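
For what it's worth, a minimal PyTorch sketch (not RLCard code) of why the learning rate is likely recovered when the optimizer state dict is reloaded:

```python
import torch

# The optimizer state dict carries the learning rate in its param_groups,
# so reloading it restores the original lr even if the agent was re-created
# with a different one.
net = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
saved_state = optimizer.state_dict()        # param_groups include 'lr': 0.001

optimizer = torch.optim.Adam(net.parameters(), lr=0.1)  # different lr at init
optimizer.load_state_dict(saved_state)                  # lr is overwritten
print(optimizer.param_groups[0]['lr'])                  # 0.001
```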

billh0420 commented 1 year ago

```python
def checkpoint_attributes(self):
    '''
    Return the current checkpoint attributes (dict)

    Checkpoint attributes are used to save and restore the model in the middle of training.
    Saves the model state dict, optimizer state dict, and all other instance variables.
    '''
    return {
        'agent_type': 'DQNAgent',
        'q_estimator': self.q_estimator.checkpoint_attributes(),
        'memory': self.memory.checkpoint_attributes(),
        'total_t': self.total_t,
        'train_t': self.train_t,
        'epsilon_start': self.epsilons.min(),
        'epsilon_end': self.epsilons.max(),
        'epsilon_decay_steps': self.epsilon_decay_steps,
        'discount_factor': self.discount_factor,
        'update_target_estimator_every': self.update_target_estimator_every,
        'batch_size': self.batch_size,
        'num_actions': self.num_actions,
        'train_every': self.train_every,
        'device': self.device,
        # add the following:
        'use_raw': self.use_raw,
        'replay_memory_init_size': self.replay_memory_init_size,
        'model_name': self.model_name,
        'save_path': self.save_path,
        'save_every': self.save_every
    }
```
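
For context, a minimal sketch of how this dict would typically be written to disk (the helper, path, and filename below are placeholders, not fixed RLCard names):

```python
import os
import torch

# Hypothetical save helper: serialize the attribute dict returned by
# checkpoint_attributes() with torch.save.
def save_checkpoint(agent, path, filename='checkpoint_dqn.pt'):
    os.makedirs(path, exist_ok=True)
    torch.save(agent.checkpoint_attributes(), os.path.join(path, filename))
```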

billh0420 commented 1 year ago

```python
@classmethod
def from_checkpoint(cls, checkpoint):
    '''
    Restore the model from a checkpoint

    Args:
        checkpoint (dict): the checkpoint attributes generated by checkpoint_attributes()
    '''
    print("\nINFO - Restoring model from checkpoint...")
    agent_instance = cls(
        replay_memory_size=checkpoint['memory']['memory_size'],
        update_target_estimator_every=checkpoint['update_target_estimator_every'],
        discount_factor=checkpoint['discount_factor'],
        epsilon_start=checkpoint['epsilon_start'],
        epsilon_end=checkpoint['epsilon_end'],
        epsilon_decay_steps=checkpoint['epsilon_decay_steps'],
        batch_size=checkpoint['batch_size'],
        num_actions=checkpoint['num_actions'],
        device=checkpoint['device'],
        state_shape=checkpoint['q_estimator']['state_shape'],
        mlp_layers=checkpoint['q_estimator']['mlp_layers'],
        train_every=checkpoint['train_every'],
        # add the following:
        replay_memory_init_size=checkpoint['replay_memory_init_size'],
        save_every=checkpoint['save_every'],
        learning_rate=checkpoint['q_estimator']['learning_rate'],
        model_name=checkpoint['model_name'],
        save_path=checkpoint['save_path'],
        use_raw=checkpoint['use_raw']
    )

    # etc.
```
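
A minimal round-trip sketch for reference (the checkpoint path is a placeholder, and torch.load is assumed as the deserializer):

```python
import torch
from rlcard.agents import DQNAgent

# Rebuild the agent from a previously serialized attribute dict.
checkpoint = torch.load('experiments/leduc_dqn/checkpoint_dqn.pt')  # placeholder path
agent = DQNAgent.from_checkpoint(checkpoint)
```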

daochenzha commented 1 year ago

@billh0420 It makes sense to me. I agree that it is better to restore these values than to initialize them again. Do you want me to fix it, or would you like to fix it with a PR?

billh0420 commented 1 year ago

I will fix it with a PR. Can I delete the line `self.use_raw = False` in the init, since `use_raw` is never referenced?

daochenzha commented 1 year ago

@billh0420 Yes, sure. I agree with it, as it is never referenced.

kaiks commented 1 year ago

Hi @billh0420! I contributed the checkpoint code so I can provide some context. Thanks for your fixes.

`learning_rate=checkpoint['learning_rate'],`

This is actually a property of the estimator, not of the agent, so adding it to the agent is a bit redundant.

For `model_name`: I don't see it included in your PR, but we already have `agent_type`. I wonder if the intention is the same?

For `save_path` and `save_every`, it makes sense to store them if you just want to continue training, but I was building training pipelines and had to modify them every time anyway.

`replay_memory_init_size` is indeed an oversight, good catch. In most cases not including it won't matter, because the agent's memory will already be full at the time of the checkpoint, so the condition:

```python
tmp = self.total_t - self.replay_memory_init_size
if tmp >= 0:  # ...
```

will always be true.
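
A simplified sketch of that gating logic, paraphrased from the agent's feed() method (not the exact source): training only starts once total_t has passed replay_memory_init_size, so an agent checkpointed after its memory is full is unaffected, but one checkpointed earlier would start training too soon if replay_memory_init_size fell back to the constructor default.

```python
def feed(self, ts):
    ''' Store one transition and train once the initial memory threshold is passed. '''
    (state, action, reward, next_state, done) = tuple(ts)
    self.feed_memory(state['obs'], action, reward, next_state['obs'],
                     list(next_state['legal_actions'].keys()), done)
    self.total_t += 1
    tmp = self.total_t - self.replay_memory_init_size
    # Without restoring replay_memory_init_size, a reloaded agent uses the
    # default value here instead of the one it was originally trained with.
    if tmp >= 0 and tmp % self.train_every == 0:
        self.train()
```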

billh0420 commented 1 year ago

@kaiks said "For model_name - I don't see it included in your PR ..."

I redid DQNAgent for my own purposes (essentially the same) and added the attribute `model_name`; RLCard doesn't have it.

@kaiks said "For save_path and save_every it makes sense to store them if you just want to continue training..."

Yes, I want to continue training. I also use `save_path` to know where 'fig.png', 'log.txt', and 'performance.csv' should be stored, and I have some other related files to store in that path as well.
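
As a small illustration, a hypothetical helper (not part of RLCard) that rebuilds those artifact paths from the restored save_path:

```python
import os

# Hypothetical helper: map the artifact filenames mentioned above to their
# locations under the restored save_path.
def artifact_paths(save_path):
    return {name: os.path.join(save_path, name)
            for name in ('fig.png', 'log.txt', 'performance.csv')}
```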