2307 lines
281 KiB
Plaintext
2307 lines
281 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "slide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"# Lecture 4: Performance analysis\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "skip"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"\n",
|
||
|
"[Run in colab](https://colab.research.google.com/drive/1qRS7NIEctu9MaqafI5IpK9RWG3yyMYO1)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:13:45.881490Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:13:45.881060Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:13:45.891554Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:13:45.891079Z"
|
||
|
},
|
||
|
"slideshow": {
|
||
|
"slide_type": "skip"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Last executed: 2024-01-10 00:13:45\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import datetime\n",
|
||
|
"# Timestamp of the last full run of this notebook\n",
|
||
|
"now = datetime.datetime.now()\n",
|
||
|
"print(f\"Last executed: {now:%Y-%m-%d %H:%M:%S}\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "slide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"## Examining datasets\n",
|
||
|
"\n",
|
||
|
"Use the MNIST digit dataset as a worked example in this lecture."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Fetch MNIST data"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:13:45.930175Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:13:45.929703Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:13:45.985455Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:13:45.984920Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Common imports\n",
|
||
|
"import os\n",
|
||
|
"import numpy as np\n",
|
||
|
"np.random.seed(42) # To make this notebook's output stable across runs"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:13:45.990862Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:13:45.989504Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:13:58.720117Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:13:58.719391Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Fetch the MNIST dataset from OpenML, pinned to version 1 for reproducibility\n",
|
||
|
"from sklearn.datasets import fetch_openml\n",
|
||
|
"mnist = fetch_openml('mnist_784', version=1, parser=\"pandas\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Extract features and targets"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"The MNIST dataset is already split into a standard training set (the first 60,000 images) and a test set (the last 10,000 images)."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:13:58.723827Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:13:58.723238Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:13:58.739013Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:13:58.738377Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"((60000,), (10000,))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# MNIST is pre-split by position: first 60,000 rows = training set, last 10,000 = test set\n",
|
||
|
"y_train = mnist.target.iloc[:60000].to_numpy(dtype=int)\n",
|
||
|
"y_test = mnist.target.iloc[-10000:].to_numpy(dtype=int)\n",
|
||
|
"y_train.shape, y_test.shape"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:13:58.742189Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:13:58.741593Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:13:58.747855Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:13:58.747206Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"((60000, 784), (10000, 784))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Pixel features, split positionally like the targets (first 60,000 / last 10,000 rows)\n",
|
||
|
"X_train = mnist.data.iloc[:60000].to_numpy()\n",
|
||
|
"X_test = mnist.data.iloc[-10000:].to_numpy()\n",
|
||
|
"X_train.shape, X_test.shape"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "fragment"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"Each datum is a vector of 784 pixel values, corresponding to a 28 x 28 image."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:13:58.750740Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:13:58.750288Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:13:58.754000Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:13:58.753467Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"28.0 28\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import math\n",
|
||
|
"n_float = np.sqrt(X_train.shape[1])\n",
|
||
|
"# Exact integer square root (avoids float rounding in floor(sqrt(...)))\n",
|
||
|
"n = math.isqrt(X_train.shape[1])\n",
|
||
|
"assert n * n == X_train.shape[1], \"images are expected to be square\"\n",
|
||
|
"print(n_float, n)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Plot image of digit"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:13:58.757013Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:13:58.756399Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:13:59.203839Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:13:59.203168Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAIO0lEQVR4nO3cr2+XVwOH4acLVFDExJio2choCA0K/gBmtmQhmVmCWVKzZRmuyRxhbmoKhwCzMDGxP2BBAQJB0KzZNBMDSSsqeF53m3fmPKE/AtflPzlPkyb395izMs/zPAHANE3vHfUHAHB8iAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgA5cdQfwLvj4cOHi3ZPnz4d3vzwww+LzmKZ/f394c3e3t6is168eDG8WVtbG96sr68Pb94GbgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACAexGORJY+Sff311wfwJf9tc3PzUM45c+bM8ObevXuLznr06NHwZnd3d3izsrIyvDnMB/H+/fff4c3p06eHN5988snw5quvvhreTNM0/fjjj4t2B8FNAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAxIN4LHrc7tq1a8Obf/75Z3gzTdM0z/Pw5urVq4dyzpLH45acc5hnLTlnY2NjeLO6ujq8maZpWl9fH95cuXJleLPk4b3DeojxILkpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAeBDvLfPs2bPhzeeffz68WfK43Ycffji8maZpunTp0vDmvffGf++8fv16eLPkobVTp04Nbw77rFHnzp07lHM4eG4KAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBAVuZ5no/6I3hzNjc3hzd//vnn8Oazzz4b3ty6dWt4M03L/iZgGTcFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQE0f9Afy327dvL9otedzu008/Hd7cv39/eLPU7u7u8ObRo0fDm1OnTg1vrly5MryB48xNAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAxIN4x9TOzs6hnfXgwYPhzcrKypv/EN64ra2t4c2dO3eGN6urq8Mbjic3BQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAkJV5nuej/gj+37Nnzxbtfv755+HNy5cvF5016oMPPli0u3z58hv+kqO1t7e3aPfTTz8Nb169ejW8+fvvv4c3586dG95wPLkpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAeBAPDtnz588X7c6fPz+8OXny5PDmr7/+Gt6cOXNmeMPx5KYAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgDkxFF/ALxrbty4sWi3t7c3vPnmm2+GN148fbe5KQAQUQAgogBARAGAiAIAEQUAIgoARBQAiC
gAEFEAIKIAQEQBgKzM8zwf9UfAcfD8+fPhzfb29vDm999/H95M0zR9/PHHw5snT54MbzyI925zUwAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCADlx1B8AB+Hhw4fDm+vXrw9vdnZ2hjenT58e3kzTNN29e3d443E7RrkpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAeBDvEDx+/Hh4s7m5ueis999/f9Fu1P7+/vDm6dOni8769ddfhze3b99edNaoL774Ynjzyy+/LDrL43YcBjcFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQlXme56P+iLfd2bNnhzc3b95cdNb6+vrw5sWLF8Ob3377bXjzxx9/DG+maZqW/ItevHhxeLO9vT282draGt6cPHlyeAOHxU0BgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFADIiaP+gHfBhQsXhjfffvvtAXzJm7O6ujq82djYWHTWd999N7z5/vvvhzdra2vDG3jbuCkAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCszPM8H/VHvO329/eHN/fu3Vt01t7e3qLdqC+//HJ489FHHx3AlwBvkpsCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIB/EAiJsCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAg/wMoh+Lly7Gm9AAAAABJRU5ErkJggg==",
|
||
|
"text/plain": [
|
||
|
"<Figure size 640x480 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%matplotlib inline\n",
|
||
|
"import matplotlib\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"plt.rcParams.update({'axes.labelsize': 14,\n",
|
||
|
"                     'xtick.labelsize': 12,\n",
|
||
|
"                     'ytick.labelsize': 12})\n",
|
||
|
"\n",
|
||
|
"# Pick one training example and show it as an n x n grayscale image\n",
|
||
|
"some_digit = X_train[41000]\n",
|
||
|
"some_digit_image = some_digit.reshape(n, n)\n",
|
||
|
"plt.imshow(some_digit_image, cmap=matplotlib.cm.binary, interpolation=\"nearest\")\n",
|
||
|
"plt.axis(\"off\");"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": [
|
||
|
"exercise_pointer"
|
||
|
]
|
||
|
},
|
||
|
"source": [
|
||
|
"**Exercises:** *You can now complete Exercise 1 in the exercises associated with this lecture.*"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Plot selection of digits"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:13:59.207815Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:13:59.207477Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:13:59.659642Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:13:59.658941Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Collect the first n_images training examples of each digit, grouped digit by digit\n",
|
||
|
"n_digits = 10\n",
|
||
|
"n_images = 10\n",
|
||
|
"example_images = np.zeros([n_images * n_digits, n*n])\n",
|
||
|
"for i in range(n_digits):\n",
|
||
|
"    start, stop = i * n_images, (i + 1) * n_images\n",
|
||
|
"    example_images[start:stop, :] = X_train[y_train == i][:n_images, :]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:13:59.663112Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:13:59.662527Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:02.060635Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:02.059958Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"<Figure size 1000x1000 with 0 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAn8AAAJ8CAYAAACP2sdVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAADTXElEQVR4nOydZ7gTVReFX2yggKgoCKJgBTsq6GfvBRWx995774q9V1QUFRW72BV77x3Fgih2LGBvFCvc74fPuieZm9ybnklmvX9ybzJJzs5MJnPWWXvvVg0NDQ0YY4wxxphEMF21B2CMMcYYYyqHL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxKEL/6MMcYYYxLEDNUeQC68+eabAAwePBiAG264AYBddtkFgIMOOgiAZZddtgqjqyzPPvssAGuttRYADQ0Nafevvvrq1RiWKSGPPPIIABtttBEA88wzDwBXX301AH369AFgrrnmqsLojDG58tVXXwFwxBFHAHDnnXc22eawww4D4KKLLqrcwEzisfJnjDHGGJMgWjVIOooZb7/9duPfa665JgC///57xm07dOgAwM8//1z2cVWL66+/HoBLL70UgPfeew+AqVOnAtC7d28gqKEHHHAAADPMUBPiLgBvvPFG49/LL788AK1atcq47WmnnQbAiSeeWP6BVRgpf5tssknGx/v37w/APffcU7ExlZp3330XgJdeeqnxvv333z+n584222wAvPzyywAsuuiipR1ciXn11VcBGDVqFAAXXHABAJ999lnjNklS7n/44YfGv/faay8ARowYkbZNrX6/pfRtvfXWQNj3mZh33nnTnnP77benPbfW6du3LwAjR44EYLrpsmtNiyyyCAA333wzAN27dwdgzjnnLOcQK8Lzzz/f+PeXX34JwFlnnQXABx98AITVnXPOOQeATTfdFIB27dqVbVxW/owxxhhjEkTslL/XX38dgC222KLxvm+++QYIKtCss84KwEwzzQTAjz/+CMCLL74IwHLLLZf2eC0jxe/GG28E4IUXXkh7XMrf9NNPn3b/J598AoQZVC2w8cYbN/798MMPA03jiiK1SMfLaqutVqbRlYdff/0VSFe9NFP87rvvMj7n1FNPBeD4448v7+BKyMcffwzAvffeC8CQIUMA+OKLLxq3yabyZmPxxRcH4NprrwWCWlxttAJx4IEHAvDUU08B6YpXlNlnnx2Abt26ZXz8jDPOAGDllVcGYI455ijNYCuI9vljjz3WeN9DDz3U7HO00rHffvuVb2Al4PDDDwfg4osvTrt/q622AuDCCy8EgtoHcMcddwCwzTbbpD2n1j2AOn/tuuuuAIwbNw5o/lwe/R3bfvvtgaCS16IC+PjjjwNhNQ6yn9Oj7L333gBccsklALRu3brEo7PyZ4wxxhiTKKqu/E2ZMgWAt956C4Add9wRCD4ICBmtUgak7B199NFAmDlpO82Sa0UZkfojn+Nuu+3W+JjUgr/++ivtOb169QLCjEnKiqgF5U+qz/rrrw/At99+2/jYxIkTgZaVP8UvhUBqS62gWXLqPv/888+BlmO/6667ABgwYECZRlc88umut956QFD2RerpJ1/lT5x33nkAHHnkkQU9v9Rov+Tj3Yqe47Kx2WabAXDLLbcA0KZNm0KGWBHkydSxre+3vtvQ8jGu+OQB1Pd7xhlnLO1g80S/T1Ji9f///vc/IKh2K664YouvNd9886W9hpBKeOihhwJBXYwrOpfJbz527Fgg++pUKtm2ee211wBYZpllSjvYCjDzzDMDTX+780EVTnL1Q+eDlT9jjDHGmARR9VTQffbZB4Bbb7015+eo7t+kSZOAkCGnjD
llwsad++67Dwj125544gkgzIIg+2zpqKOOAmDatGlAyJqrJf755x8gPesxafzyyy9AUMDzQd8dZdEpC7iaSMWWWvHAAw8AIc58kM9FfrhUZTiOvPLKK0B5v4vyTMoLdMwxx5TtvQrlueeeA2DbbbcFgie7ECZPngyEVR4hX1ylkTontU7I01eIOrflllsC8PXXXwNBPVRtQN3q8bh5AZWxKo+fxlkMWrXr2bNn0a9VaZrz9ubKUkstBcDmm29e9Gtlw8qfMcYYY0yCqJryJ/XuwQcfBNK9PwBrrLFG49/KApWnp2vXrkDwAUgZeOaZZzK+VtxQLaOdd9454+Op409VAbNt09x2cebkk0/OeduhQ4cCoWbUVVddVZYxlZtDDjkEgMsuuyzrNtq3Le1TZY4pmy4O3H333UDITi+G+eefHwjKR1zVbfkaN9xwQwB+++23vF9D5zgppPLLZUP1wOSvKmc9sHyROlmM4pcNrZJUS/mTx08om7cYP55ikWImn6BeM5pJLE+pasZVm8UWWwzIXsdPq1PNEd3m9NNPB0L9P2X/1gLHHnssUJzXr3379gDMPffcJRlTJqz8GWOMMcYkiIorf8poXWeddYAwa1aWm2bPt912W+Nz5OU788wzAdhzzz2B0Nt06aWXTnsN1Y5SBnFcev5K8ZP6Iz+fMto6deoEBC8jNO1aom01M9Dn11LWXByI9qxtDlX2V5afULxSxeKu8kbRMZpL5puyRVdZZRUg1HhUZw9tp//lsapGTawJEyYAoe92S5xwwgkALLDAAo33Pfroo0Dof3rccccB8Pfff5dsnOXgp59+AnJX/FR/NFUtUsV/fd+1iiG1U15KoffKRVWpFFJrlHmfjVxWNqLbajudG8eMGQME1ancyM8pz58UP9XqKwZl9abWAExFHj/VgJQSLoVQY6sWqsH44YcfAqHebpRM5zxlr+s4vv/++9Mel49Qv3vl9MAVy/Dhw4HcOi/16NEDSK9zWmms/BljjDHGJIiKKX8fffQREGpyaeYq9a5Lly5AqIad6mGRHya1A0RzKHNS1cHzySQuB8rqlccvOgNSZwJ1AlBXD2jqc5JCoBlQ6ra1Qi4qZVTxE1HlLO5eR3WnUQ2sm266Ke1x9alNVetUx/Lyyy8HQr2obFlkUgT1naqG8ifF/p133km7X/tJHSlUr0rZm4oNQr1CKUgLLrggENQeKYPyFcaFgQMH5rW9PEGnnHJKk8f0OaljjfzNOn70fVcNtei5pZrIw5trbc58ttV2yvi+5pprgMplvka7cJRC8csXKcXqFyyFXMpfLjUFy4EyzqXSZUMrGBA8fOphq8oF2WiuL3BckHezuRWAFVZYAYBhw4YB4bNTVYRKEv9P1BhjjDHGlIyyK3/KeFGmrvx46s+rrMA+ffoA8Mcff5TsvaMV0yuNZumqeSbkX5Di11zmp+r9aFYV7XGpGlHKgnvjjTeKGnM5aSm7VypHJuT7KkcGYTmQAqaONe+//z7QVOnYaaedgOYVDHVvkeobJ1SBX11JokjJyqWnpbaN9q2V362QLNpyIsVF2ZfZkMKn88Ass8yS83tIzdGt/FRS/tQDVFmRqhFXSaRuZ0Pqtr7DqR0+xAwz/PdTdPDBBwMhLtUMjCpD8pbKM7bqqqsWMvSc0W9JNl9eJZHqqFUQZQFXS/nr168fELpKRVl00UWBdD9w9HOUgp1NCZZCGCekwKoSRfS3SRVLUrvwKJehQ4cOAOywww5AOM7lade5Thng2Xp+F4OVP2OMMcaYBFF25U8Zt1L8hLJ61J2jHpFvTVXqhaqXK5sxSqo3QrOqzp07Z9xW3sg49/gUUndHjRqV8XGpl5mQOhpH9SsT8mgpA64YpIhK/chW41DKqrLKK8HZZ58NhBlrlFL0pHzyySfTbuOC/MvRbGStaqjKwO677552fynRe2ssuWQaloqouh1Fip+OV3lWo6sXEBS/c889N+1+ecWj6HgrpDNOMVRDWc2Guo
rI+1ct9Hsmv36U0aNHt/ga6vmt1bAoffv2BeK1sqXvdfQcr17W+j1rzteqig6q26lqKPrMRowYAbi3rzHGGGOMKZK
|
||
|
"text/plain": [
|
||
|
"<Figure size 800x800 with 100 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Plot a grid of digits: one row per class, one column per example\n",
|
||
|
"fig, axes = plt.subplots(n_digits, n_images, figsize=(8, 8),\n",
|
||
|
"                         subplot_kw={'xticks': [], 'yticks': []},\n",
|
||
|
"                         gridspec_kw=dict(hspace=0.1, wspace=0.1))\n",
|
||
|
"\n",
|
||
|
"for i, ax in enumerate(axes.flat):\n",
|
||
|
"    ax.imshow(example_images[i].reshape(n, n), cmap='binary', interpolation='nearest')\n",
|
||
|
"    ax.axis(\"off\")\n",
|
||
|
"    # Label each panel with its digit class (example_images is grouped by digit)\n",
|
||
|
"    ax.text(0.05, 0.05, str(i // n_images),\n",
|
||
|
"            transform=ax.transAxes, color='green')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:02.064046Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:02.063359Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:02.067767Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:02.067210Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def plot_digit(data):\n",
|
||
|
"    \"\"\"Display one flattened, square digit image in grayscale with axes hidden.\n",
|
||
|
"\n",
|
||
|
"    The side length is derived from ``data.size`` (28 for MNIST's 784 pixels)\n",
|
||
|
"    instead of the global ``n``, so the helper does not depend on notebook state.\n",
|
||
|
"    Raises ValueError if ``data`` does not hold a perfect-square number of pixels.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    pixels = np.asarray(data)\n",
|
||
|
"    side = math.isqrt(pixels.size)\n",
|
||
|
"    if side * side != pixels.size:\n",
|
||
|
"        raise ValueError(f\"expected a square image, got {pixels.size} pixels\")\n",
|
||
|
"    plt.imshow(pixels.reshape(side, side), cmap=matplotlib.cm.binary,\n",
|
||
|
"               interpolation=\"nearest\")\n",
|
||
|
"    plt.axis(\"off\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"Shuffle the training data so that it is not ordered by digit class."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:02.070642Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:02.070400Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:02.736323Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:02.735619Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Shuffle the training set with one shared permutation so images and labels stay aligned.\n",
|
||
|
"# Uses the global NumPy RNG seeded at the top of the notebook; re-running this cell\n",
|
||
|
"# reshuffles the already-shuffled data with a new permutation.\n",
|
||
|
"assert len(X_train) == len(y_train)\n",
|
||
|
"shuffle_index = np.random.permutation(len(X_train))\n",
|
||
|
"X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "slide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"## Binary classifier\n",
|
||
|
"\n",
|
||
|
"Construct a classifier to distinguish between 5s and all other digits."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:02.739912Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:02.739322Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:02.742873Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:02.742160Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Binary targets for the 5-detector: True for images of a 5, False otherwise\n",
|
||
|
"y_train_5, y_test_5 = (y_train == 5), (y_test == 5)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:02.745912Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:02.745363Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:02.751897Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:02.751306Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([False, False, False, ..., False, True, False])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 13,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Boolean mask over the test set: True where the true label is 5\n",
|
||
|
"y_test_5"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Train\n",
|
||
|
"\n",
|
||
|
"Train a linear model using Stochastic Gradient Descent (well suited to large datasets, as we will see later)."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 14,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:02.754951Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:02.754519Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:02.758982Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:02.758368Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# The SGDClassifier below uses a deliberately small max_iter, so training may stop\n",
|
||
|
"# before convergence and sklearn would emit a ConvergenceWarning; silence it here.\n",
|
||
|
"from warnings import simplefilter\n",
|
||
|
"from sklearn.exceptions import ConvergenceWarning\n",
|
||
|
"simplefilter(\"ignore\", category=ConvergenceWarning)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 15,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:02.761746Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:02.761309Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:04.112020Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:04.111257Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [
|
||
|
"<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label 
{background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-r
|
||
|
],
|
||
|
"text/plain": [
|
||
|
"SGDClassifier(max_iter=10, random_state=42)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 15,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from sklearn.linear_model import SGDClassifier\n",
|
||
|
"# Linear classifier trained with SGD; max_iter=10 keeps the demo fast (may not fully converge)\n",
|
||
|
"sgd_clf = SGDClassifier(random_state=42, max_iter=10)\n",
|
||
|
"sgd_clf.fit(X_train, y_train_5)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"Recall `some_digit`, extracted previously, which is a 5."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 16,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:04.115229Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:04.114730Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:04.162849Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:04.162180Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAIO0lEQVR4nO3cr2+XVwOH4acLVFDExJio2choCA0K/gBmtmQhmVmCWVKzZRmuyRxhbmoKhwCzMDGxP2BBAQJB0KzZNBMDSSsqeF53m3fmPKE/AtflPzlPkyb395izMs/zPAHANE3vHfUHAHB8iAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgA5cdQfwLvj4cOHi3ZPnz4d3vzwww+LzmKZ/f394c3e3t6is168eDG8WVtbG96sr68Pb94GbgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACAexGORJY+Sff311wfwJf9tc3PzUM45c+bM8ObevXuLznr06NHwZnd3d3izsrIyvDnMB/H+/fff4c3p06eHN5988snw5quvvhreTNM0/fjjj4t2B8FNAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAxIN4LHrc7tq1a8Obf/75Z3gzTdM0z/Pw5urVq4dyzpLH45acc5hnLTlnY2NjeLO6ujq8maZpWl9fH95cuXJleLPk4b3DeojxILkpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAeBDvLfPs2bPhzeeffz68WfK43Ycffji8maZpunTp0vDmvffGf++8fv16eLPkobVTp04Nbw77rFHnzp07lHM4eG4KAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBAVuZ5no/6I3hzNjc3hzd//vnn8Oazzz4b3ty6dWt4M03L/iZgGTcFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQE0f9Afy327dvL9otedzu008/Hd7cv39/eLPU7u7u8ObRo0fDm1OnTg1vrly5MryB48xNAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoAxIN4x9TOzs6hnfXgwYPhzcrKypv/EN64ra2t4c2dO3eGN6urq8Mbjic3BQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAkJV5nuej/gj+37Nnzxbtfv755+HNy5cvF5016oMPPli0u3z58hv+kqO1t7e3aPfTTz8Nb169ejW8+fvvv4c3586dG95wPLkpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAeBAPDtnz588X7c6fPz+8OXny5PDmr7/+Gt6cOXNmeMPx5KYAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgDkxFF/ALxrbty4sWi3t7c3vPnmm2+GN148fbe5KQAQUQAgogBARAGAiAIAEQUAIgoARBQAiC
gAEFEAIKIAQEQBgKzM8zwf9UfAcfD8+fPhzfb29vDm999/H95M0zR9/PHHw5snT54MbzyI925zUwAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCADlx1B8AB+Hhw4fDm+vXrw9vdnZ2hjenT58e3kzTNN29e3d443E7RrkpABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAeBDvEDx+/Hh4s7m5ueis999/f9Fu1P7+/vDm6dOni8769ddfhze3b99edNaoL774Ynjzyy+/LDrL43YcBjcFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQlXme56P+iLfd2bNnhzc3b95cdNb6+vrw5sWLF8Ob3377bXjzxx9/DG+maZqW/ItevHhxeLO9vT282draGt6cPHlyeAOHxU0BgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFADIiaP+gHfBhQsXhjfffvvtAXzJm7O6ujq82djYWHTWd999N7z5/vvvhzdra2vDG3jbuCkAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCszPM8H/VHvO329/eHN/fu3Vt01t7e3qLdqC+//HJ489FHHx3AlwBvkpsCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIB/EAiJsCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAg/wMoh+Lly7Gm9AAAAABJRU5ErkJggg==",
|
||
|
"text/plain": [
|
||
|
"<Figure size 640x480 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"plot_digit(data=some_digit)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "fragment"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"Predict class:"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 17,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:04.166557Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:04.166129Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:04.173666Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:04.173133Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([ True])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 17,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# predict expects a 2-D array of samples, so wrap the single image in a list\n",
|
||
|
"sgd_clf.predict([some_digit])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Test accuracy"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 18,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:04.178073Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:04.176751Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:04.213684Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:04.213035Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Predict on the test set; keep y_test as the true digit labels (do not overwrite it)\n",
|
||
|
"y_test_pred = sgd_clf.predict(X_test)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 19,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:04.218677Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:04.217226Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:04.225740Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:04.225190Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"0.9601"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 19,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from sklearn.metrics import accuracy_score\n",
|
||
|
"# Accuracy of the 5-detector: true binary labels first, predictions second\n",
|
||
|
"accuracy_score(y_test_5, y_test_pred)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Cross-validation"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"#### n-fold cross-validation\n",
|
||
|
"\n",
|
||
|
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture04_Images/5-fold-CV.png\" width=\"700\" style=\"display:block; margin:auto\"/>\n",
|
||
|
"\n",
|
||
|
"[Image credit: [VanderPlas](https://github.com/jakevdp/PythonDataScienceHandbook)]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": [
|
||
|
"exercise_pointer"
|
||
|
]
|
||
|
},
|
||
|
"source": [
|
||
|
"**Exercises:** *You can now complete Exercises 2-3 in the exercises associated with this lecture.*"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Consider a naive classifier\n",
|
||
|
"\n",
|
||
|
"Classify everything as not 5."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 20,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:04.229463Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:04.229066Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:04.234410Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:04.233859Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.base import BaseEstimator\n",
|
||
|
"\n",
|
||
|
"class Never5Classifier(BaseEstimator):\n",
|
||
|
"    \"\"\"Naive baseline classifier that classifies every instance as not 5.\"\"\"\n",
|
||
|
"\n",
|
||
|
"    def fit(self, X, y=None):\n",
|
||
|
"        # Nothing to learn; return self as scikit-learn estimators conventionally do.\n",
|
||
|
"        return self\n",
|
||
|
"\n",
|
||
|
"    def predict(self, X):\n",
|
||
|
"        # 1-D boolean array of shape (n_samples,), one False per instance, like the target labels.\n",
|
||
|
"        return np.zeros(len(X), dtype=bool)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"What accuracy do we expect?"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 21,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:04.237953Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:04.237493Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:04.867261Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:04.866601Z"
|
||
|
},
|
||
|
"slideshow": {
|
||
|
"slide_type": "fragment"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([0.909 , 0.90745, 0.9125 ])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 21,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from sklearn.model_selection import cross_val_score\n",
|
||
|
"\n",
|
||
|
"never_5_clf = Never5Classifier()\n",
|
||
|
"cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "fragment"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"Need to go beyond classification accuracy, especially for skewed datasets."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "slide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"## Confusion matrix\n",
|
||
|
"\n",
|
||
|
"Can gain further insight into performance by examining confusion matrix."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Confusion matrix shows true/false-positive/negative classifications\n",
|
||
|
"\n",
|
||
|
"- True-positive $\\text{TP}$: number of true positives (i.e. *correctly classified* as *positive*)\n",
|
||
|
"- False-positive $\\text{FP}$: number of false positives (i.e. *incorrectly classified* as *positive*)\n",
|
||
|
"- True-negative $\\text{TN}$: number of true negatives (i.e. *correctly classified* as *negative*)\n",
|
||
|
"- False-negative $\\text{FN}$: number of false negatives (i.e. *incorrectly classified* as *negative*)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"<table>\n",
|
||
|
" <tr>\n",
|
||
|
" <td></td>\n",
|
||
|
" <td></td>\n",
|
||
|
" <td>Predicted</td>\n",
|
||
|
" <td></td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <td></td>\n",
|
||
|
" <td></td>\n",
|
||
|
" <td>Negative</td>\n",
|
||
|
" <td>Positive</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <td>Actual</td>\n",
|
||
|
" <td>Negative</td>\n",
|
||
|
" <td>TN</td>\n",
|
||
|
" <td>FP</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <td></td>\n",
|
||
|
" <td>Positive</td>\n",
|
||
|
" <td>FN</td>\n",
|
||
|
" <td>TP</td>\n",
|
||
|
" </tr>\n",
|
||
|
"</table>"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Cross-validation prediction\n",
|
||
|
"\n",
|
||
|
"`cross_val_predict` performs K-fold cross-validation, returning the predictions made on each test fold. This gives a clean prediction for every instance in the training set, i.e. a prediction made by a model that never saw that instance during training."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 22,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:04.871959Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:04.870596Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:07.774513Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:07.773793Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.model_selection import cross_val_predict\n",
|
||
|
"y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Compute confusion matrix"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 23,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:07.778441Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:07.778158Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:07.793206Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:07.792567Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([[52953, 1626],\n",
|
||
|
" [ 807, 4614]])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 23,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from sklearn.metrics import confusion_matrix\n",
|
||
|
"conf_matrix = confusion_matrix(y_train_5, y_train_pred)\n",
|
||
|
"conf_matrix"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Each row represents an actual class, while each column represents a predicted class."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "fragment"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Perfect confusion matrix"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 24,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:07.796975Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:07.796729Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:07.810580Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:07.809959Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([[54579, 0],\n",
|
||
|
" [ 0, 5421]])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 24,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"y_train_perfect_predictions = y_train_5\n",
|
||
|
"confusion_matrix(y_train_5, y_train_perfect_predictions)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": [
|
||
|
"exercise_pointer"
|
||
|
]
|
||
|
},
|
||
|
"source": [
|
||
|
"**Exercises:** *You can now complete Exercise 4 in the exercises associated with this lecture.*"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "slide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"## Precision and recall"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"- **Precision**: of predicted positives, proportion that are correctly classified (also called *positive predictive value*).\n",
|
||
|
"\n",
|
||
|
"$$\\text{precision} = \\frac{\\text{TP}}{\\text{TP} + \\text{FP}}$$\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"- **Recall**: of actual positives, proportion that are correctly classified (also called *true positive rate* or *sensitivity*).\n",
|
||
|
"\n",
|
||
|
"$$\\text{recall} = \\frac{\\text{TP}}{\\text{TP} + \\text{FN}}$$"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "-"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"Remember:\n",
|
||
|
"<table>\n",
|
||
|
" <tr>\n",
|
||
|
" <td></td>\n",
|
||
|
" <td></td>\n",
|
||
|
" <td>Predicted</td>\n",
|
||
|
" <td></td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <td></td>\n",
|
||
|
" <td></td>\n",
|
||
|
" <td>Negative</td>\n",
|
||
|
" <td>Positive</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <td>Actual</td>\n",
|
||
|
" <td>Negative</td>\n",
|
||
|
" <td>TN</td>\n",
|
||
|
" <td>FP</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <td></td>\n",
|
||
|
" <td>Positive</td>\n",
|
||
|
" <td>FN</td>\n",
|
||
|
" <td>TP</td>\n",
|
||
|
" </tr>\n",
|
||
|
"</table>"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": [
|
||
|
"exercise_pointer"
|
||
|
]
|
||
|
},
|
||
|
"source": [
|
||
|
"**Exercises:** *You can now complete Exercise 5 in the exercises associated with this lecture.*"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### $F_1$ score"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"$F_1$ score is the *harmonic mean* of the precision and recall."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"$$F_1 = \\frac{2}{1/\\text{precision} + 1/\\text{recall}} = 2 \\frac{\\text{precision} \\times \\text{recall}}{\\text{precision} + \\text{recall}} = \\frac{\\text{TP}}{\\text{TP} + \\frac{\\text{FN}+\\text{FP}}{2}}$$"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "fragment"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"$F_1$ favours classifiers that have similar (and high) precision and recall.\n",
|
||
|
"\n",
|
||
|
"Sometimes may wish to favour precision or recall."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": [
|
||
|
"exercise_pointer"
|
||
|
]
|
||
|
},
|
||
|
"source": [
|
||
|
"**Exercises:** *You can now complete Exercise 6 in the exercises associated with this lecture.*"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Precision-recall tradeoff"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Under the hood the classifier computes a *score*. A binary decision is then made depending on whether the score exceeds some *threshold*.\n",
|
||
|
"\n",
|
||
|
"By changing the threshold, one can change the tradeoff between\n",
|
||
|
"precision and recall."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"Scikit-Learn does not let you set the threshold directly, but you can access the decision scores (e.g. for a linear classifier, the confidence score of a sample is proportional to the signed distance of that sample to the separating hyperplane)."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 25,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:07.814908Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:07.814670Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:07.821596Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:07.820991Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([24299.40524152])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 25,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"y_scores = sgd_clf.decision_function([some_digit])\n",
|
||
|
"y_scores"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "fragment"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"Can then make prediction for given threshold."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 26,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:07.825101Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:07.824869Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:07.831310Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:07.830702Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([ True])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 26,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"threshold = 0\n",
|
||
|
"y_some_digit_pred = (y_scores > threshold)\n",
|
||
|
"y_some_digit_pred"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 27,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:07.834535Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:07.834282Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:07.840676Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:07.840078Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([False])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 27,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"threshold = 200000\n",
|
||
|
"y_some_digit_pred = (y_scores > threshold)\n",
|
||
|
"y_some_digit_pred"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"#### Compute precision and recall for range of thresholds"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 28,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:07.844026Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:07.843796Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:10.475531Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:10.474826Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,\n",
|
||
|
" method=\"decision_function\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 29,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:10.479280Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:10.479036Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:10.494777Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:10.494077Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.metrics import precision_recall_curve\n",
|
||
|
"precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 30,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:10.499634Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:10.497954Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:10.701765Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:10.701097Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(-700000.0, 700000.0)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 30,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAz8AAAHRCAYAAABEhAeYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAACCoklEQVR4nO3dZ3gUVR+G8XvTCySUEAgl9CZFOiIdpYsKUqQoiIpYABEUe8eCCBZsoFJERUGwgEivUhWQLkhLgNAhBdJ33g/zZsOSQBJIMtnk+XHtlTNnzsz+sxuSPJmZMzbDMAxERERERETyOTerCxAREREREckNCj8iIiIiIlIgKPyIiIiIiEiBoPAjIiIiIiIFgsKPiIiIiIgUCAo/IiIiIiJSICj8iIiIiIhIgaDwIyIiIiIiBYLCj4iIiIiIFAgKPyIiIiIiUiBkOfzExMTwyiuv0KlTJ4oVK4bNZmPatGmZ3v7ChQsMGTKEEiVK4O/vT9u2bdmyZUtWyxAREREREcmSLIefM2fO8Prrr7Nnzx5uvvnmLG1rt9vp2rUr3333HU888QTjxo3j1KlTtGnThv3792e1FBERERERkUzzyOoGISEhREREUKpUKf766y8aN26c6W3nzJnDunXrmD17Nj179gSgd+/eVKtWjVdeeYXvvvsuq+WIiIiIiIhkSpaP/Hh7e1OqVKnrerI5c+ZQsmRJevTo4egrUaIEvXv35pdffiE+Pv669isiIiIiIpKRXJ3wYOvWrTRo0AA3N+enbdKkCZcuXWLfvn25WY6IiIiIiBQgWT7t7UZERETQqlWrNP0hISEAHD9+nDp16qRZHx8f73RUyG63c+7cOYoXL47NZsu5gkVEREREJE8zDIPo6GhKly6d5iDLlXI1/MTGxuLt7Z2m38fHx7E+PW+//TavvfZajtYmIiIiIiKuKzw8nLJly15zTK6GH19f33Sv64mLi3OsT89zzz3HU0895ViOjIwkNDSUz5Z9ho+/D4ZhYGA4fQTS7b/8Y3xSPBcTLzr1Xbkd4NS+1vr0tk+0JxKdEI3dsDv6UtoYpNsfERPB2UtnKVmopGNfdsPuXH86/YnJiUQlRBGXFMfFhIvEJ+kaqvzA090TL3cv3GxuuNvccXMzPwZ6B+Lj4YObzc3xsNlsjrafpx+lCpUiNCCUUoVKUdyvOD4ePni6eVLUtygB3gEU9S2Kr4ev0z5S9iMiIjfu/Hnw9gY/v+zb58SJ8OqrZvujj+D229OOKVUK3N1Tl6OjISoq4317ekJwsHPfmTOQmcuy/f2hSJHUZcOA48cz3g6geHH4/9/CAYiNhXPnMrdtmTLOy+fPw6VLGW/n42M+7+VOnoSkpIy3DQyEQoUyV5/kvKioKMqVK0fhwoUzHJur4SdlprgrpfSVLl063e28vb3TPWLUr1E/AgICsrfIfCQuKQ67YQdwCmopy5e3s3tdTEIMcUlxaQJnmhB3RfAzDIP45Hii46NJsidhN+wkG8nYDbvZticTFR/l+NxS9pneIyo+ikuJl0i0JxKbGMuFuAsk2hMdz3flw8AgLimOmIQYEpMTHc+bbE8m2UgmNjGWRHtiDr5j6Uv8/z8H8y3l9KXTOfq8Nswg5enuSaB3IN4e3rjb3PFw88DT3ZPivsUp4lOE4n7FCfAKoLB3YZqXa04hr0J4unvi4eZBUZ+i+Hv54+nmSSGvQnh7pP1/LCKSnz3zDHzxhRkKihSBgACoXNkMLO3aQfXqkNW/NwUHQ6VKkJAANWuaj4wEBKQNCZl1I79qBQZe/3OWLHn9214v/Vrp2jLzx9tcDT/16tVjzZo12O12p/PxNm7ciJ+fH9WqVcvNcvI9Hw+fjAdJlqSErsvDWEo7NimWQ+cPkWwkk5icSKI9kQtxF4hLiiPJnuQIUQnJCVyIu0BCcoKjL+VjZHykGdiSE4lLiiMyPpL4pPg0QexS4iWi4qPSDXHZxcAwa0tKJi4pLlv26eXuhZ
e7F55u5tGsMgFlKFWoFH6efo6+Yr7FKORViCC/ICoXrew4SlXEpwiebp54unsS4B2Ah1uufvsSkTzg0iUzSCQnQ8WKULUq1KgBXl7pjz93DsLCzLCR8gfh+fPhl19g1Chz2xS7d8Nzz5lHaFL2W7v2tfcP5hGVv/82w0i5cmmDTHi4+fHCBfMBsH07zJtntoOCoHt3GD0aMvtr0KOPmg8Ryboc++0hIiKCyMhIKleujKenJwA9e/Zkzpw5zJ0713GfnzNnzjB79my6deuW7tEdkbwko0AZGhiaS5VcXcqRrYiYCMIjwzkefZyImAjHEa2o+CguxF0gKiGKc7HnnALclY9kI/VIW7I92QxxRjKRcZEkG8lZri0hOYGE5ATH8smLJ6/78/Tx8MHL3QsfDx+K+BTB39MfP08/ivsVx9fDl6I+RalYtCIl/UsS5BdEi9AWBPpc558gRSRDdjukd52xYZiBIOVjipTTkjJ7OphhmKdVXcnDwwworVrBJ5+k9k+cCJedMU+ZMhASAn/9ZS4//LDzfrp3h/QmnfX0hHr1zP2/8QZcfoa+YUDbtmb4AfPIzoULMGwYDBgATZpA3brm6WYnTpgfT582t0tx5gxMmQLNmjmHnwMHYNIkGDjQfH4RyR424/KLWTJp0qRJXLhwgePHj/PZZ5/Ro0cP6tevD8CwYcMIDAxk0KBBTJ8+nUOHDlGhQgUAkpOTadGiBTt37uTpp58mKCiITz/9lLCwMDZv3kz16tUz9fxRUVEEBgYSGRmp095ELGAYBrFJsZy5dIbzsefZfXo3x6OPczb2LInJiSTZk4hJiOFC/AUSkxOJTTJPO4xPiifRbq4/F3uOM5fO5HrtXu5euNvccXczT+HzcveiuG9xivkWo0KRClQvXp06JevQIKRBngizInmNYcA335jXlHTokNrfrBls2GC2U0LAlU6fNo90AEybBg89BDfdBA0bQqNG5jUYW7fCl19C+/bQrRsMGWKOf+klePPNq9fVrBmsW5e6nNHZLw88AF9/nfnxXl5mYLv8OpqdOyGdSWoB8yhQWFja/uRk83NctAhWrYLVq83raebPh65dU8fNn29+/mAGr759zYAWHJz10+RE8rusZIPrCj8VKlTgyJEj6a5LCTvphR+A8+fP8/TTT/Pzzz8TGxtL48aNGT9+PI0aNcr08yv8iOQPsYmxRMZHOgLThbgLRCdEEx0fzf5z+zkRc4KLCRc5G3uWhOQEEu2JXEy46DhtMCE5gfNx57mYcNE8OnUdR6My4uvhi7eHN1WKVaFqsaqEFAoh2D+Y8kXKE1IohMrFKlM24Nozy4i4qjNnzCMoVaqYwePgwdR1ffvCt9+av4gnJZmnlcVd4wzZoCD4/vvUi/Ofesrc97UMGGAGLUj7C/+bb5qnqu3YAXv2wJNPwnvvmetiY9MeUSpWzPkC+o8+Mo/QpGjZEtauNdvz5pn73LkTtmyBvXvNU+B27HDeZ926zn1eXuZ1OAB33516atu1nD8PCxeaAefySaqefhrGj087vkgRuOUWuPNO6NPH/LxECrocDz9WU/gRkfRcTLjouG4qPDKcsMgwlhxcQkJyAicvnuRiwkWna6yS7Emcjz1PVHwU8cnXPztiSf+SNC7TmADvAEr4lSDIL4iyAWVpVLoRN5W4CTdbrt5PWiRb9O4Ns2dfe8zPP8Ndd5mndHXsaAaGxGvMCzN5curpZu+/bwabnTvNoyHpGTMG3nnHbPfvD999Z7aXLTMnC0hx6ZIZvFKCgGGYtaxda16T89hj5lGls2fhn3/MIy2dOzsHqhUrYP166NQJGjRwruPkSTMAtW6d2mcYzqf5jRplzsC2dSscOWKGkxv5FeXoUXjlFeejU+nZt8+8RkmkIFP4ERHJoqj4KP498y+7T+9m64mt7D69m3Ox50hITuBEzInrnl0v0DuQVuVbUb9UfYr7Fadx6cbcUvYWTScueUJysvlLvbu7+Yv2X3/BDz+Yv0w//L
B55GfFiqtvn5BgXhNz+fLZs+ZRnqNHzWtsUk4Ti41NPwzExpoTAPz1l3na3MyZ0LOnGaZ69EgNNFu2mNfG3Habud5
|
||
|
"text/plain": [
|
||
|
"<Figure size 1000x500 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):\n",
|
||
|
" plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n",
|
||
|
" plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n",
|
||
|
" plt.xlabel(\"Threshold\", fontsize=16)\n",
|
||
|
" plt.legend(loc=\"upper left\", fontsize=16)\n",
|
||
|
" plt.ylim([0, 1])\n",
|
||
|
"\n",
|
||
|
"plt.figure(figsize=(10, 5))\n",
|
||
|
"plot_precision_recall_vs_threshold(precisions, recalls, thresholds)\n",
|
||
|
"plt.xlim([-700000, 700000])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"*Raising* the threshold *increases precision* and *reduces recall*.\n",
|
||
|
"\n",
|
||
|
"Can select threshold of appropriate trade-off for problem at hand."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "fragment"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
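"Note that the recall curve is smoother than the precision curve. Recall is computed relative to the actual positives, a fixed set, so it can only decrease as the threshold is raised. Precision is computed relative to the predicted positives, which change as the threshold moves, so precision can occasionally drop when the threshold is raised."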
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "slide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"## ROC curve"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"*Receiver operating characteristic* (ROC) curve plots *true positive rate* (i.e. recall) against the *false positive rate* for different *thresholds*."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": [
|
||
|
"exercise_pointer"
|
||
|
]
|
||
|
},
|
||
|
"source": [
|
||
|
"**Exercises:** *You can now complete Exercise 7 in the exercises associated with this lecture.*"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Plot ROC curve"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 31,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:10.705549Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:10.705078Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:10.719896Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:10.719259Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.metrics import roc_curve\n",
|
||
|
"fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 32,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:10.723141Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:10.722682Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:10.871990Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:10.871269Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAn0AAAHRCAYAAAAfRQ/FAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAACDZUlEQVR4nO3dd1iTV/8G8DsJGxkqKKA4UHGCgyqKVrHW1VptFcXXhavWto7WrW9dddVdba3VqjhptY7WKrV1UF83+lPQuhVxgagoGwIk5/cHJUoBDSHhCeT+XBdXk2flTqLl6znnOUcmhBAgIiIiojJNLnUAIiIiIjI8Fn1EREREJoBFHxEREZEJYNFHREREZAJY9BERERGZABZ9RERERCaARR8RERGRCWDRR0RERGQCWPQRERERmQAWfUREREQmwCiLvpSUFMycORNdunRBhQoVIJPJsHHjRq3PT0hIwIgRI+Ds7AxbW1u0b98e58+fN1xgIiIiIiNnlEXf06dP8eWXX+Lq1ato3Lhxkc5Vq9V49913ERISglGjRmHRokV4/Pgx/P39cfPmTQMlJiIiIjJuZlIHKIirqytiY2Ph4uKCc+fOoXnz5lqfu3PnTpw8eRI///wzAgICAAB9+vSBp6cnZs6ciZCQEEPFJiIiIjJaRtnSZ2lpCRcXF53O3blzJypXroyePXtqtjk7O6NPnz749ddfoVQq9RWTiIiIqNQwyqKvOC5cuIBmzZpBLs/71lq0aIG0tDTcuHFDomRERERE0jHK7t3iiI2NRdu2bfNtd3V1BQDExMTAy8urwHOVSmWelkC1Wo1nz56hYsWKkMlkhglMREREBEAIgeTkZLi5ueVrvNKHMlf0paenw9LSMt92Kysrzf7CLFiwALNnzzZYNiIiIqLXuX//PqpWrar365a5os/a2rrAcXsZGRma/YWZOnUqxo0bp3memJiIatWq4f79+7C3t9d/WCIiKnWyVWpkZKshhIBaABCAQO5jAQFALQQgAPU/+4QAEtIzoVILzY9aDcQlZ0Auk2m2JWVk4UFCOoRa4PaTVNyMS4Z7BRvIZTLkXF78c82cx0IA6n+uL4TAjbgUyGVAeRtzqEVODnXucf+cm5GllvTzK2u8qzrATC7DpZgkdG5QGVkqNdKzVGjo6gCFXIZslRpmCjmqlLeCQi6HSq1GeVtLWJvLIZfJoJDLkJqchOF93sO77/fEmpXLYGdnZ5CsZa7oy73z999yt7m5uRV6rqWlZYGthPb29iz6iIhKiFotkJKZDaEGVP8ULWp1TsGSrVYjPiUTMhk0RU1uIZR7zJMUJRLSMvHgeTpsLcygUquhzFbj+K2nuBefhroudpoC6EUhlVOE5T6/9igZAFDR1gIqkVukCaRmqkr40zBHwpMsrY+WW9oAAJ5nF7BTlvMjz/9rTjL2VmZQyHMKH7lMhsfJOY02XlUcIJfLIJchpzCSySCTAdceJaN5jQqwtzZDTEI6mtfImctXIfvn2H+u8zRFiYZu9lDIZTCTy5GtVsPRxgK2FgrNMbnXlv3zX2tzBWwsFJD9s0/zX8ggk+d8fC8fb2Wu0Pl9K5VKhISEYPDgwZrhY7duXINSqcSalcsMNqSszBV9TZo0wbFjx6BWq/P0h585cwY2Njbw9PSUMB0RUcnKUqmRlql6URgV0PKjUgnEp+b8slULAZUaUKkFstVqXIlJwrO0TFibK6BWi38KIOBqbBIszOSwszTTXCf3ugI5BZRaCPz9MBEymQwPE3KG1lSwtdAUcf80jOXJpMw2fCvUubvPtT42PjXTgEl0J5PlLUJkMpnmeXpWTmFarYINFPIXRcqLIkeGqCcpcHGwQi3ncpDLZDCTy3D1URJ8a1aAtXlO4aN4qeiCDIhNyEBjd0coZMjZ90/xlJ6pgrOdJeytzWH2z+tVcbSGraXZi9f9V6FlJpfBTFHm7iXVypUrV9CvXz9ERkYiKysLI0aMAJDT8GToGU
ZKddEXGxuLxMRE1KpVC+bm5gCAgIAA7Ny5E7t379bM0/f06VP8/PPPeO+99wpsySMi0qfEtCxkqtTIVqsR/TQNj5MzkK36p1tPCDxLzcSjxAwkpmchMT1L0xWn+qfVKrer79DVx2ji7qgp2FTqF8XRjcfJECJvN97LBZRaAJklUEAV1TMjK6JyC6J/t+LIZDmftzJbjRoVbSCX57QmKeQyqP/pRn2zjtNLxdaLwkumuc4/rUSyF8VYfIoSzaqX11xLLpchJiEdTd0doVDIYSaXIVstUN7GHBVsLVDJzgrO5Sxha5m3BYpKHyEEVq9ejfHjxyMjIwNOTk6am0xLitEWfd9++y0SEhIQExMDAPjtt9/w4MEDAMDo0aPh4OCAqVOnYtOmTbhz5w5q1KgBIKfoa9myJYYMGYIrV67AyckJ3333HVQqFW/SICrD0jNVSEzPwsOEdGSr1HiSokRSejauxCaivI0FnqYoce1RMqo4WiPifgKAnFan3AIrJiEdSRnZcLQx14yPEsA/47VePM8trF7sy23dyikS9C03a2Gep2nf9Sc1hTxn7FqNijaa1qN/t0DlFDXA3w+T0Ka2EyzM5JpjcrsAHyako5KdJdwcrfO0JL3cohWfkonalcqhkr0VytuYa7r5HKzNUb2iDSzN5CyeqMTExcVh2LBh2L9/PwCgc+fO2Lhxo85zEutKJoTQ//+l9KBGjRq4e/dugftyi7zBgwfnK/oA4Pnz55g4cSJ++eUXpKeno3nz5liyZAneeOONImVISkqCg4MDEhMTOaaPqAiSM7IQl5SBxPTsPAPXkzKy8DwtEwqZDH/HJMLROqerTyUEVCqBozeeoJ6rPVRqNdIzVQi7/gQAYGUuh7lC/qLg+td/sw1QbBkbs9zuMXluYSODEDljzDycbF8qeF4qfuSAWg1ciU1C+7rOeboC5f86PiYhHZXtreDmaK0prhTynCI3M1sNb3dHOFqba8ZfKeQyCAE4lbPQHP/va8v+aemyMMsptohM0aFDh9C/f388fvwYlpaWWLRoEUaNGlXglCyGrjuMtugzBiz6yBQkZWQhJSMbsYkZuPU4GTLIkKVW4/bjVJgpcsbe5BZld56m4mFCOpzKWf4z9itv4fW6ViljZKGQQy7HPwPBZUhWZqNaBRuY5fTNveiuw4uuQPxrm+ylbjwZgOtxyajrYg83ByuYKeS4/TgFvd+oCgszuaZ4y1YLOJezRDkrM1S2t9KMhXq5gFLIZHCwNjfZsU9EZcGJEyfQtm1bNGzYECEhIWjUqFGhxxq67jDa7l0i0p4QAs/TsnDi1lOkZ6n+GYQvkJWtxsUHCXB1tEZWthph1x8jNjEDNhZmeJpSnAHDyXrLXhS1K5XL1xWY252XpRK49ywNPtXLw0wuQ31XeySmZ6FO5XKwsTCDe3lrmClkqGBrCVsLBawsFLC3YusTEelfUlKSpmhr3bo19u3bh/bt22vmDJYKiz4iI5CZrcb952l4lJiBB8/TAAAq9YvpKHKLuP+7+xz3nqXhckwSAMBckVP8FHXAfpqBpp2QyXK6A4Gcbj+ncpZoVasiLM0UOS1cchkS0zJRpbw1KtpaIlOl1rSq5XYZWpkr4PjPGCxzuRz2/3QpEhEZO7VajeXLl2Pu3Lk4deoU6tWrBwDo2rWrxMlysOgjKiYhBJ6mZCLqSYqmOFOpBR4kpMP8n65R9T/bz0Q9QzkrMySmZ+HglTjYWCiKVYBlqf6ZGbYYqjhaQwiBN2pUwLPUTNRytkVDNweYm8mQkaWGe3kbTTevXJ7z3xpOtjCXyzXdnblTO3BgPBGZqocPHyIoKAiHDx8GAGzcuBFfffWVxKnyYtFH9AqPkzLwICEd95+l4dTteNhZmeHCvQTNlBjFHcOmrxa3xu6OyMxWQ6VWw9XBGu96uWoKNPk/rWS548vsrczg5mhdrIlFiYjohV27dmHEiBF49uwZbGxs8PXXX2P48OFSx8qHRR+ZtCyVGpH3E3
DnaSou3E+ArYUCWSqBPy8/QkxiRonlqFbBBveepaF5jfJo4GqPCraWcHWwyjOwP3d2ekszBeq52MH1nyKOiIikkZK
|
||
|
"text/plain": [
|
||
|
"<Figure size 700x500 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"def plot_roc_curve(fpr, tpr, label=None):\n",
|
||
|
" plt.plot(fpr, tpr, linewidth=2, label=label)\n",
|
||
|
" plt.plot([0, 1], [0, 1], 'k--')\n",
|
||
|
" plt.axis([0, 1, 0, 1])\n",
|
||
|
" plt.xlabel('False Positive Rate (FPR)', fontsize=16)\n",
|
||
|
" plt.ylabel('True Positive Rate (TPR)', fontsize=16)\n",
|
||
|
"\n",
|
||
|
"plt.figure(figsize=(7, 5))\n",
|
||
|
"plot_roc_curve(fpr, tpr)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "fragment"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"Dashed line corresponds to random classifier.\n",
|
||
|
"\n",
|
||
|
"Again, there is a trade-off. As the threshold is reduced to increase the true positive rate, we get a larger false positive rate."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": [
|
||
|
"exercise_pointer"
|
||
|
]
|
||
|
},
|
||
|
"source": [
|
||
|
"**Exercises:** *You can now complete Exercise 8 in the exercises associated with this lecture.*"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Area under the ROC curve\n",
|
||
|
"\n",
|
||
|
"Area under the ROC curve (AUC) is a common performance metric."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 33,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:10.875409Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:10.874976Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:14:10.898598Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:14:10.898029Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"0.9668396102663849"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 33,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from sklearn.metrics import roc_auc_score\n",
|
||
|
"roc_auc_score(y_train_5, y_scores)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": [
|
||
|
"exercise_pointer"
|
||
|
]
|
||
|
},
|
||
|
"source": [
|
||
|
"**Exercises:** *You can now complete Exercise 9 in the exercises associated with this lecture.*"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Comparing classifier ROC curves"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 34,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:14:10.901977Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:14:10.901330Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:15:13.117633Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:15:13.116928Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.ensemble import RandomForestClassifier\n",
|
||
|
"forest_clf = RandomForestClassifier(random_state=42)\n",
|
||
|
"y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,\n",
|
||
|
" method=\"predict_proba\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 35,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:15:13.121181Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:15:13.120525Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:15:13.130757Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:15:13.130130Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n",
|
||
|
"fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 36,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:15:13.133917Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:15:13.133466Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:15:13.315999Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:15:13.315269Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"<matplotlib.legend.Legend at 0x7f34fd8df7f0>"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 36,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsoAAAIeCAYAAACiOSl5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAACtnklEQVR4nOzdd1xT5x4G8CcJWxRUQAEBcaDiQhy46h4468StKFatrVWr1WrrrOM6arVqba0LxYqjtlq11r0VrYC27oGbIcOwZ87945SwNUTgBHi+n09uk/PmnDxBL/7y5h0yQRAEEBERERFRFnKpAxARERER6SIWykREREREuWChTERERESUCxbKRERERES5YKFMRERERJQLFspERERERLlgoUxERERElAsWykREREREuWChTERERESUCxbKRERERES50MlCOTY2FvPmzYO7uzsqVKgAmUyGbdu2aXz+mzdvMG7cOFhaWqJMmTJo3749/P39Cy8wEREREZU4Olkoh4eHY+HChbhz5w4aNmyYr3NVKhV69OiBX375BZ9++imWL1+OsLAwtGvXDg8ePCikxERERERU0uhJHSA31tbWCA4ORuXKlfH333+jadOmGp+7b98+XLp0CXv37sWAAQMAAB4eHnBycsK8efPwyy+/FFZsIiIiIipBdLJH2dDQEJUrV9bq3H379qFSpUro16+f+pilpSU8PDxw4MABJCUlFVRMIiIiIirBdLJQfh8BAQFwdXWFXJ71rTVr1gzx8fG4f/++RMmIiIiIqDjRyaEX7yM4OBht2rTJcdza2hoA8OrVK9SvXz/Xc5OSkrL0OKtUKkRGRqJixYqQyWSFE5iIiIiItCYIAmJiYmBjY5Ojo/R9lbhCOSEhAYaGhjmOGxkZqdvzsnTpUixYsKDQshERERFR4Xj+/DmqVKlSoNcscYWysbFxruOQExMT1e15mTVrFj7//HP1Y6VSCXt7ezx//hzlypUr+LCZ86WkITQ6ESFvEvFKmYAQZRLCYhPxOiYJ4bHJCI9JRERcMlLShELNoS2FXAa5XAY9OSCXySCTAQqZDHKZDHJZ+rH/7sszHf/vvuK/c2TZ7qefK5cBMojH068vl8kAGf67/99xZD1PJgPkyLieLNvz1Dnl6a/z32vJZP9ly8iuyP4+/jtPlilj5mukvzeZTAY50p+fNZtC/ZyM9y5eM1NGiK8B/PczQPp98X/Sj/z340D6E9SP1efLMl0H6m9JZOrnyjI9N/M1Mp2X/XGm10u/fsZryrK8fubs6mPZMhARqVTA/QeAIAAmxoCDQ9b2x4+BN28AQQU0aAjo62e0PXkCBN4Q2xo3BuztM9pSUoCvvwIqWwO1agHdu2e97rx5QFgoYGIKfLsya9vn04DLl4AyZYAfNwA1ama0Xf8bGD5SvD/tc2Ds2Iy2V8FAx47iffeuwHffZb1uHWfxv2VNAT+/rL8HP/oIuHBRvO+7C8i8CNi//wIDPcT7kz4FRo4ATMtm/Pzq1hPvN2oI/LIr62uOGQNcviLe//tvoIxJRtvWLcDy/977mtVAly4ZbffuA336iPc9BgLZ+xU/7APcvw8YGQIBAVnbFiwE/jwC6OkDPjuAqlUz2i5fAf46Krb17QPUrZvRFhsDHD8OJKcAzZuLfxfS0tLw0/rvYWhoiFbtJiItFTAzi0HHljVRtmxZFLQSVyinr5iRXfoxGxubPM81NDTMtTe6XLly710oC4KAy48j4Pc4EndDonE/NBYRsUlISROQnKZCmupdBbAC0DOGXIs/MSN9OazKGqG8iT7KGunD1FAPpkZ6KGukh7Lq+xnHTQ31YKgnh55cDj2FDHpyGfQVcijkMugpZNBXH89o59AUIiLdIAg5P3gmJYmFoiAAJiaAQpHRFhUlFp9paYCdHfDfSEUAQHIy8Ouv4nnW1kD79lmve/w48PSp2D58OJC5L+rSJeDAASA1FRg2DHB1zWh7/lwsuvz9gVWrgK
lTM9ri4oA6VcX7HToAJ09mfc2ZnwG//y7ef/Uqa95Th4GZk8T727YBnVpltN25A/zuK94fNAiYMj7rdf8+Dzx8CFSoADjZZW2zMAEiXwKRAKzKZW0Pug2oYsX7Ua+ytpnIMtqMkfO6bZoAZ84AyljAzkL8s0nXvT1w7i/xvm22THHhWV/zyV0gfQ0DQQBcawNyOdCgTs7X7NNVPKanB9Syy/qa/XsAlmXFvx9d2wA1M51b2Qw4sEv8u2Vvn/O6238E4uPFc7O37fwZeXKyA0YNzLvd1TnjfmhoKIYNG4mTJ09CoVBg7KjhqFmzJqKjxb94hVGLlLhC2cXFBefPn4dKpcoyTsXPzw8mJiZwcnIq0jyvY5Jw+l4YtlwIwt2QGK2uIZMBFcsYwLKsEcyN9WGgJ4e+Qg4DPRkMFOJ9fT05yhnpw6qsISz/u6XfNzXUYyFLRFSEEhPFAlGpBGxts7Y9ewaEh//X81c3a3H57Blw+bJYtLq6ArVrZ7QdPSoWf4IAtGsHfPxx1uuamAAJCYCLS84evU8+ATZvFu//8w9Qr15G27FjwODB4v3sRWtCAjB0qHi/S5echfLatcAff4j3P/ww63vx9weWLxfvN2qUtVBOTBTb09/zmzeAubn4OPMQUyGXPqS3tWduS0vL2pa5F/Pff3NeN/3Dg0qVs61GjdxfAxD/fGvVAu7dy/pzBQAzM2DRIvGcWrVyXnfy5IwCN3PBCgBDhgBNmgBVqmTtGQfE1wkJETMbG4s93elkMuDatZyvlW769LzbmjQRb7kpVw7o2TPvc/Oxkq9WTp48iWHDhiE0NBQmJib44YcfULNmzXef+J6KdaEcHBwMpVKJ6tWrQ/+/714GDBiAffv2Yf/+/ep1lMPDw7F371706tUr1x7jgpaYkob9/i/hc+UpbgdH5/ocI305rM2MYaCQw0BPDiN9OSqVM4KtuTFszI1hbWYEazNjWJUzRIUyBtBXlLgFSoiI8i29MMr82T8+Hnj5UiyMKlYELC0z2uLigO+/F9tq1hR7EjObMwcIChILDm/vrG0HDgA//igWqNmLyzNngM6dxaLqyy+BxYuznmtvD7x+LRaI2TeGnTdPLHgB4PZtoE6djLZLl8QCCQBWr84olNPSxNvu3eJjY+OchXL6zya/xaVepkoge3GZn6I1e4H5tutmbtu2Dfjf/7K2jRuXd3HZuzfg6Cj+HchcIAJA69bA+vXiuS1bZm0zMAB27hR71nMr6i7+N8whc297utmzxVtu88Tq1QPu3s15HADKlgW++ir3NiBjKENu7O1zFsjp9PWBSpXyPrckSUtLw8KFC/HNN99AEATUq1cPu3fvhrOz87tPLgA6WyivW7cOb968watXrwAAf/zxB168eAEAmDRpEszMzDBr1ix4e3sjKCgIVf/7qDhgwAA0b94co0ePxu3bt2FhYYEffvgBaWlphT5RLyYxBTuuPMWm80GIjEvO0e5iZ45RLR3gYlce9hVMoJCzl5eIpBUVJf6ja2qacUwQgFOnxK/ry5YFPvgg6zm7dolf1atUwMyZYgGS7sgRsV2lAqZMyVqQPHggfj2vUgF9+4qFR2YGBmIRU6+eWJhmNmMGsHKlmO38ebEgSnfhAtC1q3h/zhxg4cKMtri4jNfp1StnoXz4sNj7qq+fs1B+/FgskgGxB1ipFHsI06Wmiv/NrQcyvS0lBbh1K+u4y7cVl3n1iCoU4jCE7NfPrEkTcahEbp1sdeqIY2Xl8pzFZfXqwMSJ4ms0apS1zdBQHFeb/nV7dhMnAj16iO3ZRyj27g04O4vXzZ7JwQGIjhbbsvek6usDP/2U87XSjRqVd5uzs3jLjUKR0Tuem4oV824r4IUUSEOCIKB79+44duwYAGDs2LFYs2YNTLL/pSlEOlsor1y5Ek+fPlU/3r9/P/bv3w8AGD58OMwy/7bKRKFQ4MiRI/jiiy/w/fffIyEhAU2bNsW2bdtQK7ePpgXgeWQ8Nl8Iwr7rLx
CblPW3V31bM7SuaYFOdazgal+eQyCISrH4eCA2FihfPusEpIsXxcIwNVX8Sj1zz2V0NDBrllggNmwIjM82rvLbb8V
|
||
|
"text/plain": [
|
||
|
"<Figure size 800x600 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"plt.figure(figsize=(8, 6))\n",
|
||
|
"plt.plot(fpr, tpr, \"b:\", linewidth=2, label=\"SGD\")\n",
|
||
|
"plot_roc_curve(fpr_forest, tpr_forest, \"Random Forest\")\n",
|
||
|
"plt.legend(loc=\"lower right\", fontsize=16)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": [
|
||
|
"inclass_exercise"
|
||
|
]
|
||
|
},
|
||
|
"source": [
|
||
|
"### Exercise: from the ROC curve, which method appears to work better?"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "fragment"
|
||
|
},
|
||
|
"tags": [
|
||
|
"solution",
|
||
|
"inclass_exercise"
|
||
|
]
|
||
|
},
|
||
|
"source": [
|
||
|
    "Random Forest, since its ROC curve gets closer to the ideal point (i.e. the top-left corner of the plot)."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Comparing metrics"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 37,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:15:13.319443Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:15:13.318949Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:15:13.354041Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:15:13.353400Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(0.9983631764491033, 0.9668396102663849)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 37,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# AUC\n",
|
||
|
"roc_auc_score(y_train_5, y_scores_forest), roc_auc_score(y_train_5, y_scores)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 38,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:15:13.357195Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:15:13.356728Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:16:15.780979Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:16:15.780322Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(0.9890893831305078, 0.739423076923077)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 38,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Precision\n",
|
||
|
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
|
||
|
"\n",
|
||
|
"y_train_pred_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3)\n",
|
||
|
"precision_score(y_train_5, y_train_pred_forest), precision_score(y_train_5, y_train_pred)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 39,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:16:15.784167Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:16:15.783687Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:16:15.822234Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:16:15.821584Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(0.8695812580704667, 0.8511344770337577)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 39,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Recall\n",
|
||
|
"recall_score(y_train_5, y_train_pred_forest), recall_score(y_train_5, y_train_pred)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 40,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:16:15.825376Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:16:15.824899Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:16:15.865268Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:16:15.864611Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(0.9254932757435947, 0.7913558013892462)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 40,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# F_1\n",
|
||
|
"f1_score(y_train_5, y_train_pred_forest), f1_score(y_train_5, y_train_pred)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Progress so far"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
    "So far we have considered binary classification only (e.g. five vs. not five).\n",
|
||
|
"\n",
|
||
|
"Clearly in many scenarios we want to classify multiple classes."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "slide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"## Multiclass classification"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Binary classifiers distinguish between two classes.\n",
|
||
|
"Multiclass classifiers can distinguish between more than two classes.\n",
|
||
|
"\n",
|
||
|
"Some algorithms can handle multiple classes directly (e.g. Random Forests, naive Bayes).\n",
|
||
|
"Others are strictly binary classifiers (e.g. Support Vector Machines, Linear classifiers)."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Multiclass classification strategies\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"However, there are various strategies that can be used to perform multiclass classification with binary classifiers."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "fragment"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
    "- **One-versus-rest (OvR) / one-versus-all (OvA)**: train one binary classifier per class, then predict the class whose classifier outputs the highest score \n",
|
||
|
"<br>(e.g. train a binary classifier for each digit)."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "fragment"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
    "- **One-versus-one (OvO)**: train one binary classifier per pair of classes, then predict the class that wins the most duels \n",
|
||
|
    "<br>(e.g. train a binary classifier for each pair of digits: 0 vs 1, 0 vs 2, ..., 1 vs 2, 1 vs 3, ...)."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Comparison of multiclass classification strategies"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"**One-versus-rest (OvR)**:\n",
|
||
|
"- $N$ classifiers for $N$ classes\n",
|
||
|
"- each classifier uses all of the training data\n",
|
||
|
"\n",
|
||
|
"$\\Rightarrow$ requires training relatively *few classifiers* but training each classifier can be *slow*."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "fragment"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"**One-versus-one (OvO)**:\n",
|
||
|
"- $N(N-1)/2$ classifiers for $N$ classes\n",
|
||
|
"- each classifier uses a subset of the training data (typically much smaller than overall training dataset)\n",
|
||
|
"\n",
|
||
|
"$\\Rightarrow$ requires training *many* classifiers but training each classifier can be *fast*."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Preferred approach\n",
|
||
|
"\n",
|
||
|
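    "OvR is usually preferred, unless the binary classifier scales poorly with the size of the training set (e.g. Support Vector Machines), in which case OvO is preferred because each of its classifiers is trained on a much smaller subset of the data."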
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "fragment"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
    "In Scikit-Learn, if you try to use a binary classifier for a multiclass classification problem, OvR (or OvO, depending on the algorithm, e.g. for SVMs) is automatically run."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 41,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:16:15.869360Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:16:15.868886Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:16:24.929227Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:16:24.928575Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([5])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 41,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"sgd_clf.fit(X_train, y_train)\n",
|
||
|
"sgd_clf.predict([some_digit])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
    "We can see that OvR was\n",
|
||
|
    "performed by inspecting the scores, where we have one score per classifier (i.e. one per class). \n",
|
||
|
"\n",
|
||
|
    "The score at index 5 (counting from 0) is clearly the largest, so the digit is classified as a 5."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 42,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:16:24.933068Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:16:24.932355Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:16:24.939500Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:16:24.938897Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([[ -81742.80600673, -226403.87485225, -310042.19433877,\n",
|
||
|
" -173577.43026798, -82855.74343468, 39922.35938292,\n",
|
||
|
" -183200.20815396, 10437.2327332 , -240036.68142135,\n",
|
||
|
" -160691.66786235]])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 42,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"some_digit_scores = sgd_clf.decision_function([some_digit])\n",
|
||
|
"some_digit_scores"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 43,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:16:24.942384Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:16:24.941943Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:16:24.948594Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:16:24.947920Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 43,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"sgd_clf.classes_"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 44,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:16:24.951594Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:16:24.951202Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:16:24.958059Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:16:24.957454Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"5"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 44,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"sgd_clf.classes_[np.argmax(some_digit_scores)]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### OvO with Scikit-Learn"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
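    "We can also force OvO multiclass classification by wrapping a binary classifier in `OneVsOneClassifier`, which trains $N(N-1)/2 = 45$ classifiers for the 10 digits."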
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 45,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:16:24.961009Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:16:24.960645Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:16:35.489682Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:16:35.489026Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(array([5]), 45)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 45,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from sklearn.multiclass import OneVsOneClassifier\n",
|
||
|
"ovo_clf = OneVsOneClassifier(SGDClassifier(random_state=42, max_iter=10))\n",
|
||
|
"ovo_clf.fit(X_train, y_train)\n",
|
||
|
"ovo_clf.predict([some_digit]), len(ovo_clf.estimators_)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Many classifiers can inherently classify multiple classes\n",
|
||
|
"\n",
|
||
|
    "A Random Forest can classify multiple classes directly, so OvR or OvO is not required."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 46,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:16:35.493180Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:16:35.492766Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:17:18.031732Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:17:18.031034Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([5])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 46,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"forest_clf.fit(X_train, y_train)\n",
|
||
|
"forest_clf.predict([some_digit])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 47,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:17:18.035061Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:17:18.034601Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:17:18.046924Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:17:18.046315Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([[0.03, 0.01, 0. , 0.02, 0.02, 0.85, 0.02, 0.03, 0.01, 0.01]])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 47,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"forest_clf.predict_proba([some_digit])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 48,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:17:18.050000Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:17:18.049472Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:17:37.039538Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:17:37.038826Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([0.8687 , 0.86975, 0.8449 ])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 48,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring=\"accuracy\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"### Error analysis"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
    "Compute the confusion matrix for multiclass classification (rows correspond to actual classes, columns to predicted classes)."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 49,
|
||
|
"metadata": {
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2024-01-10T00:17:37.044920Z",
|
||
|
"iopub.status.busy": "2024-01-10T00:17:37.043335Z",
|
||
|
"iopub.status.idle": "2024-01-10T00:17:55.983368Z",
|
||
|
"shell.execute_reply": "2024-01-10T00:17:55.982681Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([[5765, 2, 16, 16, 7, 21, 33, 9, 48, 6],\n",
|
||
|
" [ 1, 6470, 41, 17, 9, 35, 13, 16, 138, 2],\n",
|
||
|
" [ 79, 52, 5155, 153, 39, 31, 171, 85, 181, 12],\n",
|
||
|
" [ 80, 40, 224, 5087, 19, 286, 62, 78, 217, 38],\n",
|
||
|
" [ 37, 30, 46, 15, 5101, 22, 64, 75, 315, 137],\n",
|
||
|
" [ 127, 27, 35, 208, 48, 4451, 175, 35, 259, 56],\n",
|
||
|
" [ 45, 10, 57, 3, 19, 112, 5611, 11, 45, 5],\n",
|
||
|
" [ 25, 32, 91, 49, 50, 15, 8, 5824, 81, 90],\n",
|
||
|
" [ 70, 175, 127, 278, 43, 394, 86, 51, 4571, 56],\n",
|
||
|
" [ 48, 33, 54, 117, 326, 99, 5, 801, 834, 3632]])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 49,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"y_train_pred = cross_val_predict(sgd_clf, X_train, y_train, cv=3)\n",
|
||
|
"conf_mx = confusion_matrix(y_train, y_train_pred)\n",
|
||
|
"conf_mx"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": [
|
||
|
"exercise_pointer"
|
||
|
]
|
||
|
},
|
||
|
"source": [
|
||
|
"**Exercises:** *You can now complete Exercise 10 in the exercises associated with this lecture.*"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
}
|
||
|
},
|
||
|
"source": [
|
||
|
"Performance analysis can provide insight into how to make improvements.\n",
|
||
|
"\n",
|
||
|
    "For example, the confusion matrix above shows that many 9s are misclassified as 7s and 8s, so one might try to improve the performance of classifying 9s by collecting more training data for 7s, 8s and 9s."
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"celltoolbar": "Slideshow",
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.8.18"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 4
|
||
|
}
|