kanishkamisra / minicons

Utility for behavioral and representational analyses of Language Models
https://minicons.kanishka.website
MIT License
117 stars, 29 forks

Implementing "within_word_l2r" for conditional score #42

Closed. plonerma closed this pull request 10 months ago.

plonerma commented 10 months ago

Currently, the within_word_l2r strategy cannot be used in the conditional_score method of MaskedLMScorer. This PR fixes that with a masking function shared between prepare_text and prime_text (which also sets up the possibility of using non-masked suffixes in the MLM scorer).

This allows the following usage:

from minicons import scorer

model = scorer.MaskedLMScorer('distilbert-base-cased', 'cpu')

prefixes = [
    "The traveler lost",
    "The traveler lost",
]

stimuli = [
    "the souvenir.",
    "interest."
]

complete_sentences = [f"{p} {s}" for p, s in zip(prefixes, stimuli)]

def reduction(t):
    # Sum the per-token scores into a single float.
    return t.sum().item()

for PLL_metric in ("original", "within_word_l2r"):
    print("---", PLL_metric, "---")

    print("Individual tokens:")

    for sentence in model.token_score(complete_sentences, PLL_metric=PLL_metric):
        print(" ".join(f"{t} ({s})" for t, s in sentence))

    print("Complete sequence:", model.sequence_score(complete_sentences, PLL_metric=PLL_metric, reduction=reduction))
    print("Conditional:", model.conditional_score(prefix=prefixes, stimuli=stimuli, PLL_metric=PLL_metric, reduction=reduction))

    print("\n")

With the output:

--- original ---
Individual tokens:
The (-3.2606148719787598) travel (-3.4615001678466797) ##er (-3.8147225379943848) lost (-8.155838012695312) the (-2.182821273803711) so (-0.026716232299804688) ##uve (-5.7220458984375e-05) ##nir (0.0) . (-1.5065603256225586)
The (-3.3534998893737793) travel (-3.7995529174804688) ##er (-5.0453338623046875) lost (-4.697760581970215) interest (-4.439081192016602) . (-0.7187442779541016)
Complete sequence: [-22.408830642700195, -22.053972721099854]
Conditional: [-3.7161550521850586, -5.157825469970703]

--- within_word_l2r ---
Individual tokens:
The (-3.2606148719787598) travel (-8.343323707580566) ##er (-3.8147225379943848) lost (-8.155838012695312) the (-2.182821273803711) so (-8.663061141967773) ##uve (-3.5395898818969727) ##nir (0.0) . (-1.5065603256225586)
The (-3.3534998893737793) travel (-9.00644302368164) ##er (-5.0453338623046875) lost (-4.697760581970215) interest (-4.439081192016602) . (-0.7187442779541016)
Complete sequence: [-39.46653175354004, -27.260862827301025]
Conditional: [-15.892032623291016, -5.157825469970703]
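To make the difference between the two runs easier to read, here is a minimal sketch of which positions get masked when each token is scored. It is an illustration, not minicons' code: masks_for_targets is a hypothetical helper, and word_ids is assumed to follow the Hugging Face fast-tokenizer convention (one word index per token, None for special tokens). Under "within_word_l2r", the target is masked together with every later subword of the same word, so only word-initial and word-medial subwords lose context.

```python
from typing import List, Optional, Set, Tuple


def masks_for_targets(
    word_ids: List[Optional[int]], metric: str = "original"
) -> List[Tuple[int, Set[int]]]:
    """Return (target position, positions to mask) for every scorable token.

    Hypothetical helper for illustration only; it is not part of minicons.
    `word_ids` gives the word index of each token (None for special tokens),
    in the style of Hugging Face fast tokenizers' `word_ids()`.
    """
    plan = []
    for i, word in enumerate(word_ids):
        if word is None:
            continue  # special tokens such as [CLS]/[SEP] are not scored
        if metric == "original":
            masked = {i}  # mask only the target token
        elif metric == "within_word_l2r":
            # mask the target plus every later subword of the same word
            masked = {j for j in range(i, len(word_ids)) if word_ids[j] == word}
        else:
            raise ValueError(f"unknown PLL metric: {metric}")
        plan.append((i, masked))
    return plan


# Assumed tokenization of the first example sentence, read off the token output:
# [CLS] The travel ##er lost the so ##uve ##nir . [SEP]
word_ids = [None, 0, 1, 1, 2, 3, 4, 4, 4, 5, None]
for target, masked in masks_for_targets(word_ids, "within_word_l2r"):
    print(target, sorted(masked))
```

With this masking, word-final pieces (##er, ##nir) and single-token words get the same mask under both metrics, which is consistent with their identical scores in the two runs above, while travel, so and ##uve lose right context within their word and score lower.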
kanishkamisra commented 10 months ago

Oops, I probably should have merged this before I made a new change to scorer.py. Do you mind integrating my new changes and re-submitting? This is awesome, by the way. Thank you so much!!

plonerma commented 10 months ago

No problem, I will probably adapt it tomorrow. Thanks for developing the framework and accepting the change!

kanishkamisra commented 10 months ago

Great, sorry again for not merging this first!

plonerma commented 10 months ago

Hey, I merged your master into my branch and also added a suffix option for conditional MLM scores (I hope that's something you want).

Example usage:

from typing import List
from minicons import scorer

model = scorer.MaskedLMScorer('distilbert-base-cased', None)

prefixes = [
    "The traveler lost",
    "The traveler lost",
]

stimuli = [
    "the souvenir",
    "interest"
]

suffixes = [
    "at the market.",
    "at the market."
]

complete_sentences: List[str] = [f"{pre} {stim} {suff}" for pre, stim, suff in zip(prefixes, stimuli, suffixes)]

def reduction(t):
    # Sum the per-token scores into a single float.
    return t.sum().item()

for PLL_metric in ("original", "within_word_l2r"):
    print("---", PLL_metric, "---")

    print("Individual tokens:")

    for sentence in model.token_score(complete_sentences, PLL_metric=PLL_metric):
        print(" ".join(f"{t} ({s})" for t, s in sentence))

    print("Complete sequence:", model.sequence_score(complete_sentences, PLL_metric=PLL_metric, reduction=reduction))

    print("Conditional:", model.conditional_score(prefix=prefixes, stimuli=stimuli, suffix=suffixes, PLL_metric=PLL_metric, reduction=reduction))

    print("\n") 

Produces:

--- original ---
Individual tokens:
The (-2.931204319000244) travel (-3.1608409881591797) ##er (-4.340202808380127) lost (-10.719362258911133) the (-2.783437728881836) so (-0.018465042114257812) ##uve (-2.09808349609375e-05) ##nir (0.0) at (-2.0171499252319336) the (-1.7253851890563965) market (-5.643357276916504) . (-0.3891754150390625)
The (-3.215424060821533) travel (-4.790759563446045) ##er (-5.153533935546875) lost (-5.166162490844727) interest (-3.1110496520996094) at (-3.688335418701172) the (-1.3834552764892578) market (-6.61713171005249) . (-0.44433021545410156)
Complete sequence: [-33.728601932525635, -33.57018232345581]
Conditional: [-2.8019237518310547, -3.1110496520996094]

--- within_word_l2r ---
Individual tokens:
The (-2.931204319000244) travel (-8.166111946105957) ##er (-4.340202808380127) lost (-10.719362258911133) the (-2.783437728881836) so (-8.323075294494629) ##uve (-2.5555038452148438) ##nir (0.0) at (-2.0171499252319336) the (-1.7253851890563965) market (-5.643357276916504) . (-0.3891754150390625)
The (-3.215424060821533) travel (-9.80713939666748) ##er (-5.153533935546875) lost (-5.166162490844727) interest (-3.1110496520996094) at (-3.688335418701172) the (-1.3834552764892578) market (-6.61713171005249) . (-0.44433021545410156)
Complete sequence: [-49.593966007232666, -38.586562156677246]
Conditional: [-13.662016868591309, -3.1110496520996094]
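As a quick consistency check on these numbers (not a description of minicons internals), the printed conditional scores equal the sum of the stimulus-token scores alone, with prefix and suffix tokens acting only as context, while sequence_score with the sum reduction adds up every token. The snippet recomputes both from the token_score output for the first sentence; span_sum is a hypothetical helper, and the 4/4/4 prefix/stimulus/suffix token split is an assumption read off that output.

```python
import math
from typing import List

# Per-token scores copied from the token_score output above for
# "The traveler lost the souvenir at the market."
# Assumed layout: 4 prefix tokens, 4 stimulus tokens, 4 suffix tokens.
ORIGINAL = [
    -2.931204319000244, -3.1608409881591797, -4.340202808380127, -10.719362258911133,  # The travel ##er lost
    -2.783437728881836, -0.018465042114257812, -2.09808349609375e-05, 0.0,  # the so ##uve ##nir
    -2.0171499252319336, -1.7253851890563965, -5.643357276916504, -0.3891754150390625,  # at the market .
]
WITHIN_WORD_L2R = [
    -2.931204319000244, -8.166111946105957, -4.340202808380127, -10.719362258911133,
    -2.783437728881836, -8.323075294494629, -2.5555038452148438, 0.0,
    -2.0171499252319336, -1.7253851890563965, -5.643357276916504, -0.3891754150390625,
]


def span_sum(scores: List[float], start: int, length: int) -> float:
    """Sum the scores of tokens[start:start + length] (hypothetical helper)."""
    return sum(scores[start:start + length])


N_PREFIX, N_STIMULUS = 4, 4

for name, scores in (("original", ORIGINAL), ("within_word_l2r", WITHIN_WORD_L2R)):
    conditional = span_sum(scores, N_PREFIX, N_STIMULUS)  # stimulus tokens only
    sequence = sum(scores)  # every token, as with the sum reduction
    print(name, "conditional:", conditional, "sequence:", sequence)
    assert math.isclose(conditional, {"original": -2.8019237518310547,
                                      "within_word_l2r": -13.662016868591309}[name], abs_tol=1e-6)
```

The same reasoning applies to the second sentence, where the conditional score -3.1110496520996094 is simply the score of the single stimulus token interest.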