This pull request fixes the error below, which occurs when sampling from a model trained on a different classification dataset. The cause is that `y_null` and `class_labels` are hardcoded in `sample.py`.
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)`
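The idea behind the fix is to derive both values from the number of classes the checkpoint was trained with, instead of using fixed values. Below is a minimal sketch. It assumes that the model's label embedder reserves index `num_classes` as the "null" label for classifier-free guidance. The flag names `--num-classes` and `--class-labels` and the helper `resolve_labels` are hypothetical here and do not describe this PR's exact diff. With a hardcoded null label that exceeds the model's embedding table, the GPU lookup goes out of range. Because CUDA reports errors asynchronously, that failure can plausibly surface later as an unrelated-looking error such as the cuBLAS one above.

```python
import argparse


def parse_args(argv=None):
    # Hypothetical CLI flags: expose the dataset's class count instead of hardcoding it.
    parser = argparse.ArgumentParser(description="Class-conditional sampling (sketch)")
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--class-labels", type=int, nargs="+", default=None)
    return parser.parse_args(argv)


def resolve_labels(num_classes, class_labels=None):
    """Return (class_labels, y_null) consistent with the model's num_classes (hypothetical helper)."""
    if class_labels is None:
        # Fall back to the first few valid classes rather than fixed dataset-specific IDs.
        class_labels = list(range(min(num_classes, 8)))
    bad = [c for c in class_labels if not 0 <= c < num_classes]
    if bad:
        raise ValueError(f"class labels {bad} out of range [0, {num_classes})")
    # Assumption: the null (unconditional) label is the extra embedding row at index
    # num_classes, so it must follow the dataset instead of being fixed at 1000.
    y_null = [num_classes] * len(class_labels)
    return class_labels, y_null


if __name__ == "__main__":
    args = parse_args(["--num-classes", "10", "--class-labels", "1", "3"])
    print(resolve_labels(args.num_classes, args.class_labels))  # ([1, 3], [10, 10])
```

Validating the labels on the host gives a readable `ValueError` up front. Otherwise the problem would show up as an opaque device-side failure partway through sampling.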