ageron / handson-ml3

A series of Jupyter notebooks that walk you through the fundamentals of Machine Learning and Deep Learning in Python using Scikit-Learn, Keras and TensorFlow 2.
Apache License 2.0
7.94k stars 3.19k forks source link

[BUG] Running cell 5 of 05_support_vector_machines.ipynb notebook gives error #67

Open vasili111 opened 1 year ago

vasili111 commented 1 year ago

I am running cells from top till cell 5 and when running cell 5 with this code:

# extra code – this cell generates and saves Figure 5–1

import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVC
from sklearn import datasets

iris = datasets.load_iris(as_frame=True)
X = iris.data[["petal length (cm)", "petal width (cm)"]].values
y = iris.target

setosa_or_versicolor = (y == 0) | (y == 1)
X = X[setosa_or_versicolor]
y = y[setosa_or_versicolor]

# SVM Classifier model
svm_clf = SVC(kernel="linear", C=float("inf"))
svm_clf.fit(X, y)

# Bad models
x0 = np.linspace(0, 5.5, 200)
pred_1 = 5 * x0 - 20
pred_2 = x0 - 1.8
pred_3 = 0.1 * x0 + 0.5

def plot_svc_decision_boundary(svm_clf, xmin, xmax):
    w = svm_clf.coef_[0]
    b = svm_clf.intercept_[0]

    # At the decision boundary, w0*x0 + w1*x1 + b = 0
    # => x1 = -w0/w1 * x0 - b/w1
    x0 = np.linspace(xmin, xmax, 200)
    decision_boundary = -w[0] / w[1] * x0 - b / w[1]

    margin = 1/w[1]
    gutter_up = decision_boundary + margin
    gutter_down = decision_boundary - margin
    svs = svm_clf.support_vectors_

    plt.plot(x0, decision_boundary, "k-", linewidth=2, zorder=-2)
    plt.plot(x0, gutter_up, "k--", linewidth=2, zorder=-2)
    plt.plot(x0, gutter_down, "k--", linewidth=2, zorder=-2)
    plt.scatter(svs[:, 0], svs[:, 1], s=180, facecolors='#AAA',
                zorder=-1)

fig, axes = plt.subplots(ncols=2, figsize=(10, 2.7), sharey=True)

plt.sca(axes[0])
plt.plot(x0, pred_1, "g--", linewidth=2)
plt.plot(x0, pred_2, "m-", linewidth=2)
plt.plot(x0, pred_3, "r-", linewidth=2)
plt.plot(X[:, 0][y==1], X[:, 1][y==1], "bs", label="Iris versicolor")
plt.plot(X[:, 0][y==0], X[:, 1][y==0], "yo", label="Iris setosa")
plt.xlabel("Petal length")
plt.ylabel("Petal width")
plt.legend(loc="upper left")
plt.axis([0, 5.5, 0, 2])
plt.gca().set_aspect("equal")
plt.grid()

plt.sca(axes[1])
plot_svc_decision_boundary(svm_clf, 0, 5.5)
plt.plot(X[:, 0][y==1], X[:, 1][y==1], "bs")
plt.plot(X[:, 0][y==0], X[:, 1][y==0], "yo")
plt.xlabel("Petal length")
plt.axis([0, 5.5, 0, 2])
plt.gca().set_aspect("equal")
plt.grid()

save_fig("large_margin_classification_plot")
plt.show()

I am getting error:

---------------------------------------------------------------------------
InvalidParameterError                     Traceback (most recent call last)
Cell In[5], line 18
     16 # SVM Classifier model
     17 svm_clf = SVC(kernel="linear", C=float("inf"))
---> 18 svm_clf.fit(X, y)
     20 # Bad models
     21 x0 = np.linspace(0, 5.5, 200)

File ~\anaconda3\envs\ml_1\Lib\site-packages\sklearn\svm\_base.py:180, in BaseLibSVM.fit(self, X, y, sample_weight)
    147 def fit(self, X, y, sample_weight=None):
    148     """Fit the SVM model according to the given training data.
    149 
    150     Parameters
   (...)
    178     matrices as input.
    179     """
--> 180     self._validate_params()
    182     rnd = check_random_state(self.random_state)
    184     sparse = sp.isspmatrix(X)

File ~\anaconda3\envs\ml_1\Lib\site-packages\sklearn\base.py:581, in BaseEstimator._validate_params(self)
    573 def _validate_params(self):
    574     """Validate types and values of constructor parameters
    575 
    576     The expected type and values must be defined in the `_parameter_constraints`
   (...)
    579     accepted constraints.
    580     """
--> 581     validate_parameter_constraints(
    582         self._parameter_constraints,
    583         self.get_params(deep=False),
    584         caller_name=self.__class__.__name__,
    585     )

File ~\anaconda3\envs\ml_1\Lib\site-packages\sklearn\utils\_param_validation.py:97, in validate_parameter_constraints(parameter_constraints, params, caller_name)
     91 else:
     92     constraints_str = (
     93         f"{', '.join([str(c) for c in constraints[:-1]])} or"
     94         f" {constraints[-1]}"
     95     )
---> 97 raise InvalidParameterError(
     98     f"The {param_name!r} parameter of {caller_name} must be"
     99     f" {constraints_str}. Got {param_val!r} instead."
    100 )

InvalidParameterError: The 'C' parameter of SVC must be a float in the range (0.0, inf). Got inf instead.

How to fix it?

vasili111 commented 1 year ago

Also cell 6 with code:

# extra code – this cell generates and saves Figure 5–2

from sklearn.preprocessing import StandardScaler

Xs = np.array([[1, 50], [5, 20], [3, 80], [5, 60]]).astype(np.float64)
ys = np.array([0, 0, 1, 1])
svm_clf = SVC(kernel="linear", C=100).fit(Xs, ys)

scaler = StandardScaler()
X_scaled = scaler.fit_transform(Xs)
svm_clf_scaled = SVC(kernel="linear", C=100).fit(X_scaled, ys)

plt.figure(figsize=(9, 2.7))
plt.subplot(121)
plt.plot(Xs[:, 0][ys==1], Xs[:, 1][ys==1], "bo")
plt.plot(Xs[:, 0][ys==0], Xs[:, 1][ys==0], "ms")
plot_svc_decision_boundary(svm_clf, 0, 6)
plt.xlabel("$x_0$")
plt.ylabel("$x_1$    ", rotation=0)
plt.title("Unscaled")
plt.axis([0, 6, 0, 90])
plt.grid()

plt.subplot(122)
plt.plot(X_scaled[:, 0][ys==1], X_scaled[:, 1][ys==1], "bo")
plt.plot(X_scaled[:, 0][ys==0], X_scaled[:, 1][ys==0], "ms")
plot_svc_decision_boundary(svm_clf_scaled, -2, 2)
plt.xlabel("$x'_0$")
plt.ylabel("$x'_1$  ", rotation=0)
plt.title("Scaled")
plt.axis([-2, 2, -2, 2])
plt.grid()

save_fig("sensitivity_to_feature_scales_plot")
plt.show()

gives this error:

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[6], line 17
     15 plt.plot(Xs[:, 0][ys==1], Xs[:, 1][ys==1], "bo")
     16 plt.plot(Xs[:, 0][ys==0], Xs[:, 1][ys==0], "ms")
---> 17 plot_svc_decision_boundary(svm_clf, 0, 6)
     18 plt.xlabel("$x_0$")
     19 plt.ylabel("$x_1$    ", rotation=0)

NameError: name 'plot_svc_decision_boundary' is not defined

Similar error is recived when running cells: 7 and 11.