wmwv closed this issue.
I cannot replicate this error on my end.
What do you get for the full array returned by new_model.predict? I get:
In [11]: new_model.predict(time, mag, magerr, convert=True, zp=22, apply_weights=True)
Out[11]:
array([[0.00000000e+000, 4.37237005e-228],
[1.00000000e+000, 0.00000000e+000],
[2.00000000e+000, 1.00000000e+000],
[3.00000000e+000, 0.00000000e+000]])
So it identifies this lightcurve as a class 2 ("ML") event, using the class mapping:
CV -------------> 0
LPV -------------> 1
ML -------------> 2
RRLYR -------------> 3
This certainly seems plausible based on the lightcurve.
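As a minimal sketch of how to read that output, assuming (as in the session above) that column 0 holds the class index, column 1 holds the probability, and the class order is the one listed above, the predicted class is the row with the highest probability:

```python
import numpy as np

# Output copied from the IPython session above: column 0 is the class
# index, column 1 is the predicted probability for that class.
probs = np.array([
    [0.0, 4.37237005e-228],
    [1.0, 0.0],
    [2.0, 1.0],
    [3.0, 0.0],
])

# Class order for the 4-class OGLE IV model, as listed above.
class_names = ["CV", "LPV", "ML", "RRLYR"]

# Find the row with the highest probability, then read its class index.
best_row = int(np.argmax(probs[:, 1]))
best_class = int(probs[best_row, 0])
print(class_names[best_class], probs[best_row, 1])  # ML 1.0
```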
And yes, this is with scikit-learn 1.1.1:
(microlia) [wmwv@nb-wmwv MicroLIA]$ pip list
Package Version
----------------------------- -----------
absl-py 1.4.0
alembic 1.11.1
astroML 1.0.2.post1
astropy 5.0.4
asttokens 2.2.1
astunparse 1.6.3
backcall 0.2.0
backports.functools-lru-cache 1.6.5
Boruta 0.3
BorutaShap 1.0.16
cachetools 5.3.1
certifi 2023.5.7
charset-normalizer 3.1.0
cloudpickle 2.2.1
cmaes 0.9.1
colorlog 6.7.0
comm 0.1.3
cycler 0.11.0
debugpy 1.6.7
decorator 5.1.1
dill 0.3.6
executing 1.2.0
flatbuffers 1.12
fonttools 4.40.0
gast 0.4.0
gatspy 0.3
google-auth 2.20.0
google-auth-oauthlib 0.4.6
google-pasta 0.2.0
greenlet 2.0.2
grpcio 1.56.0
h5py 3.9.0
idna 3.4
imbalanced-learn 0.10.1
importlib-metadata 6.7.0
ipykernel 6.23.3
ipython 8.14.0
jedi 0.18.2
joblib 1.2.0
jupyter_client 8.3.0
jupyter_core 5.3.1
keras 2.9.0
Keras-Preprocessing 1.1.2
kiwisolver 1.4.4
libclang 16.0.0
llvmlite 0.40.1
Mako 1.2.4
Markdown 3.4.3
MarkupSafe 2.1.3
matplotlib 3.5.1
matplotlib-inline 0.1.6
MicroLIA 2.6.0
nest-asyncio 1.5.6
numba 0.57.1
numpy 1.22.4
oauthlib 3.2.2
opencv-python 4.7.0.68
opt-einsum 3.3.0
optuna 3.1.0
packaging 23.1
pandas 1.4.1
parso 0.8.3
patsy 0.5.3
PeakUtils 1.3.3
pexpect 4.8.0
pickleshare 0.7.5
Pillow 9.5.0
pip 23.1.2
platformdirs 3.8.0
progress 1.6
prompt-toolkit 3.0.38
protobuf 3.19.6
psutil 5.9.5
ptyprocess 0.7.0
pure-eval 0.2.2
pyaml 23.5.9
pyasn1 0.5.0
pyasn1-modules 0.3.0
pyerfa 2.0.0.3
Pygments 2.15.1
pyparsing 3.1.0
python-dateutil 2.8.2
pytz 2023.3
PyYAML 6.0
pyzmq 25.1.0
requests 2.31.0
requests-oauthlib 1.3.1
rsa 4.9
scikit-learn 1.1.1
scikit-optimize 0.9.0
scikit-plot 0.3.7
scipy 1.7.3
seaborn 0.12.2
setuptools 67.7.2
shap 0.41.0
six 1.16.0
slicer 0.0.7
SQLAlchemy 2.0.16
stack-data 0.6.2
statsmodels 0.14.0
tensorboard 2.9.1
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
tensorflow 2.9.3
tensorflow-estimator 2.9.0
tensorflow-io-gcs-filesystem 0.32.0
termcolor 2.3.0
threadpoolctl 3.1.0
tornado 6.3.2
tqdm 4.65.0
traitlets 5.9.0
typing_extensions 4.6.3
urllib3 1.26.16
wcwidth 0.2.6
Werkzeug 2.3.6
wheel 0.40.0
wrapt 1.15.0
xgboost 1.6.1
zipp 3.15.0
Issue identified: the problem was due to the indexing in the test file:
new_model.predict(time, mag, magerr, convert=True, zp=22, apply_weights=True)[:,1][3]
The index 3 corresponds to ML when using a simulated training set with 5 classes; however, in the test folder I included a CSV file containing the training set from real OGLE IV lightcurves, which has only 4 classes (so ML is index 2). This was an oversight on my end, as the test_model_xgb test folder contained a model trained on simulations using OGLE II cadence. For consistency, I have replaced the model inside test_model_xgb with the correct one, trained on real OGLE IV lightcurves as per the provided CSV file. This is the model created by following the OGLE IV example in the documentation.
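The failure mode can be sketched without MicroLIA itself: a hardcoded row index silently points at a different class once the number of classes changes. A safer pattern is to look the index up by class name. The helper ml_probability below is hypothetical, and it assumes column 0 is the class index, column 1 the probability, and that the class names are listed in index order (true for the 4-class mapping above, but not confirmed for MicroLIA in general):

```python
import numpy as np

def ml_probability(prediction, class_names, target="ML"):
    """Hypothetical helper: return the probability for `target` from a
    (n_classes, 2) prediction array, looking the row up by class name."""
    idx = class_names.index(target)  # no hardcoded index
    return float(prediction[idx, 1])

# 4-class OGLE IV model output from the session above: ML is index 2.
ogle4_pred = np.array([[0, 4.37237005e-228], [1, 0.0], [2, 1.0], [3, 0.0]])
ogle4_classes = ["CV", "LPV", "ML", "RRLYR"]

# Hardcoding [:, 1][3] (the 5-class ML index) reads RRLYR here, not ML.
print(ogle4_pred[:, 1][3])                         # 0.0
print(ml_probability(ogle4_pred, ogle4_classes))   # 1.0
```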
Furthermore, the test_ogle_lc.dat file was an OGLE II lightcurve; it has been replaced with a real OGLE IV lightcurve, and the test_classifier file has been updated to reflect the true values. You will note that the base neural network yields poor predictions; this is expected, as the hidden layers are not properly configured in this case.
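A consistency check along these lines in the test suite could catch a mismatch between a bundled model and a bundled training CSV. This is a hypothetical sketch, not part of MicroLIA: the helper name and passing the training labels in directly are assumptions, and the sum-to-one check assumes the probabilities are normalized, as they are in the output above:

```python
import numpy as np

def check_model_matches_training_set(prediction, training_labels):
    """Hypothetical test helper: fail if the model's class count differs
    from the number of distinct classes in the training set."""
    n_model = prediction.shape[0]
    n_train = len(set(training_labels))
    assert n_model == n_train, (
        f"model outputs {n_model} classes but the training set has {n_train}"
    )
    # Probabilities across all classes should sum to ~1.
    assert np.isclose(prediction[:, 1].sum(), 1.0)

# Toy example: 4 distinct labels and a 4-row prediction array, so this passes.
pred = np.array([[0, 0.1], [1, 0.2], [2, 0.6], [3, 0.1]])
labels = ["CV", "LPV", "ML", "RRLYR", "ML", "CV"]
check_model_matches_training_set(pred, labels)
```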
When I run
[wmwv@nb-wmwv test]$ python test_classifier.py
against 7b1a31a, I get a failure at test_classifier.py#80.
Is 7b1a31a version 2.6.0, or is that 3e567eb? I might suggest using annotated git tags for release versions to make this clearer.
Full output: