marcotcr / lime

Lime: Explaining the predictions of any machine learning classifier
BSD 2-Clause "Simplified" License
11.64k stars 1.81k forks source link

Added support for custom replacement_fn #687

Open ByrdOfAFeather opened 2 years ago

ByrdOfAFeather commented 2 years ago

Rather than removing text, which can create oddities, we may want to consider ways to replace tokens that would otherwise be removed. I added support for a custom `replacement_fn`, analogous to `classifier_fn`. My particular use case was T5; as such, I modified the generation of perturbed data to work in batches rather than one sample at a time.

This partially solves #648.

Example replacement_fn:

def _mask_tokens(text_as_list, mask, sentinels):
    """Build one T5 input: keep tokens whose mask entry is truthy, sentinel the rest.

    Returns ``(masked_text, n_used)``. Each masked token is replaced by the next
    unused sentinel string. Once the sentinels run out, further masked tokens
    are simply dropped.
    """
    parts = []
    n_used = 0
    for token, keep in zip(text_as_list, mask):
        if keep:
            parts.append(token)
        elif n_used < len(sentinels):
            parts.append(sentinels[n_used])
            n_used += 1
        # else: no sentinel left for this token -- it is dropped from the input.
    return "".join(parts), n_used


def _fill_sentinels(masked_text, suggestion, sentinel_re, used, junk_tokens):
    """Replace each sentinel in ``masked_text`` with the span generated for it.

    ``suggestion`` is the raw decoded generation, with special tokens kept.
    Tokens in ``junk_tokens`` (eos/pad) are removed first, so they never leak
    into the filled text. Only sentinels in ``used`` are filled. Any sentinel
    with no generated span, or not in ``used``, is replaced by "".
    """
    for junk in junk_tokens:
        if junk:
            suggestion = suggestion.replace(junk, "")
    # With a capturing pattern, split() yields [prefix, s0, span0, s1, span1, ...];
    # each span therefore stops at the next sentinel of ANY index.
    pieces = sentinel_re.split(suggestion)
    fills = {}
    for sentinel, span in zip(pieces[1::2], pieces[2::2]):
        if sentinel in used:
            fills.setdefault(sentinel, span)  # first occurrence wins
    # A single pass, so generated text is never re-scanned for sentinels.
    return sentinel_re.sub(lambda m: fills.get(m.group(0), ""), masked_text)


def t5_wrapper(
    text_as_list: list[str],
    masks: list[list[bool]],
    *,
    tokenizer=None,
    model=None,
    model_name: str = "t5-small",
    device: str = "cuda",
    batch_size: int = 50,
    max_length: int = 512,
) -> list[str]:
    """Example LIME ``replacement_fn``: let T5 fill in masked-out tokens.

    For each mask, kept tokens are concatenated as-is. Each removed token is
    replaced by one of the tokenizer's ``additional_special_tokens``
    (sentinels). T5 then generates a span for each sentinel, and that span is
    spliced back in.

    Args:
        text_as_list: The tokens of the text being explained.
        masks: One mask per perturbed sample. A truthy entry keeps the token
            at that position.
        tokenizer: Optional pre-loaded tokenizer. If omitted, it is loaded from
            ``model_name`` on every call. Pass one in to avoid reloading.
        model: Optional pre-loaded seq2seq model, like ``tokenizer``.
        model_name: Checkpoint to load when ``tokenizer``/``model`` is omitted.
        device: Where the model and inputs are placed (the original always
            used CUDA).
        batch_size: Number of samples per ``generate`` call.
        max_length: Inputs are truncated to this many tokens.

    Returns:
        One perturbed string per mask, in the same order as ``masks``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.

    NOTE(review): ``T5Tokenizer`` / ``T5ForConditionalGeneration`` must be in
    scope at module level (presumably imported from ``transformers``).
    """
    import re

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if tokenizer is None:
        tokenizer = T5Tokenizer.from_pretrained(model_name)
    if model is None:
        model = T5ForConditionalGeneration.from_pretrained(model_name)

    sentinels = list(tokenizer.additional_special_tokens)
    masked_texts, used_counts = [], []
    for mask in masks:
        text, n_used = _mask_tokens(text_as_list, mask, sentinels)
        masked_texts.append(text)
        used_counts.append(n_used)
    if not sentinels:
        return masked_texts  # nothing could be masked, so nothing to fill

    model.to(device)
    # Samples with no sentinel (e.g. the unperturbed all-ones mask) need no generation.
    pending = [i for i, n_used in enumerate(used_counts) if n_used]
    suggestions = [""] * len(masked_texts)
    for start in range(0, len(pending), batch_size):
        batch_idx = pending[start:start + batch_size]
        encoded = tokenizer(
            [masked_texts[i] for i in batch_idx],
            return_tensors="pt",
            padding=True,
            max_length=max_length,
            truncation=True,
        )
        encoded = {key: value.to(device) for key, value in encoded.items()}
        outputs = model.generate(**encoded)
        # Special tokens are kept so the sentinel markers survive decoding.
        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=False)
        for i, suggestion in zip(batch_idx, decoded):
            suggestions[i] = suggestion

    sentinel_re = re.compile("(" + "|".join(re.escape(s) for s in sentinels) + ")")
    junk_tokens = (getattr(tokenizer, "eos_token", None), getattr(tokenizer, "pad_token", None))
    return [
        _fill_sentinels(text, suggestion, sentinel_re, set(sentinels[:n_used]), junk_tokens)
        for text, suggestion, n_used in zip(masked_texts, suggestions, used_counts)
    ]