ki-Praktikum-MS2/TestMitValid.py
2023-06-04 12:26:34 +02:00

52 lines
No EOL
1.3 KiB
Python

# Import libraries and datasets
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import pandas as pd
import seaborn as sns
import pickle
import random
import os
# Load the pickled training and validation splits from the "dataset" folder.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only run this script on dataset files that come from a trusted source.
with open(os.path.join("dataset", "train.p"), mode='rb') as training_data:
    train = pickle.load(training_data)
with open(os.path.join("dataset", "valid.p"), mode='rb') as validation_data:
    valid = pickle.load(validation_data)

# Each loaded object is indexed by 'features' (model inputs) and 'labels'
# (targets).
X_train, y_train = train['features'], train['labels']
X_valid, y_valid = valid['features'], valid['labels']

from sklearn.utils import shuffle

# Features and labels are shuffled in one call so every sample keeps its
# own label. No random_state is passed, so the order differs between runs.
X_train, y_train = shuffle(X_train, y_train)
X_valid, y_valid = shuffle(X_valid, y_valid)
# Scale the image data by 1/255 so values land in [0, 1]
# (assumes 8-bit pixel values in 0..255 -- TODO confirm against the dataset).
# NOTE(review): X_train_norm is not read anywhere in the visible part of this
# script; it is kept in case code outside this view relies on it.
X_valid_norm = X_valid / 255
X_train_norm = X_train / 255

# Boolean selector: True for validation samples whose label is one of 0..19.
mask = np.isin(y_valid, np.arange(20))

# Apply the same selector to labels and images so they stay aligned.
y_valid_subset = y_valid[mask]
X_valid_subset = X_valid_norm[mask]
# Model names noted by the author (presumably alternative files under
# saved_model/ that can be swapped into the path below -- TODO confirm):
#   convolutionalNeuralNetwork
#   fullyConnectedNeuralNetwork
model = tf.keras.models.load_model('saved_model/convolutionalNeuralNetwork.h5')

# Evaluate on the validation samples restricted to labels 0..19.
# score[1] is assumed to be the accuracy metric the model was compiled with
# (index 0 being the loss) -- TODO confirm against the model's compile() call.
# NOTE(review): the printed label says "Test", but the data evaluated here is
# the validation subset.
score = model.evaluate(x=X_valid_subset, y=y_valid_subset)
print('Test Accuracy: {}'.format(score[1]))
from sklearn.metrics import confusion_matrix

# Predicted class per sample: index of the largest value along the last axis
# of the model output.
predicted_classes = np.argmax(model.predict(X_valid_subset), axis=-1)
y_true = y_valid_subset

# Rows are true labels, columns are predicted labels. No explicit `labels`
# argument is passed, so the matrix covers every value that occurs in either
# array -- it can be larger than 20x20 if the model predicts a class >= 20.
cm = confusion_matrix(y_true, predicted_classes)

plt.figure(figsize=(25, 25))
# fmt='d' prints the cell counts as plain integers; the seaborn default
# ('.2g') would show counts of 100 or more in scientific notation.
sns.heatmap(cm, annot=True, fmt='d')
plt.show()