289 lines
No EOL
6.8 KiB
Text
289 lines
No EOL
6.8 KiB
Text
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# Exploration of the traffic sign dataset"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# 1. Import libraries and datasets"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"outputs": [],
|
|
"source": [
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import numpy as np\n",
|
|
"import pickle\n",
|
|
"import os"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Load the pickled training and validation datasets from the dataset folder\n",
|
|
"# NOTE: pickle.load can execute arbitrary code, so only load files that come from a trusted source\n",
|
|
"with open(os.path.join(\"dataset\", \"train.p\"), mode='rb') as training_data:\n",
|
|
"    train = pickle.load(training_data)\n",
|
|
"with open(os.path.join(\"dataset\", \"valid.p\"), mode='rb') as validation_data:\n",
|
|
"    valid = pickle.load(validation_data)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Pull the features and the labels out of each loaded dataset\n",
|
|
"# (the features are the images of the signs)\n",
|
|
"X_train = train['features']\n",
|
|
"y_train = train['labels']\n",
|
|
"X_valid = valid['features']\n",
|
|
"y_valid = valid['labels']"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# 2. Visualize traffic sign dataset"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Summarise the size and shape of the loaded data\n",
|
|
"num_train_examples = X_train.shape[0]\n",
|
|
"num_valid_examples = X_valid.shape[0]\n",
|
|
"image_shape = X_train[0].shape\n",
|
|
"class_ids = np.unique(y_train)\n",
|
|
"\n",
|
|
"print(\"Number of training examples: \", num_train_examples)\n",
|
|
"print(\"Number of validation examples: \", num_valid_examples)\n",
|
|
"print(\"Image data shape =\", image_shape)\n",
|
|
"print(\"Number of classes =\", len(class_ids))"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n",
|
|
"is_executing": true
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Plot a random picture from the training dataset\n",
|
|
"# randint's upper bound is exclusive, so the lower bound must be 0 for the first image to be selectable\n",
|
|
"i = np.random.randint(0, len(X_train))\n",
|
|
"plt.grid(False)\n",
|
|
"plt.imshow(X_train[i])\n",
|
|
"print(\"Label: \", y_train[i])"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n",
|
|
"is_executing": true
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Plot a grid (grid_height rows x grid_width columns) of random pictures from the training dataset\n",
|
|
"grid_width = 5\n",
|
|
"grid_height = 4\n",
|
|
"\n",
|
|
"fig, axes = plt.subplots(nrows=grid_height, ncols=grid_width, figsize=(10, 10))\n",
|
|
"\n",
|
|
"# Each panel shows one randomly drawn training image, titled with its label\n",
|
|
"for ax in axes.ravel():\n",
|
|
"    index = np.random.randint(0, len(X_train))\n",
|
|
"    ax.imshow(X_train[index])\n",
|
|
"    ax.set_title(y_train[index], fontsize=15)\n",
|
|
"    ax.axis('off')\n",
|
|
"\n",
|
|
"plt.subplots_adjust(hspace=0.3)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n",
|
|
"is_executing": true
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Plotting histograms of the count of each sign\n",
|
|
"def histogram_plot(dataset: np.ndarray, label: str, num_classes: int = None):\n",
|
|
"    \"\"\" Plots a bar chart of how many samples fall on each class id.\n",
|
|
"\n",
|
|
"    Values are counted in unit-wide bins centred on 0, 1, 2, ..., so every bar\n",
|
|
"    sits exactly on the class id it counts.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"        dataset: The input data to be plotted as a histogram; expected to hold\n",
|
|
"            non-negative integer class ids, one per sample.\n",
|
|
"        label: The label of the histogram (used as the x-axis label).\n",
|
|
"        num_classes: Number of classes (bars) to show. When omitted it is\n",
|
|
"            inferred from the data as max(dataset) + 1.\n",
|
|
"    \"\"\"\n",
|
|
"    values = np.asarray(dataset).ravel()\n",
|
|
"    if num_classes is None:\n",
|
|
"        # Infer the class count from the data; empty input still gets one (empty) bin\n",
|
|
"        highest = int(values.max()) if values.size else 0\n",
|
|
"        num_classes = max(highest + 1, 1)\n",
|
|
"    # Edges at -0.5, 0.5, 1.5, ... keep each integer id in the middle of its own bin\n",
|
|
"    edges = np.arange(num_classes + 1) - 0.5\n",
|
|
"    counts, _ = np.histogram(values, bins=edges)\n",
|
|
"    plt.bar(np.arange(num_classes), counts, align='center', width=0.8)\n",
|
|
"    plt.xlabel(label)\n",
|
|
"    plt.ylabel(\"Image count\")\n",
|
|
"    plt.show()"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"# Draw one class-count histogram per dataset split\n",
|
|
"for split_labels, axis_label in ((y_train, \"Training examples\"), (y_valid, \"Validation examples\")):\n",
|
|
"    histogram_plot(split_labels, axis_label)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n",
|
|
"is_executing": true
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# A list of all classes:\n",
|
|
"- 0 = Speed limit (20km/h)\n",
|
|
"- 1 = Speed limit (30km/h)\n",
|
|
"- 2 = Speed limit (50km/h)\n",
|
|
"- 3 = Speed limit (60km/h)\n",
|
|
"- 4 = Speed limit (70km/h)\n",
|
|
"- 5 = Speed limit (80km/h)\n",
|
|
"- 6 = End of speed limit (80km/h)\n",
|
|
"- 7 = Speed limit (100km/h)\n",
|
|
"- 8 = Speed limit (120km/h)\n",
|
|
"- 9 = No passing\n",
|
|
"- 10 = No passing for vehicles over 3.5 metric tons\n",
|
|
"- 11 = Right-of-way at the next intersection\n",
|
|
"- 12 = Priority road\n",
|
|
"- 13 = Yield\n",
|
|
"- 14 = Stop\n",
|
|
"- 15 = No vehicles\n",
|
|
"- 16 = Vehicles over 3.5 metric tons prohibited\n",
|
|
"- 17 = No entry\n",
|
|
"- 18 = General caution\n",
|
|
"- 19 = Dangerous curve to the left\n",
|
|
"- 20 = Dangerous curve to the right\n",
|
|
"- 21 = Double curve\n",
|
|
"- 22 = Bumpy road\n",
|
|
"- 23 = Slippery road\n",
|
|
"- 24 = Road narrows on the right\n",
|
|
"- 25 = Road work\n",
|
|
"- 26 = Traffic signals\n",
|
|
"- 27 = Pedestrians\n",
|
|
"- 28 = Children crossing\n",
|
|
"- 29 = Bicycles crossing\n",
|
|
"- 30 = Beware of ice/snow\n",
|
|
"- 31 = Wild animals crossing\n",
|
|
"- 32 = End of all speed and passing limits\n",
|
|
"- 33 = Turn right ahead\n",
|
|
"- 34 = Turn left ahead\n",
|
|
"- 35 = Ahead only\n",
|
|
"- 36 = Go straight or right\n",
|
|
"- 37 = Go straight or left\n",
|
|
"- 38 = Keep right\n",
|
|
"- 39 = Keep left\n",
|
|
"- 40 = Roundabout mandatory\n",
|
|
"- 41 = End of no passing\n",
|
|
"- 42 = End of no passing by vehicles over 3.5 metric tons"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"outputs": [],
|
|
"source": [],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 2
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython2",
|
|
"version": "2.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
} |