FangShancheng / ABINet

Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition

Error(s) in loading state_dict for BCNLanguage: Missing key(s) in state_dict #58

Open VictorYang097 opened 2 years ago

VictorYang097 commented 2 years ago

Thanks for your great work! When I created an instance of BCNLanguage and loaded pretrain-language-model.pth into it, the following error emerged:

```
Traceback (most recent call last):
  File "demo_language.py", line 128, in <module>
    main()
  File "demo_language.py", line 97, in main
    model = load(model, config.model_checkpoint, device=device)
  File "demo_language.py", line 70, in load
    model.load_state_dict(state, strict=strict)
  File "/home/dell/anaconda3/envs/abinet/lib/python3.6/site-packages/torch/nn/modules/module.py", line 777, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for BCNLanguage:
    Missing key(s) in state_dict: "proj.weight", "token_encoder.pe", "pos_encoder.pe",
    "model.layers.0.multihead_attn.in_proj_weight", "model.layers.0.multihead_attn.in_proj_bias",
    "model.layers.0.multihead_attn.out_proj.weight", "model.layers.0.multihead_attn.out_proj.bias",
    "model.layers.0.linear1.weight", "model.layers.0.linear1.bias",
    "model.layers.0.linear2.weight", "model.layers.0.linear2.bias",
    "model.layers.0.norm2.weight", "model.layers.0.norm2.bias",
    "model.layers.0.norm3.weight", "model.layers.0.norm3.bias",
    "model.layers.1.multihead_attn.in_proj_weight", "model.layers.1.multihead_attn.in_proj_bias",
    "model.layers.1.multihead_attn.out_proj.weight", "model.layers.1.multihead_attn.out_proj.bias",
    "model.layers.1.linear1.weight", "model.layers.1.linear1.bias",
    "model.layers.1.linear2.weight", "model.layers.1.linear2.bias",
    "model.layers.1.norm2.weight", "model.layers.1.norm2.bias",
    "model.layers.1.norm3.weight", "model.layers.1.norm3.bias",
    "model.layers.2.multihead_attn.in_proj_weight", "model.layers.2.multihead_attn.in_proj_bias",
    "model.layers.2.multihead_attn.out_proj.weight", "model.layers.2.multihead_attn.out_proj.bias",
    "model.layers.2.linear1.weight", "model.layers.2.linear1.bias",
    "model.layers.2.linear2.weight", "model.layers.2.linear2.bias",
    "model.layers.2.norm2.weight", "model.layers.2.norm2.bias",
    "model.layers.2.norm3.weight", "model.layers.2.norm3.bias",
    "model.layers.3.multihead_attn.in_proj_weight", "model.layers.3.multihead_attn.in_proj_bias",
    "model.layers.3.multihead_attn.out_proj.weight", "model.layers.3.multihead_attn.out_proj.bias",
    "model.layers.3.linear1.weight", "model.layers.3.linear1.bias",
    "model.layers.3.linear2.weight", "model.layers.3.linear2.bias",
    "model.layers.3.norm2.weight", "model.layers.3.norm2.bias",
    "model.layers.3.norm3.weight", "model.layers.3.norm3.bias",
    "cls.weight", "cls.bias".
```
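For context, inspecting the checkpoint's top-level keys shows why load_state_dict() receives the wrong dict. A minimal diagnostic sketch, assuming the default checkpoint path used by the script below:

```python
import torch

ckpt = torch.load('workdir/pretrain-language-model/pretrain-language-model.pth',
                  map_location='cpu')
print(set(ckpt.keys()))
# load() below only unwraps the checkpoint when this prints exactly
# {'model', 'opt'}; anything else leaves the wrapper dict in place, and
# load_state_dict() then reports every real parameter as missing.
if isinstance(ckpt.get('model'), dict):
    print(list(ckpt['model'].keys())[:5])  # first few real parameter names
```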

The relevant functions and files are as follows:

main()

```python
# Excerpt from demo_language.py (imports and helpers omitted).
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/inf_language_model.yaml',
                        help='path to config file')
    # TODO: construct the language model's input; it can be a txt file
    parser.add_argument('--input', type=str, default='langs/input.txt')
    parser.add_argument('--cuda', type=int, default=-1)
    parser.add_argument('--checkpoint', type=str,
                        default='workdir/pretrain-language-model/pretrain-language-model.pth')
    parser.add_argument('--model_eval', type=str, default='language',
                        choices=['alignment', 'vision', 'language'])
    args = parser.parse_args()

    config = Config(args.config)
    if args.checkpoint is not None:
        config.model_checkpoint = args.checkpoint
    if args.model_eval is not None:
        config.model_eval = args.model_eval
    config.global_phase = 'test'
    config.model_vision_checkpoint, config.model_language_checkpoint = None, None
    device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'

    Logger.init(config.global_workdir, config.global_name, config.global_phase)
    Logger.enable_file()
    logging.info(config)

    logging.info('Construct model.')
    model = get_model(config).to(device)
    model = load(model, config.model_checkpoint, device=device)
    charset = CharsetMapper(filename=config.dataset_charset_path,
                            max_length=config.dataset_max_length + 1)

    with open(args.input, 'r') as f:
        words = [line.strip() for line in f.readlines()]

    # TODO
    for word in tqdm.tqdm(words):
        word = re.sub('[^0-9a-zA-Z]+', '', word)
        if not config.dataset_eval_case_sensitive:
            word = word.lower()

        length_x = tensor(len(word) + 1).to(dtype=torch.long)  # one for end token
        label_x = charset.get_labels(word, case_sensitive=config.dataset_eval_case_sensitive)

        label_x = tensor(label_x)
        label_x = onehot(label_x, charset.num_classes)

        label_x = torch.unsqueeze(label_x, dim=0)
        length_x = torch.unsqueeze(length_x, dim=0)

        label_x = label_x.to(device)
        length_x = length_x.to(device)
        # print(label_x.shape)
        # print(length_x.shape)
        res = model(label_x, length_x)
        pt_text, _, __ = postprocess(res, charset, config.model_eval)
        logging.info(f'{word}: {pt_text[0]}')
```

get_model()

```python
def get_model(config):
    import importlib
    names = config.model_name.split('.')
    module_name, class_name = '.'.join(names[:-1]), names[-1]
    cls = getattr(importlib.import_module(module_name), class_name)
    model = cls(config)
    logging.info(model)
    model = model.eval()
    return model
```
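With model.name set to 'modules.model_language.BCNLanguage' in the config below, this dynamic lookup boils down to a direct instantiation, roughly as follows (a sketch, assuming the repository's module layout):

```python
# What get_model() effectively does for this config (sketch; 'config'
# is the parsed inf_language_model.yaml):
from modules.model_language import BCNLanguage

model = BCNLanguage(config).eval()
```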

load()

```python
def load(model, file, device=None, strict=True):
    if device is None:
        device = 'cpu'
    elif isinstance(device, int):
        device = torch.device('cuda', device)
    assert os.path.isfile(file)
    state = torch.load(file, map_location=device)
    if set(state.keys()) == {'model', 'opt'}:
        state = state['model']
    model.load_state_dict(state, strict=strict)
    return model
```
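The guard above only unwraps the checkpoint when its top-level keys are exactly {'model', 'opt'}. As a minimal, self-contained illustration of what happens when a wrapper dict slips through to load_state_dict() (not the repository's code, just the failure mode):

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)
wrapped = {'model': net.state_dict()}   # checkpoint saved with a wrapper key

try:
    net.load_state_dict(wrapped)        # wrapper passed through unchanged
except RuntimeError as e:
    print(e)  # Missing key(s) in state_dict: "weight", "bias". ...

net.load_state_dict(wrapped['model'])   # unwrapping first loads cleanly
```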

inf_language_model.yaml

```yaml
global:
  name: inf_language_model
  phase: test
  stage: pretrain-language
  workdir: workdir
  seed: ~

dataset:
  train: { roots: ['data/WikiText-103.csv'], batch_size: 4096 }
  test: { roots: ['data/WikiText-103_eval_d1.csv'], batch_size: 4096 }
  charset_path: data/charset_36.txt
  num_workers: 4
  max_length: 25  # 30
  image_height: 32
  image_width: 128
  case_sensitive: False
  eval_case_sensitive: False
  data_aug: True
  multiscales: False
  pin_memory: True
  smooth_label: False
  smooth_factor: 0.1
  one_hot_y: True
  use_sm: False

training:
  epochs: 80
  show_iters: 50
  eval_iters: 6000
  save_iters: 3000

optimizer:
  type: Adam
  true_wd: False
  wd: 0.0
  bn_wd: False
  clip_grad: 20
  lr: 0.0001
  args: {
    betas: !!python/tuple [0.9, 0.999],  # for default Adam
  }
  scheduler: {
    periods: [70, 10],
    gamma: 0.1,
  }

model:
  name: 'modules.model_language.BCNLanguage'
  language: { num_layers: 4, loss_weight: 1., use_self_attn: False }
  checkpoint: workdir/pretrain-language-model/pretrain-language-model.pth
  strict: True
```

input.txt

opple hav convenient spel langguage

I would appreciate it if you could help me. Thanks again, @FangShancheng.

VictorYang097 commented 2 years ago

I may have solved this problem. Just modify load() as follows:

```python
def load(model, file, device=None, strict=True):
    if device is None:
        device = 'cpu'
    elif isinstance(device, int):
        device = torch.device('cuda', device)
    assert os.path.isfile(file)
    state = torch.load(file, map_location=device)
    # if set(state.keys()) == {'model', 'opt'}:
    #     state = state['model']
    state = state['model']   # unwrap unconditionally
    model.load_state_dict(state, strict=strict)
    return model
```
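This assumes every checkpoint is wrapped in a 'model' key and would break on a raw state dict. A slightly more defensive variant (a sketch, not code from the repository) handles both cases:

```python
import os
import torch

def load(model, file, device=None, strict=True):
    if device is None:
        device = 'cpu'
    elif isinstance(device, int):
        device = torch.device('cuda', device)
    assert os.path.isfile(file)
    state = torch.load(file, map_location=device)
    # Unwrap whenever a nested 'model' dict is present, regardless of
    # whatever other entries (e.g. 'opt') the checkpoint carries.
    if isinstance(state, dict) and isinstance(state.get('model'), dict):
        state = state['model']
    model.load_state_dict(state, strict=strict)
    return model
```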

SDUljn commented 1 year ago

@LualuOntheSea I also encountered this issue when loading the language model (I hit the missing-key problem when testing the language model's performance on WikiText-103 separately). Your explanation is very helpful; I adopted a similar approach and it effectively resolved my issue.