albertbup / deep-belief-network

A Python implementation of Deep Belief Networks built upon NumPy and TensorFlow with scikit-learn compatibility
MIT License

cannot import name 'SupervisedDBNClassification' #25

Closed: deerdodo closed this issue 6 years ago

deerdodo commented 6 years ago

I tried the following code on my accelerometer signal dataset:

import numpy as np

np.random.seed(1337)  # for reproducibility
from sklearn.model_selection import train_test_split
# accuracy_score lives in sklearn.metrics (sklearn.metrics.classification was removed in 0.24)
from sklearn.metrics import accuracy_score
from dbn import SupervisedDBNClassification

# Loading dataset
train = np.loadtxt("TrainDatasetFinal.txt", delimiter=",")
test = np.loadtxt("testDatasetFinal.txt", delimiter=",")

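# Column 7 is the 0/1 label
y_train = train[:, 7].astype(int)
y_test = test[:, 7].astype(int)

# Column 6 is the spectrogram; indexing with a list keeps X 2-D, shape (n_samples, 1)
train_spec = train[:, [6]]
test_spec = test[:, [6]]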

# Training
classifier = SupervisedDBNClassification(hidden_layers_structure=[256, 256],
                                         learning_rate_rbm=0.05,
                                         learning_rate=0.1,
                                         n_epochs_rbm=10,
                                         n_iter_backprop=100,
                                         batch_size=32,
                                         activation_function='relu',
                                         dropout_p=0.2)
classifier.fit(train_spec, y_train)

but it gives me the following error:

Traceback (most recent call last):
  File "G:\dbn.py", line 35, in <module>
    from dbn import SupervisedDBNClassification
  File "G:\dbn.py", line 35, in <module>
    from dbn import SupervisedDBNClassification
ImportError: cannot import name 'SupervisedDBNClassification'
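Both traceback frames point at G:\dbn.py, line 35, which suggests the script is importing itself. When a script is run, Python puts the script's own folder first on sys.path, so `from dbn import ...` inside a file named dbn.py re-imports that same file as the module `dbn` instead of the installed package, and `SupervisedDBNClassification` is not defined yet when line 35 runs the second time. The sketch below is a minimal check under that assumption; the helper names `module_origin` and `shadows_module` are hypothetical. It uses `importlib.util.find_spec` to report which file a top-level module name resolves to without executing it.

```python
import importlib.util
import os


def module_origin(name):
    """Return the file that `import name` would load, or None if it is not found.

    For a top-level name, find_spec locates the module without running it.
    """
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin


def shadows_module(script_path, name):
    """True if `script_path` is itself the file that `import name` resolves to."""
    origin = module_origin(name)
    return origin is not None and os.path.abspath(origin) == os.path.abspath(script_path)


if __name__ == "__main__":
    # Hypothetical usage: save this as a separate file in the same folder as
    # the training script (G:\) and run it from there.
    print(module_origin("dbn"))
```

If this prints G:\dbn.py rather than a path inside the installed package, renaming the training script (for example to a hypothetical train_dbn.py) is the likely fix.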

This is a sample from my dataset:

(patient number, time in milliseconds, accelerometer x-axis, accelerometer y-axis, accelerometer z-axis, magnitude, spectrogram, label (0 or 1))

1,15,70,39,-970,947321,596768455815000,0
1,31,70,39,-970,947321,612882670787000,0
1,46,60,49,-960,927601,602179976392000,0
1,62,60,49,-960,927601,808020878060000,0
1,78,50,39,-960,925621,726154800929000,0

From this dataset I use only the spectrogram (column index 6) as the input feature and the label (0 or 1, column index 7) as the output. There are 1,415,684 training samples in total.
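The column selection described above can be sketched on the five sample rows. Two points are worth checking: `train[:, 6]` is 1-D, while scikit-learn-style `fit(X, y)` calls expect X with shape (n_samples, n_features), and the raw spectrogram values are around 6e14. The min-max scaling to [0, 1] below is an assumption (RBM-based models are commonly fed inputs in that range), not something stated in the issue; the helper names are hypothetical.

```python
import io

import numpy as np

# Column layout from the issue (0-based):
# patient, time_ms, acc_x, acc_y, acc_z, magnitude, spectrogram, label
SPECTROGRAM_COL = 6
LABEL_COL = 7

SAMPLE = """1,15,70,39,-970,947321,596768455815000,0
1,31,70,39,-970,947321,612882670787000,0
1,46,60,49,-960,927601,602179976392000,0
1,62,60,49,-960,927601,808020878060000,0
1,78,50,39,-960,925621,726154800929000,0
"""


def split_features_labels(data):
    """Return X with shape (n_samples, 1) and integer labels y with shape (n_samples,)."""
    X = data[:, [SPECTROGRAM_COL]]           # list index keeps the column axis -> 2-D
    y = data[:, LABEL_COL].astype(np.int64)   # 1-D label vector
    return X, y


def minmax_scale(X_train, X_test):
    """Scale both arrays to [0, 1] using the training range only (assumed preprocessing)."""
    lo = X_train.min(axis=0)
    span = X_train.max(axis=0) - lo
    span[span == 0] = 1.0  # avoid dividing by zero for a constant column
    X_train_s = (X_train - lo) / span
    # Test values outside the training range are clipped into [0, 1]
    X_test_s = np.clip((X_test - lo) / span, 0.0, 1.0)
    return X_train_s.astype(np.float32), X_test_s.astype(np.float32)


data = np.loadtxt(io.StringIO(SAMPLE), delimiter=",")
X, y = split_features_labels(data)
print(data[:, SPECTROGRAM_COL].shape, X.shape)  # (5,) (5, 1)
X_scaled, _ = minmax_scale(X, X)
print(X_scaled.min(), X_scaled.max())  # 0.0 1.0
```

Fitting the scaling range on the training file only and reusing it for the test file keeps the test data from leaking into preprocessing.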