facebookresearch / LaMCTS

The release code of LA-MCTS, with its application to Neural Architecture Search.

Use the K-Means labels or the SVM prediction labels to separate the samples into kid nodes? #14

Closed: BigTailFox closed this issue 3 years ago

BigTailFox commented 3 years ago

Hello, I am re-implementing LA-MCTS following the paper. My question: I use the labels predicted by the SVM to decide which nodes of the parent should go to the good kid or the bad kid. However, the SVM-predicted labels are NOT always the same as the K-Means labels. In some extreme cases, the mean of the bad kid was bigger than that of the good kid because of SVM classification errors. I checked your code (below) and found that the K-Means labels are used instead of the SVM labels. I don't understand why the clustering labels should be used; they don't represent the split, right?

import numpy as np

def split_data(self):
    good_samples = []
    bad_samples = []
    train_good_samples = []
    train_bad_samples = []
    if len(self.samples) == 0:
        return good_samples, bad_samples

    # plabel holds the K-Means cluster labels (0 = good cluster, 1 = bad cluster)
    plabel = self.learn_clusters()
    # the SVM boundary is fitted on plabel, but the split below still uses plabel
    self.learn_boundary(plabel)

    for idx in range(0, len(plabel)):
        if plabel[idx] == 0:
            # ensure the consistency
            assert self.samples[idx][-1] - self.fX[idx] <= 1
            good_samples.append(self.samples[idx])
            train_good_samples.append(self.X[idx])
        else:
            bad_samples.append(self.samples[idx])
            train_bad_samples.append(self.X[idx])

    train_good_samples = np.array(train_good_samples)
    train_bad_samples = np.array(train_bad_samples)

    assert len(good_samples) + len(bad_samples) == len(self.samples)

    return good_samples, bad_samples

yuandong-tian commented 3 years ago

Good point. In a newer internal version, @linnanwang (the first author) uses the SVM labels, so we should fix this later.
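For a re-implementation, splitting on the SVM predictions instead of the raw K-Means labels might look roughly like the minimal sketch below. It is not the internal version mentioned above: the helper name `split_by_svm_labels`, the RBF kernel, `n_init=10`, the toy objective, and the `minimize` flag (good kid = lower mean f(x) when `minimize=True`) are all assumptions. It also returns how many samples the SVM assigns differently from K-Means, which is the mismatch described in the question.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.svm import SVC


def split_by_svm_labels(X, fX, minimize=True, seed=0):
    """Hypothetical helper: split samples into good/bad kids via SVC predictions.

    Labels come from K-Means on [X, fX]; an SVC is fitted on (X, labels);
    the split then uses svc.predict(X) rather than the K-Means labels.
    """
    feats = np.concatenate([X, fX.reshape(-1, 1)], axis=1)
    km_labels = KMeans(n_clusters=2, n_init=10, random_state=seed).fit_predict(feats)

    # Relabel so that 0 is the cluster with the better mean f(x)
    # (assumed convention: lower is better when minimize=True).
    mean0, mean1 = fX[km_labels == 0].mean(), fX[km_labels == 1].mean()
    zero_is_worse = (mean0 > mean1) if minimize else (mean0 < mean1)
    if zero_is_worse:
        km_labels = 1 - km_labels

    svc = SVC(kernel="rbf", C=1.0, gamma="scale").fit(X, km_labels)
    svm_labels = svc.predict(X)

    good_idx = np.flatnonzero(svm_labels == 0)
    bad_idx = np.flatnonzero(svm_labels == 1)
    # Number of samples where the SVM disagrees with K-Means.
    n_disagree = int((svm_labels != km_labels).sum())
    return good_idx, bad_idx, n_disagree


# Usage on a toy 2-D objective with inputs in [-10, 10] (lower is better).
rng = np.random.default_rng(0)
X = rng.uniform(-10, 10, size=(200, 2))
fX = (X ** 2).sum(axis=1)
good, bad, n_disagree = split_by_svm_labels(X, fX)
print(len(good), len(bad), n_disagree)
```

Because the split follows the SVM, the kids stay consistent with the boundary used later to partition the search space, at the price of occasionally misplacing samples that K-Means had put in the other cluster.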

linnanwang commented 3 years ago

"I use the labels predicted by SVM to decide which nodes of the parent should go to the good kid or bad kid."

What are the input features? Are you using the SVM as a regressor?

"In some extreme cases, the mean of the bad kid was bigger than that of the good kid, due to the SVM classification error." Then this will be wrong.

The cluster labels are fed to the SVM to learn a boundary. Note that an SVM can be used for both regression and classification; here we use it for classification: it learns a boundary and labels the existing samples into 2 classes, where those classes were obtained with both x and f(x) considered, i.e., [x, f(x)] is the input feature vector to k-means.
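To make the classification-versus-regression distinction concrete, here is a small sketch with scikit-learn's `SVC` and `SVR`: `SVC` is fitted on discrete labels and predicts class labels, while `SVR` is fitted on the continuous f(x) and predicts real values. The toy objective and the median-threshold labels (a stand-in for the k-means labels, not what LA-MCTS does) are assumptions for illustration only.

```python
import numpy as np
from sklearn.svm import SVC, SVR

rng = np.random.default_rng(0)
X = rng.uniform(-10, 10, size=(100, 2))
fX = (X ** 2).sum(axis=1)                    # toy objective
labels = (fX > np.median(fX)).astype(int)    # stand-in for cluster labels

clf = SVC(kernel="rbf", gamma="scale").fit(X, labels)  # classification: targets are labels
reg = SVR(kernel="rbf", gamma="scale").fit(X, fX)      # regression: targets are f(x)

print(np.unique(clf.predict(X)))  # discrete class labels
print(reg.predict(X[:3]))         # real-valued estimates of f(x)
```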

BigTailFox commented 3 years ago

Thank you for your reply. I tested my version on an artificial function whose inputs are 2-dimensional, with each coordinate between -10 and 10. I use SVC from sklearn, so I guess I am using the SVM as a classifier. I concatenate X and fX into a new vector to train k-means, but I train the SVC only on [X, cluster_plabels]. Should I train the SVC with [X, fX] too?

linnanwang commented 3 years ago

No problem. You need [X] and the corresponding labels to train the SVC. If you did not use k-means to generate these labels, how would you train the SVC? I suggest reading this page carefully: https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html.

Also, a quick recap of the workflow: [x, f(x)] -> use k-means to generate labels l(x) indicating good or bad; [x, l(x)] -> train the SVC to get the boundary that partitions the search space x.
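The recap can be sketched as follows, assuming scikit-learn and a minimization convention in which label 0 marks the cluster with the lower mean f(x); the helper `learn_partition` and the toy objective are hypothetical. Note that the SVC sees only x, so a new candidate can be assigned to a region before its f(x) is evaluated.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.svm import SVC


def learn_partition(X, fX, minimize=True, seed=0):
    """Hypothetical sketch: [x, f(x)] -> k-means labels l(x); [x, l(x)] -> SVC."""
    # Step 1: cluster on the concatenated feature vector [x, f(x)].
    feats = np.hstack([X, fX[:, None]])
    labels = KMeans(n_clusters=2, n_init=10, random_state=seed).fit_predict(feats)

    # Make label 0 the "good" cluster (assumed: lower mean f(x) if minimize).
    m0, m1 = fX[labels == 0].mean(), fX[labels == 1].mean()
    zero_is_worse = (m0 > m1) if minimize else (m0 < m1)
    if zero_is_worse:
        labels = 1 - labels

    # Step 2: train the SVC on x only; f(x) is not an SVC input.
    return SVC(kernel="rbf", gamma="scale").fit(X, labels)


rng = np.random.default_rng(1)
X = rng.uniform(-10, 10, size=(200, 2))
fX = (X ** 2).sum(axis=1)  # toy objective, lower is better
svc = learn_partition(X, fX)

# Assign unseen candidates to a region without evaluating f(x).
candidates = rng.uniform(-10, 10, size=(5, 2))
print(svc.predict(candidates))  # 0 = good region, 1 = bad region
```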

BigTailFox commented 3 years ago

OK, I understand. Then we should either tolerate mean_good_kid < mean_bad_kid, or increase the SVC penalty until it can split the samples with mean_good_kid < mean_bad_kid. Otherwise, it would trigger an assertion failure in the original code. However, I think k-means is a bit weak because it only forms roughly spherical clusters. Normalizing X and fX together also seems a bit weird. I am trying other clustering methods, or simply learning a regressor and letting the points with predicted fx > f_bar go to the good node. That said, the LA-MCTS framework is gorgeous work.
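The second option, increasing the SVC penalty until the predicted split is consistent, might be sketched as below. Assumptions: the hypothetical helper `fit_consistent_svc`, the grid of `C` values, label 0 = good, the toy data, and the `minimize` flag, which selects whether "better" means a lower (`True`) or higher (`False`) mean f(x). If no `C` in the grid works, the caller has to fall back to another option, e.g. tolerating the inconsistency.

```python
import numpy as np
from sklearn.svm import SVC


def fit_consistent_svc(X, fX, labels, minimize=True, Cs=(1.0, 10.0, 100.0, 1000.0)):
    """Refit the SVC with a growing penalty C until the good kid's mean f(x)
    is better than the bad kid's under the SVC's own predictions.

    `labels` are cluster labels with 0 = good, 1 = bad (assumed convention).
    Returns (svc, predicted_labels), or (None, None) if no C in `Cs` works.
    """
    for C in Cs:
        svc = SVC(kernel="rbf", C=C, gamma="scale").fit(X, labels)
        pred = svc.predict(X)
        if pred.min() == pred.max():
            continue  # every sample on one side: no usable split
        good_mean = fX[pred == 0].mean()
        bad_mean = fX[pred == 1].mean()
        consistent = (good_mean < bad_mean) if minimize else (good_mean > bad_mean)
        if consistent:
            return svc, pred
    return None, None


rng = np.random.default_rng(0)
X = rng.uniform(-10, 10, size=(200, 2))
fX = (X ** 2).sum(axis=1)                  # toy objective, lower is better
labels = (fX > np.median(fX)).astype(int)  # stand-in for k-means labels, 0 = good
svc, pred = fit_consistent_svc(X, fX, labels, minimize=True)
```

A larger `C` penalizes training misclassifications more heavily, so the predicted labels track the cluster labels more closely, at the cost of a more complex boundary that may overfit.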

linnanwang commented 3 years ago

Thank you for your kind words.

In general, you want to make sure mean_good_kid < mean_bad_kid, otherwise it might indicate an error in your code.