XuhuiZhou / cobra-frames

The official code repo for the paper "COBRA Frames: Contextual Reasoning about Effects and Harms of Offensive Statements" (https://arxiv.org/abs/2306.01985)
https://cobra.xuhuiz.com/

ValueError raised while calling a function in metrics.py #2

Closed · AmeyHengle closed this issue 1 year ago

AmeyHengle commented 1 year ago

Hi! I'm facing an error while running the train_explain_model pipeline.

The error occurs in the following line of the postprocess_text function in metrics.py: preds = np.where(labels != -100, preds, tokenizer.pad_token_id)

It seems the np.where call is trying to broadcast arrays with shapes (1000, 342) and (1000, 512) together with a scalar of shape (), but these shapes are not broadcast-compatible, presumably because the two arrays were padded to different sequence lengths.
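
For reference, here is a minimal NumPy sketch that reproduces the same error; the shape-to-array assignment is my inference from the operand order in the message (np.where lists condition, x, y):

    import numpy as np

    # Shapes inferred from the trace: the labels appear to be padded to 342
    # tokens, while generated predictions are padded to generation_max_length=512.
    labels = np.full((1000, 342), -100, dtype=np.int64)
    preds = np.zeros((1000, 512), dtype=np.int64)
    pad_token_id = 0  # scalar, shape ()

    # The (1000, 342) mask and the (1000, 512) array cannot be broadcast:
    # ValueError: operands could not be broadcast together with shapes (1000,342) (1000,512) ()
    preds = np.where(labels != -100, preds, pad_token_id)

Since the mask comes from labels but replaces values in preds, this line only works when both arrays happen to share the same padded length.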

Complete Stack Trace

bash /home/ameyh/cobra-frames/scripts/explain_model/train_explain_model.sh "train_small"
Rewritten gin arg: --gin_bindings=MODEL_DIR = '.log/explain-model-small'
Rewritten gin arg: --gin_bindings=MODE = 'deployment'
I0718 18:02:12.537055 140015488536832 gin_utils.py:54] Gin Configuration:
I0718 18:02:12.539195 140015488536832 gin_utils.py:56] from __gin__ import dynamic_registration
I0718 18:02:12.539299 140015488536832 gin_utils.py:56] import __main__ as train_script
I0718 18:02:12.539385 140015488536832 gin_utils.py:56] import sbf_modeling
I0718 18:02:12.539448 140015488536832 gin_utils.py:56] import sbf_modeling.utils.data as data_utils
I0718 18:02:12.539510 140015488536832 gin_utils.py:56] import transformers
I0718 18:02:12.539570 140015488536832 gin_utils.py:56] 
I0718 18:02:12.539630 140015488536832 gin_utils.py:56] # Macros:
I0718 18:02:12.539689 140015488536832 gin_utils.py:56] # ==============================================================================
I0718 18:02:12.539749 140015488536832 gin_utils.py:56] MODE = 'deployment'
I0718 18:02:12.539808 140015488536832 gin_utils.py:56] MODEL_DIR = '.log/explain-model-small'
I0718 18:02:12.539868 140015488536832 gin_utils.py:56] 
I0718 18:02:12.539927 140015488536832 gin_utils.py:56] # Parameters for sbf_modeling.ExplainModel:
I0718 18:02:12.539986 140015488536832 gin_utils.py:56] # ==============================================================================
I0718 18:02:12.540046 140015488536832 gin_utils.py:56] sbf_modeling.ExplainModel.t5_model_name = 'google/flan-t5-small'
I0718 18:02:12.540106 140015488536832 gin_utils.py:56] 
I0718 18:02:12.540165 140015488536832 gin_utils.py:56] # Parameters for sbf_modeling.ExplainModel.train:
I0718 18:02:12.540225 140015488536832 gin_utils.py:56] # ==============================================================================
I0718 18:02:12.540284 140015488536832 gin_utils.py:56] sbf_modeling.ExplainModel.train.args = @transformers.Seq2SeqTrainingArguments()
I0718 18:02:12.540343 140015488536832 gin_utils.py:56] sbf_modeling.ExplainModel.train.print_prediction_num_examples = 300
I0718 18:02:12.540403 140015488536832 gin_utils.py:56] 
I0718 18:02:12.540462 140015488536832 gin_utils.py:56] # Parameters for data_utils.get_data:
I0718 18:02:12.540521 140015488536832 gin_utils.py:56] # ==============================================================================
I0718 18:02:12.540580 140015488536832 gin_utils.py:56] data_utils.get_data.mode = %MODE
I0718 18:02:12.540655 140015488536832 gin_utils.py:56] 
I0718 18:02:12.540717 140015488536832 gin_utils.py:56] # Parameters for transformers.Seq2SeqTrainingArguments:
I0718 18:02:12.540791 140015488536832 gin_utils.py:56] # ==============================================================================
I0718 18:02:12.540851 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.eval_steps = 200
I0718 18:02:12.540911 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.evaluation_strategy = 'epoch'
I0718 18:02:12.540971 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.generation_max_length = 512
I0718 18:02:12.541031 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.gradient_accumulation_steps = 2
I0718 18:02:12.541090 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.learning_rate = 0.0001
I0718 18:02:12.541150 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.logging_steps = 20
I0718 18:02:12.541209 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.lr_scheduler_type = 'cosine'
I0718 18:02:12.541269 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.num_train_epochs = 2
I0718 18:02:12.541328 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.output_dir = '.log/_explain_model'
I0718 18:02:12.541388 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.per_device_eval_batch_size = 16
I0718 18:02:12.541447 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.per_device_train_batch_size = 8
I0718 18:02:12.541507 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.predict_with_generate = True
I0718 18:02:12.541566 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.prediction_loss_only = False
I0718 18:02:12.541626 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.report_to = ['wandb']
I0718 18:02:12.541685 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.save_strategy = 'epoch'
I0718 18:02:12.541744 140015488536832 gin_utils.py:56] transformers.Seq2SeqTrainingArguments.weight_decay = 0.1
I0718 18:02:12.541803 140015488536832 gin_utils.py:56] 
I0718 18:02:12.541862 140015488536832 gin_utils.py:56] # Parameters for train_script.train:
I0718 18:02:12.541922 140015488536832 gin_utils.py:56] # ==============================================================================
I0718 18:02:12.541989 140015488536832 gin_utils.py:56] train_script.train.model = @sbf_modeling.ExplainModel()
I0718 18:02:12.542059 140015488536832 gin_utils.py:56] train_script.train.model_dir = %MODEL_DIR
I0718 18:02:12.542119 140015488536832 gin_utils.py:56] train_script.train.train_data = @data_utils.get_data()
Downloading (…)lve/main/config.json: 100% 1.40k/1.40k [00:00<00:00, 114kB/s]
Downloading pytorch_model.bin: 100% 308M/308M [00:02<00:00, 106MB/s]
Downloading (…)neration_config.json: 100% 147/147 [00:00<00:00, 9.90kB/s]
Downloading spiece.model: 100% 792k/792k [00:00<00:00, 16.1MB/s]
Downloading (…)cial_tokens_map.json: 100% 2.20k/2.20k [00:00<00:00, 596kB/s]
Downloading (…)okenizer_config.json: 100% 2.54k/2.54k [00:00<00:00, 784kB/s]
Downloading readme: 100% 1.92k/1.92k [00:00<00:00, 5.76MB/s]
Downloading and preparing dataset csv/cmu-lti--cobracorpus to /home/ameyh/.cache/huggingface/datasets/cmu-lti___csv/cmu-lti--cobracorpus-86752befd5f33f8a/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317...
Downloading data: 100% 144M/144M [00:02<00:00, 49.0MB/s]
Downloading data: 100% 4.56M/4.56M [00:01<00:00, 3.17MB/s]
Downloading data files: 100% 2/2 [00:08<00:00,  4.29s/it]
Extracting data files: 100% 2/2 [00:00<00:00, 949.04it/s]
Dataset csv downloaded and prepared to /home/ameyh/.cache/huggingface/datasets/cmu-lti___csv/cmu-lti--cobracorpus-86752befd5f33f8a/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317. Subsequent calls will reuse this data.
100% 2/2 [00:00<00:00, 270.77it/s]
I0718 18:02:39.749025 140015488536832 train.py:23] Training model
/home/ameyh/miniconda3/lib/python3.8/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 2
wandb: You chose 'Use an existing W&B account'
wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
wandb: Appending key for api.wandb.ai to your netrc file: /home/ameyh/.netrc
wandb: Tracking run with wandb version 0.15.5
wandb: Run data is saved locally in /home/ameyh/cobra-frames/wandb/run-20230718_180431-g2ehes4q
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run pious-pine-9
wandb: ⭐️ View project at https://wandb.ai/counterspeech-project/context-sbf
wandb: 🚀 View run at https://wandb.ai/counterspeech-project/context-sbf/runs/g2ehes4q
{'loss': 2.7977, 'learning_rate': 9.999366806729235e-05, 'epoch': 0.01}
{'loss': 1.9479, 'learning_rate': 9.997467387290426e-05, 'epoch': 0.02}
{'loss': 1.5771, 'learning_rate': 9.994302222763414e-05, 'epoch': 0.03}
[... ~90 similar per-step log lines trimmed; loss decreases steadily to about 0.88-0.90 by the end of the first epoch ...]
{'loss': 0.8796, 'learning_rate': 5.135259776603821e-05, 'epoch': 0.98}
{'loss': 0.9021, 'learning_rate': 5.055700845229327e-05, 'epoch': 0.99}
 50% 1974/3948 [09:02<08:53,  3.70it/s]
100% 63/63 [03:15<00:00,  3.01s/it]
Traceback (most recent call last):
  File "sbf_modeling/train.py", line 65, in <module>
    gin_utils.run(main)
  File "/home/ameyh/cobra-frames/sbf_modeling/gin_utils.py", line 82, in run
    app.run(
  File "/home/ameyh/miniconda3/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/ameyh/miniconda3/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "sbf_modeling/train.py", line 40, in main
    train_using_gin()
  File "/home/ameyh/miniconda3/lib/python3.8/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/ameyh/miniconda3/lib/python3.8/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/home/ameyh/miniconda3/lib/python3.8/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "sbf_modeling/train.py", line 24, in train
    model = model.train(train_data, save_model_dir=model_dir)
  File "/home/ameyh/miniconda3/lib/python3.8/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/ameyh/miniconda3/lib/python3.8/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/home/ameyh/miniconda3/lib/python3.8/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/ameyh/cobra-frames/sbf_modeling/explain_model.py", line 187, in train
    trainer.train()
  File "/home/ameyh/miniconda3/lib/python3.8/site-packages/transformers/trainer.py", line 1662, in train
    return inner_training_loop(
  File "/home/ameyh/miniconda3/lib/python3.8/site-packages/transformers/trainer.py", line 2021, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "/home/ameyh/miniconda3/lib/python3.8/site-packages/transformers/trainer.py", line 2287, in _maybe_log_save_evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/home/ameyh/miniconda3/lib/python3.8/site-packages/transformers/trainer_seq2seq.py", line 159, in evaluate
    return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
  File "/home/ameyh/miniconda3/lib/python3.8/site-packages/transformers/trainer.py", line 2993, in evaluate
    output = eval_loop(
  File "/home/ameyh/miniconda3/lib/python3.8/site-packages/transformers/trainer.py", line 3281, in evaluation_loop
    metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
  File "/home/ameyh/cobra-frames/sbf_modeling/metrics.py", line 47, in aggregated_metrics_with_postprocess
    preds, labels = postprocess_text(tokenizer, eval_preds)
  File "/home/ameyh/cobra-frames/sbf_modeling/metrics.py", line 19, in postprocess_text
    preds = np.where(labels != -100, preds, tokenizer.pad_token_id)
  File "<__array_function__ internals>", line 180, in where
ValueError: operands could not be broadcast together with shapes (1000,342) (1000,512) ()
  In call to configurable 'train' (<function ExplainModel.train at 0x7f5730ed1310>)
  In call to configurable 'train' (<function train at 0x7f57e4741310>)
wandb: Waiting for W&B process to finish... (failed 1). Press Control-C to abort syncing.
wandb: 
wandb: Run history:
wandb:         train/epoch ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:   train/global_step ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb: train/learning_rate ██████████▇▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▂▂▂▂▁▁
wandb:          train/loss █▄▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: 
wandb: Run summary:
wandb:         train/epoch 0.99
wandb:   train/global_step 1960
wandb: train/learning_rate 5e-05
wandb:          train/loss 0.9021
wandb: 
wandb: 🚀 View run pious-pine-9 at: https://wandb.ai/counterspeech-project/context-sbf/runs/g2ehes4q
wandb: Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20230718_180431-g2ehes4q/logs
XuhuiZhou commented 1 year ago

@ProKil Do you have any insight about this problem?

ProKil commented 1 year ago

This is a known issue flagged by CI in the original code base: https://github.com/maartensap/context-sbf/actions/runs/4964345227/jobs/8884393741. It is most likely a dependency-version issue, since CI passed before.
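
For anyone debugging this, a quick way to confirm which transformers version is actually installed (plain Python, nothing repo-specific):

    import transformers
    print(transformers.__version__)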

XuhuiZhou commented 1 year ago

@ProKil Any chance we can also let @AmeyHengle see the CI run without making the repo public?

ProKil commented 1 year ago

@AmeyHengle Changing line 19 in requirements.txt to git+https://github.com/huggingface/transformers@2411f0e465e761790879e605a4256f3d4afb7f82 will temporarily fix this bug.
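
Concretely, line 19 of requirements.txt would then read as follows (assuming the line number hasn't drifted), followed by a pip install -r requirements.txt to pick up the pinned commit:

    git+https://github.com/huggingface/transformers@2411f0e465e761790879e605a4256f3d4afb7f82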

This passed the tests in the original repo. @XuhuiZhou, please update this repo.

XuhuiZhou commented 1 year ago

@AmeyHengle Hey, we've fixed this problem; please try again.

AmeyHengle commented 1 year ago

@ProKil @XuhuiZhou This worked, thanks!😊