ki-praktikum/MS2/data_exploration.ipynb
2023-05-24 14:48:55 +02:00

289 lines
6.8 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Exploration of the traffic sign dataset"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 1. Import libraries and datasets"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pickle\n",
"import os"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"# Load the pickled training and validation splits of the dataset.\n",
"# NOTE(security): pickle.load can execute arbitrary code while unpickling,\n",
"# so only run this cell on dataset files obtained from a trusted source.\n",
"with open(os.path.join(\"dataset\", \"train.p\"), mode='rb') as training_data:\n",
"    train = pickle.load(training_data)\n",
"with open(os.path.join(\"dataset\", \"valid.p\"), mode='rb') as validation_data:\n",
"    valid = pickle.load(validation_data)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"# Split every dataset into its features (the sign images) and its labels\n",
"X_train = train['features']\n",
"y_train = train['labels']\n",
"X_valid = valid['features']\n",
"y_valid = valid['labels']"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 2. Visualize traffic sign dataset"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Collect the key figures of the loaded data first, then report them\n",
"n_train = X_train.shape[0]\n",
"n_valid = X_valid.shape[0]\n",
"image_shape = X_train[0].shape\n",
"n_classes = len(np.unique(y_train))\n",
"\n",
"print(\"Number of training examples: \", n_train)\n",
"print(\"Number of validation examples: \", n_valid)\n",
"print(\"Image data shape =\", image_shape)\n",
"print(\"Number of classes =\", n_classes)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n",
"is_executing": true
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Plot a random picture from the training dataset\n",
"# The lower bound is 0 so that the very first image can be drawn as well;\n",
"# randint excludes the upper bound, so len(X_train) never goes out of range.\n",
"i = np.random.randint(0, len(X_train))\n",
"plt.grid(False)\n",
"plt.imshow(X_train[i])\n",
"print(\"Label: \", y_train[i])"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n",
"is_executing": true
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Show a (grid_width x grid_height) grid of randomly drawn training images,\n",
"# each one titled with its label\n",
"grid_width = 5\n",
"grid_height = 4\n",
"\n",
"fig, axes = plt.subplots(nrows=grid_height, ncols=grid_width, figsize=(10, 10))\n",
"axes = axes.ravel()\n",
"\n",
"# One independent random draw per subplot\n",
"for ax in axes:\n",
"    sample = np.random.randint(0, len(X_train))\n",
"    ax.imshow(X_train[sample])\n",
"    ax.set_title(y_train[sample], fontsize=15)\n",
"    ax.axis('off')\n",
"\n",
"plt.subplots_adjust(hspace=0.3)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n",
"is_executing": true
}
}
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [],
"source": [
"# Plotting histograms of the count of each sign\n",
"def histogram_plot(dataset: np.ndarray, label: str, num_classes: int = 43):\n",
"    \"\"\"Plot a bar chart of how many images each class id has.\n",
"\n",
"    Args:\n",
"        dataset: Integer class ids; values outside [0, num_classes) are not counted.\n",
"        label: Text used as the x-axis label of the plot.\n",
"        num_classes: Number of classes, i.e. the number of bars that are drawn.\n",
"    \"\"\"\n",
"    # Edges at -0.5, 0.5, 1.5, ... give exactly one bin per class id, so bar k\n",
"    # is centred on x = k even when the data does not contain every class.\n",
"    edges = np.arange(num_classes + 1) - 0.5\n",
"    hist, bins = np.histogram(dataset, bins=edges)\n",
"    width = 0.8 * (bins[1] - bins[0])\n",
"    center = (bins[:-1] + bins[1:]) / 2\n",
"    plt.bar(center, hist, align='center', width=width)\n",
"    plt.xlabel(label)\n",
"    plt.ylabel(\"Image count\")\n",
"    plt.show()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# One histogram per split, training first and validation second\n",
"for split_labels, axis_label in ((y_train, \"Training examples\"), (y_valid, \"Validation examples\")):\n",
"    histogram_plot(split_labels, axis_label)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n",
"is_executing": true
}
}
},
{
"cell_type": "markdown",
"source": [
"# 3. List of all classes\n",
"- 0 = Speed limit (20km/h)\n",
"- 1 = Speed limit (30km/h)\n",
"- 2 = Speed limit (50km/h)\n",
"- 3 = Speed limit (60km/h)\n",
"- 4 = Speed limit (70km/h)\n",
"- 5 = Speed limit (80km/h)\n",
"- 6 = End of speed limit (80km/h)\n",
"- 7 = Speed limit (100km/h)\n",
"- 8 = Speed limit (120km/h)\n",
"- 9 = No passing\n",
"- 10 = No passing for vehicles over 3.5 metric tons\n",
"- 11 = Right-of-way at the next intersection\n",
"- 12 = Priority road\n",
"- 13 = Yield\n",
"- 14 = Stop\n",
"- 15 = No vehicles\n",
"- 16 = Vehicles over 3.5 metric tons prohibited\n",
"- 17 = No entry\n",
"- 18 = General caution\n",
"- 19 = Dangerous curve to the left\n",
"- 20 = Dangerous curve to the right\n",
"- 21 = Double curve\n",
"- 22 = Bumpy road\n",
"- 23 = Slippery road\n",
"- 24 = Road narrows on the right\n",
"- 25 = Road work\n",
"- 26 = Traffic signals\n",
"- 27 = Pedestrians\n",
"- 28 = Children crossing\n",
"- 29 = Bicycles crossing\n",
"- 30 = Beware of ice/snow\n",
"- 31 = Wild animals crossing\n",
"- 32 = End of all speed and passing limits\n",
"- 33 = Turn right ahead\n",
"- 34 = Turn left ahead\n",
"- 35 = Ahead only\n",
"- 36 = Go straight or right\n",
"- 37 = Go straight or left\n",
"- 38 = Keep right\n",
"- 39 = Keep left\n",
"- 40 = Roundabout mandatory\n",
"- 41 = End of no passing\n",
"- 42 = End of no passing by vehicles over 3.5 metric tons"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}