chengchingwen / Transformers.jl

Julia Implementation of Transformer models
MIT License

[Question] Possible to retrieve layer-wise activations? #166

Open pat-alt opened 8 months ago

pat-alt commented 8 months ago

Thanks for the great package @chengchingwen 🙏🏽

I have a somewhat naive question that you might be able to help me with. For a project I'm currently working on, I am trying to run linear probes on layer activations. In particular, I'm trying to reproduce the following exercise from this paper:

[Image: the exercise from the paper (not preserved)]
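For reference, a linear probe is a linear model fit on frozen activations to predict some target property. The sketch below uses ridge-regularised least squares in plain Julia (LinearAlgebra only). The helpers fit_linear_probe and predict_probe are hypothetical, the (features × samples) layout is an assumption chosen to match Julia's column-major convention, and the data is synthetic. For categorical targets, a logistic-regression probe is a common alternative.

```julia
using LinearAlgebra, Random

"""
    fit_linear_probe(X, y; λ = 1e-3)

Hypothetical helper: fit a ridge-regularised least-squares probe on a
(features × samples) activation matrix `X` and a target vector `y`.
Returns weights `w` and bias `b` such that `X' * w .+ b ≈ y`.
"""
function fit_linear_probe(X::AbstractMatrix, y::AbstractVector; λ = 1e-3)
    Xb = vcat(X, ones(eltype(X), 1, size(X, 2)))  # append a constant row for the bias
    A = Xb * Xb' + λ * I                           # regularised normal equations (bias penalised too, for simplicity)
    θ = A \ (Xb * y)
    return θ[1:end-1], θ[end]
end

# Predictions of a fitted probe for a (features × samples) matrix.
predict_probe(w, b, X) = X' * w .+ b

# Usage on synthetic data (not real activations):
Random.seed!(0)
X = randn(4, 50)                  # 4 features × 50 samples
w_true = [1.0, -2.0, 0.5, 0.0]
y = X' * w_true .+ 3.0
w, b = fit_linear_probe(X, y; λ = 1e-8)
```

In practice, X would hold one layer's activations (one column per input sentence), and fitting one probe per layer shows where the target property becomes linearly decodable.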

I've naively tried to simply apply the Flux.activations() function with no luck. Here's an example:

```julia
using Flux
using Transformers
using Transformers.TextEncoders
using Transformers.HuggingFace

# Load model from HF 🤗:
tkr = hgf"mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis:tokenizer"
mod = hgf"mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis:ForSequenceClassification"
query = [
    "The economy is stagnant.",
    "Output has grown in H2.",
]
a = encode(tkr, query)
```
```
julia> Flux.activations(mod.model, a)
ERROR: MethodError: no method matching activations(::Transformers.HuggingFace.HGFRobertaModel{…}, ::@NamedTuple{token::OneHotArray{…}, attention_mask::NeuralAttentionlib.RevLengthMask{…}})

Closest candidates are:
  activations(!Matched::Flux.Chain, ::Any)
   @ Flux ~/.julia/packages/Flux/EHgZm/src/layers/basic.jl:102

Stacktrace:
 [1] top-level scope
   @ REPL[80]:1
```

Any advice would be much appreciated!

chengchingwen commented 8 months ago

There is an output_hidden_states configuration that can be set up with HGFConfig:

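```julia
using Transformers, Transformers.HuggingFace
model_name = "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis"
# Load the model's config, enabling per-layer hidden-state outputs
cfg = HuggingFace.HGFConfig(load_config(model_name); output_hidden_states = true)
mod = load_model(model_name, "ForSequenceClassification"; config = cfg)
```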

Then you can access all layer outputs with mod(a).outputs, which is an NTuple{number_layers, @NamedTuple{hidden_state::Array{Float32, 3}}}. A similar configuration, output_attentions, would also include the attention scores in the named tuples in .outputs.
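One way to turn these outputs into probe inputs is to take one vector per sequence from each layer. The sketch below assumes each hidden_state is laid out as (hidden_size, seq_len, batch) and takes the first token's vector (RoBERTa's <s> token). The layer_features helper and the random stand-in for mod(a).outputs are hypothetical, so check the real array sizes before relying on this indexing.

```julia
# Hypothetical stand-in for `mod(a).outputs`: 6 layers, each a named tuple
# whose hidden_state is assumed to be (hidden_size, seq_len, batch).
hidden_size, seq_len, batch = 8, 5, 2
outputs = ntuple(_ -> (hidden_state = rand(Float32, hidden_size, seq_len, batch),), 6)

"""
    layer_features(outputs)

Hypothetical helper: for every layer, return a (hidden_size × batch) matrix
holding the first token's hidden state for each sequence in the batch.
"""
layer_features(outputs) = map(o -> o.hidden_state[:, 1, :], outputs)

feats = layer_features(outputs)   # tuple of 6 matrices, each 8 × 2
```

Each matrix in feats can then serve as the (features × samples) input of a per-layer probe. Mean pooling over the sequence is a common alternative, but with padded batches it should respect the attention mask.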

BTW, if you don't need the sequence classification head, you can simply use load_model(model_name; config = cfg), which loads only the base model without the classification layers.

pat-alt commented 8 months ago

Amazing, thanks very much for the quick response 👍🏽

(I won't close this since you added the tag for documentation)

pat-alt commented 8 months ago

Small follow-up question: is it also somehow possible to collect outputs for each layer of the classifier head?

Edit: I realize I can just break the forward pass down into layer-by-layer calls, as below, but perhaps there's a more streamlined way to do this?

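```julia
# clf: the classifier head; b: assumed to be the head's input (e.g. the base model's output).
# The first layer's result has a hidden_state field, which is unwrapped before layer 2.
b = clf.layer.layers[1](b).hidden_state |>
        x -> clf.layer.layers[2](x)
```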
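The layer-by-layer calls can be generalised into a loop that keeps every intermediate output. This is a minimal sketch in plain Julia; collect_layer_outputs is a hypothetical helper, and the toy functions stand in for the real head's layers.

```julia
"""
    collect_layer_outputs(layers, x)

Hypothetical helper: apply each callable in `layers` in turn, starting from
`x`, and return a vector with the output of every layer.
"""
function collect_layer_outputs(layers, x)
    outputs = Any[]
    for layer in layers
        x = layer(x)          # feed the previous output into the next layer
        push!(outputs, x)
    end
    return outputs
end

# Toy stand-ins for the classifier-head layers (not the real ones):
toy_layers = (x -> 2 .* x, x -> x .+ 1, x -> sum(x; dims = 1))
outs = collect_layer_outputs(toy_layers, [1.0 2.0; 3.0 4.0])
```

For the head in the snippet above, the first element would need to unwrap the hidden_state field, e.g. (x -> clf.layer.layers[1](x).hidden_state, clf.layer.layers[2]), assuming the head consists of exactly those two layers.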
chengchingwen commented 8 months ago

You can try extracting the actual layers of the classifier head, constructing a Flux.Chain from them, and calling Flux.activations on it. Otherwise, I think a manual loop or sequence of calls is probably the simplest approach.
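To illustrate the Chain-based approach, here is a minimal sketch with a toy head. The two Dense layers are hypothetical stand-ins for whatever layers are extracted from the real classifier head; Flux.activations returns a tuple with one output per layer.

```julia
using Flux

# Toy two-layer head (hypothetical stand-in for the extracted head layers).
head = Chain(Dense(4 => 3, tanh), Dense(3 => 2))

x = rand(Float32, 4, 5)             # 4 features × 5 samples
acts = Flux.activations(head, x)    # tuple: (output of layer 1, output of layer 2)
```

This works only if each layer accepts the previous layer's output directly. A layer whose result must be unwrapped first (such as taking .hidden_state) can be wrapped in an anonymous function inside the Chain, since Chain accepts arbitrary callables.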