alibaba / FederatedScope

An easy-to-use federated learning platform
https://www.federatedscope.io
Apache License 2.0

Finetune GPT-2 on SST-2 #734

Closed shuangyichen closed 8 months ago

shuangyichen commented 8 months ago

I want to finetune GPT-2 on SST-2, but I got the following error:

```
2023-12-18 21:57:27,452 (server:881) INFO: ----------- Starting training (Round #0) -------------
Traceback (most recent call last):
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/main.py", line 67, in <module>
    _ = runner.run()
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/fed_runner.py", line 401, in run
    self._run_simulation()
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/fed_runner.py", line 484, in _run_simulation
    self._handle_msg(msg)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/fed_runner.py", line 425, in _handle_msg
    self.client[each_receiver].msg_handlers[msg.msg_type](msg)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/workers/client.py", line 348, in callback_funcs_for_model_para
    sample_size, model_para_all, results = self.trainer.train()
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/auxiliaries/decorators.py", line 7, in wrapper
    num_samples_train, model_para, result_metric = func(
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/trainers/trainer.py", line 238, in train
    num_samples = self._run_routine(MODE.TRAIN, hooks_set,
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/trainers/context.py", line 345, in wrapper
    res = func(self, mode, hooks_set, dataset_name)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/trainers/trainer.py", line 274, in _run_routine
    self._run_epoch(hooks_set)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/trainers/context.py", line 364, in wrapper
    res = func(self, *args, **kwargs)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/trainers/trainer.py", line 291, in _run_epoch
    self._run_batch(hooks_set)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/trainers/context.py", line 364, in wrapper
    res = func(self, *args, **kwargs)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/trainers/trainer.py", line 307, in _run_batch
    hook(self.ctx)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/llm/trainer/trainer.py", line 91, in _hook_on_batch_forward
    input_ids = ctx.data_batch['input_ids'].to(ctx.device)
TypeError: list indices must be integers or slices, not str
```

Can you give me some suggestions? Thanks!

rayrayraykk commented 8 months ago

It looks like your data format does not match the requirement (JSON lines with instruction, input, output, and category fields); please see https://github.com/alibaba/FederatedScope/blob/8da9f9fffc0309acbea7da52a050a59fcd791d52/federatedscope/llm/dataloader/dataloader.py#L112 for details.
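For reference, a minimal sketch (not from the repo) of what one SST-2 record in that JSON-lines layout might look like; the instruction wording, example sentence, and file name are placeholders, not taken from FederatedScope:

```python
import json

# Illustrative sketch: one SST-2 record in the JSON-lines layout described
# above, i.e. one JSON object per line with "instruction", "input", "output"
# and "category" keys. The instruction text and file name are placeholders.
record = {
    "instruction": "Classify the sentiment of the sentence as positive or negative.",
    "input": "a stirring , funny and finally transporting re-imagining of "
             "beauty and the beast and 1930s horror films",
    "output": "positive",
    "category": "sst2",
}

with open("sst2_train.jsonl", "a") as f:
    f.write(json.dumps(record) + "\n")
```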

shuangyichen commented 8 months ago

> It looks like your data format does not match the requirement (JSON lines with instruction, input, output, and category fields); please see https://github.com/alibaba/FederatedScope/blob/8da9f9fffc0309acbea7da52a050a59fcd791d52/federatedscope/llm/dataloader/dataloader.py#L112 for details.

Thanks for replying. I fixed this one, but I got another error. It seems ctx.data_batch does not have 'labels', although it does have 'input_ids' and 'attention_mask'.

```
Traceback (most recent call last):
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/main.py", line 67, in <module>
    _ = runner.run()
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/fed_runner.py", line 401, in run
    self._run_simulation()
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/fed_runner.py", line 484, in _run_simulation
    self._handle_msg(msg)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/fed_runner.py", line 425, in _handle_msg
    self.client[each_receiver].msg_handlers[msg.msg_type](msg)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/workers/client.py", line 348, in callback_funcs_for_model_para
    sample_size, model_para_all, results = self.trainer.train()
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/auxiliaries/decorators.py", line 7, in wrapper
    num_samples_train, model_para, result_metric = func(
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/trainers/trainer.py", line 238, in train
    num_samples = self._run_routine(MODE.TRAIN, hooks_set,
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/trainers/context.py", line 345, in wrapper
    res = func(self, mode, hooks_set, dataset_name)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/trainers/trainer.py", line 274, in _run_routine
    self._run_epoch(hooks_set)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/trainers/context.py", line 364, in wrapper
    res = func(self, *args, **kwargs)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/trainers/trainer.py", line 291, in _run_epoch
    self._run_batch(hooks_set)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/trainers/context.py", line 364, in wrapper
    res = func(self, *args, **kwargs)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/core/trainers/trainer.py", line 307, in _run_batch
    hook(self.ctx)
  File "/nfs/horai.dgpsrv/ondemand28/schen/csy/LoRA/new/FS-LLM/federatedscope/llm/trainer/trainer.py", line 95, in _hook_on_batch_forward
    labels = ctx.data_batch['labels'].to(ctx.device)
KeyError: 'labels'
```

rayrayraykk commented 8 months ago

The labels are generated from the output field, so you can set the output to true or false.
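As a hedged illustration of that suggestion (not FederatedScope code), one way to fill the output field from the SST-2 integer labels while building the JSON-lines file; the use of the HuggingFace datasets library, the instruction wording, and the file name are just convenient choices for this sketch:

```python
import json

from datasets import load_dataset  # HuggingFace datasets; just one convenient SST-2 source

# Fill "output" from the SST-2 label so training labels can be derived from it.
# Mapping 1 -> "true" and 0 -> "false" mirrors the suggestion above;
# "positive"/"negative" would work the same way.
LABEL_TO_OUTPUT = {0: "false", 1: "true"}

sst2 = load_dataset("glue", "sst2", split="train")
with open("sst2_train.jsonl", "w") as f:
    for example in sst2:
        record = {
            "instruction": "Decide whether the sentiment of the sentence is "
                           "positive (true) or negative (false).",
            "input": example["sentence"],
            "output": LABEL_TO_OUTPUT[example["label"]],
            "category": "sst2",
        }
        f.write(json.dumps(record) + "\n")
```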

shuangyichen commented 8 months ago

> The labels are generated from the output field, so you can set the output to true or false.

All resolved. Thanks!