chengchingwen / Transformers.jl

Julia Implementation of Transformer models
MIT License
526 stars 75 forks source link

Adding phi model #167

Closed pevnak closed 5 months ago

pevnak commented 10 months ago

Hello all.

I have started to work on adding the Microsoft phi models to Transformers.jl, but I got totally stuck with the rotary positional encoding and the self-attention. I cannot get my head around it. @chengchingwen Would you have a little time to help?

Thanks a lot.

Tomas

My code so far is

using ..Layers
using ..Layers: CompositeEmbedding, SelfAttention
using ChainRulesCore
using Functors
using Static

using NeuralAttentionlib
using NeuralAttentionlib: WithScore

"""
Root of the Phi model hierarchy: the base `HGFPhiModel` and the task-specific
wrappers declared with `@hgfdefmodel` below all subtype this.
"""
abstract type HGFPhiPreTrainedModel <: HGFPreTrainedModel end

"""
    HGFPhiModel(embed, decoder)

Base Phi model. Calling it on a `NamedTuple` runs `embed` on the input and feeds
the result to `decoder`.
"""
struct HGFPhiModel{E, DEC} <: HGFPhiPreTrainedModel
    embed::E
    decoder::DEC
end
@functor HGFPhiModel

function (model::HGFPhiModel)(nt::NamedTuple)
    embedded = model.embed(nt)
    return model.decoder(embedded)
end

# Task-specific wrapper types around the base model.
# (HGFPhiForSequenceClassification is not declared yet.)
@hgfdefmodel HGFPhiForCausalLM HGFPhiPreTrainedModel

basemodelkey(::Type{<:HGFPhiPreTrainedModel}) = :model
isbasemodel(::Type{<:HGFPhiPreTrainedModel}) = false
isbasemodel(::Type{<:HGFPhiModel}) = true

# Registry entry for `Val(:phi)` (presumably the config's model type — confirm).
get_model_type(::Val{:phi}) = (; model = HGFPhiModel, forcausallm = HGFPhiForCausalLM)

"""
    load_model(::Type{HGFPhiPreTrainedModel}, cfg, state_dict, prefix)

Assemble the base `HGFPhiModel` from the `state_dict` entries under `prefix`.
The token embedding is read with the Llama embedding loader; the decoder stack
uses the Phi `TransformerBlock` loader.
"""
function load_model(_type::Type{HGFPhiPreTrainedModel}, cfg, state_dict, prefix)
    token_embedding = load_model(HGFLlamaModel, CompositeEmbedding, cfg, state_dict, prefix)
    decoder_stack = load_model(_type, TransformerBlock, cfg, state_dict, prefix)
    return HGFPhiModel(token_embedding, decoder_stack)
end

"""
    load_model(::Type{HGFPhiForCausalLM}, cfg, state_dict, prefix)

Build a Phi causal language model.

The base model is loaded from the entries under `joinname(prefix, "model")`. A
language-model head then maps the final `:hidden_state` to `:logit`.

- If `cfg[:tie_word_embeddings]` is true, the head reuses the token-embedding matrix.
- Otherwise it reads `lm_head.weight`.
- `lm_head.bias` is always read, because Phi's output projection carries a bias.
"""
function load_model(_type::Type{HGFPhiForCausalLM}, cfg, state_dict, prefix)
    model = load_model(HGFPhiPreTrainedModel, cfg, state_dict, joinname(prefix, "model"))
    vocab_size = cfg[:vocab_size]
    if cfg[:tie_word_embeddings]
        embedding = model.embed.token.embeddings
    else
        dims, factor = cfg[:hidden_size], Float32(cfg[:initializer_range])
        embedding = getweight(weight_init(vocab_size, dims, factor), Layers.Embed,
                              state_dict, joinname(prefix, "lm_head.weight"))
    end
    # `zero_init(vocab_size)` is the initial value handed to `getweight`, presumably
    # used when the checkpoint has no `lm_head.bias`.
    lm_bias = getweight(zero_init(vocab_size), Array, state_dict, joinname(prefix, "lm_head.bias"))
    lmhead = Layers.EmbedDecoder(Layers.Embed(embedding), lm_bias)
    # Wrap in the Phi model type (this previously built an HGFLlamaForCausalLM).
    return HGFPhiForCausalLM(model, Layers.Branch{(:logit,), (:hidden_state,)}(lmhead))
end

"""
    load_model(::Type{<:HGFPhiPreTrainedModel}, ::Type{<:SelfAttention}, cfg, state_dict, prefix)

Load one Phi self-attention layer from the `state_dict` entries under `prefix`.

- Queries, keys and values come from separate biased projections (`q_proj`,
  `k_proj`, `v_proj`).
- The attention output goes through the biased `dense` projection.
- Rotary position embedding is applied only to the first
  `floor(Int, head_dims * cfg[:partial_rotary_factor])` features of each head.
- When `cfg[:output_attentions]` is set, the op is wrapped so that it also
  returns attention scores.

Accepts any Phi model type, so the decoder-stack loader can pass its own type through.
"""
function load_model(_type::Type{<:HGFPhiPreTrainedModel}, ::Type{<:SelfAttention}, cfg, state_dict, prefix)
    head, dims = cfg[:num_attention_heads], cfg[:hidden_size]
    @assert dims % head == 0 "The hidden size is not a multiple of the number of attention heads."
    head_dims = div(dims, head)
    kv_head = something(cfg[:num_key_value_heads], head)
    grouped_attn = kv_head != head
    @assert head % kv_head == 0 "The number of query is not dividable by the number of key/value groups"
    kv_dims = kv_head * head_dims
    # Phi rotates only a prefix of each head (32 of the 64 head dims for phi-1);
    # the remaining features bypass the rotary embedding.
    rotary_dim = floor(Int, head_dims * cfg[:partial_rotary_factor])
    return_score = cfg[:output_attentions]
    factor = Float32(cfg[:initializer_range])
    @assert isnothing(cfg[:rope_scaling]) "Scaling Rotary Embedding is not support yet"
    # Biases are vectors, so their initial values must be vectors as well.
    q_weight = getweight(weight_init(dims, dims, factor), Array,
                         state_dict, joinname(prefix, "q_proj.weight"))
    q_bias = getweight(zero_init(dims), Array, state_dict, joinname(prefix, "q_proj.bias"))
    k_weight = getweight(weight_init(dims, kv_dims, factor), Array,
                         state_dict, joinname(prefix, "k_proj.weight"))
    k_bias = getweight(zero_init(kv_dims), Array, state_dict, joinname(prefix, "k_proj.bias"))
    v_weight = getweight(weight_init(dims, kv_dims, factor), Array,
                         state_dict, joinname(prefix, "v_proj.weight"))
    v_bias = getweight(zero_init(kv_dims), Array, state_dict, joinname(prefix, "v_proj.bias"))
    o_weight = getweight(weight_init(dims, dims, factor), Array, state_dict, joinname(prefix, "dense.weight"))
    o_bias = getweight(zero_init(dims), Array, state_dict, joinname(prefix, "dense.bias"))
    qkv_proj = Layers.Fork(Layers.Dense(q_weight, q_bias), Layers.Dense(k_weight, k_bias), Layers.Dense(v_weight, v_bias))
    o_proj = Layers.Dense(o_weight, o_bias)
    if grouped_attn
        # This op is constructed without a rotary dimension, so it presumably rotates
        # whole heads; fail loudly instead of silently computing a different attention.
        @assert rotary_dim == head_dims "Partial rotary embedding with grouped-query attention is not supported yet"
        op = CausalLlamaRoPEGroupedQueryAttenOp(head, kv_head)
    else
        # First argument: number of rotated features per head (as in the GPT-NeoX
        # loader, which also uses partial rotary) — confirm against the op's docs.
        op = CausalGPTNeoXRoPEMultiheadQKVAttenOp(rotary_dim, head)
    end
    return_score && (op = WithScore(op))
    return SelfAttention(op, qkv_proj, o_proj)
end

"""
    load_model(::Type{<:HGFPhiPreTrainedModel}, ::Type{<:Layers.LayerNorm}, cfg, state_dict, prefix)

Load a LayerNorm from `prefix.weight` and `prefix.bias`. The initial values are ones
and zeros of length `cfg[:hidden_size]`, and epsilon is `cfg[:layer_norm_eps]`
converted to `Float32`.
"""
function load_model(::Type{<:HGFPhiPreTrainedModel}, ::Type{<:Layers.LayerNorm}, cfg, state_dict, prefix)
    hidden = cfg[:hidden_size]
    epsilon = Float32(cfg[:layer_norm_eps])
    scale = getweight(one_init(hidden), Array, state_dict, joinname(prefix, "weight"))
    shift = getweight(zero_init(hidden), Array, state_dict, joinname(prefix, "bias"))
    return Layers.LayerNorm(scale, shift, epsilon)
end

"""
    load_model(::Type{<:HGFPhiPreTrainedModel}, ::Type{<:TransformerBlock}, cfg, state_dict, prefix)

Build the Phi decoder stack. It holds `cfg[:num_hidden_layers]` blocks read from
`layers.0`, `layers.1`, … under `prefix`, chained with a final LayerNorm.

WORK IN PROGRESS: this method does not parse yet. See the NOTE on the
`PreNormResidual(` line below.
"""
function load_model(_type::Type{<:HGFPhiPreTrainedModel}, ::Type{<:TransformerBlock}, cfg, state_dict, prefix)
    n = cfg[:num_hidden_layers]
    # Per-layer outputs are collected when the config asks for attention scores or hidden states.
    collect_output = cfg[:output_attentions] || cfg[:output_hidden_states]
    blocks = []
    for i = 1:n
        # Checkpoint layer keys are 0-based.
        lprefix = joinname(prefix, :layers, i-1)

        ln = load_model(HGFPhiPreTrainedModel, Layers.LayerNorm, cfg, state_dict, joinname(lprefix, "input_layernorm"))
        # NOTE(review): dispatching on a Llama type selects the Llama attention loader.
        # Phi attention has biased q/k/v projections and a `dense` output projection,
        # so this presumably should pass a Phi type instead — confirm.
        sa = load_model(HGFLlamaPreTrainedModel, SelfAttention, cfg, state_dict, joinname(lprefix, "self_attn"))
        # NOTE(review): the Llama MLP loader is used here. Phi's MLP presumably has
        # different parameter names and activation (HF Phi uses fc1/fc2) — confirm.
        ff = load_model(HGFLlamaPreTrainedModel, Layers.Chain{Tuple{Layers.Dense, Layers.Dense}}, cfg, state_dict, joinname(lprefix, "mlp"))
        # NOTE(review): incomplete statement. The call on the next line is never closed,
        # so this method does not parse. The composition of `ln`, `sa` and `ff` is still
        # open. In the HF Phi decoder layer (not shown here), attention and MLP presumably
        # both read the same `input_layernorm` output and are summed with the residual.
        sa = Layers.PreNormResidual(sa, 

        block = TransformerBlock(sa, ff)
        push!(blocks, block)
    end
    collect_f = collect_output ? Layers.collect_outputs : nothing
    trf = Transformer(Tuple(blocks), collect_f)
    # NOTE(review): confirm the checkpoint key for the final norm; HF Phi may name it
    # `final_layernorm` rather than `norm`.
    final_ln = load_model(HGFPhiPreTrainedModel, Layers.LayerNorm, cfg, state_dict, joinname(prefix, "norm"))
    return Layers.Chain(trf, final_ln)
end
chengchingwen commented 10 months ago

Sure, I'll be glad to help. Can you elaborate more on the problem? I'm not sure what the question is.

pevnak commented 9 months ago

Hi Peter,

Thanks a lot for the answer. I have made some progress, but I am stuck: I cannot reproduce the output of self-attention with the rotary embedding. I have written code in Julia and Python to compare the results side by side.

The python code is as follows

import torch
import transformers
import math
torch.manual_seed(0)

def rotate_half(x):
   """Rotate the last dimension by halves: ``[x1, x2] -> [-x2, x1]``.

   The last dimension is split at ``x.shape[-1] // 2``. The second half is negated
   and moved in front of the first half.
   """
   x1 = x[..., : x.shape[-1] // 2]
   x2 = x[..., x.shape[-1] // 2 :]
   return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
   """Apply rotary position embedding to ``q`` and ``k``.

   ``cos``/``sin`` are indexed by ``position_ids``. Passing ``None`` only adds a
   leading axis. A singleton axis is then inserted at ``unsqueeze_dim`` so the tables
   broadcast over the head axis. Each tensor becomes ``x * cos + rotate_half(x) * sin``.

   Returns:
      ``(q_embed, k_embed)``, the rotated queries and keys.
   """
   # Only one definition: a second debug copy that dropped the ``sin`` term used to
   # shadow this one, so the script silently computed a different embedding.
   cos = cos[position_ids].unsqueeze(unsqueeze_dim)
   sin = sin[position_ids].unsqueeze(unsqueeze_dim)
   q_embed = (q * cos) + (rotate_half(q) * sin)
   k_embed = (k * cos) + (rotate_half(k) * sin)
   return q_embed, k_embed

# Same behaviour as transformers.models.llama.modeling_llama.repeat_kv (llama -> phi).

def repeat_kv(hidden_states, n_rep):
   """
   Repeat every key/value head ``n_rep`` times along the head axis.

   Equivalent to ``torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)``:
   (batch, num_key_value_heads, seqlen, head_dim) ->
   (batch, num_key_value_heads * n_rep, seqlen, head_dim).
   When ``n_rep == 1`` the input tensor itself is returned unchanged.
   """
   batch, kv_heads, seq_len, dim = hidden_states.shape
   if n_rep != 1:
      tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, dim)
      return tiled.reshape(batch, kv_heads * n_rep, seq_len, dim)
   return hidden_states

# Reference run of the first phi-1 decoder layer. Every intermediate tensor is saved
# under /tmp so the Julia port can load it and compare step by step.
model = transformers.AutoModelForCausalLM.from_pretrained('microsoft/phi-1', torch_dtype=torch.float32, trust_remote_code=True)
tokenizer = transformers.AutoTokenizer.from_pretrained('microsoft/phi-1', trust_remote_code=True)

s = 'Tell me something about Julia?'
inputs = tokenizer(s, return_tensors='pt', return_attention_mask=False)
torch.save(inputs, '/tmp/inputs.torch')
e = model.model.embed_tokens(inputs.input_ids)
l = model.model.layers[0]
torch.save(e, '/tmp/embedding.torch')
hidden_states = l.input_layernorm(e)
torch.save(hidden_states, '/tmp/hidden_states.torch')

# End-to-end output of the library attention (called without mask or position ids).
attn_outputs, self_attn_weights, present_key_value = l.self_attn(hidden_states)
torch.save(attn_outputs, '/tmp/attn_outputs.torch')

# ---- The same attention, re-derived one step at a time ----
sa = l.self_attn
query_states = sa.q_proj(hidden_states)
torch.save(query_states, '/tmp/query_states.torch')
key_states = sa.k_proj(hidden_states)
torch.save(key_states, '/tmp/key_states.torch')
value_states = sa.v_proj(hidden_states)
torch.save(value_states, '/tmp/value_states.torch')

# [batch, seq_len, hidden] -> [batch, heads, seq_len, head_dim]
bsz, q_len, _ = hidden_states.size()
query_states = query_states.view(bsz, q_len, sa.num_heads, sa.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, sa.num_key_value_heads, sa.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, sa.num_key_value_heads, sa.head_dim).transpose(1, 2)

# Number of key positions the rotary tables must cover.
# (Previously used without being defined, which raised a NameError.)
kv_seq_len = key_states.shape[-2]
cos, sin = sa.rotary_emb(value_states, seq_len=kv_seq_len)
torch.save(cos, '/tmp/cos.torch')
torch.save(sin, '/tmp/sin.torch')

# Partial rotary embedding: only the first sa.rotary_emb.dim features of each head are rotated.
query_rot, query_pass = (
   query_states[..., : sa.rotary_emb.dim],
   query_states[..., sa.rotary_emb.dim :],
)

torch.save(query_rot, '/tmp/query_rot.torch')
torch.save(query_pass, '/tmp/query_pass.torch')

key_rot, key_pass = (
   key_states[..., : sa.rotary_emb.dim],
   key_states[..., sa.rotary_emb.dim :],
)

torch.save(key_rot, '/tmp/key_rot.torch')
torch.save(key_pass, '/tmp/key_pass.torch')

# [batch, heads, seq_len, rotary_dim]
query_rot_pos, key_rot_pos = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, None)
torch.save(query_rot_pos, '/tmp/query_rot_pos.torch')
torch.save(key_rot_pos, '/tmp/key_rot_pos.torch')

# [batch, heads, seq_len, head_dim]
query_rot_states = torch.cat((query_rot_pos, query_pass), dim=-1)
key_rot_states = torch.cat((key_rot_pos, key_pass), dim=-1)

torch.save(query_rot_states, '/tmp/query_rot_states.torch')
torch.save(key_rot_states, '/tmp/key_rot_states.torch')

# Expand key/value heads to the number of query heads (a no-op when
# num_key_value_groups == 1). The rotated keys enter the score below, so they
# are the ones that must be expanded.
key_rot_states = repeat_kv(key_rot_states, sa.num_key_value_groups)
value_states = repeat_kv(value_states, sa.num_key_value_groups)

# Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
attn_weights = torch.matmul(
   query_rot_states.to(torch.float32), key_rot_states.to(torch.float32).transpose(2, 3)
) / math.sqrt(sa.head_dim)
# Scaled scores *before* softmax: [batch, heads, q_len, k_len]
torch.save(attn_weights, '/tmp/attn_weights.torch')

# upcast attention to fp32. Saved to its own file: writing the probabilities to
# attn_weights.torch would overwrite the raw scores saved just above.
attn_probs = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
torch.save(attn_probs, '/tmp/attn_probs.torch')
attn_probs = torch.nn.functional.dropout(attn_probs, p=sa.attention_dropout, training=sa.training)

attn_output = torch.matmul(attn_probs, value_states)
torch.save(attn_output, '/tmp/attn_output.torch')

After every operation I save the result so that I can load it into Julia and compare. The corresponding Julia code is:

using Transformers
using Transformers.Flux
using Transformers.HuggingFace
using Transformers.HuggingFace: HGFPhiPreTrainedModel, HGFPhiForCausalLM,  HGFLlamaPreTrainedModel, SelfAttention
using Transformers.HuggingFace: joinname, load_model
using Transformers.Layers: apply_on_namedtuple
using Transformers.HuggingFace: weighted_sum_mixing, gptneox_rope_multihead_qkv_attention, gptneox_rope_attention_score, generic_multihead_qkv_attention, gptneox_reorder
import Transformers.HuggingFace: one_init, zero_init, getweight
using Transformers.Layers: LayerNorm
using NeuralAttentionlib: as_collapsed, _split_and_move_head,  generic_qkv_attention, mixing, attention_score, split_head, naive_qkv_attention, naive_attention_score, scaled_dot_product_score
using TextEncodeBase
using Statistics
using StatsBase
using Pickle
using Flux
using NNlib

"""
    load_torch_matrix(filename)

Load a tensor saved with `torch.save` and return its first batch element as a
`Matrix` with the two remaining axes swapped. For example, a `(1, seq, hidden)` dump
becomes a `(hidden, seq)` matrix.
"""
function load_torch_matrix(filename)
    tensor = Pickle.Torch.THload(filename)
    first_sample = tensor[1, :, :]
    return Matrix(transpose(first_sample))
end

"""

 x: [bs, num_attention_heads, seq_len, head_size]
"""
function compare_tensors(x, filename)
    r = Pickle.Torch.THload(filename)
    size(r,1) !=1 && error("the first dimension should be one (one sample)")
    r = r[1,:,:,:]
    size(x,1) == size(r,3) || error("dimension mismatch")
    size(x,3) == size(r,2) || error("dimension mismatch")
    size(x,2) == size(r,1) || error("dimension mismatch")
    maximum(maximum(abs.(r[:,i,:] .- transpose(x[:,:,i]))) for i in 1:size(x,3))
end

"""
    phi_rotary_embedding(dim, max_position_embeddings=2048, base=10000)

Precompute rotary-embedding tables for `dim` rotated features and positions
`0:max_position_embeddings-1`. This mirrors HF's `LlamaRotaryEmbedding` (Llama -> Phi).

Returns `(sin_table, cos_table)`, each a `dim × max_position_embeddings` matrix whose
column `p + 1` holds the values for position `p`. The frequency vector is stacked
twice (`[f; f]`) so it lines up with the two halves swapped by `rotate_half`.
"""
function phi_rotary_embedding(dim, max_position_embeddings=2048, base=10000)
    exponents = collect(0:2:(dim - 1)) ./ dim
    half_freqs = 1 ./ (base .^ exponents)
    freqs = [half_freqs; half_freqs]
    positions = 0:(max_position_embeddings - 1)
    angles = freqs .* positions'
    return (sin.(angles), cos.(angles))
end

"""
    rotate_half(x)

    Rotates half the hidden dims of the input
"""
function rotate_half(x)
    d = size(x,1) ÷ 2
    x1 = @view x[1:d,:,:]
    x2 = @view x[d+1:end, :, :]
    cat(-x2, x1, dims = 1)
end

function apply_rotary_pos_emb(q, k, _cos::AbstractMatrix, _sin::AbstractMatrix)
    _sin = reshape(_sin, size(_sin,1), 1, size(_sin,2))
    _cos = reshape(_cos, size(_cos,1), 1, size(_cos,2))
    q_embed = q .* _cos .+ rotate_half(q) .* _sin
    k_embed = k .* _cos .+ rotate_half(k) .* _sin
    return (q_embed, k_embed)
end

model_type = model_name = "microsoft/phi-1"
cfg = Transformers.HuggingFace.load_config(model_type)
state_dict = Transformers.HuggingFace.load_state_dict(model_type; config=cfg)
# Keep only the first decoder layer's weights; that is all this comparison uses.
state_dict = Dict(filter((kv) -> contains(kv[1], "model.layers.0."), collect(state_dict)))

# textenc = Transformers.HuggingFace.load_tokenizer(model_name)
# model = Transformers.HuggingFace.load_model(Transformers.HuggingFace.HGFPhiForCausalLM, cfg, state_dict, "")

s = "Tell me something about Julia?"

# input = encode(textenc, s).token 
# input = OneHotArray(OneHot{0x0000c477}.([ 24447, 503, 1224, 547, 22301, 31]))
# input_ref = Pickle.Torch.THload("/tmp/inputs.torch")["input_ids"]
# e = model.model.embed(input) # verify embedding
e_ref = load_torch_matrix("/tmp/embedding.torch")
e = e_ref

lprefix = "model.layers.0"
# residual = e.hidden_state
residual = e
ln = load_model(HGFPhiPreTrainedModel, Layers.LayerNorm, cfg, state_dict, joinname(lprefix, "input_layernorm"))
hidden_state = ln(residual)
hidden_state_ref = load_torch_matrix("/tmp/hidden_states.torch")

# The attention layer must be loaded: `sa.qkv_proj` is used right below.
sa = load_model(HGFPhiForCausalLM, SelfAttention, cfg, state_dict, joinname(lprefix, "self_attn"))
# Running the whole layer end to end is still being debugged:
# attn_outputs = sa((;hidden_state = hidden_state_ref)).hidden_state
# attn_outputs .- load_torch_matrix("/tmp/attn_outputs.torch")

nt = (;hidden_state = hidden_state_ref)
qkv = apply_on_namedtuple(sa.qkv_proj, nt)
maximum(abs.(qkv.hidden_state[1] .- load_torch_matrix("/tmp/query_states.torch")))
maximum(abs.(qkv.hidden_state[2] .- load_torch_matrix("/tmp/key_states.torch")))
maximum(abs.(qkv.hidden_state[3] .- load_torch_matrix("/tmp/value_states.torch")))

# Step through the attention score by hand.
head = 32        # number of attention heads
rotary_dim = 32  # rotated features per head (sa.rotary_emb.dim on the Python side)
len = 6          # number of prompt tokens
_sin_full, _cos_full = phi_rotary_embedding(rotary_dim)
_sin = _sin_full[:, 1:len]
_cos = _cos_full[:, 1:len]
# abs: a signed maximum would hide negative deviations.
maximum(abs.(_sin .- Pickle.Torch.THload("/tmp/sin.torch")'))
maximum(abs.(_cos .- Pickle.Torch.THload("/tmp/cos.torch")'))

q,k,v = qkv.hidden_state
# q/k/v are plain (hidden, seq_len) matrices, so the split treats the sequence axis as
# the batch axis. The results are laid out as (head_dim, heads, seq_len), which is the
# layout `compare_tensors` and `apply_rotary_pos_emb` above expect.
query_states = _split_and_move_head(head, q)
key_states = _split_and_move_head(head, k)
hv = _split_and_move_head(head, v)

query_rot = query_states[1:rotary_dim, :, :]
query_pass = query_states[(rotary_dim + 1):end, :, :]

compare_tensors(query_rot, "/tmp/query_rot.torch")
compare_tensors(query_pass, "/tmp/query_pass.torch")

key_rot = key_states[1:rotary_dim, :, :]
key_pass = key_states[(rotary_dim + 1):end, :, :]

compare_tensors(key_rot, "/tmp/key_rot.torch")
compare_tensors(key_pass, "/tmp/key_pass.torch")

query_rot_pos, key_rot_pos = apply_rotary_pos_emb(query_rot, key_rot, _cos, _sin)

compare_tensors(query_rot_pos, "/tmp/query_rot_pos.torch")
compare_tensors(key_rot_pos, "/tmp/key_rot_pos.torch")

query_rot_states = cat(query_rot_pos, query_pass, dims=1)
key_rot_states = cat(key_rot_pos, key_pass, dims=1)

compare_tensors(query_rot_states, "/tmp/query_rot_states.torch")
compare_tensors(key_rot_states, "/tmp/key_rot_states.torch")

# The score must attend across positions with heads acting as the batch axis, so move
# (head_dim, heads, seq_len) to (head_dim, seq_len, heads). Using the original layout
# scores heads against heads, which is why the shape came out wrong.
q_heads = permutedims(query_rot_states, (1, 3, 2))
k_heads = permutedims(key_rot_states, (1, 3, 2))
# NeuralAttentionlib orders scores as (key_len, query_len, batch) — confirm in its docs.
scores = scaled_dot_product_score(q_heads, k_heads)
# compare_tensors wants (k_len, heads, q_len) to match PyTorch's [1, heads, q_len, k_len].
# /tmp/attn_weights.torch must hold the scaled scores from *before* the softmax.
attn_weights = permutedims(scores, (1, 3, 2))
compare_tensors(attn_weights, "/tmp/attn_weights.torch")

It is quite a lot of code, but I had to write it to isolate the effect of the individual layers in NeuralAttentionlib and better understand what is going on. All the intermediate results match the Python version except `attn_weights`, where I get completely different dimensions, as I discussed on the Discourse forum. I do not know whether NeuralAttentionlib computes self-attention very differently or whether I need to permute dimensions. I have tried a naive permutation of the dimensions, and it did not work.

Any help would be appreciated.

chengchingwen commented 9 months ago

This:

q,k,v = qkv.hidden_state
# Problematic: q, k and v are plain matrices here, so nothing marks which axis is the
# sequence length, and the split below treats the length axis as the batch axis.
query_states = _split_and_move_head(head, q)
key_states = _split_and_move_head(head, k)
hv = _split_and_move_head(head, v)

mistakes the length dimension for the batch dimension: the real batch size is omitted and the meaning of each dimension is not specified, so it gives the wrong results.

it should either be:

# usually preferable
# `as_collapsed` wraps each array in a CollapsedDimsArray, which groups its dimensions
# into the (feature, length, batch) triple the attention interface works with.
query_states = _split_and_move_head(head, as_collapsed(q))
key_states = _split_and_move_head(head, as_collapsed(k))
hv = _split_and_move_head(head, as_collapsed(v))

or:

# `reshape(x, Val(3))` appends a trailing singleton axis, turning a (feature, length)
# matrix into a (feature, length, 1) array with an explicit batch size of one.
query_states = _split_and_move_head(head, reshape(q, Val(3)))
key_states = _split_and_move_head(head, reshape(k, Val(3)))
hv = _split_and_move_head(head, reshape(v, Val(3)))

It's worth noting that in most (probably all) Python implementations, the meaning of each tensor dimension is fixed. NeuralAttentionlib, on the other hand, adds an abstraction layer (CollapsedDimsArray) on top of those tensors. The attention "algorithm" requires the tensor to have 3 dimensions: the feature dimension, the length dimension, and the batch dimension. CollapsedDimsArray groups the dimensions of the tensor into these 3. The whole attention interface in NeuralAttentionlib is built on top of this abstraction layer.

pevnak commented 9 months ago

I do not fully understand this yet, but I finally made it through the attention, and it now gives the same results as the PyTorch version. I will create a PR containing bits of what I have done, but I think it will also require adding the right version of the rotary embedding.

I am not sure my code is worth a PR. It is a mess.

I have it here https://github.com/pevnak/Transformers.jl#tp/phi

The important part is in test/debug, where I run the first layer of phi in Python and in Julia to see how to code the self-attention. I have prepared a bit of implementation/phi, but it still lacks the construction of the correct self-attention and the wiring.

chengchingwen commented 9 months ago

I added it in #168. Let me know once you have tested it. If we don't find any problems, I'll merge it.