JamesPerlman closed this 2 years ago
The result of `jax.local_device_count()` is discarded here. I do not believe this codepath is currently reachable, since every existing call to `utils.shard(...)` passes the `device_count` arg.
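For illustration, here is a minimal sketch of the pattern being described. It is not the repository's actual `utils.shard`: the function bodies, the list-based sharding, and `local_device_count()` (a hypothetical pure-Python stand-in for `jax.local_device_count()`, hard-coded to 2) are all assumptions made so the snippet runs without JAX installed.

```python
def local_device_count() -> int:
    """Hypothetical stand-in for jax.local_device_count(); assumes 2 devices."""
    return 2


def shard_buggy(xs, device_count=None):
    """Split xs into device_count equal chunks (illustrative, with the bug)."""
    if device_count is None:
        # Bug: the return value is computed and thrown away,
        # so device_count is still None after this line.
        local_device_count()
    per_device = len(xs) // device_count  # TypeError if device_count is None
    return [xs[i * per_device:(i + 1) * per_device] for i in range(device_count)]


def shard_fixed(xs, device_count=None):
    """Same as above, but the default is actually assigned."""
    if device_count is None:
        device_count = local_device_count()
    per_device = len(xs) // device_count
    return [xs[i * per_device:(i + 1) * per_device] for i in range(device_count)]
```

In this sketch the bug only surfaces when `device_count` is omitted, which would match the observation that no current caller hits it.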
Strange, not sure how that happened. Thanks for the fix!