huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

BERT Safetensors variable mismatch #1887

Open Christof23 opened 7 months ago

Christof23 commented 7 months ago

Hi, I was running the BERT example code and noticed that some of the variable names don't match the tensor names in the current safetensors file, obtained via:

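// Assumes hf_hub::api::sync::{Api, ApiRepo} and std::path::PathBuf are in scope.
let api = Api::new()?;
let repo: ApiRepo = api.model("bert-base-uncased".to_string());
let weights_path: PathBuf = repo.get("model.safetensors")?;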

For example, loading the model defined in candle-transformers/src/models/bert.rs fails with: Error: TensorNotFound("embeddings.word_embeddings.weight").
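When a lookup fails with `TensorNotFound`, it can help to search the checkpoint for names that end with the requested one. The sketch below is a hypothetical std-only helper (`suggest_matches` is not part of candle); it assumes the tensor names have already been read from the safetensors file into a slice of strings.

```rust
/// Hypothetical diagnostic helper (not part of candle): return every checkpoint
/// tensor name that ends with `.{requested}`, i.e. the same tensor stored under
/// an extra prefix such as `bert.`.
fn suggest_matches<'a>(requested: &str, available: &[&'a str]) -> Vec<&'a str> {
    let dotted = format!(".{requested}");
    available
        .iter()
        .copied()
        .filter(|name| name.ends_with(&dotted))
        .collect()
}

fn main() {
    // A few names taken from the bert-base-uncased safetensors listing.
    let available = [
        "bert.embeddings.word_embeddings.weight",
        "bert.embeddings.position_embeddings.weight",
        "cls.predictions.bias",
    ];
    let hits = suggest_matches("embeddings.word_embeddings.weight", &available);
    println!("{hits:?}");
}
```

For the name in the error above, the only match is the same tensor under the `bert.` prefix.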

The safetensors file prefixes all variable names with bert and uses the older gamma/beta naming for the LayerNorm parameters. This issue has also been noted here.
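One way to reconcile the two naming schemes is to rename the tensors on the checkpoint side. The sketch below is a hypothetical helper (`to_candle_name` is not a candle function). It assumes the only differences are the `bert.` prefix and the legacy `LayerNorm.gamma`/`LayerNorm.beta` suffixes, as in the listing in this issue.

```rust
/// Hypothetical helper (not part of candle): map a tensor name from the
/// bert-base-uncased safetensors file to the name candle's BertModel requests.
/// Assumes the only differences are the `bert.` prefix and the legacy
/// LayerNorm `gamma`/`beta` suffixes.
fn to_candle_name(checkpoint_name: &str) -> String {
    // Drop the `bert.` prefix if present; other names (e.g. `cls.*`) pass through.
    let name = checkpoint_name
        .strip_prefix("bert.")
        .unwrap_or(checkpoint_name);
    // Rename only LayerNorm parameters, so no other tensor is touched.
    if let Some(stem) = name.strip_suffix("LayerNorm.gamma") {
        format!("{stem}LayerNorm.weight")
    } else if let Some(stem) = name.strip_suffix("LayerNorm.beta") {
        format!("{stem}LayerNorm.bias")
    } else {
        name.to_string()
    }
}

fn main() {
    for n in [
        "bert.embeddings.LayerNorm.gamma",
        "bert.encoder.layer.0.output.LayerNorm.beta",
        "bert.embeddings.word_embeddings.weight",
    ] {
        println!("{n} -> {}", to_candle_name(n));
    }
}
```

Applying the mapping on the checkpoint side leaves `layer_norm` itself untouched.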

I think the problem is in layer_norm, which expects weight and bias rather than gamma and beta:

let weight = vb.get_with_hints(size, "weight", crate::Init::Const(1.))?;
let bias = if config.affine {
    Some(vb.get_with_hints(size, "bias", crate::Init::Const(0.))?)
} else {
    None
};
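Instead of renaming tensors, the loader could accept both namings. The sketch below models that idea with a plain `HashMap` standing in for candle's `VarBuilder`; `Store` and `get_with_legacy_fallback` are hypothetical names, and this is not how candle-nn's `layer_norm` is actually implemented.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a weight store: tensor name -> flat values.
type Store = HashMap<String, Vec<f32>>;

/// Look up `{prefix}.{name}`; if it is missing, retry with the legacy name
/// (e.g. `gamma` for `weight`, `beta` for `bias`).
fn get_with_legacy_fallback<'a>(
    store: &'a Store,
    prefix: &str,
    name: &str,
    legacy: &str,
) -> Option<&'a Vec<f32>> {
    store
        .get(&format!("{prefix}.{name}"))
        .or_else(|| store.get(&format!("{prefix}.{legacy}")))
}

fn main() {
    let mut store = Store::new();
    // Only the legacy names exist, as in the bert-base-uncased checkpoint.
    store.insert("embeddings.LayerNorm.gamma".to_string(), vec![1.0; 4]);
    store.insert("embeddings.LayerNorm.beta".to_string(), vec![0.0; 4]);

    let weight = get_with_legacy_fallback(&store, "embeddings.LayerNorm", "weight", "gamma");
    let bias = get_with_legacy_fallback(&store, "embeddings.LayerNorm", "bias", "beta");
    println!("weight found: {}, bias found: {}", weight.is_some(), bias.is_some());
}
```

The current name is tried first, so checkpoints that already use weight/bias are unaffected.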

The tensor names in the safetensors file are as follows:

bert.embeddings.LayerNorm.beta
bert.embeddings.LayerNorm.gamma
bert.embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight
bert.embeddings.word_embeddings.weight
bert.encoder.layer.0.attention.output.LayerNorm.beta
bert.encoder.layer.0.attention.output.LayerNorm.gamma
bert.encoder.layer.0.attention.output.dense.bias
bert.encoder.layer.0.attention.output.dense.weight
bert.encoder.layer.0.attention.self.key.bias
bert.encoder.layer.0.attention.self.key.weight
bert.encoder.layer.0.attention.self.query.bias
bert.encoder.layer.0.attention.self.query.weight
bert.encoder.layer.0.attention.self.value.bias
bert.encoder.layer.0.attention.self.value.weight
bert.encoder.layer.0.intermediate.dense.bias
bert.encoder.layer.0.intermediate.dense.weight
bert.encoder.layer.0.output.LayerNorm.beta
bert.encoder.layer.0.output.LayerNorm.gamma
bert.encoder.layer.0.output.dense.bias
bert.encoder.layer.0.output.dense.weight
bert.encoder.layer.1.attention.output.LayerNorm.beta
bert.encoder.layer.1.attention.output.LayerNorm.gamma
bert.encoder.layer.1.attention.output.dense.bias
bert.encoder.layer.1.attention.output.dense.weight
bert.encoder.layer.1.attention.self.key.bias
bert.encoder.layer.1.attention.self.key.weight
bert.encoder.layer.1.attention.self.query.bias
bert.encoder.layer.1.attention.self.query.weight
bert.encoder.layer.1.attention.self.value.bias
bert.encoder.layer.1.attention.self.value.weight
bert.encoder.layer.1.intermediate.dense.bias
bert.encoder.layer.1.intermediate.dense.weight
bert.encoder.layer.1.output.LayerNorm.beta
bert.encoder.layer.1.output.LayerNorm.gamma
bert.encoder.layer.1.output.dense.bias
bert.encoder.layer.1.output.dense.weight
bert.encoder.layer.10.attention.output.LayerNorm.beta
bert.encoder.layer.10.attention.output.LayerNorm.gamma
bert.encoder.layer.10.attention.output.dense.bias
bert.encoder.layer.10.attention.output.dense.weight
bert.encoder.layer.10.attention.self.key.bias
bert.encoder.layer.10.attention.self.key.weight
bert.encoder.layer.10.attention.self.query.bias
bert.encoder.layer.10.attention.self.query.weight
bert.encoder.layer.10.attention.self.value.bias
bert.encoder.layer.10.attention.self.value.weight
bert.encoder.layer.10.intermediate.dense.bias
bert.encoder.layer.10.intermediate.dense.weight
bert.encoder.layer.10.output.LayerNorm.beta
bert.encoder.layer.10.output.LayerNorm.gamma
bert.encoder.layer.10.output.dense.bias
bert.encoder.layer.10.output.dense.weight
bert.encoder.layer.11.attention.output.LayerNorm.beta
bert.encoder.layer.11.attention.output.LayerNorm.gamma
bert.encoder.layer.11.attention.output.dense.bias
bert.encoder.layer.11.attention.output.dense.weight
bert.encoder.layer.11.attention.self.key.bias
bert.encoder.layer.11.attention.self.key.weight
bert.encoder.layer.11.attention.self.query.bias
bert.encoder.layer.11.attention.self.query.weight
bert.encoder.layer.11.attention.self.value.bias
bert.encoder.layer.11.attention.self.value.weight
bert.encoder.layer.11.intermediate.dense.bias
bert.encoder.layer.11.intermediate.dense.weight
bert.encoder.layer.11.output.LayerNorm.beta
bert.encoder.layer.11.output.LayerNorm.gamma
bert.encoder.layer.11.output.dense.bias
bert.encoder.layer.11.output.dense.weight
bert.encoder.layer.2.attention.output.LayerNorm.beta
bert.encoder.layer.2.attention.output.LayerNorm.gamma
bert.encoder.layer.2.attention.output.dense.bias
bert.encoder.layer.2.attention.output.dense.weight
bert.encoder.layer.2.attention.self.key.bias
bert.encoder.layer.2.attention.self.key.weight
bert.encoder.layer.2.attention.self.query.bias
bert.encoder.layer.2.attention.self.query.weight
bert.encoder.layer.2.attention.self.value.bias
bert.encoder.layer.2.attention.self.value.weight
bert.encoder.layer.2.intermediate.dense.bias
bert.encoder.layer.2.intermediate.dense.weight
bert.encoder.layer.2.output.LayerNorm.beta
bert.encoder.layer.2.output.LayerNorm.gamma
bert.encoder.layer.2.output.dense.bias
bert.encoder.layer.2.output.dense.weight
bert.encoder.layer.3.attention.output.LayerNorm.beta
bert.encoder.layer.3.attention.output.LayerNorm.gamma
bert.encoder.layer.3.attention.output.dense.bias
bert.encoder.layer.3.attention.output.dense.weight
bert.encoder.layer.3.attention.self.key.bias
bert.encoder.layer.3.attention.self.key.weight
bert.encoder.layer.3.attention.self.query.bias
bert.encoder.layer.3.attention.self.query.weight
bert.encoder.layer.3.attention.self.value.bias
bert.encoder.layer.3.attention.self.value.weight
bert.encoder.layer.3.intermediate.dense.bias
bert.encoder.layer.3.intermediate.dense.weight
bert.encoder.layer.3.output.LayerNorm.beta
bert.encoder.layer.3.output.LayerNorm.gamma
bert.encoder.layer.3.output.dense.bias
bert.encoder.layer.3.output.dense.weight
bert.encoder.layer.4.attention.output.LayerNorm.beta
bert.encoder.layer.4.attention.output.LayerNorm.gamma
bert.encoder.layer.4.attention.output.dense.bias
bert.encoder.layer.4.attention.output.dense.weight
bert.encoder.layer.4.attention.self.key.bias
bert.encoder.layer.4.attention.self.key.weight
bert.encoder.layer.4.attention.self.query.bias
bert.encoder.layer.4.attention.self.query.weight
bert.encoder.layer.4.attention.self.value.bias
bert.encoder.layer.4.attention.self.value.weight
bert.encoder.layer.4.intermediate.dense.bias
bert.encoder.layer.4.intermediate.dense.weight
bert.encoder.layer.4.output.LayerNorm.beta
bert.encoder.layer.4.output.LayerNorm.gamma
bert.encoder.layer.4.output.dense.bias
bert.encoder.layer.4.output.dense.weight
bert.encoder.layer.5.attention.output.LayerNorm.beta
bert.encoder.layer.5.attention.output.LayerNorm.gamma
bert.encoder.layer.5.attention.output.dense.bias
bert.encoder.layer.5.attention.output.dense.weight
bert.encoder.layer.5.attention.self.key.bias
bert.encoder.layer.5.attention.self.key.weight
bert.encoder.layer.5.attention.self.query.bias
bert.encoder.layer.5.attention.self.query.weight
bert.encoder.layer.5.attention.self.value.bias
bert.encoder.layer.5.attention.self.value.weight
bert.encoder.layer.5.intermediate.dense.bias
bert.encoder.layer.5.intermediate.dense.weight
bert.encoder.layer.5.output.LayerNorm.beta
bert.encoder.layer.5.output.LayerNorm.gamma
bert.encoder.layer.5.output.dense.bias
bert.encoder.layer.5.output.dense.weight
bert.encoder.layer.6.attention.output.LayerNorm.beta
bert.encoder.layer.6.attention.output.LayerNorm.gamma
bert.encoder.layer.6.attention.output.dense.bias
bert.encoder.layer.6.attention.output.dense.weight
bert.encoder.layer.6.attention.self.key.bias
bert.encoder.layer.6.attention.self.key.weight
bert.encoder.layer.6.attention.self.query.bias
bert.encoder.layer.6.attention.self.query.weight
bert.encoder.layer.6.attention.self.value.bias
bert.encoder.layer.6.attention.self.value.weight
bert.encoder.layer.6.intermediate.dense.bias
bert.encoder.layer.6.intermediate.dense.weight
bert.encoder.layer.6.output.LayerNorm.beta
bert.encoder.layer.6.output.LayerNorm.gamma
bert.encoder.layer.6.output.dense.bias
bert.encoder.layer.6.output.dense.weight
bert.encoder.layer.7.attention.output.LayerNorm.beta
bert.encoder.layer.7.attention.output.LayerNorm.gamma
bert.encoder.layer.7.attention.output.dense.bias
bert.encoder.layer.7.attention.output.dense.weight
bert.encoder.layer.7.attention.self.key.bias
bert.encoder.layer.7.attention.self.key.weight
bert.encoder.layer.7.attention.self.query.bias
bert.encoder.layer.7.attention.self.query.weight
bert.encoder.layer.7.attention.self.value.bias
bert.encoder.layer.7.attention.self.value.weight
bert.encoder.layer.7.intermediate.dense.bias
bert.encoder.layer.7.intermediate.dense.weight
bert.encoder.layer.7.output.LayerNorm.beta
bert.encoder.layer.7.output.LayerNorm.gamma
bert.encoder.layer.7.output.dense.bias
bert.encoder.layer.7.output.dense.weight
bert.encoder.layer.8.attention.output.LayerNorm.beta
bert.encoder.layer.8.attention.output.LayerNorm.gamma
bert.encoder.layer.8.attention.output.dense.bias
bert.encoder.layer.8.attention.output.dense.weight
bert.encoder.layer.8.attention.self.key.bias
bert.encoder.layer.8.attention.self.key.weight
bert.encoder.layer.8.attention.self.query.bias
bert.encoder.layer.8.attention.self.query.weight
bert.encoder.layer.8.attention.self.value.bias
bert.encoder.layer.8.attention.self.value.weight
bert.encoder.layer.8.intermediate.dense.bias
bert.encoder.layer.8.intermediate.dense.weight
bert.encoder.layer.8.output.LayerNorm.beta
bert.encoder.layer.8.output.LayerNorm.gamma
bert.encoder.layer.8.output.dense.bias
bert.encoder.layer.8.output.dense.weight
bert.encoder.layer.9.attention.output.LayerNorm.beta
bert.encoder.layer.9.attention.output.LayerNorm.gamma
bert.encoder.layer.9.attention.output.dense.bias
bert.encoder.layer.9.attention.output.dense.weight
bert.encoder.layer.9.attention.self.key.bias
bert.encoder.layer.9.attention.self.key.weight
bert.encoder.layer.9.attention.self.query.bias
bert.encoder.layer.9.attention.self.query.weight
bert.encoder.layer.9.attention.self.value.bias
bert.encoder.layer.9.attention.self.value.weight
bert.encoder.layer.9.intermediate.dense.bias
bert.encoder.layer.9.intermediate.dense.weight
bert.encoder.layer.9.output.LayerNorm.beta
bert.encoder.layer.9.output.LayerNorm.gamma
bert.encoder.layer.9.output.dense.bias
bert.encoder.layer.9.output.dense.weight
bert.pooler.dense.bias
bert.pooler.dense.weight
cls.predictions.bias
cls.predictions.transform.LayerNorm.beta
cls.predictions.transform.LayerNorm.gamma
cls.predictions.transform.dense.bias
cls.predictions.transform.dense.weight
cls.seq_relationship.bias
cls.seq_relationship.weight
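The listing above contains 206 tensors: 5 embedding tensors, 16 tensors for each of the encoder layers 0 to 11, 2 pooler tensors and 7 cls.* tensors. The names are sorted lexicographically, so layer.10 and layer.11 appear before layer.2. A minimal std-only sketch, assuming the names are available as strings, that recovers the number of encoder layers so it can be compared against `num_hidden_layers` in the config (`encoder_layer_count` is a hypothetical helper):

```rust
use std::collections::BTreeSet;

/// Hypothetical helper: count distinct encoder layers by parsing the index in
/// `<prefix>encoder.layer.<N>.<...>`. Names that do not match are ignored.
fn encoder_layer_count(names: &[&str]) -> usize {
    let layers: BTreeSet<usize> = names
        .iter()
        .filter_map(|name| {
            // Everything after "encoder.layer." starts with the layer index.
            let rest = name.split("encoder.layer.").nth(1)?;
            rest.split('.').next()?.parse::<usize>().ok()
        })
        .collect();
    layers.len()
}

fn main() {
    let names = [
        "bert.embeddings.word_embeddings.weight",
        "bert.encoder.layer.0.attention.self.key.weight",
        "bert.encoder.layer.10.output.dense.bias",
        "bert.encoder.layer.2.intermediate.dense.weight",
        "bert.encoder.layer.2.output.LayerNorm.gamma",
        "cls.predictions.bias",
    ];
    println!("{}", encoder_layer_count(&names));
}
```

Collecting indices into a set makes the count independent of the lexicographic ordering and of how many tensors each layer has.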
Christof23 commented 7 months ago

This PR addresses the above issue, but I'm not sure whether the changes to layer_norm are appropriate: https://github.com/huggingface/candle/pull/1888

vrama628 commented 6 months ago

I'm running into the same issue. I followed the instructions in the Candle reference guide on running a HuggingFace model in Candle, and was surprised that the recommended steps (loading the bert-base-uncased model into the BertModel struct) result in an error.