foundation-model-stack / fms-fsdp

🚀 Efficiently (pre)training foundation models with native PyTorch features, including FSDP for training and SDPA implementation of Flash attention v2.
https://pytorch.org/docs/stable/fsdp.html
Apache License 2.0

Faulty type handling for 'weight' kwarg #31

Closed daviswer closed 4 months ago

daviswer commented 4 months ago

The current scheme for overriding the defaults in the training config file is to pass keyword arguments to the training script and then update the config with those values. This scheme breaks down for weights because there is no type handling: the weights field takes a string of comma-separated numbers, but passing --weights="1,2,3" to the training script yields the tuple (1,2,3) rather than the required string, because Python interprets the argument as a tuple. This causes dataloader construction to fail. We should either enforce typing on the training script args, or update the config and dataloader code to match Python's default arg handling.
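One way to read the first option, enforcing types when overrides are applied, can be sketched as follows. The `TrainConfig` dataclass, its fields and defaults, and the `update_config` helper are hypothetical stand-ins for illustration, not the repository's actual code.

```python
from dataclasses import dataclass, fields


@dataclass
class TrainConfig:
    # Hypothetical config: field names and defaults are illustrative only.
    datasets: str = "dataset_a,dataset_b"
    weights: str = "1,1"
    seq_length: int = 4096


def update_config(cfg, **overrides):
    """Apply overrides, coercing each value to the field's declared type."""
    declared = {f.name: f.type for f in fields(cfg)}
    for name, value in overrides.items():
        if name not in declared:
            raise KeyError(f"unknown config field: {name}")
        target = declared[name]
        if target is str and isinstance(value, (list, tuple)):
            # The argument parser turned "1,2,3" into (1, 2, 3); undo that.
            value = ",".join(str(item) for item in value)
        elif not isinstance(value, target):
            # Fall back to the declared type's constructor, e.g. int("2048").
            value = target(value)
        setattr(cfg, name, value)
    return cfg


cfg = update_config(TrainConfig(), weights=(1, 2, 3), seq_length="2048")
print(repr(cfg.weights), cfg.seq_length)
```

With this approach the dataloader code keeps receiving a comma-separated string regardless of how the command-line parser typed the value.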

daviswer commented 4 months ago

The datasets field avoids this issue because its values include / and = characters, which force the type to str, but I think simpler subfolder names will run into this issue as well.

lchu-ibm commented 4 months ago

Good catch.

Fire (not Python) is what interprets it as a tuple.
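For illustration, here is a minimal standard-library sketch of this kind of literal parsing. It only approximates the behavior described in this thread and is not Fire's actual code; `parse_cli_value` and `_NamesToStrings` are hypothetical names.

```python
import ast


class _NamesToStrings(ast.NodeTransformer):
    """Treat bare identifiers (e.g. dataset_a) as string literals."""

    def visit_Name(self, node):
        return ast.copy_location(ast.Constant(value=node.id), node)


def parse_cli_value(value):
    """Interpret a raw argv string as a Python literal when possible."""
    try:
        tree = ast.parse(value, mode="eval")
    except (SyntaxError, ValueError):
        # Not a valid expression at all (e.g. contains "="): keep the string.
        return value
    tree = _NamesToStrings().visit(tree)
    try:
        return ast.literal_eval(tree)
    except (ValueError, SyntaxError):
        # A valid expression but not a plain literal (e.g. a/b): keep the string.
        return value


for raw in ["1,2,3", "'1,2,3'", "a,b", "data/a=1"]:
    parsed = parse_cli_value(raw)
    print(f"{raw!r} -> {type(parsed).__name__} {parsed!r}")
```

In this sketch the unquoted 1,2,3 becomes a tuple, the quoted '1,2,3' stays a string, and a value containing / and = also stays a string because it is not a valid literal expression.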

This is the combined effect of two facts:

  1. bash strips the quotes, i.e. your "1,2,3" becomes 1,2,3. Otherwise Fire would keep the quoted string as a string instead of interpreting it as another type.
  2. Fire interprets an unquoted string as a List, Tuple, Dict, etc. whenever applicable, so the de-quoted 1,2,3 becomes (1,2,3).

lchu-ibm commented 4 months ago

@daviswer as a workaround, please use "'1,2,3'" as input so that:

  1. after bash strips your outer "", you get '1,2,3'
  2. because '1,2,3' still has quotes, Fire parses it as a regular string.

I think this should definitely work. We can revisit this later when we have more bandwidth, as this is low priority.
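The quote-stripping step can be checked without launching a training run. The sketch below uses Python's `shlex.split`, which approximates POSIX shell word splitting (it is not bash itself); `train.py` is a placeholder script name and `argv_after_shell` is a hypothetical helper.

```python
import shlex


def argv_after_shell(command_line):
    """Return the argument list a POSIX-style shell would hand to the program."""
    return shlex.split(command_line)


# Outer double quotes are removed, so the program sees a bare 1,2,3.
print(argv_after_shell('train.py --weights="1,2,3"'))

# Single quotes nested inside double quotes survive and reach the program.
print(argv_after_shell('''train.py --weights="'1,2,3'"'''))
```

The second form is the one that leaves quotes in place for the argument parser to see.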

daviswer commented 4 months ago

Interestingly, "'1,2,3'" is still typed as the tuple (1,2,3), although in a simple test script that just prints the argument's type and value, it is cast as desired. So it seems there is an additional failure point somewhere.

lchu-ibm commented 4 months ago

That's interesting. Then we need a better fix, and I will run some tests to make sure it covers these cases well.

In the meantime, I guess the best workaround is to hardcode the values in the config file (luckily, this is something that usually does not change across runs).

lchu-ibm commented 4 months ago

@daviswer Instead, I am thinking of changing, in parse_data_args,

datas = splitstrip(datas)

to

datas = splitstrip(datas) if isinstance(datas, str) else datas

Same for weights.
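A self-contained sketch of that guard is shown below. The `splitstrip` here is a hypothetical stand-in (the real helper is not shown in this thread), and `normalize_list_arg`, including its handling of a single scalar value, is an assumption added for illustration.

```python
def splitstrip(text):
    """Hypothetical stand-in: split on commas and strip whitespace."""
    return [item.strip() for item in text.split(",")]


def normalize_list_arg(value, cast=str):
    """Accept either a comma-separated string or an already-parsed sequence."""
    if isinstance(value, str):
        items = splitstrip(value)
    elif isinstance(value, (list, tuple)):
        # The argument parser already split the value, e.g. (1, 2, 3).
        items = list(value)
    else:
        # Assumption: a lone value such as --weights=1 arrives as a scalar.
        items = [value]
    return [cast(item) for item in items]


print(normalize_list_arg("1, 2,3", cast=float))
print(normalize_list_arg((1, 2, 3), cast=float))
print(normalize_list_arg(5, cast=float))
```

Accepting both forms means the config file can keep its string value while command-line overrides work however the parser happens to type them.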