saeyslab / napari-sparrow

Other
17 stars 0 forks source link

AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11' #146

Closed: Lucas-Maciel closed this issue 7 months ago

Lucas-Maciel commented 11 months ago

Hi,

I tried to import sparrow but ran into a problem that seems to be related to JAX and to the fact that I'm using Windows. I'm attaching my environment YAML file as well.

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 1
----> 1 import napari_sparrow as nas

File D:\napari-sparrow\napari-sparrow-main\src\napari_sparrow\__init__.py:7
      4 import os
      5 os.environ["USE_PYGEOS"] = "0"
----> 7 from . import io
      8 from . import image as im
      9 from . import shape as sh

File D:\napari-sparrow\napari-sparrow-main\src\napari_sparrow\io\__init__.py:1
----> 1 from ._spatial_data import create_sdata
      2 from ._transcripts import read_resolve_transcripts, read_vizgen_transcripts, read_stereoseq_transcripts, read_transcripts

File D:\napari-sparrow\napari-sparrow-main\src\napari_sparrow\io\_spatial_data.py:15
     12 from spatialdata.models.models import ScaleFactors_t
     13 from spatialdata.transformations import Translation
---> 15 from napari_sparrow.image._image import _add_image_layer
     16 from napari_sparrow.utils.pylogger import get_pylogger
     18 log = get_pylogger(__name__)

File D:\napari-sparrow\napari-sparrow-main\src\napari_sparrow\image\__init__.py:2
      1 from ._segmentation import segment
----> 2 from ._tiling import tiling_correction
      3 from ._contrast import enhance_contrast
      4 from ._minmax import min_max_filtering

File D:\napari-sparrow\napari-sparrow-main\src\napari_sparrow\image\_tiling.py:7
      5 import numpy as np
      6 import squidpy as sq
----> 7 from basicpy import BaSiC
      8 from spatialdata import SpatialData
      9 from spatialdata.models.models import ScaleFactors_t

File D:\miniconda3\envs\napari-sparrow\envs\napari-sparrow\lib\site-packages\basicpy\__init__.py:7
      4 import os
      6 from basicpy import data
----> 7 from basicpy.basicpy import BaSiC
      9 # Set logger level from environment variable
     10 logging_level = os.getenv("BASIC_LOG_LEVEL", default="INFO").upper()

File D:\miniconda3\envs\napari-sparrow\envs\napari-sparrow\lib\site-packages\basicpy\basicpy.py:16
     13 from pathlib import Path
     14 from typing import Dict, List, Optional, Tuple, Union
---> 16 import jax.numpy as jnp
     18 # 3rd party modules
     19 import numpy as np

File D:\miniconda3\envs\napari-sparrow\envs\napari-sparrow\lib\site-packages\jax\__init__.py:35
     30 del _cloud_tpu_init
     32 # Confusingly there are two things named "config": the module and the class.
     33 # We want the exported object to be the class, so we first import the module
     34 # to make sure a later import doesn't overwrite the class.
---> 35 from jax import config as _config_module
     36 del _config_module
     38 # Force early import, allowing use of `jax.core` after importing `jax`.

File D:\miniconda3\envs\napari-sparrow\envs\napari-sparrow\lib\site-packages\jax\config.py:17
      1 # Copyright 2018 The JAX Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     14 
     15 # TODO(phawkins): fix users of this alias and delete this file.
---> 17 from jax._src.config import config  # noqa: F401

File D:\miniconda3\envs\napari-sparrow\envs\napari-sparrow\lib\site-packages\jax\_src\config.py:24
     21 import threading
     22 from typing import Any, List, Callable, Hashable, NamedTuple, Iterator, Optional
---> 24 from jax._src import lib
     25 from jax._src.lib import jax_jit
     26 from jax._src.lib import transfer_guard_lib

File D:\miniconda3\envs\napari-sparrow\envs\napari-sparrow\lib\site-packages\jax\_src\lib\__init__.py:92
     89 except ImportError:
     90   utils = None
---> 92 import jaxlib.xla_client as xla_client
     93 import jaxlib.lapack as lapack
     95 import jaxlib.ducc_fft as ducc_fft

File D:\miniconda3\envs\napari-sparrow\envs\napari-sparrow\lib\site-packages\jaxlib\xla_client.py:225
    223 bfloat16 = ml_dtypes.bfloat16
    224 float8_e4m3fn = ml_dtypes.float8_e4m3fn
--> 225 float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11
    226 float8_e5m2 = ml_dtypes.float8_e5m2
    228 XLA_ELEMENT_TYPE_TO_DTYPE = {
    229     PrimitiveType.PRED: np.dtype('bool'),
    230     PrimitiveType.S8: np.dtype('int8'),
   (...)
    248     PrimitiveType.TOKEN: np.dtype(np.object_),
    249 }

AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11'

env_sparrow.txt

Thank you,
Lucas

ArneDefauw commented 11 months ago

Hi Lucas, thanks for opening the issue. Have you tried the solution suggested in https://github.com/google/jax/issues/17693, i.e. pinning ml_dtypes to an older version:

pip install ml_dtypes==0.2.0

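I see from your YAML file that you have ml_dtypes version 0.3.0. In that release the `float8_e4m3b11` attribute is no longer available under that name (it is `float8_e4m3b11fnuz`), so the jaxlib version in your environment fails when it looks it up.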

Let me know if this solution does not work!