Hoppa till innehåll

Support Vector Machine (SVM) klassificering i Python

Jag kommer att implementera en SVM-algoritm (Support Vector Machine) för klassificering i denna handledning. Jag kommer att visualisera datauppsättningen, hitta de bästa hyperparametrarna att använda, träna en modell och utvärdera resultaten. Support Vector Machine är en snabb algoritm som kan användas för att klassificera datauppsättningar med linjär separation, SVM:s kan vara hjälpfulla vid textkategorisering.

Support Vector Machine kan användas för binära klassificeringsproblem och för flerklassproblem. Support Vector Machine är en linjär metod och fungerar därför inte bra för datauppsättningar som har en icke-linjär struktur. Support Vector Machine kan användas med icke-linjär data om man tillämpar kärntricket. Support Vector Machine försöker konstruera centrerade hyperplan mellan klasser, algoritmen vill hitta hyperplan som har den högsta marginalen mellan grupper av datapunkter. Datapunkterna närmast hyperplanet kallas supportvektorer.

SVM, hyperplan och supportvektorer

Support Vector Machine-algoritmen är enkel att använda, den är snabb och den resulterande modellen tar inte mycket hårddiskutrymme i anspråk. Scikit-learn har tre modeller för SVM som skiljer sig åt i implementeringen: SVC, NuSVC och LinearSVC. SVC är baserad på libsvm, passningstiden skalas åtminstone kvadratiskt med antalet datapunkter. NuSVC, liknar SVC men använder en parameter för att kontrollera antalet supportvektorer. LinearSVC liknar SVC, men den använder en linjär kärna och implementeras i termer av liblinear snarare än libsvm. Jag kommer att använda LinearSVC eftersom den skalar bäst till ett stora datauppsättningar.

Datauppsättning och bibliotek

Jag kommer att använda datauppsättningen Iris (ladda ner) i den här handledningen. Iris-uppsättningen består av 150 blommor, varje blomma har fyra indatavärden och ett målvärde. Jag använder också följande bibliotek: pandas, joblib, numpy, matplotlib och scikit-learn.

Python-modul

Jag har inkluderat all kod i en fil, ett projekt består normalt av många filer (moduler). Du kan skapa namnområden genom att placera filer i mappar och du importerar en fil med dess namnområden plus dess filnamn. En fil med namnet common.py i mappen annytab/learning importeras som import annytab.learn.common. Jag kommer att förklara mer om koden i avsnitten nedan.

  1. # Import libraries
  2. import pandas
  3. import joblib
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. import sklearn.model_selection
  7. import sklearn.svm
  8. import sklearn.metrics
  9. import sklearn.pipeline
  10. # Visualize data set
  11. def visualize_dataset(ds):
  12. # Print first 5 rows in data set
  13. print('--- First 5 rows ---\n')
  14. print(ds.head())
  15. # Print the shape
  16. print('\n--- Shape of data set ---\n')
  17. print(ds.shape)
  18. # Print class distribution
  19. print('\n--- Class distribution ---\n')
  20. print(ds.groupby('species').size())
  21. # Box plots
  22. plt.figure(figsize = (12, 8))
  23. ds.boxplot()
  24. #plt.show()
  25. plt.savefig('plots\\iris-boxplots.png')
  26. plt.close()
  27. # Scatter plots (4 subplots in 1 figure)
  28. figure = plt.figure(figsize = (12, 8))
  29. grouped_dataset = ds.groupby('species')
  30. values = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
  31. for i, value in enumerate(values):
  32. plt.subplot(2, 2, i + 1) # 2 rows and 2 columns
  33. for name, group in grouped_dataset:
  34. plt.scatter(group.index, ds[value][group.index], label=name)
  35. plt.ylabel(value)
  36. plt.xlabel('index')
  37. plt.legend()
  38. #plt.show()
  39. plt.savefig('plots\\iris-scatterplots.png')
  40. # Train and evaluate
  41. def train_and_evaluate(X, Y):
  42. # Create a model
  43. model = sklearn.svm.LinearSVC(penalty='l1', loss='squared_hinge', dual=False, tol=0.0001, C=0.4, multi_class='ovr',
  44. fit_intercept=True, intercept_scaling=1, class_weight=None, verbose=0, random_state=None, max_iter=10000)
  45. # Train the model on the whole data set
  46. model.fit(X, Y)
  47. # Save the model (Make sure that the folder exists)
  48. joblib.dump(model, 'models\\svm.jbl')
  49. # Evaluate on training data
  50. print('\n-- Training data --\n')
  51. predictions = model.predict(X)
  52. accuracy = sklearn.metrics.accuracy_score(Y, predictions)
  53. print('Accuracy: {0:.2f}'.format(accuracy * 100.0))
  54. print('Classification Report:')
  55. print(sklearn.metrics.classification_report(Y, predictions))
  56. print('Confusion Matrix:')
  57. print(sklearn.metrics.confusion_matrix(Y, predictions))
  58. print('')
  59. # Evaluate with 10-fold CV
  60. print('\n-- 10-fold CV --\n')
  61. predictions = sklearn.model_selection.cross_val_predict(model, X, Y, cv=10)
  62. accuracy = sklearn.metrics.accuracy_score(Y, predictions)
  63. print('Accuracy: {0:.2f}'.format(accuracy * 100.0))
  64. print('Classification Report:')
  65. print(sklearn.metrics.classification_report(Y, predictions))
  66. print('Confusion Matrix:')
  67. print(sklearn.metrics.confusion_matrix(Y, predictions))
  68. # Perform a grid search to find the best hyperparameters
  69. def grid_search(X, Y):
  70. # Create a pipeline
  71. clf_pipeline = sklearn.pipeline.Pipeline([
  72. ('m', sklearn.svm.LinearSVC(loss='squared_hinge', tol=0.0001, multi_class='ovr', dual=False, class_weight=None, verbose=0, random_state=None, max_iter=10000))
  73. ])
  74. # Set parameters (name in pipeline + name of parameter)
  75. parameters = {
  76. 'm__penalty': ('l1', 'l2'),
  77. 'm__C': (0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0),
  78. 'm__fit_intercept': (False, True),
  79. 'm__intercept_scaling': (0.5, 1, 2)
  80. }
  81. # Create a grid search classifier
  82. gs_classifier = sklearn.model_selection.GridSearchCV(clf_pipeline, parameters, cv=10, iid=False, n_jobs=2, scoring='accuracy', verbose=1)
  83. # Start a search (Warning: can take a long time if the whole dataset is used)
  84. gs_classifier = gs_classifier.fit(X, Y)
  85. # Print results
  86. print('---- Results ----')
  87. print('Best score: ' + str(gs_classifier.best_score_))
  88. for name in sorted(parameters.keys()):
  89. print('{0}: {1}'.format(name, gs_classifier.best_params_[name]))
  90. # Predict and evaluate on test data
  91. def predict_and_evaluate(X, Y):
  92. # Load the model
  93. model = joblib.load('models\\svm.jbl')
  94. # Make predictions
  95. predictions = model.predict(X)
  96. # Print results
  97. print('\n---- Results ----')
  98. for i in range(len(predictions)):
  99. print('Input: {0}, Predicted: {1}, Actual: {2}'.format(X[i], predictions[i], Y[i]))
  100. accuracy = sklearn.metrics.accuracy_score(Y, predictions)
  101. print('\nAccuracy: {0:.2f}'.format(accuracy * 100.0))
  102. print('\nClassification Report:')
  103. print(sklearn.metrics.classification_report(Y, predictions))
  104. print('Confusion Matrix:')
  105. print(sklearn.metrics.confusion_matrix(Y, predictions))
  106. # The main entry point for this module
  107. def main():
  108. # Load data set (includes header values)
  109. dataset = pandas.read_csv('files\\iris.csv')
  110. # Visualize data set
  111. visualize_dataset(dataset)
  112. # Slice data set in values and targets (2D-array)
  113. X = dataset.values[:,0:4]
  114. Y = dataset.values[:,4]
  115. # Split data set in train and test (use random state to get the same split every time, and stratify to keep balance)
  116. X_train, X_test, Y_train, Y_test = sklearn.model_selection.train_test_split(X, Y, test_size=0.2, random_state=1, stratify=Y)
  117. # Make sure that data still is balanced
  118. print('\n--- Class balance ---\n')
  119. print(np.unique(Y_train, return_counts=True))
  120. print(np.unique(Y_test, return_counts=True))
  121. # Perform a grid search
  122. #grid_search(X, Y)
  123. # Train and evaluate
  124. #train_and_evaluate(X_train, Y_train)
  125. # Predict on test set
  126. predict_and_evaluate(X_test, Y_test)
  127. # Tell python to run main method
  128. if __name__ == "__main__": main()

Läs in och visualisera datauppsättningen

Datauppsättningen laddas med pandas genom att använda en relativ sökväg, använd en absolut sökväg om dina filer lagras utanför projektet. Vi vill visualisera datauppsättningen för att se till att den är balanserad och vi vill lära oss mer om datauppsättningen. Det är viktigt att ha en balanserad datauppsättning vid klassificering, varje klass tränas lika frekvent med en balanserad träningsuppsättning. Vi kan plotta en datauppsättning för att hitta mönster, ta bort extremvärden och för att besluta om de lämpligaste algoritmerna att använda.

  1. # Load data set (includes header values)
  2. dataset = pandas.read_csv('files\\iris.csv')
  3. # Visualize data set
  4. visualize_dataset(dataset)
  5. --- First 5 rows ---
  6. sepal_length sepal_width petal_length petal_width species
  7. 0 5.1 3.5 1.4 0.2 Iris-setosa
  8. 1 4.9 3.0 1.4 0.2 Iris-setosa
  9. 2 4.7 3.2 1.3 0.2 Iris-setosa
  10. 3 4.6 3.1 1.5 0.2 Iris-setosa
  11. 4 5.0 3.6 1.4 0.2 Iris-setosa
  12. --- Shape of dataset ---
  13. (150, 5)
  14. --- Class distribution ---
  15. species
  16. Iris-setosa 50
  17. Iris-versicolor 50
  18. Iris-virginica 50
  19. dtype: int64
Iris spridningsdiagram

Dela upp datauppsättningen

Jag måste först dela upp värden i datauppsättningen för att få indata (X) och utdata (Y), de första 4 kolumnerna är indata och den sista kolumnen utgör målvärdet. Jag delar upp datauppsättningen i en träningsuppsättning och en testuppsättning, 80 % är för träning och 20 % för test. Jag vill se till att datauppsättningarna fortfarande är balanserade efter denna delning och jag använder därför en stratify-parameter.

  1. # Slice data set in values and targets (2D-array)
  2. X = dataset.values[:,0:4]
  3. Y = dataset.values[:,4]
  4. # Split data set in train and test (use random state to get the same split every time, and stratify to keep balance)
  5. X_train, X_test, Y_train, Y_test = sklearn.model_selection.train_test_split(X, Y, test_size=0.2, random_state=1, stratify=Y)
  6. # Make sure that data still is balanced
  7. print('\n--- Class balance ---\n')
  8. print(np.unique(Y_train, return_counts=True))
  9. print(np.unique(Y_test, return_counts=True))

Baslinjeprestanda

Vår datauppsättning har 150 blommor och 50 blommor i varje klass, vår träningsuppsättning har samma balans. En slumpvis förutsägelse kommer att vara korrekt i 33% (50/150) av alla fall och vår modell måste ha en noggrannhet som är bättre än 33 % för att vara användbar.

Rutnätssökning

Jag gör en rutnätsökning för att hitta de bästa parametrarna för träning. En rutnätsökning kan ta lång tid att utföra på stora datauppsättningar, men det är antagligen snabbare jämfört med en manuell process. Resultatet från denna process visas nedan och jag kommer att använda dessa parametrar när jag tränar modellen.

  1. # Perform a grid search
  2. grid_search(X, Y)
  3. Fitting 10 folds for each of 108 candidates, totalling 1080 fits
  4. [Parallel(n_jobs=2)]: Using backend LokyBackend with 2 concurrent workers.
  5. [Parallel(n_jobs=2)]: Done 968 tasks | elapsed: 2.9s
  6. [Parallel(n_jobs=2)]: Done 1080 out of 1080 | elapsed: 3.1s finished
  7. ---- Results ----
  8. Best score: 0.9666666666666668
  9. m__C: 0.4
  10. m__fit_intercept: True
  11. m__intercept_scaling: 1
  12. m__penalty: l1

Träning och utvärdering

Jag tränar modellen genom att använda parametrarna från rutnätsökningen och sparar modellen till en fil med joblib. Utvärderingen görs på träningsuppsättningen och med korsvalidering. Korsvalideringsutvärderingen ger en antydan om modellens generaliseringsprestanda. Jag hade 95 % exakthet på träningsdata och 95 % exakthet med tiofaldig korsvalidering.

  1. # Train and evaluate
  2. train_and_evaluate(X_train, Y_train)
  3. -- Training data --
  4. Accuracy: 95.00
  5. Classification Report:
  6. precision recall f1-score support
  7. Iris-setosa 1.00 1.00 1.00 40
  8. Iris-versicolor 0.95 0.90 0.92 40
  9. Iris-virginica 0.90 0.95 0.93 40
  10. accuracy 0.95 120
  11. macro avg 0.95 0.95 0.95 120
  12. weighted avg 0.95 0.95 0.95 120
  13. Confusion Matrix:
  14. [[40 0 0]
  15. [ 0 36 4]
  16. [ 0 2 38]]
  17. -- 10-fold CV --
  18. Accuracy: 95.00
  19. Classification Report:
  20. precision recall f1-score support
  21. Iris-setosa 1.00 1.00 1.00 40
  22. Iris-versicolor 0.93 0.93 0.93 40
  23. Iris-virginica 0.93 0.93 0.93 40
  24. accuracy 0.95 120
  25. macro avg 0.95 0.95 0.95 120
  26. weighted avg 0.95 0.95 0.95 120
  27. Confusion Matrix:
  28. [[40 0 0]
  29. [ 0 37 3]
  30. [ 0 3 37]]

Test och utvärdering

Det sista steget i denna process är att göra förutsägelser och utvärdera prestationer avseende testdata. Jag läser in modellen, gör förutsägelser och skriver ut resultaten. X-variabeln är en 2D-array, om du vill göra en förutsägelse för en blomma måste du ange indata så här: X = np.array ([[7.3, 2.9, 6.3, 1.8]]).

  1. # Predict on test set
  2. predict_and_evaluate(X_test, Y_test)
  3. ---- Results ----
  4. Input: [7.3 2.9 6.3 1.8], Predicted: Iris-virginica, Actual: Iris-virginica
  5. Input: [4.9 3.1 1.5 0.1], Predicted: Iris-setosa, Actual: Iris-setosa
  6. Input: [5.1 2.5 3.0 1.1], Predicted: Iris-versicolor, Actual: Iris-versicolor
  7. Input: [4.8 3.4 1.6 0.2], Predicted: Iris-setosa, Actual: Iris-setosa
  8. Input: [5.0 3.5 1.6 0.6], Predicted: Iris-setosa, Actual: Iris-setosa
  9. Input: [5.1 3.5 1.4 0.2], Predicted: Iris-setosa, Actual: Iris-setosa
  10. Input: [6.2 3.4 5.4 2.3], Predicted: Iris-virginica, Actual: Iris-virginica
  11. Input: [6.4 2.7 5.3 1.9], Predicted: Iris-virginica, Actual: Iris-virginica
  12. Input: [5.6 2.8 4.9 2.0], Predicted: Iris-virginica, Actual: Iris-virginica
  13. Input: [6.8 2.8 4.8 1.4], Predicted: Iris-versicolor, Actual: Iris-versicolor
  14. Input: [5.4 3.9 1.3 0.4], Predicted: Iris-setosa, Actual: Iris-setosa
  15. Input: [5.5 2.3 4.0 1.3], Predicted: Iris-versicolor, Actual: Iris-versicolor
  16. Input: [6.8 3.0 5.5 2.1], Predicted: Iris-virginica, Actual: Iris-virginica
  17. Input: [6.0 2.2 4.0 1.0], Predicted: Iris-versicolor, Actual: Iris-versicolor
  18. Input: [5.7 2.5 5.0 2.0], Predicted: Iris-virginica, Actual: Iris-virginica
  19. Input: [5.7 4.4 1.5 0.4], Predicted: Iris-setosa, Actual: Iris-setosa
  20. Input: [7.1 3.0 5.9 2.1], Predicted: Iris-virginica, Actual: Iris-virginica
  21. Input: [6.1 2.8 4.0 1.3], Predicted: Iris-versicolor, Actual: Iris-versicolor
  22. Input: [4.9 2.4 3.3 1.0], Predicted: Iris-versicolor, Actual: Iris-versicolor
  23. Input: [6.1 3.0 4.9 1.8], Predicted: Iris-virginica, Actual: Iris-virginica
  24. Input: [6.4 2.9 4.3 1.3], Predicted: Iris-versicolor, Actual: Iris-versicolor
  25. Input: [5.6 3.0 4.5 1.5], Predicted: Iris-versicolor, Actual: Iris-versicolor
  26. Input: [4.9 3.1 1.5 0.1], Predicted: Iris-setosa, Actual: Iris-setosa
  27. Input: [4.4 2.9 1.4 0.2], Predicted: Iris-setosa, Actual: Iris-setosa
  28. Input: [6.5 3.0 5.2 2.0], Predicted: Iris-virginica, Actual: Iris-virginica
  29. Input: [4.9 2.5 4.5 1.7], Predicted: Iris-virginica, Actual: Iris-virginica
  30. Input: [5.4 3.9 1.7 0.4], Predicted: Iris-setosa, Actual: Iris-setosa
  31. Input: [4.8 3.0 1.4 0.1], Predicted: Iris-setosa, Actual: Iris-setosa
  32. Input: [6.3 3.3 4.7 1.6], Predicted: Iris-versicolor, Actual: Iris-versicolor
  33. Input: [6.5 2.8 4.6 1.5], Predicted: Iris-versicolor, Actual: Iris-versicolor
  34. Accuracy: 100.00
  35. Classification Report:
  36. precision recall f1-score support
  37. Iris-setosa 1.00 1.00 1.00 10
  38. Iris-versicolor 1.00 1.00 1.00 10
  39. Iris-virginica 1.00 1.00 1.00 10
  40. accuracy 1.00 30
  41. macro avg 1.00 1.00 1.00 30
  42. weighted avg 1.00 1.00 1.00 30
  43. Confusion Matrix:
  44. [[10 0 0]
  45. [ 0 10 0]
  46. [ 0 0 10]]
Etiketter:

Lämna ett svar

Din e-postadress kommer inte publiceras. Obligatoriska fält är märkta *