cstjean / ScikitLearn.jl

Julia implementation of the scikit-learn API https://cstjean.github.io/ScikitLearn.jl/dev/

Bug with weighted DBSCAN #120

Open nickkeepfer opened 1 year ago

nickkeepfer commented 1 year ago

Problem

I'm using DBSCAN to find clusters in a 3D dataset that varies with time. Occasionally (<5% of the time), DBSCAN completely fails to detect a very obvious cluster. Circularly shifting the array (circshift) sometimes makes it work, but not always.

There seems to be no clear reason why it fails; it just sometimes does.

Please see the following example (file is included for replication purposes):

using JLD2
using ScikitLearn
using PyCall

# Wrapper for DBSCAN 
DBSCAN = pyimport("sklearn.cluster").DBSCAN

# Load data
f = jldopen("DBSCAN_BUG.jld2")
x, y, z = f["x"], f["y"], f["z"]

# Format data such that each voxel is given as an (x,y,z) coordinate
X = repeat(x',length(y),1,length(z)) .+ 2*maximum(x)
Y = repeat(y,1,length(x),length(z)) .+ 2*maximum(y)
Z = permutedims(repeat(z,1,length(x),length(y)),[3 2 1]) .+ 2*maximum(z)
dat = zeros(length(X[:]),3)
dat[:,1] = X[:]
dat[:,2] = Y[:]
dat[:,3] = Z[:]

# Perform DBSCAN where points are weighted by density array
decomp = DBSCAN(eps=abs(x[1]-x[2]),min_samples=1).fit_predict(dat,sample_weight=f["dens"][:])
dbscan = replace!(reshape(decomp,size(f["dens"])),-1=>0)
dbscan[dbscan.>0] .= 1.0

# Plot DBSCAN results alongside the density 
using CairoMakie
fig = Figure()
ax, hm1 = heatmap(fig[1,1], x, y, f["dens"][:,:,72])
ax, hm2 = heatmap(fig[2,1], x, y, dbscan[:,:,72])
fig

[Attachment: DBSCAN_BUG.jld2.zip] [Image: Screenshot 2023-02-23 at 14 28 54]

Expected result

There should be a yellow blob in the second heatmap, corresponding to the (very obvious) cluster that DBSCAN should identify.
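
For reference, scikit-learn documents sample_weight such that a sample whose weight is at least min_samples is a core sample by itself. More generally, the weights inside a point's eps-neighbourhood (including the point itself) are summed and compared against min_samples. The following is a minimal pure-Python sketch of that rule only, not scikit-learn's implementation; core_samples and the toy points are hypothetical names made up for illustration.

```python
import math


def core_samples(points, weights, eps, min_samples):
    """Flag each point whose weighted eps-neighbourhood reaches min_samples.

    Hypothetical helper for illustration: brute-force O(n^2), Euclidean metric.
    """
    flags = []
    for p in points:
        # Sum the weights of every point within eps of p (p itself included).
        total = sum(w for q, w in zip(points, weights) if math.dist(p, q) <= eps)
        flags.append(total >= min_samples)
    return flags


# Three points spaced exactly eps apart, plus one isolated point.
points = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0), (10.0, 0.0)]
weights = [0.4, 0.4, 0.4, 0.4]

flags = core_samples(points, weights, eps=1.0, min_samples=1)
print(flags)
```

With min_samples=1, a voxel whose own density is below 1 is core only if nearby voxels raise the summed weight to at least 1, so the result depends on which neighbours fall within eps. Since eps here is exactly the grid spacing, that membership may be sensitive to floating-point rounding in the distances.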

cstjean commented 1 year ago

Isn't that a problem with the scikit-learn library? ScikitLearn.jl is just an interface to the python scikit-learn. If so, I would encourage you to translate your example to Python and post it there.

nickkeepfer commented 1 year ago

Hmm, yes, probably. I'll see if I'm able to translate it.

nickkeepfer commented 1 year ago

I'm actually not sure it is an issue with scikit-learn, as it works just fine when used natively in Python (see below):

import numpy as np
import h5py
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt
import matplotlib.colors as colors

# Load data
f = h5py.File("DBSCAN_BUG.jld2", "r")
x, y, z = f["x"][:], f["y"][:], f["z"][:]

# Format data such that each voxel is given as an (x,y,z) coordinate
X = np.repeat(x, len(y) * len(z)).reshape(len(x), len(y), len(z), order='F') + 2 * np.max(x)
Y = np.repeat(y, len(x) * len(z)).reshape(len(y), len(x), len(z), order='C') + 2 * np.max(y)
Z = np.repeat(z, len(x) * len(y)).reshape(len(z), len(x), len(y), order='F').transpose((1, 2, 0)) + 2 * np.max(z)
dat = np.vstack((X.ravel('F'), Y.ravel('F'), Z.ravel('F'))).T

# Perform DBSCAN where points are weighted by density array
decomp = DBSCAN(eps=np.abs(x[0] - x[1]), min_samples=1).fit_predict(dat, sample_weight=f["dens"][:].ravel())
dbscan = np.reshape(np.where(decomp != -1, 1, 0), f["dens"].shape)

# Plot DBSCAN results alongside the density
fig, axs = plt.subplots(2, 1)
hm1 = axs[0].imshow(f["dens"][72, :, :], norm=colors.LogNorm())
hm2 = axs[1].imshow(dbscan[72, :, :], cmap='binary')
plt.show()

[Image: Figure_1]

cstjean commented 1 year ago

I'm not super-familiar with DBSCAN. I scanned your code and nothing looked obviously wrong. Beware that in

dbscan = replace!(reshape(decomp,size(f["dens"])),-1=>0)

reshape returns an array that shares memory with decomp, so replace! on this line also mutates decomp. That shouldn't change the outcome, though.

Beyond that, I can't offer advice other than: try to figure out what's different between Python and Julia. Ultimately, it's the same library doing the work, so presumably the inputs (or the plotting) are different.
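
One input difference worth checking when comparing the two versions is flattening order: Julia's A[:] walks an array in column-major order (first index fastest), while NumPy's ravel() defaults to row-major order (last index fastest). The sketch below uses only the standard library to show how a weight can end up attached to the wrong coordinate when coordinates and weights are flattened in different orders. It is an illustration under that assumption, not a diagnosis of this issue; indices_row_major, indices_col_major, heavy_coordinate and the toy dens grid are hypothetical.

```python
from itertools import product


def indices_row_major(shape):
    """Index tuples in row-major (C) order: the last index varies fastest."""
    return list(product(*(range(n) for n in shape)))


def indices_col_major(shape):
    """Index tuples in column-major (Fortran/Julia) order: the first index varies fastest."""
    reversed_axes = product(*(range(n) for n in reversed(shape)))
    return [idx[::-1] for idx in reversed_axes]


def heavy_coordinate(coords, weights):
    """Return the coordinate paired with the largest weight."""
    return coords[weights.index(max(weights))]


# Toy 2x3 density grid with a single heavy cell at index (0, 1).
dens = [[0, 5, 0],
        [0, 0, 0]]
shape = (2, 3)

coords_f = indices_col_major(shape)                            # coordinates, column-major
weights_c = [dens[i][j] for i, j in indices_row_major(shape)]  # weights, row-major
weights_f = [dens[i][j] for i, j in coords_f]                  # weights, column-major

mismatched = heavy_coordinate(coords_f, weights_c)
matched = heavy_coordinate(coords_f, weights_f)
print("mixed orders:     heavy weight lands on", mismatched)
print("consistent order: heavy weight lands on", matched)
```

Dumping the coordinate array and the weight vector from both the Julia and the Python script and comparing a few rows in this way is a cheap check that each weight is paired with the voxel it belongs to.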