huggingface / setfit

Efficient few-shot learning with Sentence Transformers
https://hf.co/docs/setfit
Apache License 2.0

Good strategies for hierarchical classification with many classes #552

Open miguelwon opened 1 month ago

miguelwon commented 1 month ago

I'm working on a hierarchical multi-class problem. If I flatten the labels (flat approach), I get about 1193 classes, which can arguably already be considered an extreme multi-class classification problem. Furthermore, I have fewer than 10 examples per class.

With so many classes, I can't build pairs from all combinations, because that would produce a huge number of pairs, and I'm a bit limited in hardware and time.

Also, since the problem is hierarchical, I think it would work better to prioritize pairs of examples that share the same "father", because I want good discrimination even between examples within the same "father" category.

Do you know of any good strategy for this kind of problem? Perhaps train first on pairs drawn from randomly picked high-level categories, and then continue training with pairs that share the same root?

haukelicht commented 1 week ago

I have a similar use case and was thinking about implementing the method proposed in "A Multi-task Approach to Neural Multi-label Hierarchical Patent Classification Using Transformers" (doi).

The paper's authors provide an implementation using Keras: https://github.com/boschresearch/hierarchical_patent_classification_ecir2021/blob/main/text_classification/model/THMM.py

You could adapt their code to PyTorch and subclass the classification head of SetFitModel as described here: https://huggingface.co/docs/setfit/en/how_to/classification_heads#custom-differentiable-head

miguelwon commented 1 week ago

Thanks @haukelicht for the reference. I'll have a look.

But here is what I have done. I built pairs from datapoints that share the same hierarchical "father". That way the pairs are somewhat related (they share the same high-level class), but I know they should be classified differently. These pairs act like hard negatives and make the task of distinguishing them harder. Since I only combine examples within the same high-level class, the total number of pairs is significantly reduced.

Then I fine-tuned a retrieval model (I worked with gte-multilingual-base) and afterwards trained a classification head with a simple NN.

With this approach I was able to achieve good evaluation results.
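
To give a rough sense of how much restricting pairs to the same "father" shrinks the pair set, here is a back-of-the-envelope count. It is only a sketch: the thread gives 1193 classes and fewer than 10 examples per class, but `examples_per_class = 9` and `classes_per_section = 10` are assumptions made for illustration, and a real hierarchy will not be this balanced.

```python
from math import comb

# Illustrative, perfectly balanced hierarchy (hypothetical numbers).
n_classes = 1193
examples_per_class = 9      # assumption: "fewer than 10 examples per class"
classes_per_section = 10    # assumption: not stated in the thread

n_texts = n_classes * examples_per_class

# Every unordered pair of texts, regardless of where they sit in the hierarchy.
all_pairs = comb(n_texts, 2)

# Pairs whose two texts share a section: positives plus sibling (hard) negatives.
n_sections = n_classes // classes_per_section  # remainder dropped for simplicity
texts_per_section = classes_per_section * examples_per_class
within_section_pairs = n_sections * comb(texts_per_section, 2)

print(f"all pairs:            {all_pairs:,}")
print(f"within-section pairs: {within_section_pairs:,}")
print(f"reduction factor:     {all_pairs / within_section_pairs:.0f}x")
```

Under these assumed numbers the sibling-only pair set is roughly two orders of magnitude smaller than the set of all pairs.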

haukelicht commented 1 week ago

Sounds great, @miguelwon! Can you maybe point me to the class or method you changed/subclassed to change how setfit constructs the pairwise data?

miguelwon commented 1 week ago

I didn't use setfit. Since I wanted such a custom setup, I coded it myself. It's a bit of a mess, but I'll copy it here just to give you an idea.

Suppose you have a list of dicts in main_train, where the value of "title" contains the full hierarchy in the form "chapter | section | form_title":

from collections import defaultdict

# Nested mapping: chapter -> section -> form_title -> list of texts
main_train_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
for row in main_train:
    # "title" holds the full hierarchy as "chapter | section | form_title"
    chapter, section, form_title = (part.strip() for part in row['title'].split("|"))
    main_train_dict[chapter][section][form_title].append(row['text'])

Then, to build the pairs, I have the following code:

import random
from itertools import combinations

def create_pairs(main_train_dict, same_section_ratio=1.0, other_chapter_ratio=0.3):
    positive_pairs = []
    negative_pairs = []

    for chapter in main_train_dict:
        for section in main_train_dict[chapter]:
            section_texts = []
            for form_title in main_train_dict[chapter][section]:
                texts = main_train_dict[chapter][section][form_title]

                # Create positive pairs within the same form_title
                for pair in combinations(texts, 2):
                    positive_pairs.append((pair[0], pair[1], 1))

                section_texts.extend([(text, form_title) for text in texts])

            # Create negative pairs within the same section
            for (text1, title1), (text2, title2) in combinations(section_texts, 2):
                if title1 != title2:
                    negative_pairs.append((text1, text2, 0))

    # Create negative pairs from other chapters
    all_texts = [(text, chapter, section, form_title) 
                 for chapter in main_train_dict 
                 for section in main_train_dict[chapter] 
                 for form_title in main_train_dict[chapter][section] 
                 for text in main_train_dict[chapter][section][form_title]]

    # Note: this enumerates every pair of texts, so it grows quadratically with the dataset size
    other_chapter_negatives = []
    for (text1, ch1, _, _), (text2, ch2, _, _) in combinations(all_texts, 2):
        if ch1 != ch2:
            other_chapter_negatives.append((text1, text2, 0))

    # Balance the dataset
    total_pairs = len(positive_pairs)
    num_same_section = int(total_pairs * same_section_ratio)
    num_other_chapter = int(total_pairs * other_chapter_ratio)

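    # random.sample raises ValueError if asked for more items than exist, so cap the sizes
    negative_pairs = random.sample(negative_pairs, min(num_same_section, len(negative_pairs)))
    other_chapter_negatives = random.sample(
        other_chapter_negatives, min(num_other_chapter, len(other_chapter_negatives)))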

    all_pairs = positive_pairs + negative_pairs + other_chapter_negatives
    random.shuffle(all_pairs)

    return all_pairs

# Create the pairs
pairs = create_pairs(main_train_dict)

# Print some statistics
positive_count = sum(1 for _, _, label in pairs if label == 1)
negative_count = sum(1 for _, _, label in pairs if label == 0)

print(f"Total pairs: {len(pairs)}")
print(f"Positive pairs: {positive_count}")
print(f"Negative pairs: {negative_count}")
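
One caveat with create_pairs: it materializes every cross-chapter pair before sampling, which can get expensive with thousands of texts. Below is a minimal sketch of an alternative that draws cross-chapter negatives by rejection sampling instead. `sample_cross_chapter_negatives` and its parameters are hypothetical and not part of the original code; it assumes the same chapter -> section -> form_title -> texts layout as main_train_dict.

```python
import random

def sample_cross_chapter_negatives(main_train_dict, n_samples, seed=None, max_tries_factor=100):
    """Draw up to n_samples (text1, text2, 0) pairs whose texts come from different chapters."""
    rng = random.Random(seed)
    # Flatten to (text, chapter); sections and form titles do not matter here.
    flat = [(text, chapter)
            for chapter, sections in main_train_dict.items()
            for forms in sections.values()
            for texts in forms.values()
            for text in texts]
    if len({chapter for _, chapter in flat}) < 2:
        return []  # no cross-chapter pair exists

    negatives, seen = [], set()
    tries, max_tries = 0, n_samples * max_tries_factor  # guard against endless loops
    while len(negatives) < n_samples and tries < max_tries:
        tries += 1
        i, j = rng.sample(range(len(flat)), 2)
        key = (min(i, j), max(i, j))
        # Skip duplicates and pairs from the same chapter.
        if key in seen or flat[i][1] == flat[j][1]:
            continue
        seen.add(key)
        negatives.append((flat[i][0], flat[j][0], 0))
    return negatives

# Toy data in the same nested layout (illustrative only).
toy = {
    "A": {"s1": {"f1": ["a1", "a2"], "f2": ["a3"]}},
    "B": {"s2": {"f3": ["b1", "b2"]}},
}
print(sample_cross_chapter_negatives(toy, n_samples=3, seed=0))
```

The trade-off is that the function may return fewer than n_samples pairs when few cross-chapter pairs exist, in exchange for never holding the full quadratic list in memory.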

Build test_pairs from the test set in the same way with create_pairs, and then prepare the data:

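from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer, evaluation, losses

# Prepare train data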
train_examples = [InputExample(texts=[text1, text2], label=float(label)) for text1, text2, label in pairs]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

# Prepare test data
test_examples = [InputExample(texts=[text1, text2], label=float(label)) for text1, text2, label in test_pairs]
test_evaluator = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(test_examples, name='test_evaluation')

And train with:

# Initialize the model
model = SentenceTransformer('Alibaba-NLP/gte-multilingual-base', trust_remote_code=True)

# Define the loss
train_loss = losses.CosineSimilarityLoss(model)

# Train the model
num_epochs = 3
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=test_evaluator,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path='./results')

After this, you have a gte model fine-tuned for your classes. You can then easily use sklearn, for example, to train a NN or a logistic regression on the gte embeddings.
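
A minimal sketch of that last step, assuming scikit-learn and NumPy are installed. To keep it self-contained, synthetic vectors stand in for the real embeddings; in practice you would replace them with the output of `model.encode(texts)` from the fine-tuned model. The cluster sizes and dimensions here are arbitrary illustration values.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)

# Stand-in for `model.encode(texts)`: three well-separated synthetic clusters.
n_classes, per_class, dim = 3, 20, 768  # arbitrary illustration values
centers = rng.normal(size=(n_classes, dim))
labels = np.repeat(np.arange(n_classes), per_class)
embeddings = centers[labels] + 0.1 * rng.normal(size=(labels.size, dim))

# Hold out a stratified test split so every class appears in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    embeddings, labels, test_size=0.25, stratify=labels, random_state=0
)

# Train a simple classification head on top of the (frozen) embeddings.
clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)

accuracy = clf.score(X_test, y_test)
print(f"held-out accuracy: {accuracy:.2f}")
```

With real data and fewer than 10 examples per class, a stratified split may leave very few test examples per class, so cross-validation could be a more reliable way to evaluate the head.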