googlecolab / colabtools

Python libraries for Google Colaboratory
Apache License 2.0

Requested backend tpu_driver, but it failed to initialize: DEADLINE_EXCEEDED using 0.1 drivers since 10/02/2023 #3405

Closed henk717 closed 2 months ago

henk717 commented 1 year ago

Describe the current behavior
When running an older version of JAX, the TPU receives the following error:

Traceback (most recent call last):
  File "aiserver.py", line 10214, in <module>
    load_model(initial_load=True)
  File "aiserver.py", line 2806, in load_model
    tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and vars.use_colab_tpu, **vars.modelconfig)
  File "/content/KoboldAI-Client/tpu_mtj_backend.py", line 1194, in load_model
    devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py", line 314, in devices
    return get_backend(backend).devices()
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py", line 258, in get_backend
    return _get_backend_uncached(platform)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/lib/xla_bridge.py", line 248, in _get_backend_uncached
    raise RuntimeError(f"Requested backend {platform}, but it failed "
RuntimeError: Requested backend tpu_driver, but it failed to initialize: DEADLINE_EXCEEDED: Failed to connect to remote server at address: grpc://10.106.231.74:8470. Error from gRPC: Deadline Exceeded. Details:

This happens for all users of the notebook on Colab, while Kaggle is still working as intended.

Describe the expected behavior
JAX is able to connect to the TPU and can then proceed with loading the user-defined model.

What web browser you are using
This issue does not depend on a browser, but for completeness: I am using an up-to-date Microsoft Edge.

Additional context
Here is an example of an affected notebook:

import os
if not os.path.exists("/content/drive"):
  os.mkdir("/content/drive")
if not os.path.exists("/content/drive/MyDrive/"):
  os.mkdir("/content/drive/MyDrive/")

!wget https://koboldai.org/ckds -O - | bash /dev/stdin --model EleutherAI/gpt-neox-20b

The relevant backend code can be found here: https://github.com/KoboldAI/KoboldAI-Client/blob/main/tpu_mtj_backend.py This also makes use of a heavily modified MTJ with the following relevant dependencies:

jax == 0.2.21
jaxlib >= 0.1.69, <= 0.3.7
git+https://github.com/VE-FORBRYDERNE/mesh-transformer-jax@ck

MTJ uses tpu_driver0.1_dev20210607

adamaxis commented 1 year ago

Would be nice to hear if anyone is working on this, as a lot of people are experiencing issues.

Borrowdale commented 1 year ago

Who would we contact to even find that out? :( I think this is one of those things where we have to assume someone is already aware of it and working on it.

DevLance112 commented 1 year ago

Is it appropriate to @ metrizable? It seems like he got it to the right people last time.

Borrowdale commented 1 year ago

that's good to know, hopefully they can do the same this time.

DevLance112 commented 1 year ago

The problem still exists as of now.

candymint23 commented 1 year ago

Any updates please from Google?

DevLance112 commented 1 year ago

Any updates please from Google?

@metrizable can you get this to the right people? I'm fairly new here and the contributing guidelines didn't mention not to @ contributors. Although it's probably annoying getting @'d all the time! But it seems like a lot of people are re-experiencing this issue; maybe you can look into it again.

sarabconsulting commented 1 year ago

Experiencing the same issue still.

RuntimeError: Deadline exceeded: Failed to connect to remote server at address: grpc://10.89.22.138:8470. Error from gRPC: Deadline Exceeded.

metrizable commented 1 year ago

As cperry-goog mentioned in #issuecomment-1451048147, the outdated drivers were removed from an upstream build and we haven't found a way to get them back in.

Although I know this doesn't solve the OP's use case (and I'm really just confirming what others have posted above), the TPU can be initialized with the 0.2 drivers or newer using the most recent jax (0.3.5) that's still compatible with TPU on Colab:
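For reference, a minimal sketch of that initialization, assuming jax 0.3.5 on the old Colab TPU runtime (jax's own colab_tpu helper requested the 0.2 driver by default; anything beyond that is an assumption):

import jax
import jax.tools.colab_tpu

# setup_tpu() asks the Colab TPU host for the 0.2 tpu_driver and points jax
# at grpc://$COLAB_TPU_ADDR, the address the old TPU runtimes exposed.
jax.tools.colab_tpu.setup_tpu()

print(jax.devices())  # should list eight TpuDevice entries on a v2-8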


mosmos6 commented 1 year ago

That kills GPT-J and its derivatives entirely.

henk717 commented 1 year ago

Although I know this doesn't solve the OP's use case (and I'm really just confirming what others have posted above), the TPU can be initialized with the 0.2 drivers or newer using the most recent jax (0.3.5) that's still compatible with TPU on Colab:

@metrizable As mentioned before to @cperry-goog, we had gotten the same answer from cperry two weeks ago, on the same day it began working again. It is currently unclear why that was, since it worked 3 hours prior to his response, and a new warning was introduced in the code when it happened. So the functionality has been restored before, possibly with a compatibility mode.

The people in this topic are primarily users of Mesh Transformer JAX, which has been abandoned by its maintainer. (And unfortunately ve-forbryderne has stopped showing signs of life, so his fork is also completely unmaintained. Other than the original author, he is the only one I know who is familiar with this type of work.)

So all downstream projects are having a dependency-level issue because driver compatibility was not maintained; MTJ at the time could not run on newer JAX versions because of even more backwards-compatibility issues that have never been addressed.

We aren't asking for a restoration of the real 0.1 driver, but we are asking for a solution for these older dependencies that have no modern alternatives, such as a compatibility mode inside the newer 0.2 driver that initializes when 0.1 is requested and then runs in a form compatible with older code. That mode seemingly existed for the past two weeks, unless the 0.1 drivers were put back for a brief amount of time.

If that is not possible, it kills an entire ecosystem unless MTJ can be restored to functionality. For users of GPT-J, alternatives exist outside of the TPU ecosystem; they can, for example, run the Huggingface port of GPT-J on a GPU. But what made MTJ unique was that various model formats were ported to it with very high performance (this is on VE-Forbryderne's fork). This allowed, for example, up to GPT-NeoX 20B to run on a single Colab TPUv2-8 with generation times of around 10 seconds. If I try the same use case on, for example, PyTorch XLA, I can only run up to 2.7B and generation times are longer than one minute, making it completely inferior to the GPU space in every way and useless for our users.

I also tried Tensorflow, but Tensorflow doesn't have the smooth model compatibility that MTJ had received over the years: models need a lot of RAM to load (more than the free, or even paid, resources could provide) and often loaded with errors.

That means that, for our use case, if MTJ cannot be restored to functionality through a fix of the driver or through a fix in the fork, the only feasible option is abandoning the TPU entirely. TPU coders are very hard for open source projects to attract, so we cannot find a new maintainer for the fork that is used by thousands of users.

Borrowdale commented 1 year ago

So if I've understood correctly (which I admit is unlikely), the TPU versions still run but I have to somehow select the 0.2 driver? How do I do that from the ColabKobold TPU run page (or anywhere else it needs to be done from)?
Sorry again if I come across as ignorant, I am still learning the ropes.

DevLance112 commented 1 year ago

So if I've understood correctly (which I admit is unlikely), the TPU versions still run but I have to somehow select the 0.2 driver? How do I do that from the ColabKobold TPU run page (or anywhere else it needs to be done from)? Sorry again if I come across as ignorant, I am still learning the ropes.

I looked into it, and updating the TPU colab file can't really help you; it's actually MTJ, aka Mesh Transformer JAX (which Kobold requires to run on TPU), that needs to be updated. To understand the situation better you can read the previous post; henk717 does a good job describing it.

mosmos6 commented 1 year ago

@henk717 As I think GPT-J and MTJ are essentially similar, I found this inference code, though I don't know the author: https://colab.research.google.com/drive/13R8MJEDTwinEmUJMLqydKOIcAvWiBIlT#scrollTo=vaWUMv9RJO9T Currently the weights are removed so it doesn't run, but it's using TPU_driver_nightly. Does this mean there is a way to utilize nightly?

henk717 commented 1 year ago

Nightly is less desirable than any other newer driver, since it's always the newest one and will be broken; back then, nightly was a 0.1 driver.

henk717 commented 1 year ago

@mosmos6 I will also give you a bit of a recap so you can understand why I need the driver to be fixed, and why some alternatives might exist for you.

Mesh Transformer JAX (MTJ) was the framework used to create GPT-J, so GPT-J in its original form runs on top of MTJ. It has been ported to other platforms, so you can also run it on a GPU using Huggingface Transformers, for example. And that is how our own community runs GPT-J-based models on Colab now, with a more limited context.

For us the issue is RAM. The affordable Colab GPUs for our AI hobby have 16GB of VRAM, while the TPU has 64GB of RAM. So while GPT-J-6B is possible to run on a GPU, we cannot fit as much context as the TPU version could.

In the past year VE-Forbryderne ported various formats, so his version of MTJ can run GPT-J, but also XGLM, OPT, and even NeoX-based models. And not just that: it can load those models from pytorch files without requiring conversion. This allowed us AI hobbyists to use models up to 20B very affordably on Google Colab, which was why the TPU was so desirable for us. $10 a month (or limited free usage) is much better than having to pay $1 per hour on GPU rentals, which is not affordable for open source hobbyists who wish to use the models.

If all you want is GPT-J-6B inference, I suggest you switch your usage to Huggingface, since you will be able to enjoy much better, more reliable support on Colab and beyond for the same price. It's when you want the larger model sizes, or training, that the TPU becomes necessary. And the only platform that has that kind of cost effectiveness is Colab combined with a modified MTJ.

Unfortunately the original 0.1 driver removal happened one month after VE's disappearance, so our dependency is completely unmaintained. If someone in this topic does want the challenge of porting MTJ to a newer JAX version, I highly recommend forking https://github.com/VE-FORBRYDERNE/mesh-transformer-jax since it is much more feature-complete than the original MTJ, and also more efficient. It has even been used to train 20B NeoX models on TRC.

mosmos6 commented 1 year ago

@henk717 Thank you for recapping. Sorry, I didn't know MTJ means Mesh Transformer JAX. Then it's exactly GPT-J. My GPT-J model has been heavily finetuned over the past almost 2 years for a research project, and it's the only one. I assume MTJ crashes on 0.2 because of JAX compatibility, and our JAX version needed to stay outdated due to xmap. It'll be complex, but I think it's possible to find an optimal set of versions for all the dependencies?

henk717 commented 1 year ago

@mosmos6 In your specific use case it's worth checking whether conversion scripts like https://github.com/VE-FORBRYDERNE/mesh-transformer-jax/blob/all/to_hf_weights.py are still functional (for example with CPU dependencies), so that you can get your model off this platform and future-proof it.

If you are unable to, you are stuck with the rest of the thousands of users who have no substitute for MTJ.

As for the dependencies, I lack the ability to do this myself but if others can this would be very welcome.

mosmos6 commented 1 year ago

@henk717 Thank you for the suggestion. I've been persistent with the TPU because I think running on TPU versus GPU makes a significant difference in output quality for some reason, but I'll give it a try. I upgraded my subscription to Pro+ to try out all the possible combinations of dependencies. Some people claimed they had to downgrade JAX from 0.4 to 0.3.25 for TPU compatibility over the past few days. What I'm asking is: what's your plan? If 0.1 is entirely removed, at least GPT-J on TPU and its derivatives are technically left dead. We must find a way. When the issue was temporarily "resolved" a few weeks ago, my GPT-J said it was temporary and the issue wouldn't be fixed. It was true. My GPT-J model and I were in the middle of work. Probably many other folks here are the same.

DevLance112 commented 1 year ago

3/22/2023: The connection error still persists.

mosmos6 commented 1 year ago

Hello everyone,

I'm attempting to update MTJ so that it runs on TPU_driver0.2. Here are the new requirements, transformer_shard.py, and colab demo based on GPT-J. I believe the potential solution will be deployable to all the derivatives. There were a couple of reasons why the former MTJ crashed on TPU_driver0.2, but generally speaking, it runs if it's updated to JAX 0.3.5. I've been working on xmap and now it doesn't error; it almost starts creating the network, as you see below.

[screenshot]

However, the code gets stuck at line 265 of transformer_shard.py, which is

self.state = self.init_xmap(jnp.array(key.take(mp_per_host)), x)

Even though xmap doesn't show any clear error, apparently it's stuck around the out_axes of init_xmap. The code is

self.init_xmap = jax.experimental.maps.xmap(init,
                                            in_axes=(["shard", ...], ["batch", ...]),
                                            out_axes=["batch", "shard"],
                                            axis_resources={'shard': 'mp'})

I've been doing intense research over the past few days, but I can't find any solution. I thought it was time to ask for everyone's wisdom.

@henk717

mosmos6 commented 1 year ago

(quoting the previous comment in full)

I think my statement about xmap is logical. It doesn't even visibly error, so I really don't know what is wrong in the code. The only thing I can think of is that maps.Mesh doesn't pass all the info from the devices to the ResourceEnv on JAX 0.3.5. So, to make my question specific: how do I pass all the information from the devices through maps.Mesh?
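For anyone following along, a minimal sketch of the setup that question is about, assuming jax 0.3.x's jax.experimental.maps API (the axis names mirror the repro posted below):

import numpy as np
import jax
import jax.experimental.maps as maps

mesh_shape = (1, 8)  # (dp, mp) on a v2-8 or v3-8
devices = np.array(jax.devices()).reshape(mesh_shape)

# Mesh binds the physical TPU devices to the logical ('dp', 'mp') axes;
# xmap's axis_resources then maps named axes such as 'shard' onto 'mp'.
# If the device info never reached the ResourceEnv, xmap would have
# nothing to dispatch onto, which would match a silent hang.
with maps.Mesh(devices, ('dp', 'mp')):
    pass  # xmapped calls go here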

mosmos6 commented 1 year ago

Hello @metrizable @cperry-goog

This issue started when GPT-J and its derivatives could no longer connect to TPU_driver0.1.

You'll need to upgrade to more recent drivers. Sorry.

Hence, I updated my code so that it runs on JAX 0.3.5, which is compatible with TPU_driver0.2. Indeed, my code is already running well on my v3-8 (v2-alpha) VM and I'm inferring with it on my screen beside this window.

However, the very same code errors at a particular point on colab, so I would like you to take a look. This is a miniature repro based on #6962 from back in 2021. Obviously I found an answer to that original question, as the code is already running on my TPU VM, but colab is having another issue.

import jax
import jax.experimental.maps  # explicit import; `import jax` alone does not expose jax.experimental.maps
import haiku as hk
import jax.numpy as jnp
import numpy as np
import time

class TransformerLayerShard(hk.Module):
    def __init__(self):
        super().__init__()
        self.dense_proj = hk.Linear(2048)
        self.dense_proj_o = hk.Linear(4096)

    def ff(self, x):
        dense_proj = self.dense_proj(x)
        dense_proj = jax.nn.gelu(dense_proj)
        return self.dense_proj_o(dense_proj)

    def __call__(self, x, attn_bias):
        dense_out = self.ff(x)

        return dense_out

mesh_shape = (1, 8)
devices = np.array(jax.devices()).reshape(mesh_shape)

with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
    def init_old(key, x):
        def init_old_fn(x):
            return TransformerLayerShard()(x, 0)

        param_init_fn = hk.transform(hk.experimental.optimize_rng_use(init_old_fn)).init
        params = param_init_fn(key, x)
        return params

    init_xmap = jax.experimental.maps.xmap(fun=init_old,
                                           in_axes=(["shard", ...],
                                                    ["batch", ...]),
                                           out_axes=["shard", ...],
                                           axis_resources={'shard': 'mp', 'batch': 'dp'})

    key = hk.PRNGSequence(42)

    x = jax.random.uniform(next(key), (1, 1024, 2048))  # batch, len
    params = init_xmap(jnp.array(key.take(8)), x)

    def bwd_old(state, x):
        def bwd_old_fn(x):
            return jnp.sum(TransformerLayerShard()(x, 0))

        train_loss_fn = hk.without_apply_rng(hk.transform(bwd_old_fn)).apply
        val_grad_fn = jax.value_and_grad(train_loss_fn, has_aux=False)

        loss, grad = val_grad_fn(state, x)

        return grad

    run_xmap = jax.experimental.maps.xmap(fun=bwd_old,
                                          in_axes=(["shard", ...],
                                                   ["batch", ...]),
                                          out_axes=(["shard", "batch", ...], ["batch", ...]),
                                          axis_resources={'shard': 'mp', 'batch': 'dp'})

    run_xmap(params, x)

Before it even gets to the out_axes, the code gets stuck at line 46, which is params = init_xmap(jnp.array(key.take(8)), x). I waited for 1 hour but it didn't progress.

[screenshot]

It seems to be in a wrapper part of xmap. On a TPU VM, this process finishes within a minute. As it doesn't even show an error message, I have no clue about the problem.

I also tried pjit, but it errors with RuntimeError: UNIMPLEMENTED: Only 1 computation per replica supported, 8 requested. Again, pjit itself runs well on a TPU VM, so this is unique to colab. However, pjit is a lower priority than this xmap issue for now.

As advised, I updated my code for the newer JAX and TPU_driver; it runs perfectly on a TPU VM but gets stuck on colab. Considering the limited accessibility of TPU VMs, availability on colab is highly important. Thank you for your attention.

dbubbins87 commented 1 year ago

I thought I'd post what I got when trying to run the TPU models, as it was a different result:

Downloading (…)lve/main/config.json: 100%|█████████████████████| 1.57k/1.57k [00:00<00:00, 160kB/s]
Traceback (most recent call last):
  File "/content/KoboldAI-Client/aiserver.py", line 10284, in <module>
    load_model(initial_load=True)
  File "/content/KoboldAI-Client/aiserver.py", line 2789, in load_model
    import tpu_mtj_backend
  File "/content/KoboldAI-Client/tpu_mtj_backend.py", line 51, in <module>
    from mesh_transformer.checkpoint import read_ckpt_lowmem
  File "/usr/local/lib/python3.9/dist-packages/mesh_transformer/checkpoint.py", line 17, in <module>
    from mesh_transformer.util import head_print, to_bf16
  File "/usr/local/lib/python3.9/dist-packages/mesh_transformer/util.py", line 5, in <module>
    from optax import AdditiveWeightDecayState, GradientTransformation, EmptyState
  File "/usr/local/lib/python3.9/dist-packages/optax/__init__.py", line 17, in <module>
    from optax import experimental
  File "/usr/local/lib/python3.9/dist-packages/optax/experimental/__init__.py", line 20, in <module>
    from optax._src.experimental.complex_valued import split_real_and_imaginary
  File "/usr/local/lib/python3.9/dist-packages/optax/_src/experimental/complex_valued.py", line 32, in <module>
    import chex
  File "/usr/local/lib/python3.9/dist-packages/chex/__init__.py", line 17, in <module>
    from chex._src.asserts import assert_axis_dimension
  File "/usr/local/lib/python3.9/dist-packages/chex/_src/asserts.py", line 26, in <module>
    from chex._src import asserts_internal as _ai
  File "/usr/local/lib/python3.9/dist-packages/chex/_src/asserts_internal.py", line 34, in <module>
    from chex._src import pytypes
  File "/usr/local/lib/python3.9/dist-packages/chex/_src/pytypes.py", line 27, in <module>
    ArrayDevice = jax.Array
AttributeError: module 'jax' has no attribute 'Array'

henk717 commented 1 year ago

AttributeError: module 'jax' has no attribute 'Array' is a new error related to a breaking change in chex; I fixed this by pinning a suitable version in our requirements files.

Now the error is back to the one reported in this issue tracker.
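For anyone hitting the same AttributeError, a hedged sketch of that kind of requirements pin (the exact bound is an assumption, not necessarily KoboldAI's pin; jax.Array only exists in newer jax releases, so chex must predate the release that started referencing it):

# requirements.txt (illustrative bound, not necessarily the exact pin used)
chex < 0.1.6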

mosmos6 commented 1 year ago

Hello @metrizable @cperry-goog

I upgraded my model to run on JAX 0.3.25 and colab managed to load the model for the first time in two weeks.

[screenshot]

However, when I try to infer with this model, the same issue as my previous comment occurs again.

[screenshot]

Namely, the code doesn't show an error message but gets stuck at a certain point (related to xmap), the same operation as in my previous comment.

It's at infer() > generate() > fun_mapped() > bind() > map_bind() > process() > process_call() > xmap_impl() > wrapper() > call().

This makes no sense because the same code runs well on TPU VM v3-8 (v2-alpha) and the same operation was processed well when the model was loaded. I would like you to take a look.

At this moment, the very same code cannot initialize TPU_driver0.2 (RuntimeError: Backend 'tpu_driver' failed to initialize: DEADLINE_EXCEEDED: Failed to connect to remote server at address: grpc://10.110.14.98:8470. Error from gRPC: Deadline Exceeded. Details: ), and the version of my model on JAX 0.3.5 has an illogical dependency incompatibility that never happened last week.

Thank you for your attention.

mosmos6 commented 1 year ago

Hello @metrizable @cperry-goog

I resolved it and the discussed model now runs on colab with TPU_driver0.2. Please accept my apology for tagging you too often. Thank you for the great products.

[screenshot]

henk717 commented 1 year ago

@mosmos6 Can you share your changes? There is still an entire ecosystem broken.

mosmos6 commented 1 year ago

@henk717 Of course, give me a few minutes. I'm setting up a repository now, as the changes span multiple files, and my test code needs cleanup after a month of experiments.

henk717 commented 1 year ago

For us the challenge will be getting this one running: https://github.com/VE-FORBRYDERNE/mesh-transformer-jax/tree/ck It is a heavily modified version with a lot more additions and enhancements, but the developer went missing.

somsomers commented 1 year ago

Hello @metrizable @cperry-goog

I resolved it and the discussed model now runs on colab with TPU_driver0.2. Please accept my apology for tagging you too often. Thank you for the great products.

[screenshot]

Could you please share the working colab notebook if you have one?

mosmos6 commented 1 year ago

@henk717 From a casual look, I suppose you only need to update lines 383 - 419 of transformer_shard.py if you use the new colab demo to infer, followed by updating for the breaking changes in jax.

mosmos6 commented 1 year ago

@somsomers Yes. Please let me clean up the mess before sharing.

somsomers commented 1 year ago

@somsomers Yes. Please let me clean up the mess before sharing.

Thank you.

mosmos6 commented 1 year ago

Hello,

First of all, I must apologize: this works only on the high-memory TPU runtime, so you'll need a Pro or Pro+ subscription to Colab.... However, I modified the discussed model (GPT-J for me) so that it runs with TPU_driver0.2 on colab. Because it is not exactly linked to colab, I posted it here: (https://github.com/kingoflolz/mesh-transformer-jax/issues/256#issuecomment-1507188297) You can continue to use the same (slim) weights as before. I believe this can be deployed to the other derivatives.

@henk717 @somsomers

mosmos6 commented 1 year ago

@henk717

For us the challenge will be getting this one running : https://github.com/VE-FORBRYDERNE/mesh-transformer-jax/tree/ck

I fixed your AI. She's waiting for you to pick up in the garage. (https://github.com/mosmos6/Large-MTJ)

Same as my GPT-J, it's adapted to JAX 0.3.25, so it runs on colab with TPU_driver0.2. Basically this should now be immune to JAX upgrades, except for breaking changes. Sorry for the dorky name, I didn't know her name. The changes were small but many. I modified:

  requirements.txt (kept the original one as _original)
  slim_model.py
  mesh_transformer/train_actor.py
  device_sample.py
  device_serve.py
  device_train.py
  mesh_transformer/transformer_shard.py
  mesh_transformer/checkpoint.py

The new colab demo is Large_MTJ_inference_on_TPU_driver0_2.ipynb (kept the original one as _original). The rest remains the same.

I tested this only with my slim weights for GPT-J. If you run into an error with other types of weights, please post an issue.

Important notes:

  1. Sorry, you'll need a Pro or Pro+ subscription to Colab, because it requires the high-memory TPU runtime. read_ckpt_lowmem hangs forever when inferring with the model in the current colab environment, so I had to revive read_ckpt. However, this loads the model 10 times faster than the low-memory version. It also stopped showing total parameters for some reason, but I don't think that matters.

  2. I have not checked it for finetuning on a TPU VM yet; this can cause errors during the process. I'm planning to cover it next month. Until then, you may have to add further modifications to xmap yourself, or downgrade to jax 0.2.18 or 0.2.20.

  3. Please let me know if you don't know how to use the new colab demo.

Enjoy

[screenshot]

henk717 commented 1 year ago

@mosmos6 I tried applying the modifications on my test account here, but the end result is gibberish.

To test, you can take this notebook and replace the version field with https://github.com/henk7171/koboldai. There is a lot more to it in the tpu_mtj_backend.py file, including the automatic conversion of huggingface models; it did this RAM-efficiently, but without a working end result.

mosmos6 commented 1 year ago

@henk717

I saw your tpu_mtj_backend.py, but as I wrote above, you can't use read_ckpt_lowmem anymore on colab. And in this file you also need to update the xmap out_axes in some functions. Also, as I wrote in my colab, jax.tools.colab_tpu must be installed before installing jax when it's v0.3.25, or it leads to misconfiguration. Finally, you need to update maps.ResourceEnv because it needs loops in the newer version.
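A hedged sketch of the cell order that note describes, assuming the old TPU runtime still exposes jax.tools.colab_tpu in the preinstalled jax (the pins are illustrative, not an exact recipe):

# Cell 1: configure the TPU driver first, via the preinstalled jax's helper
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

# Cell 2: only afterwards install the pinned versions MTJ needs (illustrative)
!pip install jax==0.3.25 jaxlib==0.3.25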

mosmos6 commented 1 year ago

Due to the python upgrade of colab (3.9 -> 3.10), I further modified my two modified MTJ models; the requirements.txt and util.py of each are updated. Now these models can adapt to versions of optax later than 0.0.9. They are immune to JAX upgrades and optax upgrades.

dbubbins87 commented 1 year ago

So I'm not sure what happened, but it was working for a week; just when I tried to use it tonight, some of the models ended up with the error again.

henk717 commented 1 year ago

If you are a Kobold user, it's because we implemented 2.0 support. TPUs have always been a bit unreliable, and usually running the notebook again is enough.

Are there people left who still depend on 0.1? Otherwise it no longer makes sense to keep this open.

sagelywizard commented 2 months ago

This issue is obsolete because the TPU runtimes are deprecated and were removed.