segment-any-text / wtpsplit

Toolkit to segment text into sentences or other semantic units in a robust, efficient and adaptable way.
MIT License

Questions concerning configuring train_lora.py for custom corpus #130

Open eshau opened 4 weeks ago

eshau commented 4 weeks ago

Hi! I followed the instructions for fine-tuning on my corpus and (I think) managed to do so successfully after days of debugging. I have A LOT of implementation questions, and the following is half guide, half questions about the process. I want to make it very clear that I am VERY thankful for the existence of this library, but I feel obligated to point out the issues below.



I first created a new conda environment (with conda create -n sat-finetune python=3.9 and conda install pip just to be safe) and ran the following:

git clone https://github.com/segment-any-text/wtpsplit
cd wtpsplit
pip install -r requirements.txt
pip install adapters==0.2.1 --no-dependencies
cd ..


Then I created the .pth dataset as per this format:

import torch

torch.save(
    {
        "language_code": {
            "sentence": {
                "sat-dataset": {
                    "meta": {
                        "train_data": ["train sentence 1", "train sentence 2"],
                    },
                    "data": [
                        "test sentence 1",
                        "test sentence 2",
                    ]
                }
            }
        }
    },
    "<path>/sat-dataset.pth"
)


My config is below:

{
    "model_name_or_path": "segment-any-text/sat-3l",
    "output_dir": "sat-3l-LL_lora",
    "block_size": 256,
    "eval_stride": 128,
    "do_train": true,
    "do_eval": true,
    "per_device_train_batch_size": 64,
    "per_device_eval_batch_size": 32,
    "gradient_accumulation_steps": 1,
    "eval_accumulation_steps": 8,
    "dataloader_num_workers": 1,
    "preprocessing_num_workers": 1,
    "learning_rate": 3e-4,
    "fp16": false,
    "num_train_epochs": 30,
    "logging_steps": 50,
    "report_to": "wandb",
    "wandb_project": "sentence",
    "save_steps": 100000000,
    "remove_unused_columns": false,
    "do_sentence_training": true,
    "do_auxiliary_training": false,
    "warmup_ratio": 0.1,
    "non_punctuation_sample_ratio": null,
    "prediction_loss_only": true,
    "use_auxiliary": true,
    "ddp_timeout": 3600,
    "use_subwords": true,
    "custom_punctuation_file": "punctuation_xlmr_unk.txt",
    "log_level": "warning",
    "adapter_config": "lora[r=16,alpha=32,intermediate_lora=True]",
    "text_path": "<path>/sat-dataset.pth",
    "weight_decay": 0.0,
    "auxiliary_remove_prob": 0.0,
    "train_adapter": true
}


The first issue that popped up was that wtpsplit wasn't installed. To fix this, I added the path of the wtpsplit dir to train_lora.py:

...
import wandb

# Added this line below
sys.path.insert(0, os.path.abspath('<path>/wtpsplit'))

from wtpsplit.models import SubwordXLMConfig, SubwordXLMForTokenClassification
...


In hindsight, I think this could have been fixed with a pip install . in the correct directory, but I wasn't sure.


After this, I received an outdated CUDA version error, since the torch version in the requirements.txt file installs 1.7.1 by default. I tried upgrading to the version on my kernel with the recommended torch command (conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia), but this updates numpy to 2.0 among other things and causes version errors. Downgrading to torch=1.13.1+cu117 did not help (from a brief Google search, that version itself is buggy), and I progressively downgraded to torch=1.10.1+cu111 to make it work.
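For anyone hitting the same problem, a quick way to check what is actually installed (a throwaway snippet, not part of the repo):

# Throwaway environment check: shows which torch build is installed and whether
# it can actually see the GPU; a False here usually means the torch build's CUDA
# version does not match the installed driver.
import torch

print(torch.__version__)           # e.g. 1.10.1+cu111
print(torch.version.cuda)          # CUDA version torch was compiled against
print(torch.cuda.is_available())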


This made CUDA work but then I got an index error along these lines:

...
(Thousands of the line below)
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [864,0,0], thread: [126,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

...
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: CUDA error: device-side assert triggered


I believe this is because in train_lora.py we add the newline token to the tokenizer:

...
tokenizer = AutoTokenizer.from_pretrained(args.base_model)
# needed since we create labels in collate_fn based on tokens
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
custom_token_id = tokenizer.convert_tokens_to_ids("\n")
# used later to filter out special tokens
special_tokens_ids = set(tokenizer.all_special_ids)
special_tokens_ids.discard(custom_token_id)
...


but never update the size of the embedding matrix in the backbone model. This leads to the tokenizer generating the input id corresponding to the newline token (250002) while the embedding matrix is not big enough to accommodate it. I thought this was because I had newlines in my sentences, but even after removing them I still received this error (I also later realized we add newlines anyway in prepare_dataset). To fix this, I added this line:

...
special_tokens_ids = set(tokenizer.all_special_ids)
special_tokens_ids.discard(custom_token_id)

# Added this line below
backbone.resize_token_embeddings(len(tokenizer))

if "short" in dataset_name:
    one_sample_per_line = True
...
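For reference, the mismatch can be reproduced outside of train_lora.py with something like the snippet below (it reuses the classes train_lora.py already imports; the exact counts assume the stock XLM-R vocabulary):

# Standalone sanity check (not part of train_lora.py): adding "\n" grows the
# tokenizer to 250003 entries while the backbone's embedding matrix still has
# 250002 rows, so any "\n" input id indexes out of range until we resize.
from tokenizers import AddedToken
from transformers import AutoTokenizer
from wtpsplit.models import SubwordXLMForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
backbone = SubwordXLMForTokenClassification.from_pretrained("segment-any-text/sat-3l")

tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
print(len(tokenizer))                                  # 250003
print(backbone.get_input_embeddings().num_embeddings)  # 250002

backbone.resize_token_embeddings(len(tokenizer))
print(backbone.get_input_embeddings().num_embeddings)  # 250003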


This led to another error, which I will explain further below, but at this point I had a few questions:



1. Should we have newlines in our train / valid sentences?

2. Adding the extra embedding for newline feels like a hack and I am pretty sure the original SaT-3l I used as the base model should have had this in the embedding. Was the wrong model used as the base? And is this new embedding for newline changing enough given we freeze the weights of the original model? Furthermore, is this leaking into the TokenClassifier part of the model?





After this, I tried running train_lora.py but the code was stuck. I did some debugging and it was stuck on this line:

...
    **kwargs,
)

# Stuck on this line below
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)

logger.warning(f"Finished training for {lang} {dataset_name}.")
...


I did some digging and I think it was because I was running on 4 GPUS and the inter-GPU communication was not working. I tried several things (setting ddp_backend=gloo in the config.json file, setting os.environ["NCCL_P2P_DISABLE"]="1" on top of train_lora.py) but the only thing that fixed it was restricting CUDA to one device by setting os.environ["CUDA_VISIBLE_DEVICES"]="0" on top of train_lora.py. From my understanding of the paper and the Github Issues I have read, the paper's experiments were run on TPUs (on Amazon Sagemaker) and 1 GPU for fine-tuning so this seems like an oversight. I feel like I have to ask the question here:



3. Is there a way to run fine-tuning on multiple GPUs?
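For completeness, the single-GPU workaround boils down to this at the top of train_lora.py:

# Workaround for the multi-GPU hang: set these before torch initializes CUDA.
import os

os.environ["NCCL_P2P_DISABLE"] = "1"      # tried first; did not help on my setup
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # restricting to one device is what fixed it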





After I fixed this, the code ran but I received another error involving infinite values:


ValueError: Input contains infinity or a value too large for dtype('float64')


When I went through the traceback, I found the error to be in evaluate.py, specifically when compute_metrics in train_lora.py calls evaluate_sentence, which in turn calls get_metrics here:

newline_probs = char_probs[:, positive_index]

# This line below
metrics, info = get_metrics(newline_labels, newline_probs, threshold=threshold)

info["newline_labels"] = newline_labels


This is because of this line in get_metrics:

def get_metrics(labels, preds, threshold: float = 0.01):
    # Compute precision-recall curve and AUC

    # This line below
    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(labels, preds)

    pr_auc = sklearn.metrics.auc(recall, precision)
    ...


because preds contains -np.inf. I did some more digging and found it was because we call token_to_char_probs here in evaluate_sentence:

...
if "xlm" in model.config.model_type:
    tokens = tokenizer.tokenize(text, verbose=False)

    # This line below
    char_probs = token_to_char_probs(text, tokens, logits, tokenizer, offsets_mapping)

else:
    ....


This is because token_to_char_probs in utils/__init__.py initializes the return array char_probs with -np.inf here:

def token_to_char_probs(text, tokens, token_logits, tokenizer, offsets_mapping):
    """Map from token probabalities to character probabilities"""

    # This line below
    char_probs = np.full((len(text), token_logits.shape[1]), -np.inf)  # Initialize with very low numbers
    ...


And because we only replace the rows whose corresponding character is the last character of a non-special token:

...
# Assign the token's probability to the last character of the token
for i in range(valid_offsets.shape[0]):
    start, end = valid_offsets[i]

    # This line below
    char_probs[end - 1] = token_logits[valid_indices[i]]

...


We are left with a lot of -np.inf in the first column when we call get_metrics:

...
# This line below
newline_probs = char_probs[:, positive_index]

metrics, info = get_metrics(newline_labels, newline_probs, threshold=threshold)
...


and sklearn.metrics.precision_recall_curve really does not like that. To fix this, I changed the offending line in get_metrics to use sigmoid(preds), matching the call to f1_score in the same function:

def get_metrics(labels, preds, threshold: float = 0.01):
    # Compute precision-recall curve and AUC

    # I changed this line below
    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(labels, sigmoid(preds))

    ....
    # Compute F1 score for a specific threshold (e.g., 0.01 after applying sigmoid)
    f1_at_specific_threshold = sklearn.metrics.f1_score(labels, sigmoid(preds) > threshold)
    ...
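To see why this helps, here is a tiny standalone repro (illustrative only, with a local stand-in for sigmoid):

# Minimal repro: precision_recall_curve rejects -inf scores outright, but
# sigmoid maps -inf to 0.0, so the curve (and its AUC) can be computed.
import numpy as np
import sklearn.metrics

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

labels = np.array([0, 1, 0, 1])
preds = np.array([-np.inf, 2.3, -np.inf, -0.7])  # -inf rows as left by token_to_char_probs

# sklearn.metrics.precision_recall_curve(labels, preds)
# -> ValueError: Input contains infinity or a value too large for dtype('float64')

precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, sigmoid(preds))
print(sklearn.metrics.auc(recall, precision))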


With all of these changes, running the initial command with train_lora.py and config.json worked. The results look OK, but I have a few more questions at this point:



4. Is the auxiliary objective function necessary for fine-tuning? I tried setting "use_auxiliary": false in the config.json file but ran into warnings while training and it errored while loading with SaT.

5. What is the difference between one_sample_per_line=True and one_sample_per_line=False? Does it actually make a difference on training results?

6. Is the optimal threshold returned by compute_metrics really the best threshold in your empirical experience?

7. Is corruption mandatory / desirable for validation? From my understanding, the default LoRA configs do not corrupt training samples, but evaluate_sentence has this line:
separator = Constants.SEPARATORS[lang_code]

# This line below
sentences = [corrupt(sentence, do_lowercase, do_remove_punct) for sentence in sentences]
text = separator.join(sentences)



Once again, thank you for reading this and I hope you can answer my questions / concerns.

markus583 commented 3 weeks ago

Hi, thanks for your detailed post. I'll go through your issues one by one.

Indeed, you don't need to add sys.path.insert(0, os.path.abspath('<path>/wtpsplit')). We instead assume that the library is already installed from pip, or, as you suggested, via pip install -e ..

As for the torch version, this depends a lot on the specific hardware you're using (including CUDA, cuDNN, OS, etc.). So I'm afraid there's not a lot I can do here. In our experiments, we used torch==1.7.1. In general, there's nothing wrong with using a newer version.

The np.inf error you were observing is due to a recent change in the library (see #127). Replacing with sigmoid(preds) should work indeed. Can you confirm this did not cause any other errors (e.g., ValueError) down the road?

As for the questions:

1. Should we have newlines in our train / valid sentences?

No, there should not be any newlines! You should indicate sentences in the specified format: ["train sentence 1", "train sentence 2"].

2. Adding the extra embedding for newline feels like a hack and I am pretty sure the original SaT-3l I used as the base model should have had this in the embedding. Was the wrong model used as the base? And is this new embedding for newline changing enough given we freeze the weights of the original model? Furthermore, is this leaking into the TokenClassifier part of the model?

No, we never add this to the embedding. We never modified it. Instead, the underlying dataset should not contain any newlines. We simply add the new token to the tokenizer so it can tokenize the text properly. Later, in collate_fn, we get rid of them. So the model is never exposed to any \n (added) tokens, and there will be no errors. Please check your dataset and let me know how it goes.
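Conceptually, it works roughly like the simplified sketch below (not the actual collate_fn): the newline token only marks boundaries and is dropped before the batch reaches the model:

# Simplified illustration of the labeling idea: join sentences with "\n", give
# the token preceding each "\n" a positive label, then drop the "\n" tokens so
# the model itself never sees them.
from tokenizers import AddedToken
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
newline_id = tokenizer.convert_tokens_to_ids("\n")

ids = tokenizer("train sentence 1\ntrain sentence 2", add_special_tokens=False)["input_ids"]

kept_ids, labels = [], []
for token_id in ids:
    if token_id == newline_id:
        if labels:
            labels[-1] = 1  # previous token ends a sentence
        continue
    kept_ids.append(token_id)
    labels.append(0)

print(tokenizer.convert_ids_to_tokens(kept_ids))
print(labels)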

3. Is there a way to run fine-tuning on multiple GPUs?

In principle, it should be possible, but I currently don't have a way to test this. If you are getting errors, I suggest you look into the AdapterTrainer class. However, LoRA adaptation is generally quite fast so I never needed to use multiple devices (except if your dataset is huuuuge)

4. Is the auxiliary objective function necessary for fine-tuning? I tried setting "use_auxiliary": false in the config.json file but ran into warnings while training and it errored while loading with SaT.

No, it is not necessary. I suggest you keep "do_auxiliary_training": false and "use_auxiliary": true. It is what we used in our experiments and it worked very well.

5. What is the difference between one_sample_per_line=True and one_sample_per_line=False? Does it actually make a difference on training results?

We used one_sample_per_line=True only for verse segmentation. If you use the format as mentioned above, just set it to False.

6. Is the optimal threshold returned by compute_metrics really the best threshold in your empirical experience?

Typically, yes. It proved to work well, but there's a risk of overfitting.

7. Is corruption mandatory / desirable for validation? From my understanding the default LORA configs do not corrupt training samples but evaluate_sentence has this line

No, it is not necessary. It is best to run validation as is, using your original dataset. We used this at first, but later did all the corruption on the dataset side, so the corruption in the line you mentioned is no longer used.

I hope this is helpful!

eshau commented 3 weeks ago

Thank you for your detailed response! I have some responses and questions to your answers.

The np.inf error you were observing is due to a recent change in the library (see https://github.com/segment-any-text/wtpsplit/issues/127). Replacing with sigmoid(preds) should work indeed. Can you confirm this did not cause any other errors (e.g., ValueError) down the road?

Yep! It all works out once the sigmoid is added.

1. Should we have newlines in our train / valid sentences?

No, there should not be any newlines! You should indicate sentences in the specified format: ["train sentence 1", "train sentence 2"].

I think it would be a good idea to make this explicit as a comment somewhere in train_lora.py or the README. This caused me a lot of confusion 😭

2. Adding the extra embedding for newline feels like a hack and I am pretty sure the original SaT-3l I used as the base model should have had this in the embedding. Was the wrong model used as the base? And is this new embedding for newline changing enough given we freeze the weights of the original model? Furthermore, is this leaking into the TokenClassifier part of the model?

No, we never add this to the embedding. We never modified it. Instead, the underlying dataset should not contain any newlines. We simply add the new token to the tokenizer so it can tokenize the text properly. Later, in collate_fn, we get rid of them. So the model is never exposed to any \n (added) tokens, and there will be no errors. Please check your dataset and let me know how it goes.

This was my first instinct, and I initially had a dataset with no newlines (but it still errored). I just tried removing them again, and I still receive thousands of

/pytorch/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [824,0,0], thread: [127,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

errors once again.

3. Is there a way to run fine-tuning on multiple GPUs?

In principle, it should be possible, but I currently don't have a way to test this. If you are getting errors, I suggest you look into the AdapterTrainer class. However, LoRA adaptation is generally quite fast so I never needed to use multiple devices (except if your dataset is huuuuge)

Yeah I don't think I need the extra devices but it feels bad to leave compute power on the table :(

4. Is the auxiliary objective function necessary for fine-tuning? I tried setting "use_auxiliary": false in the config.json file but ran into warnings while training and it errored while loading with SaT.

No, it is not necessary. I suggest you keep "do_auxiliary_training": false and "use_auxiliary": true. It is what we used in our experiments and it worked very well.

Noted!

5. What is the difference between one_sample_per_line=True and one_sample_per_line=False? Does it actually make a difference on training results?

We used one_sample_per_line=True only for verse segmentation. If you use the format as mentioned above, just set it to False.

To clarify, my dataset is a bunch of documents that I would like to separate into sentences. Would one_sample_per_line=True be relevant here?

6. Is the optimal threshold returned by compute_metrics really the best threshold in your empirical experience?

Typically, yes. It proved to work well, but there's risk for overfitting.

Also noted, though I am curious how the auxiliary objective affects this, if at all.

7. Is corruption mandatory / desirable for validation? From my understanding the default LORA configs do not corrupt training samples but evaluate_sentence has this line

No, it is not necessary. It is best to use the validation as is using your original dataset. We used this at first, but later did all the corruption on the dataset side, so we did not use the corruption in the line you mentioned any longer.

Got it!

Thanks again!

markus583 commented 3 weeks ago

Thanks for the sigmoid/np.inf fix. I will adapt this later.

1. I'll also add a note to the README that there should be no newlines. First, I'll test it myself though, but I can't do it just now.

2. It is weird that you get these errors. Would it be possible for you to share (upload) your dataset (or a tiny subset) somewhere so I can try and debug it myself?

5. No, I suggest you keep it to False.

eshau commented 3 weeks ago

Unfortunately, I cannot share our dataset. I will try to reproduce the issue with a Wikipedia article. On a side note, how does SaT deal with newlines when it splits data that contains them?

markus583 commented 3 weeks ago

I understand that - that would be very helpful - whatever does the job! I don't get the question, could you please reformulate?

eshau commented 3 weeks ago

If I pass a document with newlines into SaT.split(), how does the model deal with the newlines?

markus583 commented 3 weeks ago

I see. The tokenizer takes care of this. If you pass a newline, it will be tokenized in the same way as a space. But we only return character probabilities, so the newline is retained in the output and the text can be reconstructed. However, we cannot do this for LoRA fine-tuning as we need the newlines in the tokenizer to later create labels in collate_fn.
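In terms of the public API, that looks roughly like this (output shown for illustration only):

# Newlines pass through split() because segmentation works on character
# probabilities and the original characters are kept in the output.
from wtpsplit import SaT

sat = SaT("sat-3l")
print(sat.split("First sentence.\nSecond sentence. Third sentence."))
# e.g. ['First sentence.\n', 'Second sentence. ', 'Third sentence.']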

mexus commented 1 hour ago

@eshau thank you for such a thorough investigation! I happened to encounter the same issues, and thanks to your findings, I was able to finally train LoRA :)

@markus583 thanks for your great work!

Regarding the issue with indexing errors, it occurs even on the dummy dataset that is mentioned on the README page:

import torch

torch.save(
    {
        "language_code": {
            "sentence": {
                "dummy-dataset": {
                    "meta": {
                        "train_data": ["train sentence 1", "train sentence 2"],
                    },
                    "data": [
                        "test sentence 1",
                        "test sentence 2",
                    ]
                }
            }
        }
    },
    "dummy-dataset.pth"
)

Although the fix (backbone.resize_token_embeddings(len(tokenizer))) from @eshau helps.