stanfordnlp / GloVe

Software in C and data files for the popular GloVe model for distributed word representations, a.k.a. word vectors or embeddings
Apache License 2.0

Clarification on the diagrams in the readme #139

Open ixxie opened 5 years ago

ixxie commented 5 years ago

I am trying to reproduce the geometry diagrams on the front page and was hoping for more details on how they are generated. I have tried the following procedure:

Am I missing something? There is a chance this is due to the particular mixture of models I am using (I stack GloVe with some other models).

AngledLuffa commented 4 years ago

No one here worked on these diagrams, but that is a reasonable approach. Is there an issue?

ixxie commented 4 years ago

Since then I made this for a private project; feel free to use it if it helps:

import torch
import matplotlib.pyplot as plt
import numpy as np
import re

head = {'head_width': 0.1, 'head_length': 0.2}
solid = {'linestyle': 'solid'}
dashed = {'linestyle': 'dashed'}
dotted = {'linestyle': 'dotted'}

arrow_styles = {
  '->': {**solid, **head},
  '-': {**solid},
  '-->': {**dashed, **head},
  '--': {**dashed},
  '..': {**dotted},
  '..>': {**dotted, **head}
}

arrow_symbols = list(arrow_styles.keys())
escaped_arrow_symbols = [re.escape(symbol) for symbol in arrow_symbols]

def normalize_tensors(tensors):

    # center the embeddings on their mean
    center = torch.mean(torch.stack(tensors), dim=0)

    return [tensor - center for tensor in tensors]

def PCA(tensors, k=2):

    # preprocessing
    tensor = torch.squeeze(torch.stack(tensors))
    mean = torch.mean(tensor, 0)
    tensor = tensor - mean.expand_as(tensor)

    # SVD: columns of U are the principal directions of the data
    U, S, V = torch.svd(torch.t(tensor), some=False)
    tensor = torch.mm(tensor, U[:, :k])

    # post processing: split the (n, k) projection into per-text vectors
    tensor = torch.squeeze(tensor)
    pc_tensors = torch.split(tensor, 1, 0)
    pc_tensors = [torch.squeeze(t) for t in pc_tensors]

    return pc_tensors

def lex_veclang(text):

    # tokens are {bracketed phrases} or arrow symbols like '-->'
    patterns = "|".join(["{.*?}"] + escaped_arrow_symbols)
    matches = re.findall(patterns, text)
    tokens = [re.sub('[{}]', '', match) for match in matches]

    return tokens

def parse_veclang(text):

    tokens = lex_veclang(text)

    # non-arrow tokens are the phrases to embed
    text = [token for token in tokens if token not in arrow_symbols]
    arrows = []

    # each arrow connects the phrases immediately before and after it
    arrow_instances = [(index, token) for index, token in enumerate(tokens)
                       if token in arrow_symbols]
    for index, token in arrow_instances:
        first = text.index(tokens[index-1])
        second = text.index(tokens[index+1])
        arrows.append((first, second, token))

    return text, arrows

def plot_word_vectors(lines, tensors, arrows):

    # lines: plain-text labels; tensors: the 2-d projections from PCA
    vectors = [tensor.detach().numpy() for tensor in tensors]
    filename = './store/plots/fig.png'

    factor = 2
    fig, ax = plt.subplots(figsize=(factor*6.4, factor*4.8))

    x = [vec[0] for vec in vectors]
    y = [vec[1] for vec in vectors]

    # white markers are effectively invisible; they only set the axis limits
    ax.scatter(x, y, color='white')

    # trim each arrow so it starts and ends a little short of the labels
    gap = 0.5
    for i, j, style in arrows:
        delta = vectors[j] - vectors[i]
        delta_unit = delta/np.linalg.norm(delta)
        base = vectors[i] + gap*delta_unit
        diff = delta - 2*gap*delta_unit
        ax.arrow(base[0], base[1], diff[0], diff[1],
                 color='#3a3a3a',
                 length_includes_head=True, antialiased=True,
                 **arrow_styles[style])

    # annotate each point with its text, skipping duplicates
    printed = []
    for i, line in enumerate(lines):
        if line not in printed:
            ax.annotate(line, (x[i], y[i]), ha='center', va='center')
            printed.append(line)

    plt.axis('off')
    plt.savefig(filename)

    return filename

def embedding_plot(model, text):

    # parsing; `model` is assumed to expose a parse() method returning
    # objects with a get_embedding() method
    lines, arrows = parse_veclang(text)
    sentences = [model.parse(line) for line in lines]
    tensors = [s.get_embedding() for s in sentences]

    # tensor processing
    norm_tensors = normalize_tensors(tensors)
    flat_tensors = PCA(norm_tensors)

    # plotting
    filename = plot_word_vectors(lines, flat_tensors, arrows)

    return f'Plotted embeddings to {filename}'
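
For reference, an invocation would look something like the line below (this assumes `model` exposes a flair-style parse() returning objects with a get_embedding() method; that API is an assumption about the caller's stack, not part of this repo):

# hypothetical usage; `model` comes from whatever embedding stack you wrap
embedding_plot(model, '{king} --> {queen} {man} --> {woman}')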

AngledLuffa commented 4 years ago

Thanks, much appreciated! Do you want us to include it in our distributions? We would need to have the license for it in that case.

ixxie commented 4 years ago

For your convenience @AngledLuffa, I hereby license the above code under the Apache License version 2.0, just like the rest of the project.

AngledLuffa commented 4 years ago

Thank you! I'll merge it this week.

numpy also has an svd routine, which might be a bit slower without GPU support, but it would remove the dependency on torch. Does that sound like a reasonable tradeoff? I can make that adjustment myself; just checking that it sounds good.
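
A minimal sketch of what that swap might look like, with numpy.linalg.svd standing in for torch.svd (an illustration of the idea, not the committed change):

import numpy as np

def pca_numpy(vectors, k=2):

    # vectors: a list of equal-length 1-d numpy arrays
    data = np.stack(vectors)
    data = data - data.mean(axis=0)

    # rows of vt are the principal directions of the centered data
    _, _, vt = np.linalg.svd(data, full_matrices=False)
    projected = data @ vt[:k].T

    return list(projected)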

AngledLuffa commented 4 years ago

It looks like there is no specific main() method yet. Worth adding one?
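
For reference, a minimal entry point might look roughly like the sketch below; the CLI shape and the load_model() helper are placeholders, since they depend on how the script gets packaged:

import argparse

def main():

    parser = argparse.ArgumentParser(description='Plot a word-vector diagram')
    parser.add_argument('text', help="e.g. '{king} --> {queen} {man} --> {woman}'")
    args = parser.parse_args()

    model = load_model()  # placeholder: supply your own embedding model
    print(embedding_plot(model, args.text))

if __name__ == '__main__':
    main()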

ixxie commented 4 years ago

Use your judgement; I don't know enough about your project's development strategy to have a say here :) feel free to do whatever!