google / orbax

Orbax provides common checkpointing and persistence utilities for JAX users
https://orbax.readthedocs.io/
Apache License 2.0

Warnings When Restoring the Params #1110

Open richardmkit opened 2 months ago

richardmkit commented 2 months ago

The code is below. It runs, but it prints the following warning:

/opt/conda/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1544: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
  warnings.warn(

The model params are a nested dictionary pytree of float32 arrays, like this (array values omitted; each bias is 1-D, each kernel is 2-D):

{'params': {
    'Dense_0': {'bias': <float32 Array, shape (64,)>, 'kernel': <float32 Array, 2-D>},
    'Dense_1': {'bias': <float32 Array, shape (16,)>, 'kernel': <float32 Array, 2-D>},
    'Dense_2': {'bias': <float32 Array, shape (64,)>, 'kernel': <float32 Array, 2-D>},
    'Dense_3': {'bias': <float32 Array, shape (16,)>, 'kernel': <float32 Array, 2-D>}}}

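import os
import shutil

import orbax.checkpoint

# params is the pytree shown above.

# Save the params, removing any previous checkpoint at this path first.
if os.path.exists('/user/working/model_params'):
    shutil.rmtree('/user/working/model_params')

checkpoint = orbax.checkpoint.PyTreeCheckpointer()
checkpoint.save('/user/working/model_params', params)

# Restore the params (this call emits the warning).
checkpoint = orbax.checkpoint.PyTreeCheckpointer()
restored_params = checkpoint.restore('/user/working/model_params')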

How should I fix this? Thanks.

cpgaffney1 commented 2 months ago

This documentation should be useful: https://orbax.readthedocs.io/en/latest/checkpointing_pytrees.html#checkpointing-pytrees-of-arrays