facebookresearch / TaBERT

This repository contains source code for the TaBERT model, a pre-trained language model for learning joint representations of natural language utterances and (semi-)structured tables for semantic parsing. TaBERT is pre-trained on a massive corpus of 26M Web tables and their associated natural language context, and can be used as a drop-in replacement for a semantic parser's original encoder to compute representations for utterances and table schemas (columns).
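
For orientation, basic usage looks roughly like the sketch below, which follows the project's usage example; the toy table contents, column types, and sample values are made up, so treat this as illustrative rather than authoritative.

from table_bert import Table, Column, TableBertModel

model = TableBertModel.from_pretrained(
    model_name_or_path="tabert_base_k3/model.bin",
    config_file="tabert_base_k3/tb_config.json",
)

# A toy table; each header column carries a type and a sample value.
table = Table(
    id="countries_by_gdp",
    header=[
        Column("Nation", "text", sample_value="United States"),
        Column("GDP", "real", sample_value="21,439,453"),
    ],
    data=[
        ["United States", "21,439,453"],
        ["China", "27,308,857"],
    ],
).tokenize(model.tokenizer)

context = model.tokenizer.tokenize("show me countries ranked by GDP")

# encode() takes parallel lists of contexts and tables and returns per-token
# context encodings, per-column encodings, and an info dict.
context_encoding, column_encoding, info_dict = model.encode(
    contexts=[context], tables=[table]
)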

How to Fine-Tune for a Binary Classification Task? #8

Open EhsanM4t1qbit opened 3 years ago

EhsanM4t1qbit commented 3 years ago

Hi, thanks for this great work. I'm trying to use your package for a binary classification task, and I'd like your feedback on how I'm using it for my problem. The task is to determine whether a given context belongs to a given table (binary 0/1). Following the example posted on the homepage, I use the encode() method to get the context and table embeddings, concatenate and average them, and then pass the averaged tensor through a linear layer followed by a sigmoid and cross-entropy loss. I have a few specific questions:

1. Is encode() the right method for training, or is it meant to be used for inference only?
2. encode() receives lists of lists as contexts and tables. By inspecting the code, I can see that they are converted to tensors internally. Does this mean I can't use a PyTorch Dataset in my training loop? Currently I'm using plain lists (a rough sketch of the DataLoader alternative I have in mind is included after the training loop below).
3. I couldn't figure out how to access the output representation for the [CLS] token. As mentioned, I am concatenating and averaging the table and context embeddings. Is there a [CLS] token I can use instead?
4. Currently, I get a CUDA out-of-memory error after 617 steps with batch size 8. I know this is caused by the batches of tables and contexts built inside my training loop: if I move them out of the loop and feed the model the same single batch over and over, the problem goes away.

Here is a snippet of my code that captures the process.

import torch
from tqdm import tqdm, trange
from table_bert import TableBertModel


class TabertForClassification(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._tabert_model = TableBertModel.from_pretrained(
            model_name_or_path="tabert_base_k3/model.bin",
            config_file="tabert_base_k3/tb_config.json",
        )
        # Binary classification head on top of the pooled TaBERT encoding.
        self._linear_layer = torch.nn.Linear(in_features=768, out_features=1)

    def forward(self, batch_tables, batch_context):
        context_encoding, column_encoding, info_dict = self._tabert_model.encode(
            contexts=batch_context, tables=batch_tables
        )
        # Concatenate context and column encodings, then mean-pool per example.
        aggregate_encoding = torch.cat([context_encoding, column_encoding], dim=1).mean(dim=1)
        logits = self._linear_layer(aggregate_encoding)
        return logits

# tokenizer, tabert_train_set_batched, loss_fn (sigmoid + cross-entropy on the
# raw logits, e.g. torch.nn.BCEWithLogitsLoss), optimizer, scheduler and
# config.MAX_GRAD_NORM are defined elsewhere in my script.
model = TabertForClassification()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
num_epochs = 1
global_step = 0
for _ in trange(num_epochs, desc="Epoch"):
    for step, batch in enumerate(tqdm(tabert_train_set_batched, desc="Iteration")):
        # If the batch construction below is moved out of the loop (i.e. the same
        # batch is reused every step), there is no CUDA error.
        batch_tables = [ex['table'] for ex in batch]
        batch_context = [tokenizer.tokenize(ex['question']) for ex in batch]
        batch_labels = [ex['label'] for ex in batch]
        batch_labels = torch.Tensor(batch_labels).view(-1, 1).to(device)

        logits = model(batch_tables, batch_context)
        loss = loss_fn(logits, target=batch_labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.MAX_GRAD_NORM)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        global_step += 1
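
For question 2, this is roughly the Dataset/DataLoader variant I have in mind: a custom collate_fn keeps the tables and tokenized contexts as plain Python lists until encode() is called, and only the labels become a tensor. The TabertPairDataset name and train_examples follow my data format above, and I haven't verified that this runs.

import torch
from torch.utils.data import Dataset, DataLoader

class TabertPairDataset(Dataset):
    """Wraps (table, question, label) examples like those used above."""
    def __init__(self, examples, tokenizer):
        self.examples = examples
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        return ex['table'], self.tokenizer.tokenize(ex['question']), ex['label']

def collate(batch):
    # Keep tables/contexts as lists for encode(); stack only the labels.
    tables, contexts, labels = zip(*batch)
    return list(tables), list(contexts), torch.Tensor(labels).view(-1, 1)

train_loader = DataLoader(
    TabertPairDataset(train_examples, tokenizer),  # train_examples: my raw examples
    batch_size=8,
    shuffle=True,
    collate_fn=collate,
)

for batch_tables, batch_context, batch_labels in train_loader:
    logits = model(batch_tables, batch_context)
    loss = loss_fn(logits, target=batch_labels.to(device))
    ...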

Your feedback is greatly appreciated.

YaooXu commented 3 years ago

Have you finished the fine-tuning task you describe above? I'm also interested in it and have the same questions. I'd appreciate it if you could share how you solved them. :)

AhmedMasryKU commented 3 years ago

I am also interested in a similar task. Were you able to figure it out? I noticed that the context encoding size equals the number of input words + 1. Maybe that extra token is the [CLS] token.
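
If that extra leading position really is the [CLS] token (just my assumption from the shape, not something I've confirmed in the source), the forward pass above could pool it instead of averaging everything, roughly like this:

    def forward(self, batch_tables, batch_context):
        context_encoding, column_encoding, info_dict = self._tabert_model.encode(
            contexts=batch_context, tables=batch_tables
        )
        # Assumption: position 0 of context_encoding is the [CLS] token.
        cls_encoding = context_encoding[:, 0, :]  # (batch_size, hidden_size)
        return self._linear_layer(cls_encoding)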