# cds1011/classification/classification_mnist_demo.py
# MNIST classification demo: binary classification (digit 5 / digit 0),
# multiclass via One-vs-One and One-vs-Rest, multilabel, and evaluation
# with precision/recall/F1 — all with scikit-learn.
import math
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.multiclass import OneVsOneClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.datasets import fetch_openml
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import precision_score, recall_score, f1_score
# Download the MNIST dataset (70,000 28x28 grayscale digit images) from OpenML.
print("✅ Datensatz herunterladen")
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
# Fix: these were bare expressions whose results were silently discarded
# (notebook leftovers) — print them so the script actually shows something.
print(mnist.keys())
X, y = mnist["data"], mnist["target"]
print(X.shape)  # (70000, 784): one flat row of 784 pixels per image
# Render the first training sample (its label is 5).
print("✅ Ziffer aus dem Datensatz: 5")
some_digit = X[0]
# 784 flat pixel values -> 28x28 image grid
some_digit_image = np.reshape(some_digit, (28, 28))
plt.imshow(some_digit_image, cmap=mpl.cm.binary)
plt.show()
# Render the second training sample (its label is 0).
print("✅ Ziffer aus dem Datensatz: 0")
some_other_digit = X[1]
# 784 flat pixel values -> 28x28 image grid
some_other_digit_image = np.reshape(some_other_digit, (28, 28))
plt.imshow(some_other_digit_image, cmap=mpl.cm.binary)
plt.show()
# Show the ground-truth label of the first sample.
print("✅ Label")
print(y[0])
# fetch_openml returns labels as strings; convert to small unsigned integers.
y = y.astype("uint8")
# Print the raw pixel values of the first digit as a 28-column grid,
# highlighting nonzero pixels in red via ANSI escape codes.
print("✅ Zahlen Matrix")
for idx, number in enumerate(some_digit):
    value = math.trunc(number.item())
    # Fix: the original only colored columns 1-27; the 28th column of each
    # row was printed uncolored even when the pixel was nonzero. Compose the
    # cell (with its ANSI wrap) once, then choose the row separator.
    if value > 0:
        cell = "\x1b[31m{:03d}\x1b[0m".format(value)
    else:
        cell = "{:03d}".format(value)
    # Newline after every 28th column (end of an image row), space otherwise.
    print(cell, end="\n" if idx % 28 == 27 else " ")
# Train/test split — MNIST ships pre-shuffled: first 60k train, last 10k test.
print("✅ Train-Test-Split")
split = 60000
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]
# Binary targets for the "is this digit a 5?" classification task.
print("✅ Testdaten vorbereiten für die Klassifikation der Ziffer 5")
y_train_5 = y_train == 5
y_test_5 = y_test == 5
print(y_train_5)
# Logistic regression (fitted via SGD, log loss) for binary classification
# of the digit 5.
print("✅ Logistische Regression zur binären Klassifikation")
model_log = SGDClassifier(loss="log_loss", max_iter=1000, tol=1e-3, random_state=42)
model_log.fit(X_train, y_train_5)
# Fix: the prediction was computed but discarded (bare expression) — show it.
print(model_log.predict([some_digit]))  # expect [ True]: X[0] is a 5
# Linear SVM (fitted via SGD, hinge loss) for binary classification.
# NOTE(review): despite the "digit 0" comment it is trained on y_train_5,
# i.e. it still classifies "is this a 5?" — presumably intentional so that
# predicting on the 0-digit yields False; confirm with the author.
print("✅ Support Vector Machine zur binären Klassifikation")
model_hinge = SGDClassifier(loss="hinge", max_iter=1000, tol=1e-3, random_state=42)
model_hinge.fit(X_train, y_train_5)
# Fix: the prediction was computed but discarded (bare expression) — show it.
print(model_hinge.predict([some_other_digit]))
# Evaluate the hinge-loss classifier with 3-fold cross-validated predictions.
print("✅ Evaluation")
model = model_hinge
y_train_pred = cross_val_predict(model_hinge, X_train, y_train_5, cv=3)
y_test_pred = cross_val_predict(model_hinge, X_test, y_test_5, cv=3)
# Fix: the metric values were computed and silently discarded (bare
# expressions), so the "Evaluation" section printed nothing — label and
# print them.
print("Precision:", precision_score(y_test_5, y_test_pred))
print("Recall:   ", recall_score(y_test_5, y_test_pred))
print("F1:       ", f1_score(y_test_5, y_test_pred))
# One-versus-One multiclass strategy: trains one SVC per pair of classes.
print("✅ One-versus-One (OvO)")
model_ovo = OneVsOneClassifier(SVC(gamma="auto", random_state=42))
# Only 100 samples: OvO with a kernel SVC is too slow on the full 60k set.
model_ovo.fit(X_train[:100], y_train[:100])
# Fix: the prediction was computed but discarded (bare expression) — show it.
print(model_ovo.predict([some_digit]))
# One-versus-the-Rest multiclass strategy: trains one SVC per class.
print("✅ One-versus-the-Rest (OvR)")
model_ovr = OneVsRestClassifier(SVC(gamma="auto", random_state=42))
# Only 100 samples, same as the OvO demo, to keep training fast.
model_ovr.fit(X_train[:100], y_train[:100])
# Fix: the prediction was computed but discarded (bare expression) — show it.
print(model_ovr.predict([some_digit]))
# Multilabel classification: each sample gets TWO boolean targets at once.
print("✅ Multilabel Classification")
y_train_large = y_train >= 7       # large digits: 7, 8, 9
y_train_odd = y_train % 2 == 1     # True for odd digits, False for even
# Stack the two 1-D boolean targets into an (n, 2) label matrix.
y_multilabel = np.column_stack((y_train_large, y_train_odd))
model_knn = KNeighborsClassifier()
model_knn.fit(X_train, y_multilabel)
# Multiclass classification with a kernel SVC on the full 10-class labels.
print("✅ Multiclass Multioutput Classification")
model_svc = SVC(gamma="auto", random_state=42)
# Train on y_train (all ten digit classes), not the binary y_train_5;
# only 1000 samples to keep the kernel SVC training fast.
model_svc.fit(X_train[:1000], y_train[:1000])
# Fix: both the prediction and the learned class list were computed but
# discarded (bare expressions) — print them.
print(model_svc.predict([some_digit]))
print(model_svc.classes_)  # the ten digit classes 0..9 seen during fit