# Imports implied by the rest of the snippet (added here for completeness).
import gc
import json
import logging
import os

import numpy as np
import torch
from torch.utils.data import IterableDataset


def normalize_bbox(bbox, width, height):
    # Assumed helper: the standard LayoutLM-style scaling of pixel coordinates
    # into the 0-1000 range that layout models expect.
    return [
        int(1000 * bbox[0] / width),
        int(1000 * bbox[1] / height),
        int(1000 * bbox[2] / width),
        int(1000 * bbox[3] / height),
    ]


class LiltNerDatasetIterable(IterableDataset):
    # Reconstructed from the getDataset() call site below; subclassing torch's
    # IterableDataset is an assumption so the Trainer can consume it directly.
    def __init__(self, df, tokenizer, image_col, entities_col, label2id,
                 max_length, label_all_tokens, stride):
        self.df = df
        self.tokenizer = tokenizer
        self.image_col = image_col
        self.entities_col = entities_col
        self.label2id = label2id
        self.max_length = max_length
        self.label_all_tokens = label_all_tokens
        self.stride = stride

    def __iter__(self):
        # Stream examples out of the Spark DataFrame one at a time so the whole
        # dataset never has to fit in driver memory.
        for example in self.df.rdd.toLocalIterator():
            entities = example[self.entities_col]
            width = example["width"]
            height = example["height"]
            words = []
            bbox_raw = []
            labels_raw = []
            for ent in entities:
                meta = ent.metadata
                # Leading space so BPE tokenizers treat each word as word-initial.
                word = " " + meta['token']
                words.append(word)
                x, y, w, h = int(meta['x']), int(meta['y']), int(meta['width']), int(meta['height'])
                box = normalize_bbox([x, y, x + w, y + h], width, height)
                bbox_raw.append(box)
                labels_raw.append(self.label2id[ent.result.upper()])
            # Tokenize the whole page; with return_overflowing_tokens=True a long
            # page is split into overlapping windows of max_length tokens.
            tokenized_inputs = self.tokenizer(
                words,
                boxes=bbox_raw,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                stride=self.stride,
                return_overflowing_tokens=True
            )
            # Each overflow window becomes its own training example.
            for batch_index in range(len(tokenized_inputs["input_ids"])):
                word_ids = tokenized_inputs.word_ids(batch_index=batch_index)
                previous_word_idx = None
                label_ids = []
                bbox_inputs = []
                for word_idx in word_ids:
                    # Special tokens have a word id of None. Set the label to -100
                    # so they are ignored by the loss function.
                    if word_idx is None:
                        label_ids.append(-100)
                        bbox_inputs.append([0, 0, 0, 0])
                    # Label only the first token of each word by default.
                    elif word_idx != previous_word_idx:
                        label_ids.append(labels_raw[word_idx])
                        bbox_inputs.append(bbox_raw[word_idx])
                    # For the remaining tokens of a word, either repeat the word's
                    # label or mask it with -100, depending on label_all_tokens.
                    else:
                        label_ids.append(labels_raw[word_idx] if self.label_all_tokens else -100)
                        bbox_inputs.append(bbox_raw[word_idx])
                    previous_word_idx = word_idx
                encoded = {"input_ids": tokenized_inputs["input_ids"][batch_index],
                           "labels": label_ids,
                           "bbox": bbox_inputs,
                           "attention_mask": tokenized_inputs["attention_mask"][batch_index]}
                yield {k: torch.tensor(v) for k, v in encoded.items()}
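
# Thin wrapper that wires the Spark-backed iterable dataset above into the
# Hugging Face token-classification fine-tuning API.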
class Lilt:
    def __init__(self, base_model, tokenizer, params):
        try:
            import torch
            from transformers import AutoTokenizer, AutoModelForTokenClassification
        except ImportError as e:
            raise ImportError(
                f"Cannot import {e.name}, which is required to run the model. Please install {e.name} first.")
        self.params = json.loads(params)
        self.base_model = base_model
        self.max_length = self.params.get('maxSentenceLength')
        self.true_predictions = None
        self.true_labels = None
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        if len(self.params["labels"]) > 0:
            # Fresh classification head sized for the supplied label set.
            self.num_labels = len(self.params["labels"])
            self.model = AutoModelForTokenClassification.from_pretrained(base_model, num_labels=self.num_labels,
                                                                         ignore_mismatched_sizes=True)
            self.update_config_labels(self.params["labels"])
        else:
            # No labels supplied: reuse the mapping shipped with the checkpoint;
            # without this, self.label2id would be undefined in getDataset().
            self.model = AutoModelForTokenClassification.from_pretrained(base_model)
            self.label2id = self.model.config.label2id
        self.model.to(self.device)
        gc.collect()
    def getDataset(self, df, inputCol, outputCol, split="train"):
        return LiltNerDatasetIterable(df, tokenizer=self.tokenizer,
                                      image_col=inputCol, entities_col=outputCol,
                                      label2id=self.label2id, max_length=self.max_length,
                                      label_all_tokens=self.params["labelAllTokens"],
                                      stride=self.params["stride"])

    def update_config_labels(self, labels):
        # Keep the model config's label maps in sync with the upper-cased labels.
        self.labels = list(map(str.upper, labels))
        label2id = {label: idx for idx, label in enumerate(self.labels)}
        self.model.config.id2label = {v: k for k, v in label2id.items()}
        self.model.config.label2id = label2id
        self.label2id = label2id
    def train(self, df):
        logging.info('LiLT model training.')
        from transformers import Trainer, TrainingArguments
        eval_size = self.params.get('evalSize')
        label_list = self.params.get('labels')
        os.environ["WANDB_DISABLED"] = "true"
        train_df, eval_df = df.randomSplit([1.0 - eval_size, eval_size], seed=4)
        train_dataset = self.getDataset(train_df, split="train",
                                        inputCol="image", outputCol="entities")
        eval_dataset = self.getDataset(eval_df, split="test",
                                       inputCol="image", outputCol="entities")
        # `Seqeval` is assumed to come from the surrounding project: a wrapper
        # around the seqeval metric (equivalent to evaluate.load("seqeval")).
        metric = Seqeval()
        return_entity_level_metrics = False

        def compute_metrics(p):
            predictions, labels = p
            predictions = np.argmax(predictions, axis=2)
            # Remove the ignored index (special tokens, label -100).
            self.true_predictions = [
                [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
                for prediction, label in zip(predictions, labels)
            ]
            self.true_labels = [
                [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
                for prediction, label in zip(predictions, labels)
            ]
            results = metric.compute(predictions=self.true_predictions, references=self.true_labels)
            return {
                "precision": results["overall_precision"],
                "recall": results["overall_recall"],
                "f1": results["overall_f1"],
                "accuracy": results["overall_accuracy"],
            }
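
        # Note: with an iterable dataset the Trainer cannot infer an epoch length,
        # so max_steps (rather than num_train_epochs) bounds training. With
        # evaluation_strategy="steps" and load_best_model_at_end=True, save_steps
        # must be a round multiple of eval_steps.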
        training_args = TrainingArguments(output_dir="lilt_runs_english_v2",
                                          log_level="info",
                                          learning_rate=float(self.params['learningRate']),
                                          evaluation_strategy="steps",
                                          eval_steps=int(self.params['evalSteps']),
                                          load_best_model_at_end=True,
                                          per_device_train_batch_size=int(self.params['trainBatchSize']),
                                          per_device_eval_batch_size=int(self.params['evalBatchSize']),
                                          metric_for_best_model="f1",
                                          max_steps=int(self.params['maxSteps']),
                                          save_steps=int(self.params['saveSteps']),
                                          )
        # Initialize our Trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=self.tokenizer,
            compute_metrics=compute_metrics,
        )
        trainer.train()
        return trainer
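
# Hypothetical usage sketch (values are illustrative, not from the original snippet):
#
#   params = json.dumps({
#       "labels": ["O", "B-HEADER", "I-HEADER"],  # assumed label set
#       "maxSentenceLength": 512, "labelAllTokens": False, "stride": 128,
#       "evalSize": 0.1, "learningRate": 5e-5, "evalSteps": 100,
#       "trainBatchSize": 8, "evalBatchSize": 8, "maxSteps": 1000, "saveSteps": 100,
#   })
#   lilt = Lilt("SCUT-DLVCLab/lilt-roberta-en-base",
#               "SCUT-DLVCLab/lilt-roberta-en-base", params)
#   trainer = lilt.train(df)  # df: Spark DataFrame with "image" and "entities" columns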
Hi @NielsRogge,
Thanks for the great work and effort!
Could you please provide the inference code for LiLT visual NER?
I already have an OCR engine that gets tokens and their coordinates from images.
The snippet I used for fine-tuning/evaluation is shown above.
It's quite urgent, please.
Thanks again.