lightvector / KataGo

GTP engine and self-play learning in Go
https://katagotraining.org/

Pytorch load model from tensorflow .ckpt and play: Mismatch and Performance Issue with b18c384nbt model_pytorch and Checkpoint File #914

Open xingchengxu opened 6 months ago

xingchengxu commented 6 months ago

KataGo version: 1.14.0
torch version: 1.12.1+cu113

I encountered a compatibility issue when loading the kata1-b18c384nbt-s9402410496-d4158172623/model.ckpt checkpoint with the original b18c384nbt model configuration. The initial setup did not match the checkpoint, leading to discrepancies in the network's ResBlock: conv1x1 was not initialized in 'bottlenest2', and the intermediate_value_head and intermediate_policy_head modules were missing entirely.

To address this, I adjusted the model configuration as follows:

model_config = modelconfigs.config_of_name['b18c384nbt']
model_config['use_repvgg_linear'] = True
model_config['bnorm_use_gamma'] = True
model_config['trunkfinal_use_gamma'] = False
model_config['has_intermediate_head'] = True
model_config['intermediate_head_blocks'] = 1
model_config['v1_num_channels'] = 48
model_config['interm_norm_kind'] = 'fixscaleonenorm'
device = 'cpu'

# interm_norm_kind is not defined in the original code; however, without it the
# model does not correctly define norm_intermediate_trunkfinal to have gamma,
# beta, running_mean, and running_std

With this configuration, plus some tweaks in model_pytorch.py, we are able to load the .ckpt and match all parameters to the model:

state_dict = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(state_dict['model'], strict=True)
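The strict=True flag is what surfaces exactly this kind of config/checkpoint mismatch. A minimal, self-contained sketch (plain torch, not KataGo code) of how a mismatched architecture shows up as unexpected keys at load time:

```python
# Illustrative only: a checkpoint saved from a two-layer model fails to
# strict-load into a one-layer model, and the error names the offending keys.
import torch
import torch.nn as nn

small = nn.Sequential(nn.Linear(4, 8))                    # one layer
big = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))     # extra layer

ckpt = {"model": big.state_dict()}

err_msg = ""
try:
    small.load_state_dict(ckpt["model"], strict=True)
except RuntimeError as e:
    err_msg = str(e)

# The error message lists "1.weight" and "1.bias" as unexpected keys,
# which is how a wrong config reveals itself.
print("strict load failed:", "1.weight" in err_msg)
```

With strict=False these mismatches would be silently ignored, which is why matching all parameters (as above) is the right sanity check.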


However, when we try to play with the model to check whether it behaves sensibly, using the following code:

from KataGo.python.play import *

model.eval()
board_size = 19
gs = GameState(board_size)

# Play an opening move at Q16 for Black
pla = Board.BLACK
loc = parse_coord('Q16', gs.board)
gs.board.play(pla, loc)
gs.moves.append((pla, loc))
gs.boards.append(gs.board.copy())

# Ask the network for its reply (`rules` must be defined beforehand)
outputs = get_outputs(gs, rules)
loc = outputs["genmove_result"]
pla = gs.board.pla
gs.board.play(pla, loc)
gs.moves.append((pla, loc))
gs.boards.append(gs.board.copy())
ret = str_coord(loc, gs.board)
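For reference, the coordinate convention that parse_coord/str_coord follow can be sketched independently: GTP column letters skip "I", and row numbers count up from the bottom edge. A small illustrative helper (gtp_to_xy is a made-up name, not KataGo's API):

```python
# Hedged sketch of GTP coordinate parsing, independent of KataGo's code.
COLS = "ABCDEFGHJKLMNOPQRST"  # note: the letter "I" is skipped

def gtp_to_xy(coord, size=19):
    """Convert a GTP coordinate such as 'Q16' to 0-indexed (x, y), y=0 at the top."""
    x = COLS.index(coord[0].upper())
    y = size - int(coord[1:])
    return x, y

print(gtp_to_xy("Q16"))  # (15, 3)
print(gtp_to_xy("A1"))   # (0, 18)
```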

The model made an unexpected move ("H7") and produced an overconfident policy prediction (outputs['policy1'][-1] was 1.0), indicating a potential misunderstanding of the game state.

I am seeking advice on resolving this issue and identifying compatible checkpoints for model_pytorch to ensure accurate model performance and behavior.

lightvector commented 6 months ago

You are loading the model the wrong way. You're not supposed to guess which config to use; the checkpoint file itself contains the correct config. See, for example, the proper way to load the model: https://github.com/lightvector/KataGo/blob/master/python/load_model.py#L36-L55

This is because, for any network architecture, there are many small options that can be configured; see https://github.com/lightvector/KataGo/blob/master/python/modelconfigs.py#L1401-L1504. Usually a network will be using some combination of these options, so you don't want to guess; you just want to trust the config stored inside the checkpoint itself.
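The pattern described here can be sketched in a few lines. This is an illustrative, self-contained example, not KataGo's actual loading code; the "config"/"model" key names mirror the snippets above, but the build_model helper is invented for the sketch:

```python
# Sketch of "config travels with the weights": the saver bundles the config
# into the checkpoint, and the loader rebuilds the architecture from that
# embedded config instead of guessing.
import io
import torch
import torch.nn as nn

def build_model(config):
    # Hypothetical builder: constructs the architecture the config describes.
    return nn.Linear(config["in_features"], config["out_features"])

# Saving side: bundle the config alongside the weights.
config = {"in_features": 4, "out_features": 2}
model = build_model(config)
buffer = io.BytesIO()
torch.save({"config": config, "model": model.state_dict()}, buffer)

# Loading side: trust the embedded config, then strict-load the weights.
buffer.seek(0)
state = torch.load(buffer, map_location="cpu")
restored = build_model(state["config"])
restored.load_state_dict(state["model"], strict=True)
```

Because the architecture is rebuilt from state["config"], the strict load is guaranteed to match, with no manual flag-twiddling required.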

xingchengxu commented 6 months ago

Great! Thank you very much for pointing out the correct approach to loading the model. Your guidance clarifies the importance of using the configuration embedded within the checkpoint file itself, rather than attempting to deduce or manually adjust the configuration. Thanks!