erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0
168 stars, 19 forks

TypeError: __call__() takes from 2 to 9 positional arguments but 10 were given #31

Closed: JinSeoungwoo closed this issue 9 months ago

JinSeoungwoo commented 9 months ago

Below is the code I used to build train_data:

train_data = dataset.map(
    lambda x: tokenizer(generate_prompt(x), max_length=4096, padding='max_length', add_special_tokens=False),
    remove_columns=dataset.column_names,
)
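With remove_columns=dataset.column_names, the only columns left in train_data are the ones the tokenizer returns, and the traceback shows the trainer passing each batch straight into the model as keyword arguments (`state.apply_fn(params=params, **batch, ...)`). The sketch below mimics what padding='max_length' produces for one example so you can see which keys end up in a batch. `pad_to_max_length` and the token ids are hypothetical; a real Hugging Face tokenizer may return extra keys (for example token_type_ids) and pads on whichever side it is configured for.

```python
from typing import Dict, List


def pad_to_max_length(token_ids: List[int], max_length: int, pad_id: int = 0) -> Dict[str, List[int]]:
    """Hypothetical stand-in for a tokenizer called with padding='max_length'.

    Right-pads token_ids with pad_id up to max_length and builds the matching
    attention mask (1 = real token, 0 = padding). Sequences that are already
    longer than max_length are returned unchanged, i.e. nothing is truncated.
    """
    n_pad = max(0, max_length - len(token_ids))
    return {
        "input_ids": token_ids + [pad_id] * n_pad,
        "attention_mask": [1] * len(token_ids) + [0] * n_pad,
    }


example = pad_to_max_length([101, 2009, 2003], max_length=6)
print(example)
# {'input_ids': [101, 2009, 2003, 0, 0, 0], 'attention_mask': [1, 1, 1, 0, 0, 0]}
```

Every key in this dict becomes a keyword argument to the model, so any leftover dataset column has to match a parameter the model accepts.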

Error:

  File "/usr/local/lib/python3.8/dist-packages/EasyDel/trainer/fsdp_train.py", line 409, in train
    sharded_train_state_, loss, accuracy = self.sharded_train_step_fn(sharded_train_state_,
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/trainer/fsdp_train.py", line 312, in fsdp_train_step_
    (loss__, accuracy__), grad = grad_fn(state.params)
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/trainer/fsdp_train.py", line 303, in calculate_loss
    logits = state.apply_fn(params=params, **batch,
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 565, in __call__
    outputs = self.module.apply(
TypeError: __call__() takes from 2 to 9 positional arguments but 10 were given

There is also an o_proj error in the Mistral model.

erfanzar commented 9 months ago

Make sure that generate_prompt(x) returns a string.

JinSeoungwoo commented 9 months ago

> Make sure that generate_prompt(x) returns a string.

generate_prompt does return a string:

def generate_prompt(data_point):
    full_prompt = prompter.generate_prompt(
        data_point["instruction"],
        data_point["input"],
        data_point["output"],
    )
    return full_prompt
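A quick way to follow the maintainer's advice is to run the prompt builder on a few rows and check the return type before calling dataset.map. The template and the `assert_returns_str` helper below are hypothetical (the real `prompter` is not shown in this thread); only the instruction/input/output field names come from the code above.

```python
from typing import Callable, Iterable, Mapping

# Hypothetical Alpaca-style template; the thread does not show the real `prompter`.
TEMPLATE = (
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}\n\n"
    "### Response:\n{output}"
)


def generate_prompt(data_point: Mapping[str, str]) -> str:
    """Build one training prompt from an instruction/input/output record."""
    return TEMPLATE.format(
        instruction=data_point["instruction"],
        input=data_point["input"],
        output=data_point["output"],
    )


def assert_returns_str(build_prompt: Callable[[Mapping[str, str]], object],
                       rows: Iterable[Mapping[str, str]]) -> None:
    """Raise TypeError if build_prompt returns anything other than str for any row."""
    for i, row in enumerate(rows):
        out = build_prompt(row)
        if not isinstance(out, str):
            raise TypeError(f"row {i}: expected str, got {type(out).__name__}")


row = {"instruction": "Add the numbers.", "input": "2 and 3", "output": "5"}
assert_returns_str(generate_prompt, [row])  # passes silently
print(generate_prompt(row))
```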

I think this code in modelling_mistral_flax.py is the problem:

outputs = self.module.apply(
    inputs,
    jnp.array(input_ids, dtype="i4"),
    jnp.array(attention_mask, dtype="i4"),
    jnp.array(position_ids, dtype="i4"),
    not train,
    None,
    False,
    output_attentions,
    output_hidden_states,
    return_dict,
    rngs=rng_s,
    mutable=mutable,
)
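The numbers in the error count `self`: the module's `__call__` accepts `self` plus at most eight arguments, while the `apply` call above forwards nine positional arguments after `inputs`, and Flax's `Module.apply` passes those straight through to `__call__`. The toy class below reproduces the same message in plain Python; `ToyModule` and its parameter names are hypothetical and are not EasyDeL's actual signature.

```python
class ToyModule:
    """Hypothetical stand-in for a Flax module's __call__ (names are illustrative)."""

    def __call__(self, input_ids, attention_mask=None, position_ids=None,
                 deterministic=True, init_cache=False, output_attentions=False,
                 output_hidden_states=False, return_dict=True):
        return {"deterministic": deterministic, "init_cache": init_cache,
                "return_dict": return_dict}


module = ToyModule()

# Nine positional arguments after self: one more than __call__ accepts.
message = ""
try:
    module([1], [1], [0], True, None, False, False, False, True)
except TypeError as err:
    message = str(err)
print(message)

# Passing the optional flags by keyword binds each value to its named parameter,
# so an extra argument fails with an "unexpected keyword" error instead of
# silently shifting every later value one slot to the right.
result = module([1], attention_mask=[1], position_ids=[0], deterministic=True,
                init_cache=False, output_attentions=False,
                output_hidden_states=False, return_dict=True)
print(result)
```

Dropping one positional argument (as done later in this thread by removing the None) brings the count back within the limit, but keyword arguments make this kind of mismatch harder to introduce in the first place.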

JinSeoungwoo commented 9 months ago

Removing the None seems to fix that error, but now I get:

Initializer expected to generate shape (1024, 4096) but got shape (4096, 1024) instead for parameter "kernel" in "/model/layers/remat(0)/self_attn/k_proj"
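The two shapes in that message are transposes of each other. Assuming the checkpoint is Mistral-7B (hidden_size 4096, 32 attention heads, 8 key/value heads), k_proj maps 4096 inputs to 8 × 128 = 1024 outputs. PyTorch's `nn.Linear` stores that weight as (out_features, in_features) = (1024, 4096), whereas `flax.linen.Dense` stores its kernel as (in_features, out_features) = (4096, 1024). The sketch below works through that arithmetic and checks that a transpose preserves the layer's output. It is background on the two layouts, not a diagnosis of the EasyDeL bug (which the maintainer fixed in the library), and `torch_linear_to_flax_kernel` is a hypothetical helper.

```python
import numpy as np

# Assumed Mistral-7B attention sizes (from its public config).
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
head_dim = hidden_size // num_attention_heads   # 128
kv_out = num_key_value_heads * head_dim          # 1024: k_proj output width

# nn.Linear weight layout is (out, in); flax.linen.Dense kernel layout is (in, out).
torch_k_proj_shape = (kv_out, hidden_size)       # (1024, 4096)
flax_k_proj_shape = (hidden_size, kv_out)        # (4096, 1024)


def torch_linear_to_flax_kernel(weight: np.ndarray) -> np.ndarray:
    """Hypothetical helper: turn a bias-free (out, in) Linear weight into an (in, out) kernel."""
    return weight.T


# Small random check that the transpose preserves the layer's output.
rng = np.random.default_rng(0)
w = rng.standard_normal((6, 16))   # tiny (out, in) stand-in for k_proj
x = rng.standard_normal((2, 16))   # batch of two input vectors
kernel = torch_linear_to_flax_kernel(w)

torch_style = x @ w.T              # what nn.Linear computes (ignoring bias)
flax_style = x @ kernel            # what Dense computes (ignoring bias)
print(kernel.shape, np.allclose(torch_style, flax_style))  # (16, 6) True
```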

erfanzar commented 9 months ago

I'll fix that in the next commit.

erfanzar commented 9 months ago

Fixed <3. Please revert the code you edited, then clone the repo or install it with:

pip install git+https://github.com/erfanzar/EasyDeL

and transform the weights again.