Inital Commit
This commit is contained in:
commit
3fdc0785bb
18 changed files with 1050 additions and 0 deletions
8
.idea/.gitignore
vendored
Normal file
8
.idea/.gitignore
vendored
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
8
.idea/ForStudent.iml
Normal file
8
.idea/ForStudent.iml
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
6
.idea/inspectionProfiles/profiles_settings.xml
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
4
.idea/misc.xml
Normal file
4
.idea/misc.xml
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="kipraktikum2" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
8
.idea/modules.xml
Normal file
8
.idea/modules.xml
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/ForStudent.iml" filepath="$PROJECT_DIR$/.idea/ForStudent.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
6
.idea/vcs.xml
Normal file
6
.idea/vcs.xml
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
123
ConvolutionalNeuralNetwor.py
Normal file
123
ConvolutionalNeuralNetwor.py
Normal file
|
@ -0,0 +1,123 @@
|
||||||
|
#Import libaries and datasets
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import pandas as pd
|
||||||
|
import seaborn as sns
|
||||||
|
import pickle
|
||||||
|
import random
|
||||||
|
import os
|
||||||
|
|
||||||
|
with open(os.path.join("dataset", "train.p"), mode='rb') as training_data:
|
||||||
|
train = pickle.load(training_data)
|
||||||
|
with open(os.path.join("dataset", "valid.p"), mode='rb') as validation_data:
|
||||||
|
valid = pickle.load(validation_data)
|
||||||
|
|
||||||
|
X_train, y_train = train['features'], train['labels']
|
||||||
|
X_valid, y_valid = valid['features'], valid['labels']
|
||||||
|
|
||||||
|
from sklearn.utils import shuffle
|
||||||
|
X_train, y_train = shuffle(X_train, y_train)
|
||||||
|
X_valid, y_valid = shuffle(X_valid, y_valid)
|
||||||
|
|
||||||
|
# Normalize image to [0, 1]
|
||||||
|
X_train_norm = X_train / 255
|
||||||
|
X_valid_norm = X_valid / 255
|
||||||
|
|
||||||
|
# Check that the images have been correctly converted and normalised
|
||||||
|
i = random.randint(1, len(X_train_norm))
|
||||||
|
plt.grid(False)
|
||||||
|
plt.imshow(X_train[i])
|
||||||
|
plt.figure()
|
||||||
|
plt.grid(False)
|
||||||
|
plt.imshow(X_train_norm[i].squeeze(), cmap = 'gray') # cmap
|
||||||
|
|
||||||
|
|
||||||
|
#Select 10-20 different signs from the dataset
|
||||||
|
pictures = []
|
||||||
|
for i in range(20):
|
||||||
|
randomZahl = np.random.randint(1, len(X_train_norm))
|
||||||
|
picture = X_train_norm[randomZahl]
|
||||||
|
while any(np.array_equal(picture, pic) for pic in pictures):
|
||||||
|
randomZahl = np.random.randint(1, len(X_train_norm))
|
||||||
|
picture = X_train_norm[randomZahl]
|
||||||
|
pictures.append(picture)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Plot a random picture from the training dataset
|
||||||
|
z = np.random.randint(1, len(X_train))
|
||||||
|
plt.imshow(X_train[z])
|
||||||
|
plt.show()
|
||||||
|
print("Label: ", y_train[z])
|
||||||
|
|
||||||
|
#for i in range(20):
|
||||||
|
# plt.imshow(pictures[i])
|
||||||
|
# plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
from tensorflow.keras import datasets, layers, models
|
||||||
|
model = models.Sequential()
|
||||||
|
|
||||||
|
# Only in the first layer you have to select the input_shape of the data (image).
|
||||||
|
# TODO: Replace the question marks:
|
||||||
|
model.add(layers.Conv2D( filters = 2 , kernel_size = ( 3 , 3 ), padding = "same" , activation = 'relu' , input_shape = ( 32 , 32 , 3)))
|
||||||
|
|
||||||
|
# TODO: Add layers to the model:
|
||||||
|
model.add(layers.MaxPool2D(pool_size=(2, 2), strides=None, padding='valid', data_format=None))
|
||||||
|
model.add(layers.Flatten())
|
||||||
|
model.add(layers.Dense(43, activation='softmax'))
|
||||||
|
|
||||||
|
|
||||||
|
# Prints a summary of your network
|
||||||
|
model.summary()
|
||||||
|
|
||||||
|
model.compile(optimizer = 'Adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])
|
||||||
|
|
||||||
|
# TODO: Choose the batch size and the epochs
|
||||||
|
history = model.fit(x = X_train_norm,
|
||||||
|
y = y_train,
|
||||||
|
batch_size = 32,
|
||||||
|
epochs = 10,
|
||||||
|
verbose = 1,
|
||||||
|
validation_data = (X_valid_norm, y_valid))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
model.save('saved_model/my_model.h5')
|
||||||
|
|
||||||
|
history.history.keys()
|
||||||
|
|
||||||
|
accuracy = history.history['accuracy']
|
||||||
|
val_accuracy = history.history['val_accuracy']
|
||||||
|
loss = history.history['loss']
|
||||||
|
val_loss = history.history['val_loss']
|
||||||
|
|
||||||
|
epochs = range(len(accuracy))
|
||||||
|
plt.plot(epochs, loss, 'b', label = 'Training loss')
|
||||||
|
plt.plot(epochs, val_loss, 'r', label = 'Validation loss')
|
||||||
|
plt.title('Training and Validation loss')
|
||||||
|
|
||||||
|
plt.plot(epochs, accuracy, 'ro', label = 'Training accuracy')
|
||||||
|
plt.plot(epochs, val_accuracy, 'r', label = 'Validation accuracy')
|
||||||
|
plt.title('Training and Validation accuracy')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
model = tf.keras.models.load_model('saved_model/my_model.h5')
|
||||||
|
|
||||||
|
score = model.evaluate(X_valid_norm, y_valid)
|
||||||
|
print('Test Accuracy: {}'.format(score[1]))
|
||||||
|
|
||||||
|
from sklearn.metrics import confusion_matrix
|
||||||
|
predicted_classes = np.argmax(model.predict(X_valid_norm), axis=-1)
|
||||||
|
y_true = y_valid
|
||||||
|
|
||||||
|
cm = confusion_matrix(y_true, predicted_classes)
|
||||||
|
|
||||||
|
plt.figure(figsize = (25, 25))
|
||||||
|
sns.heatmap(cm, annot = True)
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
BIN
Environment_mit_Anaconda_aufsetzen.pdf
Normal file
BIN
Environment_mit_Anaconda_aufsetzen.pdf
Normal file
Binary file not shown.
0
FullyConnectedNeuralNetwork.py
Normal file
0
FullyConnectedNeuralNetwork.py
Normal file
10
LoadModel.py
Normal file
10
LoadModel.py
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
#Import libaries and datasets
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import pandas as pd
|
||||||
|
import seaborn as sns
|
||||||
|
import pickle
|
||||||
|
import random
|
||||||
|
import os
|
||||||
|
|
BIN
Praktikum 2 - Aufgabenstellung.pdf
Normal file
BIN
Praktikum 2 - Aufgabenstellung.pdf
Normal file
Binary file not shown.
289
data_exploration.ipynb
Normal file
289
data_exploration.ipynb
Normal file
|
@ -0,0 +1,289 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# Exploration of the traffic sign dataset"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# 1. Import libraries and datasets"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import pickle\n",
|
||||||
|
"import os"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Get the training and testing dataset\n",
|
||||||
|
"with open(os.path.join(\"dataset\", \"train.p\"), mode='rb') as training_data:\n",
|
||||||
|
" train = pickle.load(training_data)\n",
|
||||||
|
"with open(os.path.join(\"dataset\", \"valid.p\"), mode='rb') as validation_data:\n",
|
||||||
|
" valid = pickle.load(validation_data)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Get the features and labels of the datasets\n",
|
||||||
|
"# The features are the images of the signs\n",
|
||||||
|
"X_train, y_train = train['features'], train['labels']\n",
|
||||||
|
"X_valid, y_valid = valid['features'], valid['labels']"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# 2. Visualize traffic sign dataset"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"print(\"Number of training examples: \", X_train.shape[0])\n",
|
||||||
|
"print(\"Number of validation examples: \", X_valid.shape[0])\n",
|
||||||
|
"print(\"Image data shape =\", X_train[0].shape)\n",
|
||||||
|
"print(\"Number of classes =\", len(np.unique(y_train)))"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n",
|
||||||
|
"is_executing": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Plot a random picture from the training dataset\n",
|
||||||
|
"i = np.random.randint(1, len(X_train))\n",
|
||||||
|
"plt.grid(False)\n",
|
||||||
|
"plt.imshow(X_train[i])\n",
|
||||||
|
"print(\"Label: \", y_train[i])"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n",
|
||||||
|
"is_executing": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Plot (width x height) pictures from the training dataset\n",
|
||||||
|
"grid_width = 5\n",
|
||||||
|
"grid_height = 4\n",
|
||||||
|
"\n",
|
||||||
|
"fig, axes = plt.subplots(grid_height, grid_width, figsize = (10,10))\n",
|
||||||
|
"axes = axes.ravel()\n",
|
||||||
|
"\n",
|
||||||
|
"for i in np.arange(0, grid_width * grid_height):\n",
|
||||||
|
" index = np.random.randint(0, len(X_train))\n",
|
||||||
|
" axes[i].imshow(X_train[index])\n",
|
||||||
|
" axes[i].set_title(y_train[index], fontsize = 15)\n",
|
||||||
|
" axes[i].axis('off')\n",
|
||||||
|
"\n",
|
||||||
|
"plt.subplots_adjust(hspace = 0.3)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n",
|
||||||
|
"is_executing": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Plotting histograms of the count of each sign\n",
|
||||||
|
"def histogram_plot(dataset: np.ndarray, label: str):\n",
|
||||||
|
" \"\"\" Plots a histogram of the dataset\n",
|
||||||
|
"\n",
|
||||||
|
" Args:\n",
|
||||||
|
" dataset: The input data to be plotted as a histogram.\n",
|
||||||
|
" label: The label of the histogram.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" hist, bins = np.histogram(dataset, bins=43)\n",
|
||||||
|
" width = 0.8 * (bins[1] - bins[0])\n",
|
||||||
|
" center = (bins[:-1] + bins[1:]) / 2\n",
|
||||||
|
" plt.bar(center, hist, align='center', width=width)\n",
|
||||||
|
" plt.xlabel(label)\n",
|
||||||
|
" plt.ylabel(\"Image count\")\n",
|
||||||
|
" plt.show()"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"histogram_plot(y_train, \"Training examples\")\n",
|
||||||
|
"histogram_plot(y_valid, \"Validation examples\")"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n",
|
||||||
|
"is_executing": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# A list of all classes:\n",
|
||||||
|
"- 0 = Speed limit (20km/h)\n",
|
||||||
|
"- 1 = Speed limit (30km/h)\n",
|
||||||
|
"- 2 = Speed limit (50km/h)\n",
|
||||||
|
"- 3 = Speed limit (60km/h)\n",
|
||||||
|
"- 4 = Speed limit (70km/h)\n",
|
||||||
|
"- 5 = Speed limit (80km/h)\n",
|
||||||
|
"- 6 = End of speed limit (80km/h)\n",
|
||||||
|
"- 7 = Speed limit (100km/h)\n",
|
||||||
|
"- 8 = Speed limit (120km/h)\n",
|
||||||
|
"- 9 = No passing\n",
|
||||||
|
"- 10 = No passing for vehicles over 3.5 metric tons\n",
|
||||||
|
"- 11 = Right-of-way at the next intersection\n",
|
||||||
|
"- 12 = Priority road\n",
|
||||||
|
"- 13 = Yield\n",
|
||||||
|
"- 14 = Stop\n",
|
||||||
|
"- 15 = No vehicles\n",
|
||||||
|
"- 16 = Vehicles over 3.5 metric tons prohibited\n",
|
||||||
|
"- 17 = No entry\n",
|
||||||
|
"- 18 = General caution\n",
|
||||||
|
"- 19 = Dangerous curve to the left\n",
|
||||||
|
"- 20 = Dangerous curve to the right\n",
|
||||||
|
"- 21 = Double curve\n",
|
||||||
|
"- 22 = Bumpy road\n",
|
||||||
|
"- 23 = Slippery road\n",
|
||||||
|
"- 24 = Road narrows on the right\n",
|
||||||
|
"- 25 = Road work\n",
|
||||||
|
"- 26 = Traffic signals\n",
|
||||||
|
"- 27 = Pedestrians\n",
|
||||||
|
"- 28 = Children crossing\n",
|
||||||
|
"- 29 = Bicycles crossing\n",
|
||||||
|
"- 30 = Beware of ice/snow\n",
|
||||||
|
"- 31 = Wild animals crossing\n",
|
||||||
|
"- 32 = End of all speed and passing limits\n",
|
||||||
|
"- 33 = Turn right ahead\n",
|
||||||
|
"- 34 = Turn left ahead\n",
|
||||||
|
"- 35 = Ahead only\n",
|
||||||
|
"- 36 = Go straight or right\n",
|
||||||
|
"- 37 = Go straight or left\n",
|
||||||
|
"- 38 = Keep right\n",
|
||||||
|
"- 39 = Keep left\n",
|
||||||
|
"- 40 = Roundabout mandatory\n",
|
||||||
|
"- 41 = End of no passing\n",
|
||||||
|
"- 42 = End of no passing by vehicles over 3.5 metric tons"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 2
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython2",
|
||||||
|
"version": "2.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
|
}
|
BIN
dataset/train.p
Normal file
BIN
dataset/train.p
Normal file
Binary file not shown.
BIN
dataset/valid.p
Normal file
BIN
dataset/valid.p
Normal file
Binary file not shown.
15
environment.yml
Normal file
15
environment.yml
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
name: kipraktikum2
|
||||||
|
dependencies:
|
||||||
|
- python=3.9
|
||||||
|
- numpy
|
||||||
|
- matplotlib
|
||||||
|
- pip
|
||||||
|
- seaborn
|
||||||
|
- scikit-learn
|
||||||
|
- scikit-image
|
||||||
|
- jupyter
|
||||||
|
- pandas
|
||||||
|
- scipy
|
||||||
|
- pip:
|
||||||
|
- tensorflow
|
||||||
|
- optuna
|
84
main.py
Normal file
84
main.py
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
|
||||||
|
#Import libaries and datasets
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import pandas as pd
|
||||||
|
import seaborn as sns
|
||||||
|
import pickle
|
||||||
|
import random
|
||||||
|
import os
|
||||||
|
|
||||||
|
with open(os.path.join("dataset", "train.p"), mode='rb') as training_data:
|
||||||
|
train = pickle.load(training_data)
|
||||||
|
with open(os.path.join("dataset", "valid.p"), mode='rb') as validation_data:
|
||||||
|
valid = pickle.load(validation_data)
|
||||||
|
|
||||||
|
X_train, y_train = train['features'], train['labels']
|
||||||
|
X_valid, y_valid = valid['features'], valid['labels']
|
||||||
|
|
||||||
|
# TODO: Select 10-20 different signs from the dataset
|
||||||
|
|
||||||
|
#Visualize traffic sign dataset
|
||||||
|
print("Number of training examples: ", X_train.shape[0])
|
||||||
|
print("Number of validation examples: ", X_valid.shape[0])
|
||||||
|
print("Image data shape =", X_train[0].shape)
|
||||||
|
print("Number of classes =", len(np.unique(y_train)))
|
||||||
|
|
||||||
|
# Plot a random picture from the training dataset
|
||||||
|
i = np.random.randint(1, len(X_train))
|
||||||
|
plt.grid(False)
|
||||||
|
plt.imshow(X_train[i])
|
||||||
|
print("Label: ", y_train[i])
|
||||||
|
|
||||||
|
# Plot (width x height) pictures from the training dataset
|
||||||
|
grid_width = 5
|
||||||
|
grid_height = 4
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(grid_height, grid_width, figsize = (10,10))
|
||||||
|
axes = axes.ravel()
|
||||||
|
|
||||||
|
for i in np.arange(0, grid_width * grid_height):
|
||||||
|
index = np.random.randint(0, len(X_train))
|
||||||
|
axes[i].imshow(X_train[index])
|
||||||
|
axes[i].set_title(y_train[index], fontsize = 15)
|
||||||
|
axes[i].axis('off')
|
||||||
|
|
||||||
|
plt.subplots_adjust(hspace = 0.3)
|
||||||
|
|
||||||
|
# Plotting histograms of the count of each sign
|
||||||
|
def histogram_plot(dataset: np.ndarray, label: str):
|
||||||
|
""" Plots a histogram of the dataset
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: The input data to be plotted as a histogram.
|
||||||
|
label: The label of the histogram.
|
||||||
|
"""
|
||||||
|
hist, bins = np.histogram(dataset, bins=43)
|
||||||
|
width = 0.8 * (bins[1] - bins[0])
|
||||||
|
center = (bins[:-1] + bins[1:]) / 2
|
||||||
|
plt.bar(center, hist, align='center', width=width)
|
||||||
|
plt.xlabel(label)
|
||||||
|
plt.ylabel("Image count")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
histogram_plot(y_train, "Training examples")
|
||||||
|
histogram_plot(y_valid, "Validation examples")
|
||||||
|
|
||||||
|
#Shuffle and configure the datasets
|
||||||
|
|
||||||
|
from sklearn.utils import shuffle
|
||||||
|
X_train, y_train = shuffle(X_train, y_train)
|
||||||
|
X_valid, y_valid = shuffle(X_valid, y_valid)
|
||||||
|
|
||||||
|
# Normalize image to [0, 1]
|
||||||
|
X_train_norm = X_train / 255
|
||||||
|
X_valid_norm = X_valid / 255
|
||||||
|
|
||||||
|
# Check that the images have been correctly converted and normalised
|
||||||
|
i = random.randint(1, len(X_train_norm))
|
||||||
|
plt.grid(False)
|
||||||
|
plt.imshow(X_train[i])
|
||||||
|
plt.figure()
|
||||||
|
plt.grid(False)
|
||||||
|
plt.imshow(X_train_norm[i].squeeze(), cmap = 'gray') # cmap
|
BIN
saved_model/my_model.h5
Normal file
BIN
saved_model/my_model.h5
Normal file
Binary file not shown.
489
student_task.ipynb
Normal file
489
student_task.ipynb
Normal file
|
@ -0,0 +1,489 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# German traffic sign classification\n",
|
||||||
|
"\n",
|
||||||
|
"### Aufgaben:\n",
|
||||||
|
"1. Suchen Sie sich 10-20 unterschiedliche Schilder aus, welche Sie mit Ihrem Modell klassifizieren wollen. Trainieren Sie dann ein Modell Ihrer Wahl. Wählen Sie danach eine der folgenden zusätzlichen Aufgaben aus:\n",
|
||||||
|
" - Erstellen Sie ein Convolutional Neural Network und ein Fully Connected Neural Network und vergleichen Sie die Leistung beider Modelle mit Ihrem Modell.\n",
|
||||||
|
" - Trainieren Sie ein Modell mit farbigen und ein Modell mit grauen Bildern (Farbbilder umgewandelt in Graustufen) und vergleichen Sie die Performance von beiden Modellen mit Ihrem Modell.\n",
|
||||||
|
" - Erstellen Sie zwei weitere Modelle und vergleichen Sie deren Performance. Das zweite Modell soll alle 43 Verkehrsschilder klassifizieren und das dritte Modell nur ein Verkehrsschild erkennen können.\n",
|
||||||
|
" - Benutzen Sie das Framework [Optuna](https://optuna.readthedocs.io/en/stable/index.html), um die Hyperparameter (Anzahl Layer, Anzahl Neuronen pro Layer, …) Ihres Modells zu optimieren. Welchen Vorteil bringt es [Optuna](https://optuna.readthedocs.io/en/stable/index.html) zu nutzen?\n",
|
||||||
|
" - Trainieren Sie ein Modell mit einer hohen Accuracy und ein Modell mit einer geringen Latenz. Wie stark unterscheidet sich die Latenz und Accuracy von beiden Modellen. Die Latenz ist die Zeit, die das Modell braucht, ein Verkehrsschild zu klassifizieren.\n",
|
||||||
|
" - Machen Sie eigene Bilder von Schildern und testen Sie Ihr Modell mit den eigenen Bildern. Variieren Sie Lichtverhältnisse, Wetter, … Die Bilder in den Trainingsdaten haben eine Auflösung von 32 x 32 Pixeln.\n",
|
||||||
|
"2. Sie bekommen ein paar Tage vor der Abgabe einen Testdatensatz. Testet einmal eure erstellten Modelle mit den Testdaten und prüft, wie gut die Modelle unbekannte Daten klassifizieren können.\n",
|
||||||
|
" **Verändern Sie danach <u>nicht</u> mehr die Modelle, um ein besseres Ergebnis zu bekommen**. Das Ziel ist es zu prüfen, wie gut eure Modelle in der Realität funktionieren würden.\n"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# Import libraries and datasets"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import tensorflow as tf\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import seaborn as sns\n",
|
||||||
|
"import pickle\n",
|
||||||
|
"import random\n",
|
||||||
|
"import os"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n",
|
||||||
|
"is_executing": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"with open(os.path.join(\"dataset\", \"train.p\"), mode='rb') as training_data:\n",
|
||||||
|
" train = pickle.load(training_data)\n",
|
||||||
|
"with open(os.path.join(\"dataset\", \"valid.p\"), mode='rb') as validation_data:\n",
|
||||||
|
" valid = pickle.load(validation_data)\n",
|
||||||
|
"\n",
|
||||||
|
"X_train, y_train = train['features'], train['labels']\n",
|
||||||
|
"X_valid, y_valid = valid['features'], valid['labels']"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# TODO: Select 10-20 different signs from the dataset\n"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# Shuffle and configure the datasets"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from sklearn.utils import shuffle\n",
|
||||||
|
"X_train, y_train = shuffle(X_train, y_train)\n",
|
||||||
|
"X_valid, y_valid = shuffle(X_valid, y_valid)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Normalize image to [0, 1]\n",
|
||||||
|
"X_train_norm = X_train / 255\n",
|
||||||
|
"X_valid_norm = X_valid / 255"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Check that the images have been correctly converted and normalised\n",
|
||||||
|
"i = random.randint(1, len(X_train_norm))\n",
|
||||||
|
"plt.grid(False)\n",
|
||||||
|
"plt.imshow(X_train[i])\n",
|
||||||
|
"plt.figure()\n",
|
||||||
|
"plt.grid(False)\n",
|
||||||
|
"plt.imshow(X_train_norm[i].squeeze(), cmap = 'gray') # cmap"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n",
|
||||||
|
"is_executing": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# Create the Convolutional Neural Network with keras\n",
|
||||||
|
"For example a CNN that recognises handwritten numbers: [https://adamharley.com/nn_vis/cnn/2d.html](https://adamharley.com/nn_vis/cnn/2d.html)\n",
|
||||||
|
"Here a fully connected neural network that also recognise handwritten numbers: [https://adamharley.com/nn_vis/mlp/2d.html](https://adamharley.com/nn_vis/mlp/2d.html)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"Use CNN.add() to add a [Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers) to your model. Here is a list of layers that might be useful:\n",
|
||||||
|
"- [Convolution Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2D): layers.Conv2D()\n",
|
||||||
|
"- [Average Pooling](https://www.tensorflow.org/api_docs/python/tf/keras/layers/AveragePooling2D): layers.AveragePooling2D()\n",
|
||||||
|
"- [Max Pooling](https://www.tensorflow.org/api_docs/python/tf/keras/layers/MaxPool2D): layers.MaxPool2D()\n",
|
||||||
|
"- [Dropout](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dropout): layers.Dropout()\n",
|
||||||
|
"- [Flattens](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Flatten) the input. 2D -> 1D: layers.Flatten()\n",
|
||||||
|
"- [Densely-connected NN layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense): layers.Dense()"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from tensorflow.keras import datasets, layers, models\n",
|
||||||
|
"model = models.Sequential()\n",
|
||||||
|
"\n",
|
||||||
|
"# Only in the first layer you have to select the input_shape of the data (image).\n",
|
||||||
|
"# TODO: Replace the question marks:\n",
|
||||||
|
"# model.add(layers.Conv2D( filters = ? , kernel_size = ( ? , ? ), padding = ? , activation = ? , input_shape = ( ? , ? , ?)))\n",
|
||||||
|
"\n",
|
||||||
|
"# TODO: Add layers to the model:\n",
|
||||||
|
"\n",
|
||||||
|
"# Prints a summary of your network\n",
|
||||||
|
"model.summary()"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# Compile your model\n",
|
||||||
|
"When you want, you can change the [optimizer](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers) or the [loss function](https://www.tensorflow.org/api_docs/python/tf/keras/losses)."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model.compile(optimizer = 'Adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# Train your model\n",
|
||||||
|
"The documentation of the fit method: [https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# TODO: Choose the batch size and the epochs\n",
|
||||||
|
"history = model.fit(x = X_train_norm,\n",
|
||||||
|
" y = y_train,\n",
|
||||||
|
" batch_size = ...,\n",
|
||||||
|
" epochs = ...,\n",
|
||||||
|
" verbose = 1,\n",
|
||||||
|
" validation_data = (X_valid_norm, y_valid))"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# Save your model\n",
|
||||||
|
"Create a folder for your models"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model.save('saved_model/my_model.h5')"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# Analyse the results"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"history.history.keys()"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"accuracy = history.history['accuracy']\n",
|
||||||
|
"val_accuracy = history.history['val_accuracy']\n",
|
||||||
|
"loss = history.history['loss']\n",
|
||||||
|
"val_loss = history.history['val_loss']"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"epochs = range(len(accuracy))\n",
|
||||||
|
"plt.plot(epochs, loss, 'b', label = 'Training loss')\n",
|
||||||
|
"plt.plot(epochs, val_loss, 'r', label = 'Validation loss')\n",
|
||||||
|
"plt.title('Training and Validation loss')"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"plt.plot(epochs, accuracy, 'ro', label = 'Training accuracy')\n",
|
||||||
|
"plt.plot(epochs, val_accuracy, 'r', label = 'Validation accuracy')\n",
|
||||||
|
"plt.title('Training and Validation accuracy')"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# Load your model"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model = tf.keras.models.load_model('saved_model/my_model.h5')"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n",
|
||||||
|
"is_executing": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# Test your model with the test dataset\n",
|
||||||
|
"If you don't have the test dataset use the validation dataset."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"score = model.evaluate(X_valid_norm, y_valid)\n",
|
||||||
|
"print('Test Accuracy: {}'.format(score[1]))"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n",
|
||||||
|
"is_executing": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from sklearn.metrics import confusion_matrix\n",
|
||||||
|
"predicted_classes = np.argmax(model.predict(X_valid_norm), axis=-1)\n",
|
||||||
|
"y_true = y_valid\n",
|
||||||
|
"\n",
|
||||||
|
"cm = confusion_matrix(y_true, predicted_classes)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"plt.figure(figsize = (25, 25))\n",
|
||||||
|
"sns.heatmap(cm, annot = True)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 2
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython2",
|
||||||
|
"version": "2.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
|
}
|
Loading…
Reference in a new issue