PAIR-code / lit

The Learning Interpretability Tool: Interactively analyze ML models to understand their behavior in an extensible and framework agnostic interface.
https://pair-code.github.io/lit
Apache License 2.0

display attention bugs - issues with Chinese characters #28

Open bigprince97 opened 4 years ago

bigprince97 commented 4 years ago

When I display GPT-2 or BERT attentions, the visualization is truncated and doesn't show the whole sequence. How can I fix this?

bigprince97 commented 4 years ago

If the input is English, it displays correctly, but Chinese input is not displayed in full.

jameswex commented 4 years ago

Can you provide the link to your code / model / dataset so we can reproduce, if possible?

bigprince97 commented 4 years ago
```python
from typing import Dict, List, Tuple

import numpy as np
import torch
import transformers

from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.lib import utils


class MyBertMLM(lit_model.Model):
  """BERT masked LM that also returns per-layer attention for LIT."""
  MASK_TOKEN = "[MASK]"

  @property
  def num_layers(self):
    return self.model.config.num_hidden_layers

  @property
  def max_seq_length(self):
    return self.model.config.max_position_embeddings

  def __init__(self, model_name="bert-base-chinese", top_k=10):
    super().__init__()
    self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    self.model = transformers.BertForMaskedLM.from_pretrained(
        model_name, output_hidden_states=True, output_attentions=True)
    self.top_k = top_k

  def _get_topk_tokens(self,
                       scores: np.ndarray) -> List[List[Tuple[str, float]]]:
    # Unordered top-k vocab ids per position, then sorted by descending score.
    index_array = np.argpartition(scores, -self.top_k, axis=1)[:, -self.top_k:]
    top_tokens = [
        self.tokenizer.convert_ids_to_tokens(idxs) for idxs in index_array
    ]
    top_scores = np.take_along_axis(scores, index_array, axis=1)
    return [
        sorted(zip(toks, tok_scores), key=lambda ab: -ab[1])
        for toks, tok_scores in zip(top_tokens, top_scores)
    ]

  def _postprocess(self, output: Dict[str, np.ndarray]):
    # Keep only real tokens: drop [CLS], [SEP] and padding.
    slicer = slice(1, output.pop("ntok") - 1)
    output["tokens"] = self.tokenizer.convert_ids_to_tokens(
        output.pop("input_ids")[slicer])
    probas = output.pop("probas")
    for i in range(self.num_layers):
      # Per-example attention is [num_heads, seq_len, seq_len].
      key = f"layer_{i:d}_attention"
      output[key] = output[key][:, slicer, slicer]
    output["pred_tokens"] = self._get_topk_tokens(probas[slicer])
    # Only report predictions at [MASK] positions.
    for i, token in enumerate(output["tokens"]):
      if token != self.MASK_TOKEN:
        output["pred_tokens"][i] = []
    return output

  def max_minibatch_size(self, unused_config=None) -> int:
    return 8

  def predict_minibatch(self, inputs, config=None):
    tokenized_texts = [
        ex.get("tokens") or self.tokenizer.tokenize(ex["text"]) for ex in inputs
    ]
    # Newer transformers releases use is_split_into_words / padding /
    # truncation in place of the deprecated is_pretokenized / pad_to_max_length.
    encoded_input = self.tokenizer.batch_encode_plus(
        tokenized_texts,
        is_split_into_words=True,
        return_tensors="pt",
        add_special_tokens=True,
        max_length=self.max_seq_length,
        padding="max_length",
        truncation=True)
    # Trim the padding down to the longest sequence in the batch.
    max_tokens = int(torch.max(torch.sum(encoded_input["attention_mask"], dim=1)))
    encoded_input = {k: v[:, :max_tokens] for k, v in encoded_input.items()}
    with torch.no_grad():
      out = self.model(**encoded_input)
    batched_outputs = {
        "probas": torch.softmax(out.logits, dim=-1).numpy(),
        "input_ids": encoded_input["input_ids"].numpy(),
        "ntok": torch.sum(encoded_input["attention_mask"], dim=1).numpy(),
        "cls_emb": out.hidden_states[-1][:, 0].numpy(),  # last layer, first token
    }
    # out.attentions holds one [batch, num_heads, seq_len, seq_len] tensor per layer.
    for i, layer_attention in enumerate(out.attentions):
      batched_outputs[f"layer_{i:d}_attention"] = layer_attention.numpy()
    unbatched_outputs = utils.unbatch_preds(batched_outputs)
    return map(self._postprocess, unbatched_outputs)

  def input_spec(self):
    return {
        "text": lit_types.TextSegment(),
        "tokens": lit_types.Tokens(required=False),
    }

  def output_spec(self):
    spec = {
        "tokens": lit_types.Tokens(parent="text"),
        "pred_tokens": lit_types.TokenTopKPreds(align="tokens"),
        "cls_emb": lit_types.Embeddings(),
    }
    for i in range(self.num_layers):
      spec[f"layer_{i:d}_attention"] = lit_types.AttentionHeads(
          align=("tokens", "tokens"))
    return spec
```
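For readers following the slicing in _postprocess and the top-k logic in _get_topk_tokens above, here is a numpy-only sketch of the same array operations on toy data. strip_special_tokens and top_k_ids are hypothetical helper names (not LIT or transformers functions), and the shapes and scores are made up for illustration.

```python
import numpy as np


def strip_special_tokens(attention: np.ndarray, ntok: int) -> np.ndarray:
  """Keep positions 1..ntok-2: drops [CLS], the final [SEP] and any padding."""
  slicer = slice(1, ntok - 1)
  return attention[:, slicer, slicer]  # attention is [num_heads, seq, seq]


def top_k_ids(scores: np.ndarray, k: int) -> np.ndarray:
  """Per-row indices of the k highest scores, highest first."""
  idx = np.argpartition(scores, -k, axis=1)[:, -k:]  # top k, unordered
  top = np.take_along_axis(scores, idx, axis=1)
  order = np.argsort(-top, axis=1)                   # sort descending by score
  return np.take_along_axis(idx, order, axis=1)


# Toy example: 2 heads, 6 padded positions; 5 are real ([CLS] 我 爱 你 [SEP]).
attn = np.random.rand(2, 6, 6)
print(strip_special_tokens(attn, ntok=5).shape)  # (2, 3, 3)

scores = np.array([[0.1, 0.5, 0.2, 0.9],
                   [0.7, 0.0, 0.3, 0.6]])
print(top_k_ids(scores, k=2).tolist())  # [[3, 1], [0, 3]]
```

Because the same slicer is applied to the tokens and to both attention axes, the token list and the attention matrix stay the same length, which is what the attention module relies on when it draws one line per token pair.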

I changed pretrained_lms.py to use the PyTorch Chinese BERT model and added the attention layers to the output spec.

```python
if model_name.startswith("bert-"):
  models[model_name] = pretrained_lms.MyBertMLM(
      model_name_or_path, top_k=FLAGS.top_k)
```

This is in pretrained_lm_demo.py, where I load my model.

It displays well when the input is English but not when it is Chinese. The GPT-2 model behaves the same way.

bigprince97 commented 4 years ago

I also tried changing attention_module.ts, but it didn't help.

jameswex commented 4 years ago

We will reproduce this locally and work on a fix. Thanks for discovering the issue!

2020zyc commented 4 years ago

I have the same problem. Any news? @bigprince97 @jameswex Thanks.

jameswex commented 4 years ago

Sorry for the lack of updates. The issue is that the attention_module rendering logic (https://github.com/PAIR-code/lit/blob/main/lit_nlp/client/modules/attention_module.ts#L109) assumes that, because it uses a fixed-width font, every character takes up the same width in pixels, and it places the attention lines based on that. The fixed-width font renders Chinese characters wider than that, however, so the math for the x position of the attention lines is wrong.

The number of attention lines is correct, but they are squeezed into too small a space. As a result, the text gets cut off, and the tokens that are shown don't line up with the lines they belong to.
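To make the mismatch concrete, here is a small Python sketch of the layout arithmetic. It is not LIT's TypeScript code: CHAR_WIDTH_PX, cells and x_positions are hypothetical names, and it assumes the font draws each East Asian wide or fullwidth character (Unicode east_asian_width "W" or "F") two cells wide, which is common for monospace fonts but not guaranteed.

```python
import unicodedata
from typing import List

CHAR_WIDTH_PX = 8  # hypothetical pixel width of one narrow monospace cell


def cells(token: str) -> int:
  """Monospace cells a token needs: wide/fullwidth (e.g. CJK) chars take two."""
  return sum(2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1
             for ch in token)


def x_positions(tokens: List[str], width_aware: bool) -> List[int]:
  """Left x offset (px) of each token when tokens are laid out side by side."""
  xs, x = [], 0
  for tok in tokens:
    xs.append(x)
    # Naive layout counts characters; width-aware layout counts cells.
    n = cells(tok) if width_aware else len(tok)
    x += n * CHAR_WIDTH_PX
  return xs


print(x_positions(["我", "爱", "北", "京"], width_aware=False))  # [0, 8, 16, 24]
print(x_positions(["我", "爱", "北", "京"], width_aware=True))   # [0, 16, 32, 48]
print(x_positions(["the", "cat"], width_aware=False))            # [0, 24]
```

With one cell per character, the Chinese tokens are packed into half the width they actually occupy, while the English tokens are unaffected. Counting cells is still an approximation; measuring the rendered text in the browser (for example with the canvas measureText API) avoids assumptions about the font.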

We'll work on a fix. In the meantime, you could try changing the width setting on the line referenced above (and rebuilding the client) to see whether you can get the spacing to look right for your use case. The proper fix will make it work correctly regardless of language.

jameswex commented 4 years ago

To rebuild the client, see https://github.com/PAIR-code/lit/#download-and-installation, specifically the "yarn && yarn build" command.

2020zyc commented 4 years ago

I changed the width and rebuilt the client successfully. The attention graph changes with the width setting, but it still doesn't display correctly.

Looking forward to your fix. Thanks.

pratikchhapolika commented 2 years ago

Has align=("tokens", "tokens") been changed to lit_types.AttentionHeads(align_in="tokens", align_out="tokens")?