hulianyuyy / DSTA-SLR

Dynamic Spatial-Temporal Aggregation for Skeleton-Aware Sign Language Recognition (COLING2024)

Training on WLASL300 #4

Open pcc03 opened 3 months ago

pcc03 commented 3 months ago

Hi author,

It's great work! I tried to reproduce your results on the WLASL dataset, but the trained model performs very badly (top-1 per-instance accuracy = 1.66%, top-5 per-instance accuracy = 1.77%) when using your preprocessed train_label.pkl and train_data_joint.npy for WLASL300.

My train.yaml is below. Could you please suggest how to modify the configuration to reproduce your results?


===================================================================
Experiment_name: experiment

# feeder
dataset: WLASL300 # [WLASL100, WLASL300, WLASL1000, WLASL2000, MLASL100, MLASL200, MLASL500, MLASL1000, SLR500, NMFs-CSL]
feeder: feeders.feeder.Feeder
train_feeder_args:
  debug: True
  random_choose: True
  window_size: 120  
  random_shift: True
  normalization: True
  random_mirror: True
  random_mirror_p: 0.5
  is_vector: False
  lap_pe: False
  bone_stream: False # True or False
  motion_stream: True # True or False

test_feeder_args:
  random_mirror: False
  normalization: True
  lap_pe: False
  debug: False
  random_choose: False
  window_size: 120  
  bone_stream: False # True or False
  motion_stream: True # True or False

# model
model: model.fstgan.Model
model_args:
  num_class: 300   # 100 for WLASL100, 300 for WLASL300, 1000 for WLASL1000, 2000 for WLASL2000, 500 for SLR500, 100 for MLASL100, 200 for MLASL200, 500 for MLASL500, 1000 for MLASL1000, 1067 for NMFs-CSL
  num_point: 27
  num_person: 1
  graph: graph.sign_27.Graph
  groups: 16
  block_size: 41
  graph_args:
    labeling_mode: 'spatial'
  inner_dim: 64
  depth: 4
  drop_layers: 2

#optim
weight_decay: 0.0001
base_lr: 0.1
step: [150, 200]

# training
device: [0]
# weights: save_models/wlasl_all_attn-187.pt
# start_epoch: 188
keep_rate: 0.9
only_train_epoch: 1
batch_size: 24
test_batch_size: 24
num_epoch: 250
nesterov: True
warm_up_epoch: 20

wandb: False
wandb_project: SLGTformer First Run
wandb_entity: irvl
wandb_name: Twin Attention, No Shift, 24BS

num_worker: 4
save_interval: 5
===================================================================
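The `num_class` comment in the config above pairs each dataset with its class count, and a mismatch there is an easy mistake when switching datasets. A minimal sketch of a pre-flight check, assuming the YAML has already been parsed into a plain dict (for example with `yaml.safe_load`); `EXPECTED_NUM_CLASS` and `check_num_class` are hypothetical helpers, not part of DSTA-SLR:

```python
# Class counts copied from the num_class comment in train.yaml above.
EXPECTED_NUM_CLASS = {
    "WLASL100": 100, "WLASL300": 300, "WLASL1000": 1000, "WLASL2000": 2000,
    "MLASL100": 100, "MLASL200": 200, "MLASL500": 500, "MLASL1000": 1000,
    "SLR500": 500, "NMFs-CSL": 1067,
}


def check_num_class(cfg: dict) -> None:
    """Raise ValueError if model_args.num_class does not match cfg['dataset'].

    `cfg` is assumed to be the parsed train.yaml as a plain dict.
    """
    dataset = cfg["dataset"]
    expected = EXPECTED_NUM_CLASS.get(dataset)
    if expected is None:
        raise ValueError(f"unknown dataset: {dataset!r}")
    actual = cfg["model_args"]["num_class"]
    if actual != expected:
        raise ValueError(
            f"{dataset} expects num_class={expected}, config has {actual}"
        )


# The config posted above (WLASL300, num_class: 300) passes silently.
check_num_class({"dataset": "WLASL300", "model_args": {"num_class": 300}})
```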
hulianyuyy commented 3 months ago

Did you load wlasl_all_attn-187.pt as initial weights at the beginning of training, or did you resume training from it?

pcc03 commented 3 months ago

I restarted the training from the beginning. Also, I didn't find the wlasl_all_attn-187.pt in the repository. Could you please let me know where I can download it?

BTW, I also tested the pretrained model "./pretrained_models/pretrained_model_for_WLASL2000.pt", and its accuracy is also very low. I'm not sure whether the provided checkpoint is a working model.

Could you give me some hints on whether I am doing something wrong when training and testing on WLASL?

hulianyuyy commented 3 months ago

I will check the code and respond later.

pcc03 commented 3 months ago

Thank you! Appreciate it.

hulianyuyy commented 2 months ago

I have trained on WLASL2000 and WLASL300 and obtained correct results; the log files are log_wlasl300_joint.txt and log_wlasl2000_joint.txt. However, when I test with the saved weights, I also get 0.1% accuracy. I have checked the code but haven't found the problem yet. I will keep investigating.

pcc03 commented 2 months ago

Thank you for your reply! By comparing the log files, I found that I need to set phase=train; otherwise, the script runs in the test phase by default. I can now train on WLASL300 successfully.
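A hedged sketch of that fix: a stdlib-only helper (hypothetical; `ensure_phase` is not part of DSTA-SLR) that rewrites a config file's text so it contains a top-level `phase: train` entry. It assumes that main.py accepts `phase` as a key in the YAML config, which the comment above implies but does not show. It does not parse YAML; it only overwrites an unindented `phase:` line, or appends one if none exists.

```python
import re

# Matches an unindented "phase: ..." line, i.e. a top-level YAML key.
_PHASE_LINE = re.compile(r"^phase:.*$", re.MULTILINE)


def ensure_phase(config_text: str, phase: str = "train") -> str:
    """Return config_text with a top-level `phase:` key set to `phase`."""
    new_line = f"phase: {phase}"
    if _PHASE_LINE.search(config_text):
        # Overwrite the first existing top-level phase setting.
        # A lambda avoids backslash-escape handling in the replacement.
        return _PHASE_LINE.sub(lambda _m: new_line, config_text, count=1)
    # No top-level phase key: append one, keeping a trailing newline.
    if config_text and not config_text.endswith("\n"):
        config_text += "\n"
    return config_text + new_line + "\n"
```

For example, with `from pathlib import Path`, `Path("train.yaml").write_text(ensure_phase(Path("train.yaml").read_text()))` would update the file in place.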

One more question: when I try to train on WLASL100 using the provided .npy and .pkl files, I get the following error. Could you please help me with this?

/home/mssn/DSTA-SLR/model/fstgan.py:114: UserWarning: nn.init.constant is now deprecated in favor of nn.init.constant_.
  nn.init.constant(self.Linear_bias, 1e-6)
Attention Enabled!
Attention Enabled!
Attention Enabled!
Attention Enabled!
/home/mssn/DSTA-SLR/model/fstgan.py:791: UserWarning: nn.init.normal is now deprecated in favor of nn.init.normal_.
  nn.init.normal(self.fc.weight, 0, math.sqrt(2.0 / num_class))
1442
Traceback (most recent call last):
  File "main.py", line 801, in <module>
    processor = Processor(arg)
  File "main.py", line 238, in __init__
    self.load_data()
  File "main.py", line 270, in load_data
    dataset=Feeder(
  File "/home/mssn/DSTA-SLR/feeders/feeder.py", line 70, in __init__
    self.load_data()
  File "/home/mssn/DSTA-SLR/feeders/feeder.py", line 129, in load_data
    self.data = np.load(self.data_path, mmap_mode="r")
  File "/home/mssn/anaconda3/envs/DSTA-SLR/lib/python3.8/site-packages/numpy/lib/npyio.py", line 438, in load
    raise ValueError("Cannot load file containing pickled data "
ValueError: Cannot load file containing pickled data when allow_pickle=False
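In NumPy's `np.load`, this ValueError is raised when a file begins with neither the `.npy` magic string nor a zip signature, so the loader falls through to its pickle branch; with `allow_pickle=True`, the same file is handed to `pickle.load` instead, which fails if the file is not a pickle either. Reading the first bytes of the data file shows what it really contains. The sketch below uses hypothetical helper names (`classify_header`, `inspect_data_file`), and its Git LFS check is only one guess at a common cause of a non-`.npy` file (a small pointer file checked out in place of the real data), not a confirmed diagnosis for this issue.

```python
import os

import numpy as np

NPY_MAGIC = b"\x93NUMPY"                     # every .npy file starts with these bytes
ZIP_MAGICS = (b"PK\x03\x04", b"PK\x05\x06")  # .npz archives are zip files
LFS_PREFIX = b"version https://git-lfs"      # Git LFS pointer files are plain text


def classify_header(head: bytes) -> str:
    """Classify a file from its first bytes: npy, npz, git-lfs-pointer or unknown."""
    if head.startswith(NPY_MAGIC):
        return "npy"
    if head.startswith(ZIP_MAGICS):
        return "npz"
    if head.startswith(LFS_PREFIX):
        return "git-lfs-pointer"
    return "unknown"


def inspect_data_file(path: str) -> str:
    """Return a one-line description of what `path` actually contains."""
    with open(path, "rb") as f:
        head = f.read(64)
    size = os.path.getsize(path)
    kind = classify_header(head)
    if kind == "npy":
        try:
            # mmap_mode="r" maps the file instead of reading the array into RAM.
            arr = np.load(path, mmap_mode="r")
            return f"{path}: .npy, shape={arr.shape}, dtype={arr.dtype}, {size} bytes"
        except ValueError as exc:
            # e.g. object-dtype arrays cannot be memory-mapped
            return f"{path}: .npy header present, but mmap failed: {exc}"
    if kind == "git-lfs-pointer":
        return f"{path}: Git LFS pointer ({size} bytes), not the actual data"
    return f"{path}: {kind}, {size} bytes, first bytes {head[:16]!r}"
```

For example, `print(inspect_data_file("./data/WLASL100/val_data_joint.npy"))` would report whether that file is a genuine `.npy` array.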

hulianyuyy commented 2 months ago

You could change line 129 in ./feeders/feeder.py to self.data = np.load(self.data_path, mmap_mode="r", allow_pickle=True) and test it again.

pcc03 commented 2 months ago

I changed this line, but now a new error appears, shown below.

Could you let me know how you generated train_labels.pkl and train_data_joint.npy for WLASL100? It would help if I could regenerate these files myself.

Traceback (most recent call last):
  File "/home/mssn/anaconda3/envs/DSTA-SLR/lib/python3.8/site-packages/numpy/lib/npyio.py", line 441, in load
    return pickle.load(fid, **pickle_kwargs)
_pickle.UnpicklingError: could not find MARK

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "main.py", line 801, in <module>
    processor = Processor(arg)
  File "main.py", line 238, in __init__
    self.load_data()
  File "main.py", line 270, in load_data
    dataset=Feeder(
  File "/home/mssn/DSTA-SLR/feeders/feeder.py", line 70, in __init__
    self.load_data()
  File "/home/mssn/DSTA-SLR/feeders/feeder.py", line 130, in load_data
    self.data = np.load(self.data_path, mmap_mode="r", allow_pickle=True)
  File "/home/mssn/anaconda3/envs/DSTA-SLR/lib/python3.8/site-packages/numpy/lib/npyio.py", line 443, in load
    raise pickle.UnpicklingError(
_pickle.UnpicklingError: Failed to interpret file './data/WLASL100/val_data_joint.npy' as a pickle

hulianyuyy commented 2 months ago

I generated the WLASL100 data by extracting the samples that belong to the WLASL100 classes from the WLASL2000 data. The error is strange, because the input data is an .npy file, not a pickle file.
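The extraction described above (keeping only the WLASL2000 samples whose classes belong to WLASL100) can be sketched as follows. Everything here is an assumption for illustration, not the script used to produce the released files: that the data array has samples on axis 0, that the label file holds a list of sample names plus a parallel list of integer labels, and that you already have the ordered list of WLASL2000 class ids making up WLASL100. `subset_by_classes` is a hypothetical name.

```python
import numpy as np


def subset_by_classes(data, names, labels, keep_classes):
    """Keep samples whose label is in `keep_classes` and relabel them 0..K-1.

    data:         array with samples on axis 0 (e.g. N x C x T x V x M)
    names:        list of N sample names
    labels:       list of N integer labels in the larger (e.g. 2000-class) space
    keep_classes: ordered list of class ids to keep (e.g. the WLASL100 ids)
    """
    # Map each kept old class id to its new contiguous id.
    remap = {old: new for new, old in enumerate(keep_classes)}
    idx = [i for i, lab in enumerate(labels) if lab in remap]
    # Fancy indexing copies only the selected samples (works on memmaps too).
    sub_data = np.asarray(data)[idx]
    sub_names = [names[i] for i in idx]
    sub_labels = [remap[labels[i]] for i in idx]
    return sub_data, sub_names, sub_labels
```

Saving `sub_data` with `np.save` writes a standard `.npy` file, which `np.load(..., mmap_mode="r")` can read without `allow_pickle` as long as the array has a numeric dtype.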

pcc03 commented 2 months ago

I see. Thank you for the reply! I will try to regenerate the WLASL100 files by myself.