luoqiaoyang / ACL2021-LaSAML

The repository for the ACL 2021 Findings paper "Don't Miss the Labels: Label-semantic Augmented Meta-Learner for Few-Shot Text Classification".

task_data = self.data[randm_domain_ids] #2

Status: Open. YvoGao opened this issue 1 year ago

YvoGao commented 1 year ago

File "E:\code_git\few-shot\ACL2021-LaSAML-main\src\dataset\parallel_sampler_new.py", line 417, in get_epoch
    task_data = self.data[randm_domain_ids]

I get this error with the following settings:

args.embedding = 'ebdnew'
args.maml = True
args.cuda = -1
args.sup_feature = 'cls'
args.que_feature = 'cls'
args.addCtagSup = 'one'
args.addCtagQue = 'none'
args.classifier = 'mbc'

Why does this happen? I suspect there is a problem where dataset\parallel_sampler_new.py computes task_data.

luoqiaoyang commented 1 year ago (2/25/2023)

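Hi, I can't tell what went wrong from your error message. Could you post the full error output? Based on what you have posted so far, I suggest printing self.args.test_domains and randm_domain_ids (computed at line 412) to check whether either looks wrong. Also note that only the clinc150 dataset needs to handle the cross-domain case.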
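The debugging step suggested above can be sketched as a small helper. This is a hypothetical sketch, not code from the repository: it assumes `self.data` is a list-like container indexed by integer domain id, and the name `select_domains` is made up for illustration. It prints the two values the maintainer recommends checking, then validates every sampled id before indexing, so a bad id fails with a readable message instead of deep inside `get_epoch`.

```python
def select_domains(data, randm_domain_ids, test_domains):
    """Hypothetical debugging helper (not part of ACL2021-LaSAML).

    Assumes `data` is a list-like container indexed by integer domain id.
    """
    # Print the two values worth inspecting, as suggested in the reply
    print("test_domains:", test_domains)
    print("randm_domain_ids:", randm_domain_ids)

    n_domains = len(data)
    # Collect any sampled id that falls outside the available domains
    bad = [i for i in randm_domain_ids if not 0 <= int(i) < n_domains]
    if bad:
        raise IndexError(f"domain ids {bad} out of range for {n_domains} domains")

    # A plain Python list cannot be indexed with a list of ids,
    # so gather the selected domains one at a time
    return [data[int(i)] for i in randm_domain_ids]


# Usage with toy data: three domains, sampling domains 2 and 0
print(select_domains(["d0", "d1", "d2"], [2, 0], test_domains=[0]))
```

If the printed ids exceed the number of loaded domains, the domain counts in the configuration (for example `n_test_domain`) are a likely place to look.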

YvoGao commented 1 year ago

Dr. Qiaoyang, thank you very much for your reply. The previous problem is solved, but I have run into a new one. Apart from changing the hyperparameters in the code, I have not modified anything else. The full error is:

    File "E:/code_git/ACL2021-LaSAML/src/main.py", line 471, in <module>
      main()
    File "E:/code_git/ACL2021-LaSAML/src/main.py", line 420, in main
      train_utils.train(train_data, val_data, model, args)
    File "E:\code_git\ACL2021-LaSAML\src\train\factory.py", line 10, in train
      return regular.train(train_data, val_data, model, args)
    File "E:\code_git\ACL2021-LaSAML\src\train\regular.py", line 72, in train
      train_one(task, model, opt, args, grad, train_loss)
    File "E:\code_git\ACL2021-LaSAML\src\train\regular.py", line 168, in train_one
      XS = model['ebd'](...)
    File "C:\Users\yunlongG\.conda\envs\MLADL\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
      result = self.forward(*input, **kwargs)
    File "E:\code_git\ACL2021-LaSAML\src\embedding\cxtebd_new.py", line 323, in forward
      return self.get_bert(data)
    File "E:\code_git\ACL2021-LaSAML\src\embedding\cxtebd_new.py", line 145, in get_bert
      text_embedding_output = self.get_bert_ebd(text_token_ids)
    File "E:\code_git\ACL2021-LaSAML\src\embedding\cxtebd_new.py", line 66, in get_bert_ebd
      input_ids=token_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
    File "C:\Users\yunlongG\.conda\envs\MLADL\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
      result = self.forward(*input, **kwargs)
    File "C:\Users\yunlongG\.conda\envs\MLADL\lib\site-packages\transformers\models\bert\modeling_bert.py", line 230, in forward
      inputs_embeds = self.word_embeddings(input_ids)
    File "C:\Users\yunlongG\.conda\envs\MLADL\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
      result = self.forward(*input, **kwargs)
    File "C:\Users\yunlongG\.conda\envs\MLADL\lib\site-packages\torch\nn\modules\sparse.py", line 147, in forward
      self.norm_type, self.scale_grad_by_freq, self.sparse)
    File "C:\Users\yunlongG\.conda\envs\MLADL\lib\site-packages\torch\nn\functional.py", line 1913, in embedding
      return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
    RuntimeError: Input, output and indices must be on the current device

My configuration is:

    args.dataset = "huffpost"
    args.data_path = "../data/huffpost.json"
    args.n_train_class = 20
    args.n_val_class = 5
    args.n_test_class = 16
    args.n_train_domain = 1
    args.n_val_domain = 1
    args.n_test_domain = 1
    args.pretrained_bert = 'bert-base-uncased'
    args.bert_cache_dir = '~/.pytorch_pretrained_bert/'
    args.finetune_ebd = True
    args.sup_feature = 'cls'
    args.que_feature = 'cls'
    args.lr = '2e-5'
    args.seed = 330
    args.addCtagSup = 'one'
    args.addCtagQue = 'none'
    args.notqdm = True
    args.cuda = -1
    args.way = 5
    args.shot = 5
    args.query = 4
    args.mode = 'train'
    args.embedding = 'ebdnew'
    args.classifier = 'mbc'

Because I run the code locally directly from PyCharm, I have attached my main file. I hope you can find time to help; thank you very much.


luoqiaoyang commented 1 year ago

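Your error is RuntimeError: Input, output and indices must be on the current device. Please check whether the model has actually been loaded onto the GPU in your local environment. Note that our.sh sets CUDA_VISIBLE_DEVICES=0 by default; change it as needed.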
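This error typically means the embedding weights and the token ids passed to them live on different devices. Since the code is launched from PyCharm rather than through our.sh, the CUDA_VISIBLE_DEVICES=0 setting in our.sh presumably never takes effect; it would have to be set in the PyCharm run configuration's environment variables (before CUDA is initialized) instead. The sketch below is illustrative, not the repository's code: it assumes `args.cuda = -1` means "run on CPU", and `pick_device` is a hypothetical helper. A small `nn.Embedding` stands in for BERT's word embeddings, and the point is simply that the model and the inputs are moved to one device before the lookup.

```python
import torch
import torch.nn as nn


def pick_device(cuda_id: int) -> torch.device:
    """Hypothetical helper: assumes cuda_id = -1 means 'use the CPU'."""
    if cuda_id >= 0 and torch.cuda.is_available():
        return torch.device(f"cuda:{cuda_id}")
    return torch.device("cpu")


device = pick_device(-1)  # mirrors args.cuda = -1 from the configuration above

# Stand-in for BERT's word-embedding table (30522 = bert-base-uncased vocab size)
ebd = nn.Embedding(num_embeddings=30522, embedding_dim=8).to(device)

token_ids = torch.tensor([[101, 2023, 102]])  # toy batch of token ids
# Move the inputs to wherever the weights are before the embedding lookup
token_ids = token_ids.to(ebd.weight.device)

out = ebd(token_ids)
print(out.shape, out.device)
```

In the full model, the same rule applies to every tensor in the task batch: whichever device the model is moved to, the support and query inputs must be moved there as well.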