s-smits closed this 3 weeks ago
Hi, update your JAX version to 0.4.28.
After doing that, it installs correctly, but it freezes when importing the easydel libraries:
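A quick way to confirm which JAX version the notebook actually sees before retrying. This is only a sketch: `version_tuple` and `installed_version` are hypothetical helper names, and 0.4.28 is simply the version suggested above.

```python
from importlib import metadata


def version_tuple(version: str) -> tuple:
    """Turn '0.4.28' into (0, 4, 28); stop at the first non-numeric part."""
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def installed_version(package: str):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


found = installed_version("jax")
if found is None:
    print("jax is not installed")
elif version_tuple(found) != version_tuple("0.4.28"):
    print(f"jax {found} is installed; the suggestion above was 0.4.28")
else:
    print("jax 0.4.28 is installed")
```

On Kaggle, restarting the session after changing the version is usually needed so the new package is the one that gets imported.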
Loaded pretrained model: ssmits/Falcon2-5.5B-Dutch
Input shape: (1, 4096)
Attention partitions: {'query_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp'), 'key_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp'), 'value_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp'), 'bias_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, None), 'attention_partition_spec': PartitionSpec(('dp', 'fsdp'), 'sp', None, 'tp')}
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[2], line 32
22 attention_partitions = dict(
23 query_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
24 key_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
(...)
27 attention_partition_spec=PartitionSpec(("dp","fsdp"), "sp", None, "tp"),
28 )
30 print(f"Attention partitions: {attention_partitions}")
---> 32 model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
33 pretrained_model_name_or_path,
34 device=jax.devices('cpu')[0],
35 input_shape=input_shape,
36 device_map="auto",
37 sharding_axis_dims=sharding_axis_dims,
38 config_kwargs=dict(
39 use_scan_mlp=False,
40 attn_mechanism=attn_mechanism,
41 **attention_partitions
42 ),
43 **attention_partitions
44 )
46 print(f"Loaded model with params shape: {jax.tree_util.tree_map(lambda x: x.shape, params)}")
48 config = model.config
NameError: name 'AutoEasyDeLModelForCausalLM' is not defined
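The NameError means the name was never bound in the notebook session, which typically happens when the import cell did not finish (or failed) before this cell ran. A hedged sketch of a check that could run first: `try_import` is a hypothetical helper, and the assumption that `AutoEasyDeLModelForCausalLM` is exported from the top level of a package named `easydel` is not confirmed by this thread.

```python
import importlib


def try_import(module_name: str, attr: str):
    """Return module.attr, or None if the module or the attribute is missing."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, attr, None)


# Assumption: the class lives at the top level of a package named "easydel".
cls = try_import("easydel", "AutoEasyDeLModelForCausalLM")
if cls is None:
    print("import failed; re-run the import cell before calling from_pretrained")
else:
    print(f"found {cls.__name__}")
```

If the check prints the failure message, the problem is in the import step rather than in the `from_pretrained` arguments.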
Should I make a separate issue for this?
No, it's fine. Take a look at this:
https://www.kaggle.com/citifer/easydel-causal-language-model-trainer-example
Yes, thank you, it's working!
Describe the bug
It would be great to keep support for TPU v3s on Kaggle. After 0.0.66 I get this error:
To Reproduce