spce0038-machine-learning-w.../week2/slides/Lecture04_PerformanceAnalysis.ipynb

2307 lines
281 KiB
Plaintext
Raw Normal View History

2025-02-22 19:16:55 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Lecture 4: Performance analysis\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"source": [
"![](https://www.tensorflow.org/images/colab_logo_32px.png)\n",
"[Run in colab](https://colab.research.google.com/drive/1qRS7NIEctu9MaqafI5IpK9RWG3yyMYO1)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:13:45.881490Z",
"iopub.status.busy": "2024-01-10T00:13:45.881060Z",
"iopub.status.idle": "2024-01-10T00:13:45.891554Z",
"shell.execute_reply": "2024-01-10T00:13:45.891079Z"
},
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Last executed: 2024-01-10 00:13:45\n"
]
}
],
"source": [
"import datetime\n",
"now = datetime.datetime.now()\n",
"print(\"Last executed: \" + now.strftime(\"%Y-%m-%d %H:%M:%S\"))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Examining datasets\n",
"\n",
"Use the MNIST digit dataset as a worked example in this lecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Fetch MNIST data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:13:45.930175Z",
"iopub.status.busy": "2024-01-10T00:13:45.929703Z",
"iopub.status.idle": "2024-01-10T00:13:45.985455Z",
"shell.execute_reply": "2024-01-10T00:13:45.984920Z"
}
},
"outputs": [],
"source": [
"# Common imports\n",
"import os\n",
"import numpy as np\n",
"np.random.seed(42) # To make this notebook's output stable across runs"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:13:45.990862Z",
"iopub.status.busy": "2024-01-10T00:13:45.989504Z",
"iopub.status.idle": "2024-01-10T00:13:58.720117Z",
"shell.execute_reply": "2024-01-10T00:13:58.719391Z"
}
},
"outputs": [],
"source": [
"# Fetch MNIST dataset\n",
"from sklearn.datasets import fetch_openml\n",
"#mnist = fetch_openml('mnist_784')\n",
"mnist = fetch_openml('mnist_784', parser=\"pandas\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Extract features and targets"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"MNIST dataset is already split into standard training set (first 60,000 images) and test set (last 10,000 images)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:13:58.723827Z",
"iopub.status.busy": "2024-01-10T00:13:58.723238Z",
"iopub.status.idle": "2024-01-10T00:13:58.739013Z",
"shell.execute_reply": "2024-01-10T00:13:58.738377Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"((60000,), (10000,))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train = mnist.target[:60000].to_numpy(dtype=int)\n",
"y_test = mnist.target[-10000:].to_numpy(dtype=int)\n",
"y_train.shape, y_test.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:13:58.742189Z",
"iopub.status.busy": "2024-01-10T00:13:58.741593Z",
"iopub.status.idle": "2024-01-10T00:13:58.747855Z",
"shell.execute_reply": "2024-01-10T00:13:58.747206Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"((60000, 784), (10000, 784))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train = mnist.data[:60000].to_numpy()\n",
"X_test = mnist.data[-10000:].to_numpy()\n",
"X_train.shape, X_test.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Each datum corresponds to a 28 x 28 image."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:13:58.750740Z",
"iopub.status.busy": "2024-01-10T00:13:58.750288Z",
"iopub.status.idle": "2024-01-10T00:13:58.754000Z",
"shell.execute_reply": "2024-01-10T00:13:58.753467Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"28.0 28\n"
]
}
],
"source": [
"import math\n",
"n_float = np.sqrt(X_train.shape[1])\n",
"n = math.floor(n_float)\n",
"print(n_float, n)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Plot image of digit"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:13:58.757013Z",
"iopub.status.busy": "2024-01-10T00:13:58.756399Z",
"iopub.status.idle": "2024-01-10T00:13:59.203839Z",
"shell.execute_reply": "2024-01-10T00:13:59.203168Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAIO0lEQVR4nO3cr2+XVwOH4acLVFDExJio2choCA0K/gBmtmQhmVmCWVKzZRmuyRxhbmoKhwCzMDGxP2BBAQJB0KzZNBMDSSsqeF53m3fmPKE/AtflPzlPkyb395izMs/zPAHANE3vHfUHAHB8iAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgA5cdQfwLvj4cOHi3ZPnz4d3vzwww+LzmKZ/f394c3e3t6is168eDG8WVtbG96sr68Pb94GbgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACAexGORJY+Sff311wfwJf9tc3PzUM45c+bM8ObevXuLznr06NHwZnd3d3izsrIyvDnMB/H+/fff4c3p06eHN5988snw5quvvhreTNM0/fjjj4t2B8FNAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAxIN4LHrc7tq1a8Obf/75Z3gzTdM0z/Pw5urVq4dyzpLH45acc5hnLTlnY2NjeLO6ujq8maZpWl9fH95cuXJleLPk4b3DeojxILkpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAeBDvLfPs2bPhzeeffz68WfK43Ycffji8maZpunTp0vDmvffGf++8fv16eLPkobVTp04Nbw77rFHnzp07lHM4eG4KAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBAVuZ5no/6I3hzNjc3hzd//vnn8Oazzz4b3ty6dWt4M03L/iZgGTcFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQE0f9Afy327dvL9otedzu008/Hd7cv39/eLPU7u7u8ObRo0fDm1OnTg1vrly5MryB48xNAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAxIN4x9TOzs6hnfXgwYPhzcrKypv/EN64ra2t4c2dO3eGN6urq8Mbjic3BQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAkJV5nuej/gj+37Nnzxbtfv755+HNy5cvF5016oMPPli0u3z58hv+kqO1t7e3aPfTTz8Nb169ejW8+fvvv4c3586dG95wPLkpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAeBAPDtnz588X7c6fPz+8OXny5PDmr7/+Gt6cOXNmeMPx5KYAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgDkxFF/ALxrbty4sWi3t7c3vPnmm2+GN148fbe5KQAQUQAgogBARAGAiAIAEQUAIgoARBQAiC
gAEFEAIKIAQEQBgKzM8zwf9UfAcfD8+fPhzfb29vDm999/H95M0zR9/PHHw5snT54MbzyI925zUwAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCADlx1B8AB+Hhw4fDm+vXrw9vdnZ2hjenT58e3kzTNN29e3d443E7RrkpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAeBDvEDx+/Hh4s7m5ueis999/f9Fu1P7+/vDm6dOni8769ddfhze3b99edNaoL774Ynjzyy+/LDrL43YcBjcFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQlXme56P+iLfd2bNnhzc3b95cdNb6+vrw5sWLF8Ob3377bXjzxx9/DG+maZqW/ItevHhxeLO9vT282draGt6cPHlyeAOHxU0BgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFADIiaP+gHfBhQsXhjfffvvtAXzJm7O6ujq82djYWHTWd999N7z5/vvvhzdra2vDG3jbuCkAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCszPM8H/VHvO329/eHN/fu3Vt01t7e3qLdqC+//HJ489FHHx3AlwBvkpsCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIB/EAiJsCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAg/wMoh+Lly7Gm9AAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['axes.labelsize'] = 14\n",
"plt.rcParams['xtick.labelsize'] = 12\n",
"plt.rcParams['ytick.labelsize'] = 12\n",
"\n",
"some_digit = X_train[41000]\n",
"some_digit_image = some_digit.reshape(n, n)\n",
"plt.imshow(some_digit_image, cmap = matplotlib.cm.binary,\n",
" interpolation=\"nearest\")\n",
"plt.axis(\"off\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercise 1 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Plot selection of digits"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:13:59.207815Z",
"iopub.status.busy": "2024-01-10T00:13:59.207477Z",
"iopub.status.idle": "2024-01-10T00:13:59.659642Z",
"shell.execute_reply": "2024-01-10T00:13:59.658941Z"
}
},
"outputs": [],
"source": [
"# Extract digits\n",
"n_digits = 10\n",
"n_images = 10\n",
"example_images = np.zeros([n_images * n_digits, n*n])\n",
"for i in range(n_digits):\n",
" example_images[i*n_images:(i+1)*n_images,:] = X_train[np.where(y_train == i)][0:n_images,:]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:13:59.663112Z",
"iopub.status.busy": "2024-01-10T00:13:59.662527Z",
"iopub.status.idle": "2024-01-10T00:14:02.060635Z",
"shell.execute_reply": "2024-01-10T00:14:02.059958Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1000x1000 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAn8AAAJ8CAYAAACP2sdVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAADTXElEQVR4nOydZ7gTVReFX2yggKgoCKJgBTsq6GfvBRWx995774q9V1QUFRW72BV77x3Fgih2LGBvFCvc74fPuieZm9ybnklmvX9ybzJJzs5MJnPWWXvvVg0NDQ0YY4wxxphEMF21B2CMMcYYYyqHL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxLEDNUeQC68+eabAAwePBiAG264AYBddtkFgIMOOgiAZZddtgqjqyzPPvssAGuttRYADQ0Nafevvvrq1RiWKSGPPPIIABtttBEA88wzDwBXX301AH369AFgrrnmqsLojDG58tVXXwFwxBFHAHDnnXc22eawww4D4KKLLqrcwEzisfJnjDHGGJMgWjVIOooZb7/9duPfa665JgC///57xm07dOgAwM8//1z2cVWL66+/HoBLL70UgPfeew+AqVOnAtC7d28gqKEHHHAAADPMUBPiLgBvvPFG49/LL788AK1atcq47WmnnQbAiSeeWP6BVRgpf5tssknGx/v37w/APffcU7ExlZp3330XgJdeeqnxvv333z+n584222wAvPzyywAsuuiipR1ciXn11VcBGDVqFAAXXHABAJ999lnjNklS7n/44YfGv/faay8ARowYkbZNrX6/pfRtvfXWQNj3mZh33nnTnnP77benPbfW6du3LwAjR44EYLrpsmtNiyyyCAA333wzAN27dwdgzjnnLOcQK8Lzzz/f+PeXX34JwFlnnQXABx98AITVnXPOOQeATTfdFIB27dqVbVxW/owxxhhjEkTslL/XX38dgC222KLxvm+++QYIKtCss84KwEwzzQTAjz/+CMCLL74IwHLLLZf2eC0jxe/GG28E4IUXXkh7XMrf9NNPn3b/J598AoQZVC2w8cYbN/798MMPA03jiiK1SMfLaqutVqbRlYdff/0VSFe9NFP87rvvMj7n1FNPBeD4448v7+BKyMcffwzAvffeC8CQIUMA+OKLLxq3yabyZmPxxRcH4NprrwWCWlxttAJx4IEHAvDUU08B6YpXlNlnnx2Abt26ZXz8jDPOAGDllVcGYI455ijNYCuI9vljjz3WeN9DDz3U7HO00rHffvuVb2Al4PDDDwfg4osvTrt/q622AuDCCy8EgtoHcMcddwCwzTbbpD2n1j2AOn/tuuuuAIwbNw5o/lwe/R3bfvvtgaCS16IC+PjjjwNhNQ6yn9Oj7L333gBccsklALRu3brEo7PyZ4wxxhiTKKqu/E2ZMgWAt956C4Add9wRCD4ICBmtUgak7B199NFAmDlpO82Sa0UZkfojn+Nuu+3W+JjUgr/++ivtOb169QLCjEnKiqgF5U+qz/rrrw/At99+2/jYxIkTgZaVP8UvhUBqS62gWXLqPv/888+BlmO/6667ABgwYECZRlc88umut956QFD2RerpJ1/lT5x33nkAHHnkkQU9v9Rov+Tj3Yqe47Kx2WabAXDLLbcA0KZNm0KGWBHkydSxre+3vtvQ8jGu+OQB1Pd7xhlnLO1g80S/T1Ji9f///vc/IKh2K664YouvNd9886W9hpBKeOihhwJBXYwrOpfJbz527Fgg++pUKtm2ee211wBYZpllSjvYCjDzzDMDTX+780EVTnL1Q+eDlT9jjDHGmARR9VTQffbZB4Bbb7015+eo7t+kSZOAkCGnjD
llwsad++67Dwj125544gkgzIIg+2zpqKOOAmDatGlAyJqrJf755x8gPesxafzyyy9AUMDzQd8dZdEpC7iaSMWWWvHAAw8AIc58kM9FfrhUZTiOvPLKK0B5v4vyTMoLdMwxx5TtvQrlueeeA2DbbbcFgie7ECZPngyEVR4hX1ylkTontU7I01eIOrflllsC8PXXXwNBPVRtQN3q8bh5AZWxKo+fxlkMWrXr2bNn0a9VaZrz9ubKUkstBcDmm29e9Gtlw8qfMcYYY0yCqJryJ/XuwQcfBNK9PwBrrLFG49/KApWnp2vXrkDwAUgZeOaZZzK+VtxQLaOdd9454+Op409VAbNt09x2cebkk0/OeduhQ4cCoWbUVVddVZYxlZtDDjkEgMsuuyzrNtq3Le1TZY4pmy4O3H333UDITi+G+eefHwjKR1zVbfkaN9xwQwB+++23vF9D5zgppPLLZUP1wOSvKmc9sHyROlmM4pcNrZJUS/mTx08om7cYP55ikWImn6BeM5pJLE+pasZVm8UWWwzIXsdPq1PNEd3m9NNPB0L9P2X/1gLHHnssUJzXr3379gDMPffcJRlTJqz8GWOMMcYkiIorf8poXWeddYAwa1aWm2bPt912W+Nz5OU788wzAdhzzz2B0Nt06aWXTnsN1Y5SBnFcev5K8ZP6Iz+fMto6deoEBC8jNO1aom01M9Dn11LWXByI9qxtDlX2V5afULxSxeKu8kbRMZpL5puyRVdZZRUg1HhUZw9tp//lsapGTawJEyYAoe92S5xwwgkALLDAAo33Pfroo0Dof3rccccB8Pfff5dsnOXgp59+AnJX/FR/NFUtUsV/fd+1iiG1U15KoffKRVWpFFJrlHmfjVxWNqLbajudG8eMGQME1ancyM8pz58UP9XqKwZl9abWAExFHj/VgJQSLoVQY6sWqsH44YcfAqHebpRM5zxlr+s4vv/++9Mel49Qv3vl9MAVy/Dhw4HcOi/16NEDSK9zWmms/BljjDHGJIiKKX8fffQREGpyaeYq9a5Lly5AqIad6mGRHya1A0RzKHNS1cHzySQuB8rqlccvOgNSZwJ1AlBXD2jqc5JCoBlQ6ra1Qi4qZVTxE1HlLO5eR3WnUQ2sm266Ke1x9alNVetUx/Lyyy8HQr2obFlkUgT1naqG8ifF/p133km7X/tJHSlUr0rZm4oNQr1CKUgLLrggENQeKYPyFcaFgQMH5rW9PEGnnHJKk8f0OaljjfzNOn70fVcNtei5pZrIw5trbc58ttV2yvi+5pprgMplvka7cJRC8csXKcXqFyyFXMpfLjUFy4EyzqXSZUMrGBA8fOphq8oF2WiuL3BckHezuRWAFVZYAYBhw4YB4bNTVYRKEv9P1BhjjDHGlIyyK3/KeFGmrvx46s+rrMA+ffoA8Mcff5TsvaMV0yuNZumqeSbkX5Di11zmp+r9aFYV7XGpGlHKgnvjjTeKGnM5aSm7VypHJuT7KkcGYTmQAqaONe+//z7QVOnYaaedgOYVDHVvkeobJ1SBX11JokjJyqWnpbaN9q2V362QLNpyIsVF2ZfZkMKn88Ass8yS83tIzdGt/FRS/tQDVFmRqhFXSaRuZ0Pqtr7DqR0+xAwz/PdTdPDBBwMhLtUMjCpD8pbKM7bqqqsWMvSc0W9JNl9eJZHqqFUQZQFXS/nr168fELpKRVl00UWBdD9w9HOUgp1NCZZCGCekwKoSRfS3SRVLUrvwKJehQ4cOAOywww5AOM7lade5Thng2Xp+F4OVP2OMMcaYBFF25U8Zt1L8hLJ61J2jHpFvTVXqhaqXK5sxSqo3QrOqzp07Z9xW3sg49/gUUndHjRqV8XGpl5mQOhpH9SsT8mgpA64YpIhK/chW41DKqrLKK8HZZ58NhBlrlFL0pHzyySfTbuOC/MvRbGStaqjKwO677552fynRe2ssuWQaloqouh1Fip+OV3lWo6sXEBS/c889N+1+ecWj6HgrpDNOMVRDWc2Guo
rI+1ct9Hsmv36U0aNHt/ga6vmt1bAoffv2BeK1sqXvdfQcr17W+j1rzteqig6q26lqKPrMRowYAbi3rzHGGGOMKZK
"text/plain": [
"<Figure size 800x800 with 100 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot digits\n",
"plt.figure(figsize=(10,10))\n",
"fig, axes = plt.subplots(n_digits, n_images, figsize=(8, 8),\n",
" subplot_kw={'xticks':[], 'yticks':[]},\n",
" gridspec_kw=dict(hspace=0.1, wspace=0.1))\n",
"\n",
"for i, ax in enumerate(axes.flat):\n",
" ax.imshow(example_images[i].reshape(n,n), cmap='binary', interpolation='nearest')\n",
" ax.axis(\"off\")\n",
" ax.text(0.05, 0.05, str(i // n_images),\n",
" transform=ax.transAxes, color='green')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:02.064046Z",
"iopub.status.busy": "2024-01-10T00:14:02.063359Z",
"iopub.status.idle": "2024-01-10T00:14:02.067767Z",
"shell.execute_reply": "2024-01-10T00:14:02.067210Z"
}
},
"outputs": [],
"source": [
"def plot_digit(data):\n",
" image = data.reshape(n, n)\n",
" plt.imshow(image, cmap = matplotlib.cm.binary,\n",
" interpolation=\"nearest\")\n",
" plt.axis(\"off\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Shuffle the training data so that it is not ordered by digit type."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:02.070642Z",
"iopub.status.busy": "2024-01-10T00:14:02.070400Z",
"iopub.status.idle": "2024-01-10T00:14:02.736323Z",
"shell.execute_reply": "2024-01-10T00:14:02.735619Z"
}
},
"outputs": [],
"source": [
"# Shuffle training data\n",
"import numpy as np\n",
"shuffle_index = np.random.permutation(60000)\n",
"X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Binary classifier\n",
"\n",
"Construct a classifier to distinguish between 5s and all other digits."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:02.739912Z",
"iopub.status.busy": "2024-01-10T00:14:02.739322Z",
"iopub.status.idle": "2024-01-10T00:14:02.742873Z",
"shell.execute_reply": "2024-01-10T00:14:02.742160Z"
}
},
"outputs": [],
"source": [
"y_train_5 = (y_train == 5)\n",
"y_test_5 = (y_test == 5)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:02.745912Z",
"iopub.status.busy": "2024-01-10T00:14:02.745363Z",
"iopub.status.idle": "2024-01-10T00:14:02.751897Z",
"shell.execute_reply": "2024-01-10T00:14:02.751306Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([False, False, False, ..., False, True, False])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_test_5"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Train\n",
"\n",
"Train a linear model using Stochastic Gradient Descent (good for large data-sets, as we will see later...)."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:02.754951Z",
"iopub.status.busy": "2024-01-10T00:14:02.754519Z",
"iopub.status.idle": "2024-01-10T00:14:02.758982Z",
"shell.execute_reply": "2024-01-10T00:14:02.758368Z"
}
},
"outputs": [],
"source": [
"# disable convergence warning from early stopping\n",
"from warnings import simplefilter\n",
"from sklearn.exceptions import ConvergenceWarning\n",
"simplefilter(\"ignore\", category=ConvergenceWarning)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:02.761746Z",
"iopub.status.busy": "2024-01-10T00:14:02.761309Z",
"iopub.status.idle": "2024-01-10T00:14:04.112020Z",
"shell.execute_reply": "2024-01-10T00:14:04.111257Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label 
{background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-r
],
"text/plain": [
"SGDClassifier(max_iter=10, random_state=42)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import SGDClassifier\n",
"sgd_clf = SGDClassifier(random_state=42, max_iter=10);\n",
"sgd_clf.fit(X_train, y_train_5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Recall that we previously extracted `some_digit`, which was a 5."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:04.115229Z",
"iopub.status.busy": "2024-01-10T00:14:04.114730Z",
"iopub.status.idle": "2024-01-10T00:14:04.162849Z",
"shell.execute_reply": "2024-01-10T00:14:04.162180Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAIO0lEQVR4nO3cr2+XVwOH4acLVFDExJio2choCA0K/gBmtmQhmVmCWVKzZRmuyRxhbmoKhwCzMDGxP2BBAQJB0KzZNBMDSSsqeF53m3fmPKE/AtflPzlPkyb395izMs/zPAHANE3vHfUHAHB8iAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgA5cdQfwLvj4cOHi3ZPnz4d3vzwww+LzmKZ/f394c3e3t6is168eDG8WVtbG96sr68Pb94GbgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACAexGORJY+Sff311wfwJf9tc3PzUM45c+bM8ObevXuLznr06NHwZnd3d3izsrIyvDnMB/H+/fff4c3p06eHN5988snw5quvvhreTNM0/fjjj4t2B8FNAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAxIN4LHrc7tq1a8Obf/75Z3gzTdM0z/Pw5urVq4dyzpLH45acc5hnLTlnY2NjeLO6ujq8maZpWl9fH95cuXJleLPk4b3DeojxILkpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAeBDvLfPs2bPhzeeffz68WfK43Ycffji8maZpunTp0vDmvffGf++8fv16eLPkobVTp04Nbw77rFHnzp07lHM4eG4KAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBAVuZ5no/6I3hzNjc3hzd//vnn8Oazzz4b3ty6dWt4M03L/iZgGTcFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQE0f9Afy327dvL9otedzu008/Hd7cv39/eLPU7u7u8ObRo0fDm1OnTg1vrly5MryB48xNAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAxIN4x9TOzs6hnfXgwYPhzcrKypv/EN64ra2t4c2dO3eGN6urq8Mbjic3BQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAkJV5nuej/gj+37Nnzxbtfv755+HNy5cvF5016oMPPli0u3z58hv+kqO1t7e3aPfTTz8Nb169ejW8+fvvv4c3586dG95wPLkpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAeBAPDtnz588X7c6fPz+8OXny5PDmr7/+Gt6cOXNmeMPx5KYAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgDkxFF/ALxrbty4sWi3t7c3vPnmm2+GN148fbe5KQAQUQAgogBARAGAiAIAEQUAIgoARBQAiC
gAEFEAIKIAQEQBgKzM8zwf9UfAcfD8+fPhzfb29vDm999/H95M0zR9/PHHw5snT54MbzyI925zUwAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCADlx1B8AB+Hhw4fDm+vXrw9vdnZ2hjenT58e3kzTNN29e3d443E7RrkpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAeBDvEDx+/Hh4s7m5ueis999/f9Fu1P7+/vDm6dOni8769ddfhze3b99edNaoL774Ynjzyy+/LDrL43YcBjcFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQlXme56P+iLfd2bNnhzc3b95cdNb6+vrw5sWLF8Ob3377bXjzxx9/DG+maZqW/ItevHhxeLO9vT282draGt6cPHlyeAOHxU0BgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFADIiaP+gHfBhQsXhjfffvvtAXzJm7O6ujq82djYWHTWd999N7z5/vvvhzdra2vDG3jbuCkAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCszPM8H/VHvO329/eHN/fu3Vt01t7e3qLdqC+//HJ489FHHx3AlwBvkpsCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIB/EAiJsCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAg/wMoh+Lly7Gm9AAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_digit(some_digit)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Predict class:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:04.166557Z",
"iopub.status.busy": "2024-01-10T00:14:04.166129Z",
"iopub.status.idle": "2024-01-10T00:14:04.173666Z",
"shell.execute_reply": "2024-01-10T00:14:04.173133Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([ True])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"some_digit.shape\n",
"\n",
"sgd_clf.predict([some_digit])\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Test accuracy"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:04.178073Z",
"iopub.status.busy": "2024-01-10T00:14:04.176751Z",
"iopub.status.idle": "2024-01-10T00:14:04.213684Z",
"shell.execute_reply": "2024-01-10T00:14:04.213035Z"
}
},
"outputs": [],
"source": [
"y_test = sgd_clf.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:04.218677Z",
"iopub.status.busy": "2024-01-10T00:14:04.217226Z",
"iopub.status.idle": "2024-01-10T00:14:04.225740Z",
"shell.execute_reply": "2024-01-10T00:14:04.225190Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"0.9601"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import accuracy_score\n",
"accuracy_score(y_test, y_test_5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Cross-validation"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### n-fold cross-validation\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture04_Images/5-fold-CV.png\" width=\"700\" style=\"display:block; margin:auto\"/>\n",
"\n",
"[Image credit: [VanderPlas](https://github.com/jakevdp/PythonDataScienceHandbook)]"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercises 2-3 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Consider naive classifier \n",
"\n",
"Classify everything as not 5."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:04.229463Z",
"iopub.status.busy": "2024-01-10T00:14:04.229066Z",
"iopub.status.idle": "2024-01-10T00:14:04.234410Z",
"shell.execute_reply": "2024-01-10T00:14:04.233859Z"
}
},
"outputs": [],
"source": [
"from sklearn.base import BaseEstimator\n",
"class Never5Classifier(BaseEstimator):\n",
" def fit(self, X, y=None):\n",
" pass\n",
" def predict(self, X):\n",
" return np.zeros((len(X), 1), dtype=bool)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"What accuracy do we expect?"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:04.237953Z",
"iopub.status.busy": "2024-01-10T00:14:04.237493Z",
"iopub.status.idle": "2024-01-10T00:14:04.867261Z",
"shell.execute_reply": "2024-01-10T00:14:04.866601Z"
},
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([0.909 , 0.90745, 0.9125 ])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"\n",
"never_5_clf = Never5Classifier()\n",
"cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Need to go beyond classification accuracy, especially for skewed datasets."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Confusion matrix\n",
"\n",
"Can gain further insight into performance by examining confusion matrix."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Confusion matrix shows true/false-positive/negative classifications\n",
"\n",
"- True-positive $\\text{TP}$: number of true positives (i.e. *correctly classified* as *positive*)\n",
"- False-positive $\\text{FP}$: number of false positives (i.e. *incorrectly classified* as *positive*)\n",
"- True-negative $\\text{TN}$: number of true negatives (i.e. *correctly classified* as *negative*)\n",
"- False-negative $\\text{FN}$: number of false negatives (i.e. *incorrectly classified* as *negative*)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table>\n",
" <tr>\n",
" <td></td>\n",
" <td></td>\n",
" <td>Predicted</td>\n",
" <td></td>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" <td></td>\n",
" <td>Negative</td>\n",
" <td>Positive</td>\n",
" </tr>\n",
" <tr>\n",
" <td>Actual</td>\n",
" <td>Negative</td>\n",
" <td>TN</td>\n",
" <td>FP</td>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" <td>Positive</td>\n",
" <td>FN</td>\n",
" <td>TP</td>\n",
" </tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Cross-validation prediction\n",
"\n",
"`cross_val_predict` performs K-fold cross-validation, returning the predictions made on each test fold. This gives a clean prediction for every instance in the training set, i.e. a prediction made by a model that did not see that instance during training."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:04.871959Z",
"iopub.status.busy": "2024-01-10T00:14:04.870596Z",
"iopub.status.idle": "2024-01-10T00:14:07.774513Z",
"shell.execute_reply": "2024-01-10T00:14:07.773793Z"
}
},
"outputs": [],
"source": [
"from sklearn.model_selection import cross_val_predict\n",
"y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Compute confusion matrix"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:07.778441Z",
"iopub.status.busy": "2024-01-10T00:14:07.778158Z",
"iopub.status.idle": "2024-01-10T00:14:07.793206Z",
"shell.execute_reply": "2024-01-10T00:14:07.792567Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[52953, 1626],\n",
" [ 807, 4614]])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"conf_matrix = confusion_matrix(y_train_5, y_train_pred)\n",
"conf_matrix"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each row represents actual class, while each column represents predicted class."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"### Perfect confusion matrix"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:07.796975Z",
"iopub.status.busy": "2024-01-10T00:14:07.796729Z",
"iopub.status.idle": "2024-01-10T00:14:07.810580Z",
"shell.execute_reply": "2024-01-10T00:14:07.809959Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[54579, 0],\n",
" [ 0, 5421]])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train_perfect_predictions = y_train_5\n",
"confusion_matrix(y_train_5, y_train_perfect_predictions)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercise 4 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Precision and recall"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- **Precision**: of predicted positives, proportion that are correctly classified (also called *positive predictive value*).\n",
"\n",
"$$\\text{precision} = \\frac{\\text{TP}}{\\text{TP} + \\text{FP}}$$\n",
"\n",
"\n",
"- **Recall**: of actual positives, proportion that are correctly classified (also called *true positive rate* or *sensitivity*).\n",
"\n",
"$$\\text{recall} = \\frac{\\text{TP}}{\\text{TP} + \\text{FN}}$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"Remember:\n",
"<table>\n",
" <tr>\n",
" <td></td>\n",
" <td></td>\n",
" <td>Predicted</td>\n",
" <td></td>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" <td></td>\n",
" <td>Negative</td>\n",
" <td>Positive</td>\n",
" </tr>\n",
" <tr>\n",
" <td>Actual</td>\n",
" <td>Negative</td>\n",
" <td>TN</td>\n",
" <td>FP</td>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" <td>Positive</td>\n",
" <td>FN</td>\n",
" <td>TP</td>\n",
" </tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercise 5 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### $F_1$ score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$F_1$ score is the *harmonic mean* of the precision and recall."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$$F_1 = \\frac{2}{1/\\text{precision} + 1/\\text{recall}} = 2 \\frac{\\text{precision} \\times \\text{recall}}{\\text{precision} + \\text{recall}} = \\frac{\\text{TP}}{\\text{TP} + \\frac{\\text{FN}+\\text{FP}}{2}}$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$F_1$ favours classifiers that have similar (and high) precision and recall.\n",
"\n",
"Sometimes may wish to favour precision or recall."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercise 6 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Precision-recall tradeoff"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Under the hood the classifier computes a *score*. Binary decision is then made depending on whether score exceeds some *threshold*.\n",
"\n",
"By changing the threshold, one can change the tradeoff between\n",
"precision and recall."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Scikit-Learn does not let you set the threshold directly but can access scores (confidence score for a sample is, e.g., the signed distance of that sample to classifying hyperplane)."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:07.814908Z",
"iopub.status.busy": "2024-01-10T00:14:07.814670Z",
"iopub.status.idle": "2024-01-10T00:14:07.821596Z",
"shell.execute_reply": "2024-01-10T00:14:07.820991Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([24299.40524152])"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_scores = sgd_clf.decision_function([some_digit])\n",
"y_scores"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Can then make prediction for given threshold."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:07.825101Z",
"iopub.status.busy": "2024-01-10T00:14:07.824869Z",
"iopub.status.idle": "2024-01-10T00:14:07.831310Z",
"shell.execute_reply": "2024-01-10T00:14:07.830702Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([ True])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"threshold = 0\n",
"y_some_digit_pred = (y_scores > threshold)\n",
"y_some_digit_pred"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:07.834535Z",
"iopub.status.busy": "2024-01-10T00:14:07.834282Z",
"iopub.status.idle": "2024-01-10T00:14:07.840676Z",
"shell.execute_reply": "2024-01-10T00:14:07.840078Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([False])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"threshold = 200000\n",
"y_some_digit_pred = (y_scores > threshold)\n",
"y_some_digit_pred"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Compute precision and recall for range of thresholds"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:07.844026Z",
"iopub.status.busy": "2024-01-10T00:14:07.843796Z",
"iopub.status.idle": "2024-01-10T00:14:10.475531Z",
"shell.execute_reply": "2024-01-10T00:14:10.474826Z"
}
},
"outputs": [],
"source": [
"y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,\n",
" method=\"decision_function\")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:10.479280Z",
"iopub.status.busy": "2024-01-10T00:14:10.479036Z",
"iopub.status.idle": "2024-01-10T00:14:10.494777Z",
"shell.execute_reply": "2024-01-10T00:14:10.494077Z"
}
},
"outputs": [],
"source": [
"from sklearn.metrics import precision_recall_curve\n",
"precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:10.499634Z",
"iopub.status.busy": "2024-01-10T00:14:10.497954Z",
"iopub.status.idle": "2024-01-10T00:14:10.701765Z",
"shell.execute_reply": "2024-01-10T00:14:10.701097Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(-700000.0, 700000.0)"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAz8AAAHRCAYAAABEhAeYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAACCoklEQVR4nO3dZ3gUVR+G8XvTCySUEAgl9CZFOiIdpYsKUqQoiIpYABEUe8eCCBZsoFJERUGwgEivUhWQLkhLgNAhBdJ33g/zZsOSQBJIMtnk+XHtlTNnzsz+sxuSPJmZMzbDMAxERERERETyOTerCxAREREREckNCj8iIiIiIlIgKPyIiIiIiEiBoPAjIiIiIiIFgsKPiIiIiIgUCAo/IiIiIiJSICj8iIiIiIhIgaDwIyIiIiIiBYLCj4iIiIiIFAgKPyIiIiIiUiBkOfzExMTwyiuv0KlTJ4oVK4bNZmPatGmZ3v7ChQsMGTKEEiVK4O/vT9u2bdmyZUtWyxAREREREcmSLIefM2fO8Prrr7Nnzx5uvvnmLG1rt9vp2rUr3333HU888QTjxo3j1KlTtGnThv3792e1FBERERERkUzzyOoGISEhREREUKpUKf766y8aN26c6W3nzJnDunXrmD17Nj179gSgd+/eVKtWjVdeeYXvvvsuq+WIiIiIiIhkSpaP/Hh7e1OqVKnrerI5c+ZQsmRJevTo4egrUaIEvXv35pdffiE+Pv669isiIiIiIpKRXJ3wYOvWrTRo0AA3N+enbdKkCZcuXWLfvn25WY6IiIiIiBQgWT7t7UZERETQqlWrNP0hISEAHD9+nDp16qRZHx8f73RUyG63c+7cOYoXL47NZsu5gkVEREREJE8zDIPo6GhKly6d5iDLlXI1/MTGxuLt7Z2m38fHx7E+PW+//TavvfZajtYmIiIiIiKuKzw8nLJly15zTK6GH19f33Sv64mLi3OsT89zzz3HU0895ViOjIwkNDSUz5Z9ho+/D4ZhYGA4fQTS7b/8Y3xSPBcTLzr1Xbkd4NS+1vr0tk+0JxKdEI3dsDv6UtoYpNsfERPB2UtnKVmopGNfdsPuXH86/YnJiUQlRBGXFMfFhIvEJ+kaqvzA090TL3cv3GxuuNvccXMzPwZ6B+Lj4YObzc3xsNlsjrafpx+lCpUiNCCUUoVKUdyvOD4ePni6eVLUtygB3gEU9S2Kr4ev0z5S9iMiIjfu/Hnw9gY/v+zb58SJ8OqrZvujj+D229OOKVUK3N1Tl6OjISoq4317ekJwsHPfmTOQmcuy/f2hSJHUZcOA48cz3g6geHH4/9/CAYiNhXPnMrdtmTLOy+fPw6VLGW/n42M+7+VOnoSkpIy3DQyEQoUyV5/kvKioKMqVK0fhwoUzHJur4SdlprgrpfSVLl063e28vb3TPWLUr1E/AgICsrfIfCQuKQ67YQdwCmopy5e3s3tdTEIMcUlxaQJnmhB3RfAzDIP45Hii46NJsidhN+wkG8nYDbvZticTFR/l+NxS9pneIyo+ikuJl0i0JxKbGMuFuAsk2hMdz3flw8AgLimOmIQYEpMTHc+bbE8m2UgmNjGWRHtiDr5j6Uv8/z8H8y3l9KXTOfq8Nswg5enuSaB3IN4e3rjb3PFw88DT3ZPivsUp4lOE4n7FCfAKoLB3YZqXa04hr0J4unvi4eZBUZ+i+Hv54+nmSSGvQnh7pP1/LCKSnz3zDHzxhRkKihSBgACoXNkMLO3aQfXqkNW/NwUHQ6VKkJAANWuaj4wEBKQNCZl1I79qBQZe/3OWLHn9214v/Vrp2jLzx9tcDT/16tVjzZo12O12p/PxNm7ciJ+fH9WqVcvNcvI9Hw+fjAdJlqSErsvDWEo7NimWQ+cPkWwkk5icSKI9kQtxF4hLiiPJnuQIUQnJCVyIu0BCcoKjL+VjZHykGdiSE4lLiiMyPpL4pPg0QexS4iWi4qPSDXHZxcAwa0tKJi4pLlv26eXuhZ
e7F55u5tGsMgFlKFWoFH6efo6+Yr7FKORViCC/ICoXrew4SlXEpwiebp54unsS4B2Ah1uufvsSkTzg0iUzSCQnQ8WKULUq1KgBXl7pjz93DsLCzLCR8gfh+fPhl19g1Chz2xS7d8Nzz5lHaFL2W7v2tfcP5hGVv/82w0i5cmmDTHi4+fHCBfMBsH07zJtntoOCoHt3GD0aMvtr0KOPmg8Ryboc++0hIiKCyMhIKleujKenJwA9e/Zkzpw5zJ0713GfnzNnzjB79my6deuW7tEdkbwko0AZGhiaS5VcXcqRrYiYCMIjwzkefZyImAjHEa2o+CguxF0gKiGKc7HnnALclY9kI/VIW7I92QxxRjKRcZEkG8lZri0hOYGE5ATH8smLJ6/78/Tx8MHL3QsfDx+K+BTB39MfP08/ivsVx9fDl6I+RalYtCIl/UsS5BdEi9AWBPpc558gRSRDdjukd52xYZiBIOVjipTTkjJ7OphhmKdVXcnDwwworVrBJ5+k9k+cCJedMU+ZMhASAn/9ZS4//LDzfrp3h/QmnfX0hHr1zP2/8QZcfoa+YUDbtmb4AfPIzoULMGwYDBgATZpA3brm6WYnTpgfT582t0tx5gxMmQLNmjmHnwMHYNIkGDjQfH4RyR424/KLWTJp0qRJXLhwgePHj/PZZ5/Ro0cP6tevD8CwYcMIDAxk0KBBTJ8+nUOHDlGhQgUAkpOTadGiBTt37uTpp58mKCiITz/9lLCwMDZv3kz16tUz9fxRUVEEBgYSGRmp095ELGAYBrFJsZy5dIbzsefZfXo3x6OPczb2LInJiSTZk4hJiOFC/AUSkxOJTTJPO4xPiifRbq4/F3uOM5fO5HrtXu5euNvccXczT+HzcveiuG9xivkWo0KRClQvXp06JevQIKRBngizInmNYcA335jXlHTokNrfrBls2GC2U0LAlU6fNo90AEybBg89BDfdBA0bQqNG5jUYW7fCl19C+/bQrRsMGWKOf+klePPNq9fVrBmsW5e6nNHZLw88AF9/nfnxXl5mYLv8OpqdOyGdSWoB8yhQWFja/uRk83NctAhWrYLVq83raebPh65dU8fNn29+/mAGr759zYAWHJz10+RE8rusZIPrCj8VKlTgyJEj6a5LCTvphR+A8+fP8/TTT/Pzzz8TGxtL48aNGT9+PI0aNcr08yv8iOQPsYmxRMZHOgLThbgLRCdEEx0fzf5z+zkRc4KLCRc5G3uWhOQEEu2JXEy46DhtMCE5gfNx57mYcNE8OnUdR6My4uvhi7eHN1WKVaFqsaqEFAoh2D+Y8kXKE1IohMrFKlM24Nozy4i4qjNnzCMoVaqYwePgwdR1ffvCt9+av4gnJZmnlcVd4wzZoCD4/vvUi/Ofesrc97UMGGAGLUj7C/+bb5qnqu3YAXv2wJNPwnvvmetiY9MeUSpWzPkC+o8+Mo/QpGjZEtauNdvz5pn73LkTtmyBvXvNU+B27HDeZ926zn1eXuZ1OAB33516atu1nD8PCxeaAefySaqefhrGj087vkgRuOUWuPNO6NPH/LxECrocDz9WU/gRkfRcTLjouG4qPDKcsMgwlhxcQkJyAicvnuRiwkWna6yS7Emcjz1PVHwU8cnXPztiSf+SNC7TmADvAEr4lSDIL4iyAWVpVLoRN5W4CTdbrt5PWiRb9O4Ns2dfe8zPP8Ndd5mndHXsaAaGxGvMCzN5curpZu+/bwabnTvNoyHpGTMG3nnHbPfvD999Z7aXLTMnC0hx6ZIZvFKCgGGYtaxda16T89hj5lGls2fhn3/MIy2dOzsHqhUrYP166NQJGjRwruPkSTMAtW6d2mcYzqf5jRplzsC2dSscOWKGkxv5FeXoUXjlFeejU+nZt8+8RkmkIFP4ERHJoqj4KP498y+7T+9m64mt7D69m3Ox50hITuBEzInrnl0v0DuQVuVbUb9UfYr7Fadx6cbcUvYWTScueUJysvlLvbu7+Yv2X3/BDz+Yv0w//L
B55GfFiqtvn5BgXhNz+fLZs+ZRnqNHzWtsUk4Ti41NPwzExpoTAPz1l3na3MyZ0LOnGaZ69EgNNFu2mNfG3Habud5
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):\n",
" plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n",
" plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n",
" plt.xlabel(\"Threshold\", fontsize=16)\n",
" plt.legend(loc=\"upper left\", fontsize=16)\n",
" plt.ylim([0, 1])\n",
"\n",
"plt.figure(figsize=(10, 5))\n",
"plot_precision_recall_vs_threshold(precisions, recalls, thresholds)\n",
"plt.xlim([-700000, 700000])"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"*Raising* the threshold *increases precision* (in general, though not monotonically) and *reduces recall*.\n",
"\n",
"Can select threshold of appropriate trade-off for problem at hand."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
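"Note the recall curve is smoother than the precision curve: recall is normalised by the number of actual positives, which is fixed, whereas precision is normalised by the number of predicted positives, which changes with the threshold."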
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## ROC curve"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Receiver operating characteristic* (ROC) curve plots *true positive rate* (i.e. recall) against the *false positive rate* for different *thresholds*."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercise 7 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Plot ROC curve"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:10.705549Z",
"iopub.status.busy": "2024-01-10T00:14:10.705078Z",
"iopub.status.idle": "2024-01-10T00:14:10.719896Z",
"shell.execute_reply": "2024-01-10T00:14:10.719259Z"
}
},
"outputs": [],
"source": [
"from sklearn.metrics import roc_curve\n",
"fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:10.723141Z",
"iopub.status.busy": "2024-01-10T00:14:10.722682Z",
"iopub.status.idle": "2024-01-10T00:14:10.871990Z",
"shell.execute_reply": "2024-01-10T00:14:10.871269Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAn0AAAHRCAYAAAAfRQ/FAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAACDZUlEQVR4nO3dd1iTV/8G8DsJGxkqKKA4UHGCgyqKVrHW1VptFcXXhavWto7WrW9dddVdba3VqjhptY7WKrV1UF83+lPQuhVxgagoGwIk5/cHJUoBDSHhCeT+XBdXk2flTqLl6znnOUcmhBAgIiIiojJNLnUAIiIiIjI8Fn1EREREJoBFHxEREZEJYNFHREREZAJY9BERERGZABZ9RERERCaARR8RERGRCWDRR0RERGQCWPQRERERmQAWfUREREQmwCiLvpSUFMycORNdunRBhQoVIJPJsHHjRq3PT0hIwIgRI+Ds7AxbW1u0b98e58+fN1xgIiIiIiNnlEXf06dP8eWXX+Lq1ato3Lhxkc5Vq9V49913ERISglGjRmHRokV4/Pgx/P39cfPmTQMlJiIiIjJuZlIHKIirqytiY2Ph4uKCc+fOoXnz5lqfu3PnTpw8eRI///wzAgICAAB9+vSBp6cnZs6ciZCQEEPFJiIiIjJaRtnSZ2lpCRcXF53O3blzJypXroyePXtqtjk7O6NPnz749ddfoVQq9RWTiIiIqNQwyqKvOC5cuIBmzZpBLs/71lq0aIG0tDTcuHFDomRERERE0jHK7t3iiI2NRdu2bfNtd3V1BQDExMTAy8urwHOVSmWelkC1Wo1nz56hYsWKkMlkhglMREREBEAIgeTkZLi5ueVrvNKHMlf0paenw9LSMt92Kysrzf7CLFiwALNnzzZYNiIiIqLXuX//PqpWrar365a5os/a2rrAcXsZGRma/YWZOnUqxo0bp3memJiIatWq4f79+7C3t9d/WCIiKnWyVWpkZKshhIBaABCAQO5jAQFALQQgAPU/+4QAEtIzoVILzY9aDcQlZ0Auk2m2JWVk4UFCOoRa4PaTVNyMS4Z7BRvIZTLkXF78c82cx0IA6n+uL4TAjbgUyGVAeRtzqEVODnXucf+cm5GllvTzK2u8qzrATC7DpZgkdG5QGVkqNdKzVGjo6gCFXIZslRpmCjmqlLeCQi6HSq1GeVtLWJvLIZfJoJDLkJqchOF93sO77/fEmpXLYGdnZ5CsZa7oy73z999yt7m5uRV6rqWlZYGthPb29iz6iIhKiFotkJKZDaEGVP8ULWp1TsGSrVYjPiUTMhk0RU1uIZR7zJMUJRLSMvHgeTpsLcygUquhzFbj+K2nuBefhroudpoC6EUhlVOE5T6/9igZAFDR1gIqkVukCaRmqkr40zBHwpMsrY+WW9oAAJ5nF7BTlvMjz/9rTjL2VmZQyHMKH7lMhsfJOY02XlUcIJfLIJchpzCSySCTAdceJaN5jQqwtzZDTEI6mtfImctXIfvn2H+u8zRFiYZu9lDIZTCTy5GtVsPRxgK2FgrNMbnXlv3zX2tzBWwsFJD9s0/zX8ggk+d8fC8fb2Wu0Pl9K5VKhISEYPDgwZrhY7duXINSqcSalcsMNqSszBV9TZo0wbFjx6BWq/P0h585cwY2Njbw9PSUMB0RUcnKUqmRlql6URgV0PKjUgnEp+b8slULAZUaUKkFstVqXIlJwrO0TFibK6BWi38KIOBqbBIszOSwszTTXCf3ugI5BZRaCPz9MBEymQwPE3KG1lSwtdAUcf80jOXJpMw2fCvUubvPtT42PjXTgEl0J5PlLUJkMpnmeXpWTmFarYINFPIXRcqLIkeGqCcpcHGwQi3ncpDLZDCTy3D1URJ8a1aAtXlO4aN4qeiCDIhNyEBjd0coZMjZ90/xlJ6pgrOdJeytzWH2z+tVcbSGraXZi9f9V6FlJpfBTFHm7iXVypUrV9CvXz9ERkYiKysLI0aMAJDT8GToGU
ZKddEXGxuLxMRE1KpVC+bm5gCAgIAA7Ny5E7t379bM0/f06VP8/PPPeO+99wpsySMi0qfEtCxkqtTIVqsR/TQNj5MzkK36p1tPCDxLzcSjxAwkpmchMT1L0xWn+qfVKrer79DVx2ji7qgp2FTqF8XRjcfJECJvN97LBZRaAJklUEAV1TMjK6JyC6J/t+LIZDmftzJbjRoVbSCX57QmKeQyqP/pRn2zjtNLxdaLwkumuc4/rUSyF8VYfIoSzaqX11xLLpchJiEdTd0doVDIYSaXIVstUN7GHBVsLVDJzgrO5Sxha5m3BYpKHyEEVq9ejfHjxyMjIwNOTk6am0xLitEWfd9++y0SEhIQExMDAPjtt9/w4MEDAMDo0aPh4OCAqVOnYtOmTbhz5w5q1KgBIKfoa9myJYYMGYIrV67AyckJ3333HVQqFW/SICrD0jNVSEzPwsOEdGSr1HiSokRSejauxCaivI0FnqYoce1RMqo4WiPifgKAnFan3AIrJiEdSRnZcLQx14yPEsA/47VePM8trF7sy23dyikS9C03a2Gep2nf9Sc1hTxn7FqNijaa1qN/t0DlFDXA3w+T0Ka2EyzM5JpjcrsAHyako5KdJdwcrfO0JL3cohWfkonalcqhkr0VytuYa7r5HKzNUb2iDSzN5CyeqMTExcVh2LBh2L9/PwCgc+fO2Lhxo85zEutKJoTQ//+l9KBGjRq4e/dugftyi7zBgwfnK/oA4Pnz55g4cSJ++eUXpKeno3nz5liyZAneeOONImVISkqCg4MDEhMTOaaPqAiSM7IQl5SBxPTsPAPXkzKy8DwtEwqZDH/HJMLROqerTyUEVCqBozeeoJ6rPVRqNdIzVQi7/gQAYGUuh7lC/qLg+td/sw1QbBkbs9zuMXluYSODEDljzDycbF8qeF4qfuSAWg1ciU1C+7rOeboC5f86PiYhHZXtreDmaK0prhTynCI3M1sNb3dHOFqba8ZfKeQyCAE4lbPQHP/va8v+aemyMMsptohM0aFDh9C/f388fvwYlpaWWLRoEUaNGlXglCyGrjuMtugzBiz6yBQkZWQhJSMbsYkZuPU4GTLIkKVW4/bjVJgpcsbe5BZld56m4mFCOpzKWf4z9itv4fW6ViljZKGQQy7HPwPBZUhWZqNaBRuY5fTNveiuw4uuQPxrm+ylbjwZgOtxyajrYg83ByuYKeS4/TgFvd+oCgszuaZ4y1YLOJezRDkrM1S2t9KMhXq5gFLIZHCwNjfZsU9EZcGJEyfQtm1bNGzYECEhIWjUqFGhxxq67jDa7l0i0p4QAs/TsnDi1lOkZ6n+GYQvkJWtxsUHCXB1tEZWthph1x8jNjEDNhZmeJpSnAHDyXrLXhS1K5XL1xWY252XpRK49ywNPtXLw0wuQ31XeySmZ6FO5XKwsTCDe3lrmClkqGBrCVsLBawsFLC3YusTEelfUlKSpmhr3bo19u3bh/bt22vmDJYKiz4iI5CZrcb952l4lJiBB8/TAAAq9YvpKHKLuP+7+xz3nqXhckwSAMBckVP8FHXAfpqBpp2QyXK6A4Gcbj+ncpZoVasiLM0UOS1cchkS0zJRpbw1KtpaIlOl1rSq5XYZWpkr4PjPGCxzuRz2/3QpEhEZO7VajeXLl2Pu3Lk4deoU6tWrBwDo2rWrxMlysOgjKiYhBJ6mZCLqSYqmOFOpBR4kpMP8n65R9T/bz0Q9QzkrMySmZ+HglTjYWCiKVYBlqf6ZGbYYqjhaQwiBN2pUwLPUTNRytkVDNweYm8mQkaWGe3kbTTevXJ7z3xpOtjCXyzXdnblTO3BgPBGZqocPHyIoKAiHDx8GAGzcuBFfffWVxKnyYtFH9AqPkzLwICEd95+l4dTteNhZmeHCvQTNlBjFHcOmrxa3xu6OyMxWQ6VWw9XBGu96uWoKNPk/rWS548vsrczg5mhdrIlFiYjohV27dmHEiBF49uwZbGxs8PXXX2P48OFSx8qHRR+ZtCyVGpH3E3
DnaSou3E+ArYUCWSqBPy8/QkxiRonlqFbBBveepaF5jfJo4GqPCraWcHWwyjOwP3d2ekszBeq52MH1nyKOiIikkZK
"text/plain": [
"<Figure size 700x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def plot_roc_curve(fpr, tpr, label=None):\n",
" plt.plot(fpr, tpr, linewidth=2, label=label)\n",
" plt.plot([0, 1], [0, 1], 'k--')\n",
" plt.axis([0, 1, 0, 1])\n",
" plt.xlabel('False Positive Rate (FPR)', fontsize=16)\n",
" plt.ylabel('True Positive Rate (TPR)', fontsize=16)\n",
"\n",
"plt.figure(figsize=(7, 5))\n",
"plot_roc_curve(fpr, tpr)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Dashed line corresponds to random classifier.\n",
"\n",
"Again, there is a trade-off. As the threshold is reduced to increase the true positive rate, we get a larger false positive rate."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercise 8 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Area under the ROC curve\n",
"\n",
"Area under the ROC curve (AUC) is a common performance metric."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:10.875409Z",
"iopub.status.busy": "2024-01-10T00:14:10.874976Z",
"iopub.status.idle": "2024-01-10T00:14:10.898598Z",
"shell.execute_reply": "2024-01-10T00:14:10.898029Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"0.9668396102663849"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import roc_auc_score\n",
"roc_auc_score(y_train_5, y_scores)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercise 9 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Comparing classifier ROC curves"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:14:10.901977Z",
"iopub.status.busy": "2024-01-10T00:14:10.901330Z",
"iopub.status.idle": "2024-01-10T00:15:13.117633Z",
"shell.execute_reply": "2024-01-10T00:15:13.116928Z"
}
},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"forest_clf = RandomForestClassifier(random_state=42)\n",
"y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,\n",
" method=\"predict_proba\")"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:15:13.121181Z",
"iopub.status.busy": "2024-01-10T00:15:13.120525Z",
"iopub.status.idle": "2024-01-10T00:15:13.130757Z",
"shell.execute_reply": "2024-01-10T00:15:13.130130Z"
}
},
"outputs": [],
"source": [
"y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n",
"fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:15:13.133917Z",
"iopub.status.busy": "2024-01-10T00:15:13.133466Z",
"iopub.status.idle": "2024-01-10T00:15:13.315999Z",
"shell.execute_reply": "2024-01-10T00:15:13.315269Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f34fd8df7f0>"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsoAAAIeCAYAAACiOSl5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAACtnklEQVR4nOzdd1xT5x4G8CcJWxRUQAEBcaDiQhy46h4468StKFatrVWr1WrrrOM6arVqba0LxYqjtlq11r0VrYC27oGbIcOwZ87945SwNUTgBHi+n09uk/PmnDxBL/7y5h0yQRAEEBERERFRFnKpAxARERER6SIWykREREREuWChTERERESUCxbKRERERES5YKFMRERERJQLFspERERERLlgoUxERERElAsWykREREREuWChTERERESUCxbKRERERES50MlCOTY2FvPmzYO7uzsqVKgAmUyGbdu2aXz+mzdvMG7cOFhaWqJMmTJo3749/P39Cy8wEREREZU4Olkoh4eHY+HChbhz5w4aNmyYr3NVKhV69OiBX375BZ9++imWL1+OsLAwtGvXDg8ePCikxERERERU0uhJHSA31tbWCA4ORuXKlfH333+jadOmGp+7b98+XLp0CXv37sWAAQMAAB4eHnBycsK8efPwyy+/FFZsIiIiIipBdLJH2dDQEJUrV9bq3H379qFSpUro16+f+pilpSU8PDxw4MABJCUlFVRMIiIiIirBdLJQfh8BAQFwdXWFXJ71rTVr1gzx8fG4f/++RMmIiIiIqDjRyaEX7yM4OBht2rTJcdza2hoA8OrVK9SvXz/Xc5OSkrL0OKtUKkRGRqJixYqQyWSFE5iIiIiItCYIAmJiYmBjY5Ojo/R9lbhCOSEhAYaGhjmOGxkZqdvzsnTpUixYsKDQshERERFR4Xj+/DmqVKlSoNcscYWysbFxruOQExMT1e15mTVrFj7//HP1Y6VSCXt7ezx//hzlypUr+LCZ86WkITQ6ESFvEvFKmYAQZRLCYhPxOiYJ4bHJCI9JRERcMlLShELNoS2FXAa5XAY9OSCXySCTAQqZDHKZDHJZ+rH/7sszHf/vvuK/c2TZ7qefK5cBMojH068vl8kAGf67/99xZD1PJgPkyLieLNvz1Dnl6a/z32vJZP9ly8iuyP4+/jtPlilj5mukvzeZTAY50p+fNZtC/ZyM9y5eM1NGiK8B/PczQPp98X/Sj/z340D6E9SP1efLMl0H6m9JZOrnyjI9N/M1Mp2X/XGm10u/fsZryrK8fubs6mPZMhARqVTA/QeAIAAmxoCDQ9b2x4+BN28AQQU0aAjo62e0PXkCBN4Q2xo3BuztM9pSUoCvvwIqWwO1agHdu2e97rx5QFgoYGIKfLsya9vn04DLl4AyZYAfNwA1ama0Xf8bGD5SvD/tc2Ds2Iy2V8FAx47iffeuwHffZb1uHWfxv2VNAT+/rL8HP/oIuHBRvO+7C8i8CNi//wIDPcT7kz4FRo4ATMtm/Pzq1hPvN2oI/LIr62uOGQNcviLe//tvoIxJRtvWLcDy/977mtVAly4ZbffuA336iPc9BgLZ+xU/7APcvw8YGQIBAVnbFiwE/jwC6OkDPjuAqlUz2i5fAf46Krb17QPUrZvRFhsDHD8OJKcAzZuLfxfS0tLw0/rvYWhoiFbtJiItFTAzi0HHljVRtmxZFLQSVyinr5iRXfoxGxubPM81NDTMtTe6XLly710oC4KAy48j4Pc4EndDonE/NBYRsUlISROQnKZCmupdBbAC0DOGXIs/MSN9OazKGqG8iT7KGunD1FAPpkZ6KGukh7Lq+xnHTQ31YKgnh55cDj2FDHpyGfQVcijkMugpZNBXH89o59AUIiLdIAg5P3gmJYmFoiAAJiaAQpHRFhUlFp9paYCdHfDfSEUAQHIy8Ouv4nnW1kD79lmve/w48PSp2D58OJC5L+rSJeDAASA1FRg2DHB1zWh7/lwsuvz9gVWrgK
lTM9ri4oA6VcX7HToAJ09mfc2ZnwG//y7ef/Uqa95Th4GZk8T727YBnVpltN25A/zuK94fNAiYMj7rdf8+Dzx8CFSoADjZZW2zMAEiXwKRAKzKZW0Pug2oYsX7Ua+ytpnIMtqMkfO6bZoAZ84AyljAzkL8s0nXvT1w7i/xvm22THHhWV/zyV0gfQ0DQQBcawNyOdCgTs7X7NNVPKanB9Syy/qa/XsAlmXFvx9d2wA1M51b2Qw4sEv8u2Vvn/O6238E4uPFc7O37fwZeXKyA0YNzLvd1TnjfmhoKIYNG4mTJ09CoVBg7KjhqFmzJqKjxb94hVGLlLhC2cXFBefPn4dKpcoyTsXPzw8mJiZwcnIq0jyvY5Jw+l4YtlwIwt2QGK2uIZMBFcsYwLKsEcyN9WGgJ4e+Qg4DPRkMFOJ9fT05yhnpw6qsISz/u6XfNzXUYyFLRFSEEhPFAlGpBGxts7Y9ewaEh//X81c3a3H57Blw+bJYtLq6ArVrZ7QdPSoWf4IAtGsHfPxx1uuamAAJCYCLS84evU8+ATZvFu//8w9Qr15G27FjwODB4v3sRWtCAjB0qHi/S5echfLatcAff4j3P/ww63vx9weWLxfvN2qUtVBOTBTb09/zmzeAubn4OPMQUyGXPqS3tWduS0vL2pa5F/Pff3NeN/3Dg0qVs61GjdxfAxD/fGvVAu7dy/pzBQAzM2DRIvGcWrVyXnfy5IwCN3PBCgBDhgBNmgBVqmTtGQfE1wkJETMbG4s93elkMuDatZyvlW769LzbmjQRb7kpVw7o2TPvc/Oxkq9WTp48iWHDhiE0NBQmJib44YcfULNmzXef+J6KdaEcHBwMpVKJ6tWrQ/+/714GDBiAffv2Yf/+/ep1lMPDw7F371706tUr1x7jgpaYkob9/i/hc+UpbgdH5/ocI305rM2MYaCQw0BPDiN9OSqVM4KtuTFszI1hbWYEazNjWJUzRIUyBtBXlLgFSoiI8i29MMr82T8+Hnj5UiyMKlYELC0z2uLigO+/F9tq1hR7EjObMwcIChILDm/vrG0HDgA//igWqNmLyzNngM6dxaLqyy+BxYuznmtvD7x+LRaI2TeGnTdPLHgB4PZtoE6djLZLl8QCCQBWr84olNPSxNvu3eJjY+OchXL6zya/xaVepkoge3GZn6I1e4H5tutmbtu2Dfjf/7K2jRuXd3HZuzfg6Cj+HchcIAJA69bA+vXiuS1bZm0zMAB27hR71nMr6i7+N8whc297utmzxVtu88Tq1QPu3s15HADKlgW++ir3NiBjKENu7O1zFsjp9PWBSpXyPrckSUtLw8KFC/HNN99AEATUq1cPu3fvhrOz87tPLgA6WyivW7cOb968watXrwAAf/zxB168eAEAmDRpEszMzDBr1ix4e3sjKCgIVf/7qDhgwAA0b94co0ePxu3bt2FhYYEffvgBaWlphT5RLyYxBTuuPMWm80GIjEvO0e5iZ45RLR3gYlce9hVMoJCzl5eIpBUVJf6ja2qacUwQgFOnxK/ry5YFPvgg6zm7dolf1atUwMyZYgGS7sgRsV2lAqZMyVqQPHggfj2vUgF9+4qFR2YGBmIRU6+eWJhmNmMGsHKlmO38ebEgSnfhAtC1q3h/zhxg4cKMtri4jNfp1StnoXz4sNj7qq+fs1B+/FgskgGxB1ipFHsI06Wmiv/NrQcyvS0lBbh1K+u4y7cVl3n1iCoU4jCE7NfPrEkTcahEbp1sdeqIY2Xl8pzFZfXqwMSJ4ms0apS1zdBQHFeb/nV7dhMnAj16iO3ZRyj27g04O4vXzZ7JwQGIjhbbsvek6usDP/2U87XSjRqVd5uzs3jLjUKR0Tuem4oV824r4IUUSEOCIKB79+44duwYAGDs2LFYs2YNTLL/pSlEOlsor1y5Ek+fPlU/3r9/P/bv3w8AGD58OMwy/7bKRKFQ4MiRI/jiiy/w/fffIyEhAU2bNsW2bdtQK7ePpgXgeWQ8Nl8Iwr7rLx
CblPW3V31bM7SuaYFOdazgal+eQyCISrH4eCA2FihfPusEpIsXxcIwNVX8Sj1zz2V0NDBrllggNmwIjM82rvLbb8V
"text/plain": [
"<Figure size 800x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(8, 6))\n",
"plt.plot(fpr, tpr, \"b:\", linewidth=2, label=\"SGD\")\n",
"plot_roc_curve(fpr_forest, tpr_forest, \"Random Forest\")\n",
"plt.legend(loc=\"lower right\", fontsize=16)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"inclass_exercise"
]
},
"source": [
"### Exercise: from the ROC curve, which method appears to work better?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
},
"tags": [
"solution",
"inclass_exercise"
]
},
"source": [
"Random Forest, since its ROC curve gets closer to the ideal point (i.e. top left of plot)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Comparing metrics"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:15:13.319443Z",
"iopub.status.busy": "2024-01-10T00:15:13.318949Z",
"iopub.status.idle": "2024-01-10T00:15:13.354041Z",
"shell.execute_reply": "2024-01-10T00:15:13.353400Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(0.9983631764491033, 0.9668396102663849)"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# AUC\n",
"roc_auc_score(y_train_5, y_scores_forest), roc_auc_score(y_train_5, y_scores)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:15:13.357195Z",
"iopub.status.busy": "2024-01-10T00:15:13.356728Z",
"iopub.status.idle": "2024-01-10T00:16:15.780979Z",
"shell.execute_reply": "2024-01-10T00:16:15.780322Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(0.9890893831305078, 0.739423076923077)"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Precision\n",
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
"\n",
"y_train_pred_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3)\n",
"precision_score(y_train_5, y_train_pred_forest), precision_score(y_train_5, y_train_pred)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:16:15.784167Z",
"iopub.status.busy": "2024-01-10T00:16:15.783687Z",
"iopub.status.idle": "2024-01-10T00:16:15.822234Z",
"shell.execute_reply": "2024-01-10T00:16:15.821584Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(0.8695812580704667, 0.8511344770337577)"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Recall\n",
"recall_score(y_train_5, y_train_pred_forest), recall_score(y_train_5, y_train_pred)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:16:15.825376Z",
"iopub.status.busy": "2024-01-10T00:16:15.824899Z",
"iopub.status.idle": "2024-01-10T00:16:15.865268Z",
"shell.execute_reply": "2024-01-10T00:16:15.864611Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(0.9254932757435947, 0.7913558013892462)"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# F_1\n",
"f1_score(y_train_5, y_train_pred_forest), f1_score(y_train_5, y_train_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Progress so far"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So far we have considered binary classification only (e.g. five and not five).\n",
"\n",
"Clearly in many scenarios we want to classify multiple classes."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Multiclass classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Binary classifiers distinguish between two classes.\n",
"Multiclass classifiers can distinguish between more than two classes.\n",
"\n",
"Some algorithms can handle multiple classes directly (e.g. Random Forests, naive Bayes).\n",
"Others are strictly binary classifiers (e.g. Support Vector Machines, Linear classifiers)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Multiclass classification strategies\n",
"\n",
"\n",
"However, there are various strategies that can be used to perform multiclass classification with binary classifiers."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"- **One-versus-rest (OvR) / one-versus-all (OvA)**: train a binary classifier for each class, then select classification with greatest score across classifiers \n",
"<br>(e.g. train a binary classifier for each digit)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"- **One-versus-one (OvO)**: train a binary classifier for each pair of classes, then select classification that wins most duels \n",
"<br>(e.g. train a binary classifier for each pair of digits: 0 vs 1, 0 vs 2, ..., 1 vs 2, 1 vs 3, ...)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Comparison of multiclass classification strategies"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**One-versus-rest (OvR)**:\n",
"- $N$ classifiers for $N$ classes\n",
"- each classifier uses all of the training data\n",
"\n",
"$\\Rightarrow$ requires training relatively *few classifiers* but training each classifier can be *slow*."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"**One-versus-one (OvO)**:\n",
"- $N(N-1)/2$ classifiers for $N$ classes\n",
"- each classifier uses a subset of the training data (typically much smaller than overall training dataset)\n",
"\n",
"$\\Rightarrow$ requires training *many* classifiers but training each classifier can be *fast*."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Preferred approach\n",
"\n",
"OvR usually preferred, unless training binary classifier is very slow with large data-sets."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"In Scikit-Learn, if you try to use a binary classifier for a multiclass classification problem, OvR is automatically run."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:16:15.869360Z",
"iopub.status.busy": "2024-01-10T00:16:15.868886Z",
"iopub.status.idle": "2024-01-10T00:16:24.929227Z",
"shell.execute_reply": "2024-01-10T00:16:24.928575Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([5])"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sgd_clf.fit(X_train, y_train)\n",
"sgd_clf.predict([some_digit])"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"We can see that OvR was performed by inspecting the scores, where we have one score per classifier. \n",
"\n",
"The score at index 5 (counting from 0) is clearly the largest."
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:16:24.933068Z",
"iopub.status.busy": "2024-01-10T00:16:24.932355Z",
"iopub.status.idle": "2024-01-10T00:16:24.939500Z",
"shell.execute_reply": "2024-01-10T00:16:24.938897Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ -81742.80600673, -226403.87485225, -310042.19433877,\n",
" -173577.43026798, -82855.74343468, 39922.35938292,\n",
" -183200.20815396, 10437.2327332 , -240036.68142135,\n",
" -160691.66786235]])"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"some_digit_scores = sgd_clf.decision_function([some_digit])\n",
"some_digit_scores"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:16:24.942384Z",
"iopub.status.busy": "2024-01-10T00:16:24.941943Z",
"iopub.status.idle": "2024-01-10T00:16:24.948594Z",
"shell.execute_reply": "2024-01-10T00:16:24.947920Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sgd_clf.classes_"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:16:24.951594Z",
"iopub.status.busy": "2024-01-10T00:16:24.951202Z",
"iopub.status.idle": "2024-01-10T00:16:24.958059Z",
"shell.execute_reply": "2024-01-10T00:16:24.957454Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"5"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sgd_clf.classes_[np.argmax(some_digit_scores)]"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### OvO with Scikit-Learn"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Can also perform OvO multiclass classification."
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:16:24.961009Z",
"iopub.status.busy": "2024-01-10T00:16:24.960645Z",
"iopub.status.idle": "2024-01-10T00:16:35.489682Z",
"shell.execute_reply": "2024-01-10T00:16:35.489026Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(array([5]), 45)"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.multiclass import OneVsOneClassifier\n",
"ovo_clf = OneVsOneClassifier(SGDClassifier(random_state=42, max_iter=10))\n",
"ovo_clf.fit(X_train, y_train)\n",
"ovo_clf.predict([some_digit]), len(ovo_clf.estimators_)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Many classifiers can inherently classify multiple classes\n",
"\n",
"Random Forest can directly classify multiple classes so OvR or OvO classification not required."
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:16:35.493180Z",
"iopub.status.busy": "2024-01-10T00:16:35.492766Z",
"iopub.status.idle": "2024-01-10T00:17:18.031732Z",
"shell.execute_reply": "2024-01-10T00:17:18.031034Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([5])"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"forest_clf.fit(X_train, y_train)\n",
"forest_clf.predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:17:18.035061Z",
"iopub.status.busy": "2024-01-10T00:17:18.034601Z",
"iopub.status.idle": "2024-01-10T00:17:18.046924Z",
"shell.execute_reply": "2024-01-10T00:17:18.046315Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.03, 0.01, 0. , 0.02, 0.02, 0.85, 0.02, 0.03, 0.01, 0.01]])"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"forest_clf.predict_proba([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:17:18.050000Z",
"iopub.status.busy": "2024-01-10T00:17:18.049472Z",
"iopub.status.idle": "2024-01-10T00:17:37.039538Z",
"shell.execute_reply": "2024-01-10T00:17:37.038826Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([0.8687 , 0.86975, 0.8449 ])"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring=\"accuracy\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Error analysis"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute confusion matrix for multiclass classification."
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:17:37.044920Z",
"iopub.status.busy": "2024-01-10T00:17:37.043335Z",
"iopub.status.idle": "2024-01-10T00:17:55.983368Z",
"shell.execute_reply": "2024-01-10T00:17:55.982681Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[5765, 2, 16, 16, 7, 21, 33, 9, 48, 6],\n",
" [ 1, 6470, 41, 17, 9, 35, 13, 16, 138, 2],\n",
" [ 79, 52, 5155, 153, 39, 31, 171, 85, 181, 12],\n",
" [ 80, 40, 224, 5087, 19, 286, 62, 78, 217, 38],\n",
" [ 37, 30, 46, 15, 5101, 22, 64, 75, 315, 137],\n",
" [ 127, 27, 35, 208, 48, 4451, 175, 35, 259, 56],\n",
" [ 45, 10, 57, 3, 19, 112, 5611, 11, 45, 5],\n",
" [ 25, 32, 91, 49, 50, 15, 8, 5824, 81, 90],\n",
" [ 70, 175, 127, 278, 43, 394, 86, 51, 4571, 56],\n",
" [ 48, 33, 54, 117, 326, 99, 5, 801, 834, 3632]])"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train_pred = cross_val_predict(sgd_clf, X_train, y_train, cv=3)\n",
"conf_mx = confusion_matrix(y_train, y_train_pred)\n",
"conf_mx"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercise 10 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Performance analysis can provide insight into how to make improvements.\n",
"\n",
"For example, for the previous dataset, one might want to consider trying to improve the performance of classifying 9s by collecting more training data for 7s and 9s."
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 4
}