eole-nlp / eole

Open language modeling toolkit based on PyTorch
https://eole-nlp.github.io/eole
MIT License

Issue in validation order with model level parameters propagation #19

Closed · funboarder13920 closed this 3 months ago

funboarder13920 commented 3 months ago

Hello,

The _override_values validator propagates model-level parameters to the encoder/decoder configs, but validation of the encoder/decoder is triggered during the propagation loop itself. Because not all model-level parameters have been set at that point, the sub-config can end up in an intermediate state where already-propagated values and remaining default values are incompatible, triggering assertion errors.

I am not sure how to solve this; pydantic currently provides no way to control the order of validation.

For example, this code raises an assertion error:

from eole.config.models import EmbeddingsConfig, TransformerLMModelConfig

conf = TransformerLMModelConfig(
    layers=10,
    hidden_size=9,
    heads=3,
    transformer_ff=64 * 3,
    embeddings=EmbeddingsConfig(
        src_word_vec_size=256,
        tgt_word_vec_size=256,
    ),
    # src_word_vec_size=src_word_vec_size,
    # tgt_word_vec_size=tgt_word_vec_size,
    model_type="text",
    mlp_activation_fn="gelu",
    self_attn_type="scaled-dot",  # not sure if scaled-dot-flash is fine
    max_relative_positions=-1,
    heads_kv=1,
    parallel_residual=True,
    shared_layer_norm=True,
    add_qkvbias=False,
    add_ffnbias=False,
)

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/pydantic/main.py", line 176, in __init__
    self.__pydantic_validator__.validate_python(data, self_instance=self)
pydantic_core._pydantic_core.ValidationError: 1 validation error for TransformerLMModelConfig
Assertion failed, Transformer Model dimension 9 must be divisible by the number of heads 8 [type=assertion_error, input_value=TransformerLMDecoderConfi...False, lambda_align=0.0), input_type=TransformerLMDecoderConfig]
    For further information visit https://errors.pydantic.dev/2.7/v/assertion_error
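
Distilled down to plain pydantic (outside of eole), the ordering problem can be sketched roughly like this. The Parent/Child models and field names here are made up for illustration, assuming the propagation loop assigns fields one by one onto a sub-config that has validate_assignment enabled, which is approximately what _override_values does:

from pydantic import BaseModel, ConfigDict, Field, model_validator

class Child(BaseModel):
    # Re-validated on every assignment, like the encoder/decoder configs.
    model_config = ConfigDict(validate_assignment=True)
    hidden_size: int = 512
    heads: int = 8

    @model_validator(mode="after")
    def check_divisible(self):
        assert self.hidden_size % self.heads == 0, (
            f"dimension {self.hidden_size} must be divisible by {self.heads} heads"
        )
        return self

class Parent(BaseModel):
    hidden_size: int = 9
    heads: int = 3
    child: Child = Field(default_factory=Child)

    @model_validator(mode="after")
    def propagate(self):
        # Field-by-field propagation: the very first assignment re-validates
        # Child with hidden_size=9 but heads still at its default of 8, so
        # the assertion fires even though 9 and 3 are compatible together.
        for name in ("hidden_size", "heads"):
            setattr(self.child, name, getattr(self, name))
        return self

# Raises a pydantic ValidationError wrapping the assertion error:
# "dimension 9 must be divisible by 8 heads"
Parent()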

francoishernandez commented 3 months ago

Hey there, that is a valid point, thanks for reporting. Tentative solution in #24. Probably not fully bulletproof yet.
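
One generic way to sidestep the ordering problem in pydantic (not necessarily what #24 does) is to propagate in a mode="before" model validator on the parent, so the sub-config is constructed, and therefore validated, only once, with a complete set of values. A minimal sketch, reusing the same made-up Child model as above:

from pydantic import BaseModel, ConfigDict, Field, model_validator

class Child(BaseModel):
    model_config = ConfigDict(validate_assignment=True)
    hidden_size: int = 512
    heads: int = 8

    @model_validator(mode="after")
    def check_divisible(self):
        assert self.hidden_size % self.heads == 0
        return self

class Parent(BaseModel):
    hidden_size: int = 9
    heads: int = 3
    child: Child = Field(default_factory=Child)

    @model_validator(mode="before")
    @classmethod
    def propagate_early(cls, data):
        # Inject model-level values into the raw child dict before Child is
        # constructed, so Child is validated exactly once with a consistent
        # set of values and no half-propagated intermediate state exists.
        if isinstance(data, dict):
            child = dict(data.get("child") or {})
            for name in ("hidden_size", "heads"):
                if name in data:
                    child.setdefault(name, data[name])
            data["child"] = child
        return data

Parent(hidden_size=9, heads=3)  # validates cleanly: 9 % 3 == 0

The trade-off is that before-validators only see raw input, so the sketch uses setdefault to let explicitly passed sub-config values win over propagated model-level ones, which seems to match the intent of _override_values.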