rapidsai / cuml

cuML - RAPIDS Machine Learning Library
https://docs.rapids.ai/api/cuml/stable/
Apache License 2.0

[BUG] Deterministic UMAP is not deterministic #5099

Open zbjornson opened 1 year ago

zbjornson commented 1 year ago

Describe the bug

UMAP with deterministic=true and random_state=something gives very different results from run to run.

Steps/Code to reproduce bug

from cuml.manifold.umap import UMAP
from cuml.datasets import make_blobs
from matplotlib import pyplot as plt
import numpy as np

X, y = make_blobs(n_samples=10000, centers=8, n_features=4, random_state=0, dtype=np.float32)

umap = UMAP(n_neighbors=15, random_state=1)

Y1 = umap.fit_transform(X).get()
plt.scatter(Y1[:, 0], Y1[:, 1])
plt.show()

Y2 = umap.fit_transform(X).get()
plt.scatter(Y2[:, 0], Y2[:, 1])
plt.show()

Expected behavior

Close to identical images. The docs say:

Setting a random_state will enable consistency of trained embeddings, allowing for reproducible results to 3 digits of precision
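One way to check that claim numerically, instead of comparing scatter plots by eye, is sketched below. It assumes "3 digits of precision" means an absolute tolerance of 1e-3, which the docs do not spell out, and the helper names are made up for illustration. It is plain Python so it runs without a GPU; with NumPy arrays, np.abs(Y1 - Y2).max() computes the same quantity much faster.

```python
def max_abs_difference(a, b):
    """Largest element-wise absolute difference between two 2-D embeddings."""
    if len(a) != len(b):
        raise ValueError("embeddings have different numbers of rows")
    worst = 0.0
    for row_a, row_b in zip(a, b):
        if len(row_a) != len(row_b):
            raise ValueError("embeddings have different numbers of columns")
        for x, y in zip(row_a, row_b):
            worst = max(worst, abs(float(x) - float(y)))
    return worst


def agrees_to_decimals(a, b, decimals=3):
    """True if every coordinate differs by at most 10**-decimals (assumed reading of the docs)."""
    return max_abs_difference(a, b) <= 10 ** -decimals


# Hypothetical usage with the two embeddings from the reproducer:
#   print(max_abs_difference(Y1, Y2), agrees_to_decimals(Y1, Y2))
```

A large maximum difference on two fits with the same random_state is the behavior this issue reports.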

Environment details (please complete the following information):

`conda list`

```
# packages in environment at /home/studio-lab-user/.conda/envs/studiolab:
#
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_gnu conda-forge
aiobotocore 2.1.2 pypi_0 pypi
aiohttp 3.8.3 py39hb9d737c_1 conda-forge
aioitertools 0.11.0 pypi_0 pypi
aiosignal 1.3.1 pyhd8ed1ab_0 conda-forge
amzn-sagemaker-scheduler 0.1.0 pypi_0 pypi
amzn_sagemaker_studiolab 0.1.5 0 file:///opt/amazon/sagemaker/packages/conda
anyio 3.6.2 pyhd8ed1ab_0 conda-forge
argon2-cffi 21.3.0 pyhd8ed1ab_0 conda-forge
argon2-cffi-bindings 21.2.0 py39hb9d737c_3 conda-forge
asttokens 2.1.0 pyhd8ed1ab_0 conda-forge
async-timeout 4.0.2 pyhd8ed1ab_0 conda-forge
attrs 22.1.0 pyh71513ae_1 conda-forge
aws-embedded-metrics 1.0.7 pypi_0 pypi
babel 2.11.0 pyhd8ed1ab_0 conda-forge
backcall 0.2.0 pyh9f0ad1d_0 conda-forge
backports 1.0 py_2 conda-forge
backports.functools_lru_cache 1.6.4 pyhd8ed1ab_0 conda-forge
beautifulsoup4 4.11.1 pyha770c72_0 conda-forge
bleach 5.0.1 pyhd8ed1ab_0 conda-forge
botocore 1.23.24 pypi_0 pypi
brotlipy 0.7.0 py39hb9d737c_1005 conda-forge
bzip2 1.0.8 h7f98852_4 conda-forge
ca-certificates 2022.9.24 ha878542_0 conda-forge
certifi 2022.9.24 pyhd8ed1ab_0 conda-forge
cffi 1.15.1 py39he91dace_2 conda-forge
charset-normalizer 2.1.1 pyhd8ed1ab_0 conda-forge
colorama 0.4.6 pyhd8ed1ab_0 conda-forge
croniter 1.3.7 pypi_0 pypi
cryptography 38.0.3 py39h3ccb8fc_0 conda-forge
debugpy 1.6.3 py39h5a03fae_1 conda-forge
decorator 5.1.1 pyhd8ed1ab_0 conda-forge
defusedxml 0.7.1 pyhd8ed1ab_0 conda-forge
entrypoints 0.4 pyhd8ed1ab_0 conda-forge
executing 1.2.0 pyhd8ed1ab_0 conda-forge
flit-core 3.8.0 pyhd8ed1ab_0 conda-forge
frozenlist 1.3.3 py39hb9d737c_0 conda-forge
fsspec 2022.2.0 pypi_0 pypi
gitdb 4.0.9 pyhd8ed1ab_0 conda-forge
gitpython 3.1.29 pyhd8ed1ab_0 conda-forge
greenlet 2.0.1 pypi_0 pypi
idna 3.4 pyhd8ed1ab_0 conda-forge
importlib-metadata 5.0.0 pyha770c72_1 conda-forge
importlib_resources 5.10.0 pyhd8ed1ab_0 conda-forge
ipykernel 6.17.1 pyh210e3f2_0 conda-forge
ipython 8.6.0 pyh41d4057_1 conda-forge
ipython_genutils 0.2.0 py_1 conda-forge
jedi 0.18.1 pyhd8ed1ab_2 conda-forge
jinja2 3.1.2 pyhd8ed1ab_1 conda-forge
jmespath 0.10.0 pypi_0 pypi
json5 0.9.5 pyh9f0ad1d_0 conda-forge
jsonschema 4.17.0 pyhd8ed1ab_0 conda-forge
jupyter-lsp 1.5.1 pyhd8ed1ab_0 conda-forge
jupyter-scheduler 1.1.4 pypi_0 pypi
jupyter-server-mathjax 0.2.6 pyhc268e32_0 conda-forge
jupyter-server-proxy 3.2.2 pyhd8ed1ab_0 conda-forge
jupyter_client 7.4.5 pyhd8ed1ab_0 conda-forge
jupyter_core 5.0.0 py39hf3d152e_0 conda-forge
jupyter_server 1.23.2 pyhd8ed1ab_0 conda-forge
jupyter_telemetry 0.0.5 py_0 conda-forge
jupyterlab 3.5.0 pyhd8ed1ab_0 conda-forge
jupyterlab-git 0.34.2 pyhd8ed1ab_0 conda-forge
jupyterlab-lsp 3.10.2 pyhd8ed1ab_0 conda-forge
jupyterlab_pygments 0.2.2 pyhd8ed1ab_0 conda-forge
jupyterlab_server 2.16.3 pyhd8ed1ab_0 conda-forge
jupyterlab_widgets 1.1.1 pyhd8ed1ab_0 conda-forge
ld_impl_linux-64 2.39 hc81fddc_0 conda-forge
libffi 3.4.2 h7f98852_5 conda-forge
libgcc-ng 12.2.0 h65d4601_19 conda-forge
libgomp 12.2.0 h65d4601_19 conda-forge
libnsl 2.0.0 h7f98852_0 conda-forge
libsodium 1.0.18 h36c2ea0_1 conda-forge
libsqlite 3.39.4 h753d276_0 conda-forge
libstdcxx-ng 12.2.0 h46fd767_19 conda-forge
libuuid 2.32.1 h7f98852_1000 conda-forge
libzlib 1.2.13 h166bdaf_4 conda-forge
markupsafe 2.1.1 py39hb9d737c_2 conda-forge
matplotlib-inline 0.1.6 pyhd8ed1ab_0 conda-forge
mistune 2.0.4 pyhd8ed1ab_0 conda-forge
multidict 6.0.2 py39hb9d737c_2 conda-forge
nb_conda_kernels 2.3.1 py39hf3d152e_2 conda-forge
nbclassic 0.4.8 pyhd8ed1ab_0 conda-forge
nbclient 0.7.0 pyhd8ed1ab_0 conda-forge
nbconvert 7.2.5 pyhd8ed1ab_0 conda-forge
nbconvert-core 7.2.5 pyhd8ed1ab_0 conda-forge
nbconvert-pandoc 7.2.5 pyhd8ed1ab_0 conda-forge
nbdime 3.1.1 pyhd8ed1ab_0 conda-forge
nbformat 5.7.0 pyhd8ed1ab_0 conda-forge
ncurses 6.3 h27087fc_1 conda-forge
nest-asyncio 1.5.6 pyhd8ed1ab_0 conda-forge
notebook 6.5.2 pyha770c72_1 conda-forge
notebook-shim 0.2.2 pyhd8ed1ab_0 conda-forge
openssl 3.0.7 h166bdaf_0 conda-forge
packaging 21.3 pyhd8ed1ab_0 conda-forge
pandoc 2.19.2 h32600fe_1 conda-forge
pandocfilters 1.5.0 pyhd8ed1ab_0 conda-forge
parso 0.8.3 pyhd8ed1ab_0 conda-forge
pexpect 4.8.0 pyh1a96a4e_2 conda-forge
pickleshare 0.7.5 py_1003 conda-forge
pip 21.2.4 pyhd8ed1ab_0 conda-forge
pkgutil-resolve-name 1.3.10 pyhd8ed1ab_0 conda-forge
platformdirs 2.5.2 pyhd8ed1ab_1 conda-forge
prometheus_client 0.15.0 pyhd8ed1ab_0 conda-forge
prompt-toolkit 3.0.32 pyha770c72_0 conda-forge
psutil 5.9.4 py39hb9d737c_0 conda-forge
ptyprocess 0.7.0 pyhd3deb0d_0 conda-forge
pure_eval 0.2.2 pyhd8ed1ab_0 conda-forge
pycparser 2.21 pyhd8ed1ab_0 conda-forge
pydantic 1.10.2 pypi_0 pypi
pygments 2.13.0 pyhd8ed1ab_0 conda-forge
pyopenssl 22.1.0 pyhd8ed1ab_0 conda-forge
pyparsing 3.0.9 pyhd8ed1ab_0 conda-forge
pyrsistent 0.19.2 py39hb9d737c_0 conda-forge
pysocks 1.7.1 pyha2e5f31_6 conda-forge
python 3.9.13 h2660328_0_cpython conda-forge
python-dateutil 2.8.2 pyhd8ed1ab_0 conda-forge
python-fastjsonschema 2.16.2 pyhd8ed1ab_0 conda-forge
python-json-logger 2.0.1 pyh9f0ad1d_0 conda-forge
python_abi 3.9 2_cp39 conda-forge
pytz 2022.6 pyhd8ed1ab_0 conda-forge
pyzmq 24.0.1 py39headdf64_1 conda-forge
readline 8.1.2 h0f457ee_0 conda-forge
requests 2.28.1 pyhd8ed1ab_1 conda-forge
ruamel.yaml 0.17.21 py39hb9d737c_2 conda-forge
ruamel.yaml.clib 0.2.7 py39hb9d737c_0 conda-forge
s3fs 2022.2.0 pypi_0 pypi
send2trash 1.8.0 pyhd8ed1ab_0 conda-forge
setuptools 65.5.1 pyhd8ed1ab_0 conda-forge
simpervisor 0.4 pyhd8ed1ab_0 conda-forge
six 1.16.0 pyh6c4a22f_0 conda-forge
smmap 3.0.5 pyh44b312d_0 conda-forge
sniffio 1.3.0 pyhd8ed1ab_0 conda-forge
soupsieve 2.3.2.post1 pyhd8ed1ab_0 conda-forge
sqlalchemy 1.4.44 pypi_0 pypi
sqlite 3.39.4 h4ff8645_0 conda-forge
stack_data 0.6.1 pyhd8ed1ab_0 conda-forge
terminado 0.17.0 pyh41d4057_0 conda-forge
tinycss2 1.2.1 pyhd8ed1ab_0 conda-forge
tk 8.6.12 h27826a3_0 conda-forge
tomli 2.0.1 pyhd8ed1ab_0 conda-forge
tornado 6.2 py39hb9d737c_1 conda-forge
traitlets 5.5.0 pyhd8ed1ab_0 conda-forge
typing-extensions 4.4.0 hd8ed1ab_0 conda-forge
typing_extensions 4.4.0 pyha770c72_0 conda-forge
tzdata 2022f h191b570_0 conda-forge
urllib3 1.26.11 pyhd8ed1ab_0 conda-forge
wcwidth 0.2.5 pyh9f0ad1d_2 conda-forge
webencodings 0.5.1 py_1 conda-forge
websocket-client 1.4.2 pyhd8ed1ab_0 conda-forge
wheel 0.38.4 pyhd8ed1ab_0 conda-forge
wrapt 1.14.1 pypi_0 pypi
xz 5.2.6 h166bdaf_0 conda-forge
yarl 1.8.1 py39hb9d737c_0 conda-forge
zeromq 4.3.4 h9c3ff4c_1 conda-forge
zipp 3.10.0 pyhd8ed1ab_0 conda-forge
```

Additional context

My primary goal was actually to address "fly aways" (points that fly away from the main blobs) in UMAP that I assume are due to numerical instability or FP errors: #1121 / #3467.

beckernick commented 1 year ago

cc @cjnolet , as I believe you've looked at this in the past and had some insight

cjnolet commented 1 year ago

We do run several pytests for reproducibility in cuML which verify near-exact match when random_state is set.

There have been some updates to both the random APIs and spectral clustering APIs in RAFT so it's possible one of those could be the culprit. One thing we can do to test that hypothesis would be to set init='random' so that it bypasses the spectral clustering step.

beckernick commented 1 year ago

Cross-linking and copying another user comment about reproducibility challenges in a different setting.

You can call BERTopic.transform() against unseen documents and it will approximate the clustering using the already established model. One thing to keep an eye out for is the stochastic behavior of UMAP. Every time it is run you will get different results. For many applications this isn't a big deal - but might be a consideration in your case. One easy way around this is to set the UMAP model's random_state to a fixed value - this will ensure reproducibility. One more issue about this is that the UMAP output is specific to a given platform - you can't move models from one architecture to another. Lastly this is also a problem with cuML if you get into that - cuML's implementation of UMAP won't provide repeatable results even when setting the random_state

Originally posted by @drob-xx in https://github.com/MaartenGr/BERTopic/issues/940#issuecomment-1385824017

drob-xx commented 1 year ago

@cjnolet So in the case I'm looking at, a "near-exact" match is unusable. The workaround in my case is to simply capture the UMAP reduction and re-use it as needed. So this isn't a show stopper - however I'm not sure it is a good idea to expose a random_state which doesn't reproduce results exactly. Perhaps a warning? I spent hours trying to figure out what was going on.
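The capture-and-reuse workaround mentioned in this comment could look roughly like the sketch below. The function name, the cache file name, and the use of pickle are all assumptions made for illustration; nothing here comes from cuML or BERTopic. It works for any picklable result, such as a NumPy array.

```python
import pickle
from pathlib import Path


def load_or_compute(cache_path, compute):
    """Return the cached result if cache_path exists; otherwise compute and store it."""
    path = Path(cache_path)
    if path.exists():
        with path.open("rb") as f:
            return pickle.load(f)
    result = compute()
    with path.open("wb") as f:
        pickle.dump(result, f)
    return result


# Hypothetical usage: the reduction is computed once and reused afterwards.
#   reduced = load_or_compute("umap_reduction.pkl",
#                             lambda: umap_model.fit_transform(embeddings))
```

The cached file has to be deleted by hand whenever the input data or the UMAP parameters change, since this sketch does not track them.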

cjnolet commented 1 year ago

@drob-xx, by "near-exact match" I mean exact to several decimal places. As mentioned, I think the reproducibility issue is the result of a bug in the spectral embedding computation, which we have plans to fix but it's nontrivial. For this reason, our reproducibility pytests use the random initialization method and verify the results are exact. Depending on the size of your dataset, the random initialization could provide the same (and in some cases better) results. Are you able to try using init="random" to see if your results become reproducible?

cjnolet commented 1 year ago

In the meantime, I definitely agree that we should add a note/warning about this to the UMAP docs (and in the code).

drob-xx commented 1 year ago

@cjnolet Understood. But right now random_state isn't working the way I would expect. So while I get the three-digit tolerance, and that sounds like a lot, in my case it doesn't provide any comfort. Great product by the way - the speedups in HDBSCAN and UMAP are off. the. charts. It feels alien. However, in this particular case a warning in the code would help, b/c I didn't read the docs until later, and anyway when I see something called random_state which works one way in the "original", I expect it to work the same way in what is billed as a parallel implementation: "..the only change is in the imports.." Thanks for all your work!

zbjornson commented 1 year ago

Using init="random" indeed fixes the problem, so it's an issue with spectral init as @cjnolet suggested.

from cuml.manifold.umap import UMAP
from cuml.datasets import make_blobs
from matplotlib import pyplot as plt
import numpy as np

X, y = make_blobs(n_samples=10000, centers=8, n_features=4, random_state=0, dtype=np.float32)

umap = UMAP(n_neighbors=15, random_state=1, init="random")

Y1 = umap.fit_transform(X).get()
plt.scatter(Y1[:, 0], Y1[:, 1])
plt.show()

Y2 = umap.fit_transform(X).get()
plt.scatter(Y2[:, 0], Y2[:, 1])
plt.show()

zbjornson commented 1 year ago

There's a great writeup of different UMAP initialization techniques here: https://jlmelville.github.io/uwot/init.html. Scaled PCA might be an easy option (in addition to or in lieu of fixing spectral).

drob-xx commented 1 year ago

@cjnolet @zbjornson So just chiming in that using init="random" seems to fix the problem (thanks!!!). I'm wondering though about what the long-term downsides of relying on cuML UMAP reproducibility will be? I get that the precision s/b within three decimal places - which sounds good - but I wonder if I'm introducing a possible problem down the road for use cases I can't foresee. My package is currently focused on BERTopic which uses UMAP to reduce LLM embeddings and from what I understand this level of precision will be more than enough for that application. I guess that a warning message within cuML would suffice to give people a heads up.

zbjornson commented 1 year ago

@drob-xx informative (non-random) initialization is critical, see https://www.biorxiv.org/content/10.1101/2019.12.19.877522v1. I was unclear earlier; using random init fixes the reproducibility, but is not a viable workaround.

UMAP with random initialization preserves global structure as poorly as t-SNE with random initialization, while t-SNE with informative initialization performs as well as UMAP with informative initialization. Hence, contrary to the claims of Becht et al., their experiments do not demonstrate any advantage of the UMAP algorithm per se, but rather warn against using random initialization.

drob-xx commented 1 year ago

:(. I must say I'm totally confused. I was just coming back here because I was getting some seemingly wacky behavior. Essentially I "tested" doing something like this:

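import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sentence_transformers import SentenceTransformer
from cuml.manifold import UMAP

def getUMAP():
  # Same parameters for every model, including the fixed random_state
  return UMAP(n_neighbors=15,
              n_components=5,
              min_dist=0.0,
              metric='cosine',
              init='random',
              random_state=42)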

docs = fetch_20newsgroups(subset='all',  remove=('headers', 'footers', 'quotes'))['data'][:200]

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = embedding_model.encode(docs)

umap_models = []
for num_model in range(2):
  umap_models.append(getUMAP())
  umap_models[num_model].fit_transform(embeddings)

print(np.all(umap_models[0].embedding_ == umap_models[1].embedding_))
print(umap_models[0].random_state)
print(umap_models[1].random_state)

Which outputs:

True
6909045637428952499
6909045637428952499

So it is not picking up the 42 value in the constructor, but (I assume) because I'm setting random_state AND setting init=random it is deciding to set its own random_state and producing identical results. So at this point I'll trust that it is broken and proceed accordingly. I appreciate the help and quick responses and am sorry for any confusion.

cjnolet commented 1 year ago

@drob-xx that's a good catch! Yes, there definitely is a bug in there- this conditional is testing for np.uint64 but 42 is an int, not a uint64. We should be testing for general integral type and casting instead.

EDIT: The more I think about the conditional, I think it's doing the right thing. If the value passed in is already a uint64, it uses that as the random_state directly, otherwise it creates a proper RandomState object from it, still using the input value as the seed.
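A rough analogue of the behavior described in this EDIT, for readers puzzled by the printed value: a small user seed is expanded deterministically into a 64-bit value, so the stored random_state no longer reads 42 but is identical on every run. This is not cuML's code. The function name is hypothetical, and Python's random.Random stands in for NumPy's RandomState.

```python
import random


def derive_internal_seed(random_state):
    """Expand a user-supplied integer seed into a 64-bit unsigned value."""
    # Assumption: random.Random is only a stand-in for the RandomState
    # object mentioned above; the derived number will differ from cuML's.
    return random.Random(random_state).getrandbits(64)


seed_a = derive_internal_seed(42)
seed_b = derive_internal_seed(42)

# The derived value is a large integer, but the same seed always gives the same one.
print(seed_a == seed_b)
```

Under that reading, two models built with random_state=42 reporting the same large number is the expected outcome and not a sign that the 42 was ignored.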

Also, the docs for the random state need to be updated- the original implementation of reproducible results had fixed part of the determinism issue and got the results quite close (to within 3 decimals) but it didn't fully fix the floating point associativity issue. We've since gotten it to an exact match (as demonstrated by the pytest link in my earlier comment).
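The floating point associativity issue is easy to reproduce on a CPU. The snippet below is a generic illustration in plain Python and is not taken from cuML: adding the same three numbers in a different grouping gives a result that differs in the last bit, which is what happens when parallel threads accumulate partial sums in an order that varies from run to run.

```python
import math

values = [0.1, 0.2, 0.3]

# Same operands, different grouping.
left_first = (values[0] + values[1]) + values[2]
right_first = values[0] + (values[1] + values[2])

print(left_first == right_first)              # False: the last bit differs
print(math.isclose(left_first, right_first))  # True: equal up to rounding

# math.fsum returns the correctly rounded sum, so it does not depend on order.
print(math.fsum(values) == math.fsum(reversed(values)))  # True
```

Tiny differences like this can be amplified over many optimization epochs, which is presumably why an exact match needs a fixed accumulation order and not only a fixed seed.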

We use trustworthiness to evaluate the general quality of an embedding that claims to preserve local neighborhood structure. Arguments about initialization, global structure, and its effect on the quality of final embeddings aside, we've found (empirically) that random initialization can yield just as good trustworthiness scores as more informed initializations, which essentially means that each point's n_neighbors nearest neighbors in the embedding are similar to its nearest neighbors in the original space.

Just to be clear- we are aware of the spectral initialization issue and it will get fixed. I'm proposing random could be an option in the short-term if you are finding the results are reasonable (cuml.metrics.trustworthiness might be able to help determine that).
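For intuition about what the score measures, here is a small pure-Python version of the standard trustworthiness definition: points that show up among the k nearest neighbors in the embedding but not in the original space are penalized by how far away they really were. It is a quadratic-time sketch for tiny inputs with Euclidean distance assumed, not cuML's implementation; for real data cuml.metrics.trustworthiness is the tool mentioned above.

```python
import math


def ranked_neighbors(points, i):
    """Indices of every other point, nearest first (Euclidean distance)."""
    others = (j for j in range(len(points)) if j != i)
    return sorted(others, key=lambda j: math.dist(points[i], points[j]))


def trustworthiness(original, embedded, n_neighbors=5):
    """Score in [0, 1]; 1.0 means no false neighbors appear in the embedding."""
    n, k = len(original), n_neighbors
    if len(embedded) != n:
        raise ValueError("original and embedded must have the same number of points")
    if not 0 < k < n / 2:
        raise ValueError("n_neighbors must be positive and less than n / 2")
    penalty = 0
    for i in range(n):
        original_order = ranked_neighbors(original, i)
        rank = {j: r for r, j in enumerate(original_order, start=1)}
        true_neighbors = set(original_order[:k])
        # Penalize embedding neighbors that were not neighbors originally.
        for j in ranked_neighbors(embedded, i)[:k]:
            if j not in true_neighbors:
                penalty += rank[j] - k
    return 1.0 - 2.0 * penalty / (n * k * (2 * n - 3 * k - 1))


line = [(0.0,), (1.0,), (2.0,), (3.0,), (4.0,), (5.0,)]
scrambled = [(0.0,), (5.0,), (1.0,), (4.0,), (2.0,), (3.0,)]
print(trustworthiness(line, line, n_neighbors=2))
print(trustworthiness(line, scrambled, n_neighbors=2))
```

The score only looks at local neighborhoods, so a high value says nothing about whether the global arrangement of clusters is preserved.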

drob-xx commented 1 year ago

@cjnolet Yes. I went to bed wondering about this - b/c if it didn't use the original passed value as it does then we wouldn't see new rand values that matched on successive invocations.

My intuition is that in my case I don't need to worry too much about the issue that @zbjornson raises re: random initialization and its relationship to global structure, however before I commit to using it I would like to know that I'm not missing something. OTOH people who are using my code in a cuML environment are going to be generally a bit more sophisticated than the average bear, so maybe I can just warn and move on. Any pointers on references that would point to down-the-road problems with my using random_state and init=random until this is fixed would be appreciated.