EricLBuehler / candle-lora

Low rank adaptation (LoRA) for Candle.
MIT License

Bert model doesn't seem to instantiate with lora weights #21

Open jcrist1 opened 3 months ago

jcrist1 commented 3 months ago

I tried to instantiate a bert model with the following code:

```rust
use candle_core::DType;
use candle_lora::LoraConfig;
use candle_lora_transformers::bert::{BertModel, Config};
use candle_nn::{VarBuilder, VarMap};

fn main() {
    let config = "config.json";
    let device: candle_core::Device = candle_core::Device::Cpu;

    let config_str = std::fs::read_to_string(config).expect("Failed to load config");
    let config: Config = serde_json::from_str(&config_str).expect("failed to parse config");

    let map = VarMap::new();

    let builder = VarBuilder::from_varmap(&map, DType::F32, &device);

    let lora_config = LoraConfig::new(32, 1.0, Some(0.1));
    BertModel::load(builder, &config, false, lora_config)
        .expect("Failed to instantiate bert model");

    let data = map.data().lock().expect("Failed to lock var map data");
    for (key, tensor) in &*data {
        println!("{key}: {:?}", tensor.shape())
    }
}
```

Cargo manifest:

```toml
[package]
name = "candle-lora-test"
version = "0.1.0"
edition = "2021"

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.6.0"}
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.6.0"}
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.6.0"}
candle-lora = { git = "https://github.com/EricLBuehler/candle-lora.git", version = "0.2.0" }
candle-lora-macro = { git = "https://github.com/EricLBuehler/candle-lora.git", version = "0.2.0" }
candle-lora-transformers = { git = "https://github.com/EricLBuehler/candle-lora.git", version = "0.2.0" }
serde_json = "1.0.127"
```

and the model config:

```json
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 1536,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.8.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}
```

Running this outputs:

```
embeddings.position_embeddings.lora_embed.a0.weight: [32, 512]
embeddings.LayerNorm.weight: [384]
encoder.layer.2.attention.self.key.bias: [384]
encoder.layer.1.attention.output.dense.bias: [384]
encoder.layer.1.output.LayerNorm.weight: [384]
encoder.layer.4.attention.self.value.weight: [384, 384]
encoder.layer.2.intermediate.dense.bias: [1536]
encoder.layer.4.attention.self.query.bias: [384]
encoder.layer.0.output.LayerNorm.bias: [384]
embeddings.word_embeddings.lora_embed.a0.weight: [32, 30522]
encoder.layer.2.intermediate.dense.weight: [1536, 384]
encoder.layer.0.attention.output.dense.weight: [384, 384]
encoder.layer.3.attention.output.LayerNorm.weight: [384]
encoder.layer.3.output.dense.weight: [384, 1536]
encoder.layer.3.attention.output.dense.weight: [384, 384]
encoder.layer.3.output.LayerNorm.bias: [384]
encoder.layer.4.output.dense.bias: [384]
encoder.layer.5.attention.output.dense.weight: [384, 384]
encoder.layer.3.attention.output.LayerNorm.bias: [384]
encoder.layer.0.attention.output.dense.bias: [384]
encoder.layer.5.output.dense.weight: [384, 1536]
encoder.layer.4.attention.self.key.bias: [384]
encoder.layer.2.attention.output.dense.bias: [384]
encoder.layer.1.attention.self.key.bias: [384]
encoder.layer.0.attention.output.LayerNorm.weight: [384]
encoder.layer.4.intermediate.dense.weight: [1536, 384]
encoder.layer.0.attention.self.value.weight: [384, 384]
encoder.layer.3.attention.self.key.bias: [384]
encoder.layer.3.attention.output.dense.bias: [384]
embeddings.token_type_embeddings.weight: [2, 384]
encoder.layer.4.attention.output.LayerNorm.bias: [384]
encoder.layer.1.output.LayerNorm.bias: [384]
encoder.layer.5.attention.self.value.weight: [384, 384]
encoder.layer.2.attention.output.dense.weight: [384, 384]
encoder.layer.5.attention.self.key.bias: [384]
encoder.layer.5.output.LayerNorm.weight: [384]
encoder.layer.2.output.LayerNorm.weight: [384]
encoder.layer.0.attention.self.key.bias: [384]
encoder.layer.0.output.dense.weight: [384, 1536]
encoder.layer.1.output.dense.bias: [384]
encoder.layer.2.output.dense.weight: [384, 1536]
embeddings.word_embeddings.weight: [30522, 384]
encoder.layer.0.attention.self.query.weight: [384, 384]
encoder.layer.2.attention.output.LayerNorm.weight: [384]
encoder.layer.0.intermediate.dense.bias: [1536]
encoder.layer.2.attention.output.LayerNorm.bias: [384]
encoder.layer.1.attention.self.query.bias: [384]
encoder.layer.4.attention.self.key.weight: [384, 384]
encoder.layer.4.attention.output.dense.bias: [384]
embeddings.position_embeddings.weight: [512, 384]
embeddings.token_type_embeddings.lora_embed.a0.weight: [32, 2]
encoder.layer.1.intermediate.dense.weight: [1536, 384]
encoder.layer.1.attention.self.query.weight: [384, 384]
encoder.layer.1.attention.self.value.weight: [384, 384]
embeddings.position_embeddings.lora_embed.b0.weight: [384, 32]
encoder.layer.4.output.dense.weight: [384, 1536]
encoder.layer.5.attention.output.LayerNorm.weight: [384]
encoder.layer.5.output.dense.bias: [384]
encoder.layer.0.attention.output.LayerNorm.bias: [384]
encoder.layer.2.output.dense.bias: [384]
embeddings.word_embeddings.lora_embed.b0.weight: [384, 32]
encoder.layer.1.intermediate.dense.bias: [1536]
encoder.layer.2.attention.self.value.weight: [384, 384]
encoder.layer.2.attention.self.query.bias: [384]
encoder.layer.1.attention.output.LayerNorm.bias: [384]
encoder.layer.1.attention.self.value.bias: [384]
encoder.layer.2.output.LayerNorm.bias: [384]
encoder.layer.3.attention.self.query.weight: [384, 384]
encoder.layer.3.attention.self.value.bias: [384]
encoder.layer.3.attention.self.key.weight: [384, 384]
encoder.layer.1.attention.output.LayerNorm.weight: [384]
encoder.layer.1.attention.output.dense.weight: [384, 384]
embeddings.LayerNorm.bias: [384]
encoder.layer.3.attention.self.value.weight: [384, 384]
encoder.layer.3.intermediate.dense.weight: [1536, 384]
encoder.layer.3.output.dense.bias: [384]
encoder.layer.0.attention.self.key.weight: [384, 384]
encoder.layer.4.attention.self.value.bias: [384]
encoder.layer.3.intermediate.dense.bias: [1536]
encoder.layer.4.attention.output.dense.weight: [384, 384]
encoder.layer.5.attention.self.query.bias: [384]
encoder.layer.5.attention.output.dense.bias: [384]
encoder.layer.5.output.LayerNorm.bias: [384]
encoder.layer.4.attention.output.LayerNorm.weight: [384]
encoder.layer.2.attention.self.query.weight: [384, 384]
encoder.layer.5.attention.self.query.weight: [384, 384]
encoder.layer.5.attention.output.LayerNorm.bias: [384]
encoder.layer.4.intermediate.dense.bias: [1536]
encoder.layer.0.attention.self.value.bias: [384]
encoder.layer.0.output.LayerNorm.weight: [384]
encoder.layer.3.attention.self.query.bias: [384]
encoder.layer.0.intermediate.dense.weight: [1536, 384]
encoder.layer.4.attention.self.query.weight: [384, 384]
encoder.layer.4.output.LayerNorm.weight: [384]
encoder.layer.0.attention.self.query.bias: [384]
encoder.layer.5.intermediate.dense.weight: [1536, 384]
encoder.layer.1.output.dense.weight: [384, 1536]
encoder.layer.4.output.LayerNorm.bias: [384]
encoder.layer.0.output.dense.bias: [384]
encoder.layer.5.attention.self.value.bias: [384]
encoder.layer.2.attention.self.value.bias: [384]
embeddings.token_type_embeddings.lora_embed.b0.weight: [384, 32]
encoder.layer.2.attention.self.key.weight: [384, 384]
encoder.layer.1.attention.self.key.weight: [384, 384]
encoder.layer.3.output.LayerNorm.weight: [384]
encoder.layer.5.attention.self.key.weight: [384, 384]
encoder.layer.5.intermediate.dense.bias: [1536]
```
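
For the encoder's linear layers I would have expected additional LoRA entries as well, something like the following (the exact key names and shapes here are my guess, extrapolating from the `a0`/`b0` pattern above and the `lora_linear` prefix that shows up in the expanded code below):

```
encoder.layer.0.attention.self.query.lora_linear.a0.weight: [32, 384]
encoder.layer.0.attention.self.query.lora_linear.b0.weight: [384, 32]
...
```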

Importantly, it doesn't seem to create LoRA weights for any of the encoder layers, only for the embedding layers. I looked at the expanded code and noticed that the generated constructor for a linear layer looks like this:

The LoRA linear layer:
```rust
    impl BertLinear {
        pub fn new(
            vb: VarBuilder,
            weight: Tensor,
            bias: Option<Tensor>,
            merge: bool,
            lora_config: LoraConfig,
        ) -> Self {
            let span = {
                use ::tracing::__macro_support::Callsite as _;
                static __CALLSITE: ::tracing::callsite::DefaultCallsite = {
                    static META: ::tracing::Metadata<'static> = {
                        ::tracing_core::metadata::Metadata::new(
                            "linear",
                            "candle_lora_transformers::bert",
                            tracing::Level::TRACE,
                            ::core::option::Option::Some(
                                "candle-lora-transformers/src/bert.rs",
                            ),
                            ::core::option::Option::Some(58u32),
                            ::core::option::Option::Some(
                                "candle_lora_transformers::bert",
                            ),
                            ::tracing_core::field::FieldSet::new(
                                &[],
                                ::tracing_core::callsite::Identifier(&__CALLSITE),
                            ),
                            ::tracing::metadata::Kind::SPAN,
                        )
                    };
                    ::tracing::callsite::DefaultCallsite::new(&META)
                };
                let mut interest = ::tracing::subscriber::Interest::never();
                if tracing::Level::TRACE <= ::tracing::level_filters::STATIC_MAX_LEVEL
                    && tracing::Level::TRACE
                        <= ::tracing::level_filters::LevelFilter::current()
                    && {
                        interest = __CALLSITE.interest();
                        !interest.is_never()
                    }
                    && ::tracing::__macro_support::__is_enabled(
                        __CALLSITE.metadata(),
                        interest,
                    )
                {
                    let meta = __CALLSITE.metadata();
                    ::tracing::Span::new(meta, &{ meta.fields().value_set(&[]) })
                } else {
                    let span = ::tracing::__macro_support::__disabled_span(
                        __CALLSITE.metadata(),
                    );
                    {};
                    span
                }
            };
            let dims = weight.dims2().unwrap();
            let linear_config = LoraLinearConfig::new(dims.1, dims.0);
            let mut this = Self {
                inner: Arc::new(Linear::new(weight, bias)),
                span,
            };
            if merge {
                this.get_merged_lora_model(
                    lora_config,
                    &vb.pp("lora_linear"),
                    Some(linear_config),
                    None,
                    None,
                    None,
                )
            } else {
                this.get_lora_model(
                    lora_config,
                    &vb.pp("lora_linear"),
                    Some(linear_config),
                    None,
                    None,
                    None,
                )
            }
            this
        }
```

But when I dug into `this.get_lora_model`, I noticed that it doesn't actually use the `self` parameter:

```rust
    impl BertLinear {
        /// Be sure to provide a configuration for each type!
        pub fn get_lora_model<'a>(
            &'a mut self,
            lora_config: candle_lora::LoraConfig,
            vb: &candle_nn::VarBuilder,
            linear_config: Option<candle_lora::LoraLinearConfig>,
            conv1d_config: Option<candle_lora::LoraConv1dConfig>,
            conv2d_config: Option<candle_lora::LoraConv2dConfig>,
            embed_config: Option<candle_lora::LoraEmbeddingConfig>,
        ) {
            let mut linear: ::std::collections::HashMap<
                String,
                &dyn candle_lora::LinearLayerLike,
            > = ::std::collections::HashMap::new();
            let mut conv1d: ::std::collections::HashMap<
                String,
                &dyn candle_lora::Conv1dLayerLike,
            > = ::std::collections::HashMap::new();
            let mut conv2d: ::std::collections::HashMap<
                String,
                &dyn candle_lora::Conv2dLayerLike,
            > = ::std::collections::HashMap::new();
            let mut embed: ::std::collections::HashMap<
                String,
                &dyn candle_lora::EmbeddingLayerLike,
            > = ::std::collections::HashMap::new();
            let mut embed: ::std::collections::HashMap<
                String,
                &dyn candle_lora::EmbeddingLayerLike,
            > = ::std::collections::HashMap::new();
            if !linear.is_empty() && linear_config.is_none() {
                {
                    ::core::panicking::panic_fmt(
                        format_args!("Config not specified for linear layers."),
                    );
                };
            }
            if !conv1d.is_empty() && conv1d_config.is_none() {
                {
                    ::core::panicking::panic_fmt(
                        format_args!("Config not specified for conv1d layers."),
                    );
                };
            }
            if !conv2d.is_empty() && conv2d_config.is_none() {
                {
                    ::core::panicking::panic_fmt(
                        format_args!("Config not specified for conv2d layers."),
                    );
                };
            }
            if !embed.is_empty() && embed_config.is_none() {
                {
                    ::core::panicking::panic_fmt(
                        format_args!("Config not specified for embedding layers."),
                    );
                };
            }
            let mut builder = candle_lora::SelectedLayersBuilder::new();
            if linear_config.is_some() {
                builder = builder.add_linear_layers(linear, linear_config.unwrap());
            }
            if conv1d_config.is_some() {
                builder = builder.add_conv1d_layers(conv1d, conv1d_config.unwrap());
            }
            if conv2d_config.is_some() {
                builder = builder.add_conv2d_layers(conv2d, conv2d_config.unwrap());
            }
            if embed_config.is_some() {
                builder = builder.add_embed_layers(embed, embed_config.unwrap());
            }
            let selection = builder.build();
            let new_layers = candle_lora::Lora::convert_model(
                selection,
                lora_config,
                &vb,
            );
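            // NOTE: `new_layers` is never written back into `self.inner` here (compare the embeddings version below)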
        }
```

For comparison, the `get_lora_model` of `BertEmbeddings` ends with:

```rust
          // ...
          [
                (self
                    .inner = ::std::sync::Arc::new(
                    new_layers.embed.get("inner").unwrap().clone(),
                )),
            ];
```
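
In other words, I would have expected the linear expansion to also register `self.inner` and then write the converted layer back, roughly like this (just a sketch of what I'd expect, not the actual expansion; it assumes `Linear` implements `LinearLayerLike` and reuses the `"inner"` key from the embeddings case):

```rust
// sketch only: register the wrapped layer before building the selection ...
linear.insert("inner".to_string(), &*self.inner);
// ... and, after `Lora::convert_model` returns, store the LoRA-wrapped layer:
self.inner = ::std::sync::Arc::new(
    new_layers.linear.get("inner").unwrap().clone(),
);
```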

It seems like the macro isn't quite expanding correctly. Could this be the case?

jcrist1 commented 3 months ago

@EricLBuehler I could have a look at this if you think it looks like an issue

EricLBuehler commented 3 months ago

@jcrist1 yeah, this seems like an issue. It may require modifying the macro itself, and/or adding features to the construction of layers to ensure recursive LoRA initialization.