spce0038-machine-learning-w.../week3/lectures/Lecture08_LogisticRegression.ipynb

2025-02-22 19:17:44 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Lecture 8: Logistic regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"source": [
"![](https://www.tensorflow.org/images/colab_logo_32px.png)\n",
"[Run in colab](https://colab.research.google.com/drive/1NSiPlX-azox9No8o7PgcF4o70bwMUH3A )"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:56.779513Z",
"iopub.status.busy": "2024-01-10T00:19:56.779124Z",
"iopub.status.idle": "2024-01-10T00:19:56.787591Z",
"shell.execute_reply": "2024-01-10T00:19:56.787055Z"
},
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Last executed: 2024-01-10 00:19:56\n"
]
}
],
"source": [
"import datetime\n",
"now = datetime.datetime.now()\n",
"print(\"Last executed: \" + now.strftime(\"%Y-%m-%d %H:%M:%S\"))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Estimating probabilities\n",
"\n",
"Estimate the probability that an instance belongs to a particular class.\n",
"\n",
"Can adapt the linear regression algorithm for this purpose to perform *logistic regression*."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Sigmoid function\n",
"\n",
"Consider the linear weighted sum of inputs $\\theta^{\\rm T} x$ again, but now apply the sigmoid function $\\sigma$:\n",
"\n",
"$$\\hat{p} = h_\\theta(x) = \\sigma(\\theta^{\\rm T} x), $$\n",
"\n",
"where\n",
"\n",
"$$\\sigma(t) = \\frac{1}{1+\\exp{(-t)}}. $$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"inclass_exercise"
]
},
"source": [
"#### Plot the sigmoid function"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:56.827687Z",
"iopub.status.busy": "2024-01-10T00:19:56.827194Z",
"iopub.status.idle": "2024-01-10T00:19:57.238220Z",
"shell.execute_reply": "2024-01-10T00:19:57.237523Z"
},
"tags": [
"solution",
"inclass_exercise"
]
},
"outputs": [],
"source": [
"import numpy as np\n",
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:57.241818Z",
"iopub.status.busy": "2024-01-10T00:19:57.241170Z",
"iopub.status.idle": "2024-01-10T00:19:57.489933Z",
"shell.execute_reply": "2024-01-10T00:19:57.489276Z"
},
"tags": [
"solution",
"inclass_exercise"
]
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"t = np.linspace(-10, 10, 100)\n",
"sig = 1 / (1 + np.exp(-t))\n",
"plt.figure(figsize=(8, 4))\n",
"plt.plot([-10, 10], [0, 0], \"k-\")\n",
"plt.plot([-10, 10], [0.5, 0.5], \"k:\")\n",
"plt.plot([-10, 10], [1, 1], \"k:\")\n",
"plt.plot([0, 0], [-1.1, 1.1], \"k-\")\n",
"plt.plot(t, sig, \"b-\", linewidth=2, label=r\"$\\sigma(t) = \\frac{1}{1 + e^{-t}}$\")\n",
"plt.xlabel(\"t\")\n",
"plt.legend(loc=\"upper left\", fontsize=20)\n",
"plt.axis([-10, 10, -0.1, 1.1]);"
]
},
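{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of the sigmoid as a reusable function; the names `sigmoid` and `t_check` are illustrative choices, not part of the lecture. It checks two properties visible in the plot above, $\\sigma(0) = 0.5$ and the symmetry $\\sigma(-t) = 1 - \\sigma(t)$, and compares the result with `scipy.special.expit`, SciPy's implementation of the same logistic function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.special import expit  # SciPy's logistic sigmoid\n",
"\n",
"def sigmoid(t):\n",
"    # Logistic sigmoid: maps any real t into the interval (0, 1)\n",
"    return 1 / (1 + np.exp(-t))\n",
"\n",
"t_check = np.linspace(-10, 10, 101)\n",
"assert sigmoid(0.0) == 0.5  # midpoint of the S-curve\n",
"assert np.allclose(sigmoid(-t_check), 1 - sigmoid(t_check))  # symmetry about (0, 0.5)\n",
"assert np.allclose(sigmoid(t_check), expit(t_check))  # agrees with SciPy\n",
"print(sigmoid(np.array([-2.0, 0.0, 2.0])))"
]
},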
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Predictions\n",
"\n",
"Can then make class predictions depending on whether the predicted probability $\\hat{p}$ is at least 0.5, i.e.\n",
"\n",
"$$ \\hat{y} = \\biggl \\{\n",
"\\begin{split}\n",
"0,\\ \\text{if}\\ \\hat{p} < 0.5\\\\\n",
"1,\\ \\text{if}\\ \\hat{p} \\geq 0.5\n",
"\\end{split}\n",
",\n",
"$$\n",
"\n",
"where we recall \n",
"$\\hat{p} = h_\\theta(x) = \\sigma(\\theta^{\\rm T} x) \\quad \\text{and} \\quad \\sigma(t) = \\frac{1}{1+\\exp{(-t)}}.$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Note that $\\sigma(t) < 0.5$ when $t<0$ and $\\sigma(t) \\geq 0.5$ when $t \\geq 0$.\n",
"\n",
"That is, logistic regression predicts class 1 when $\\theta^{\\rm T} x \\geq 0$, and class 0 when it is negative.\n",
"\n",
"The decision boundary is defined by $\\theta^{\\rm T} x=0$."
]
},
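{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch of the prediction rule, assuming a hypothetical parameter vector with $\\theta_0 = -4$ and $\\theta_1 = 2.5$ and five made-up inputs (the leading column of ones in `X_demo` multiplies the bias term); none of these values come from the lecture. Thresholding $\\hat{p}$ at 0.5 should give exactly the same classes as checking the sign of $\\theta^{\\rm T} x$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical parameters: bias theta_0 = -4, weight theta_1 = 2.5\n",
"theta_demo = np.array([-4.0, 2.5])\n",
"# Hypothetical inputs, with a leading column of ones for the bias term\n",
"X_demo = np.c_[np.ones(5), [0.5, 1.0, 1.5, 2.0, 2.5]]\n",
"\n",
"scores = X_demo @ theta_demo  # theta^T x for each instance\n",
"p_hat = 1 / (1 + np.exp(-scores))  # predicted probabilities\n",
"y_hat = (p_hat >= 0.5).astype(int)  # threshold the probabilities at 0.5\n",
"\n",
"# Identical classes from the sign of theta^T x\n",
"assert np.array_equal(y_hat, (scores >= 0).astype(int))\n",
"print(p_hat.round(3), y_hat)"
]
},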
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Cost functions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Consider the cost function for a single training instance:\n",
"\n",
"$$ C(\\theta) = \\biggl \\{\n",
"\\begin{split}\n",
"-\\log(\\hat{p}),\\ \\text{if}\\ y=1\\\\\n",
"-\\log(1 - \\hat{p}),\\ \\text{if}\\ y=0\n",
"\\end{split}\n",
".\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Exercise: Plot the cost function for $y=1$ as a function of $\\hat{p}$."
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": [
"inclass_exercise"
]
},
"source": [
"Consider the cost function for a single training instance:\n",
"\n",
"$$ C(\\theta) = \\biggl \\{\n",
"\\begin{split}\n",
"-\\log(\\hat{p}),\\ \\text{if}\\ y=1\\\\\n",
"-\\log(1 - \\hat{p}),\\ \\text{if}\\ y=0\n",
"\\end{split}\n",
".\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:57.493545Z",
"iopub.status.busy": "2024-01-10T00:19:57.492986Z",
"iopub.status.idle": "2024-01-10T00:19:57.715480Z",
"shell.execute_reply": "2024-01-10T00:19:57.714807Z"
},
"tags": [
"solution",
"inclass_exercise"
]
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ph = np.linspace(0.01, 0.99, 100)\n",
"cost_one = -np.log(ph)\n",
"plt.figure(figsize=(8, 4))\n",
"plt.plot([0, 0], [-1.1, 1.1], \"k-\")\n",
"plt.plot(ph, cost_one, \"b-\", linewidth=2, label=r\"$-\\log{(\\hat{p})}$\")\n",
"plt.xlabel(r\"$\\hat{p}$\")\n",
"plt.legend(loc=\"upper right\", fontsize=20)\n",
"plt.axis([0, 1, 0, 4]);\n",
"plt.title('Cost for $y=1$');"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"inclass_exercise"
]
},
"source": [
"What can you say intuitively about the cost function at the edges of the domain?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
},
"tags": [
"solution",
"inclass_exercise"
]
},
"source": [
"- For $\\hat{p}=1$, $C(\\hat{p}) = 0$: a confident, correct prediction costs nothing.\n",
"- As $\\hat{p} \\rightarrow 0$, $C(\\hat{p}) \\rightarrow \\infty$: a confident, wrong prediction is heavily penalised."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"inclass_exercise"
]
},
"source": [
"### Exercise: Plot the cost function for $y=0$ as a function of $\\hat{p}$."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:57.718928Z",
"iopub.status.busy": "2024-01-10T00:19:57.718470Z",
"iopub.status.idle": "2024-01-10T00:19:57.934523Z",
"shell.execute_reply": "2024-01-10T00:19:57.933872Z"
},
"tags": [
"inclass_exercise",
"solution"
]
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"cost_zero = -np.log(1-ph)\n",
"plt.figure(figsize=(8, 4))\n",
"plt.plot([0, 0], [-1.1, 1.1], \"k-\")\n",
"plt.plot(ph, cost_zero, \"b-\", linewidth=2, label=r\"$-\\log{(1-\\hat{p})}$\")\n",
"plt.xlabel(r\"$\\hat{p}$\")\n",
"plt.legend(loc=\"upper left\", fontsize=20)\n",
"plt.axis([0, 1, 0, 4]);\n",
"plt.title('Cost for $y=0$');"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"inclass_exercise"
]
},
"source": [
"What can you say intuitively about the cost function at the edges of the domain?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
},
"tags": [
"solution",
"inclass_exercise"
]
},
"source": [
"- For $\\hat{p}=0$, $C(\\hat{p}) = 0$: a confident, correct prediction costs nothing.\n",
"- As $\\hat{p} \\rightarrow 1$, $C(\\hat{p}) \\rightarrow \\infty$: a confident, wrong prediction is heavily penalised."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Log-loss function for logistic regression\n",
"\n",
"The cost over all $m$ training instances can be written as the single expression\n",
"\n",
"$$\n",
"C(\\theta) = -\\frac{1}{m} \\sum_{i=1}^m \n",
"\\left [ \n",
"y^{(i)} \\log(\\hat{p}^{(i)})\n",
"+\n",
"(1 - y^{(i)}) \\log(1 - \\hat{p}^{(i)})\n",
"\\right],\n",
"$$\n",
"\n",
"since $y^{(i)}$ is always 0 or 1: for each instance only one of the two terms survives, recovering the separate cases considered above."
]
},
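{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of the log-loss above; the helper `log_loss_np` and the toy labels and probabilities are illustrative assumptions. As a cross-check it is compared with Scikit-Learn's `sklearn.metrics.log_loss`, to which we pass the probabilities of both classes (columns ordered as classes 0 and 1)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import log_loss\n",
"\n",
"def log_loss_np(y, p_hat):\n",
"    # Average over the m instances of -[y log(p) + (1 - y) log(1 - p)]\n",
"    return -np.mean(y * np.log(p_hat) + (1 - y) * np.log(1 - p_hat))\n",
"\n",
"y_toy = np.array([1, 0, 1, 1, 0])  # toy labels\n",
"p_toy = np.array([0.9, 0.2, 0.6, 0.95, 0.4])  # toy predicted probabilities\n",
"\n",
"ours = log_loss_np(y_toy, p_toy)\n",
"# Columns are the predicted probabilities of classes 0 and 1\n",
"ref = log_loss(y_toy, np.column_stack([1 - p_toy, p_toy]))\n",
"assert np.isclose(ours, ref)\n",
"print(ours)"
]
},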
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"#### Aside: statistical interpretation\n",
"\n",
"Interpret $p$ as the probability that the target is $y = 1$, so that $y$ follows a Bernoulli distribution:\n",
"\n",
"$$\n",
"{\\rm P}(y \\: \\vert \\: p ) = p^{y} (1-p)^{1-y}\n",
"$$\n",
"\n",
"$$\n",
"{\\rm log}\\,{\\rm P}(y \\: \\vert \\: p ) = y\\,{\\rm log}(p) + (1-y)\\,{\\rm log}(1-p)\n",
"$$\n",
"\n",
"The log-loss is therefore the negative average log-likelihood of the training labels, so minimising it is equivalent to maximum likelihood estimation. See [MacKay](http://www.inference.org.uk/itila/book.html) [Chapter 41] for further details."
]
},
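{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick numerical check of this interpretation with toy values (assumptions for illustration only): the per-instance terms $y \\log(p) + (1-y) \\log(1-p)$ should match SciPy's Bernoulli log-probabilities from `scipy.stats.bernoulli.logpmf`, and minus their mean is the log-loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.stats import bernoulli\n",
"\n",
"y_toy = np.array([1, 0, 1, 1, 0])  # toy labels\n",
"p_toy = np.array([0.9, 0.2, 0.6, 0.95, 0.4])  # toy probabilities that y = 1\n",
"\n",
"# log P(y | p) for each instance, from SciPy's Bernoulli distribution\n",
"log_lik = bernoulli.logpmf(y_toy, p_toy)\n",
"# The same quantity written out as y log(p) + (1 - y) log(1 - p)\n",
"manual = y_toy * np.log(p_toy) + (1 - y_toy) * np.log(1 - p_toy)\n",
"assert np.allclose(log_lik, manual)\n",
"\n",
"# The log-loss is minus the average log-likelihood\n",
"print(-log_lik.mean())"
]
},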
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Minimising the cost function\n",
"\n",
"There is no closed-form solution, unlike the normal equation for linear regression.\n",
"\n",
"But since the cost function is convex, any local minimum is a global minimum, so gradient descent (with a suitable learning rate) converges towards the global minimum."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Derivative of the cost function\n",
"\n",
"\\begin{align*}\n",
"\\frac{\\partial C}{\\partial \\theta} \n",
"&=\n",
"\\frac{1}{m} \\sum_{i=1}^m \n",
"\\left[ \\sigma\\left(\\theta^{\\rm T} x^{(i)} \\right) - y^{(i)} \\right] \n",
"x^{(i)} \\\\\n",
"&=\n",
"\\frac{1}{m} \n",
"X^{\\rm T}\n",
"\\left[ \\sigma\\left(X \\theta \\right) - y \\right] \\\\\n",
"&=\n",
"\\frac{1}{m} \n",
"X^{\\rm T}\n",
"\\left[ h_\\theta\\left(X\\right) - y \\right]\n",
"\\end{align*}\n",
"\n"
]
},
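{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of the vectorised gradient $\\frac{1}{m} X^{\\rm T} [\\sigma(X\\theta) - y]$, checked against central finite differences of the log-loss. The random data, the parameter vector and the helper names `cost` and `grad` are assumptions made purely for illustration; agreement within the tolerance is good evidence that the derivative above is correct."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"m, n = 50, 3\n",
"X_rand = np.c_[np.ones(m), rng.normal(size=(m, n - 1))]  # bias column + 2 features\n",
"y_rand = rng.integers(0, 2, size=m)  # random 0/1 labels\n",
"theta_rand = rng.normal(size=n)  # arbitrary parameter vector\n",
"\n",
"def sigmoid(t):\n",
"    return 1 / (1 + np.exp(-t))\n",
"\n",
"def cost(theta):\n",
"    # Log-loss averaged over the m instances\n",
"    p = sigmoid(X_rand @ theta)\n",
"    return -np.mean(y_rand * np.log(p) + (1 - y_rand) * np.log(1 - p))\n",
"\n",
"def grad(theta):\n",
"    # Vectorised gradient: (1/m) X^T [sigma(X theta) - y]\n",
"    return X_rand.T @ (sigmoid(X_rand @ theta) - y_rand) / m\n",
"\n",
"# Central finite differences, perturbing one parameter at a time\n",
"eps = 1e-6\n",
"numeric = np.array([\n",
"    (cost(theta_rand + eps * e) - cost(theta_rand - eps * e)) / (2 * eps)\n",
"    for e in np.eye(n)\n",
"])\n",
"assert np.allclose(grad(theta_rand), numeric, atol=1e-6)\n",
"print(grad(theta_rand))"
]
},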
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Similarity with linear regression\n",
"\n",
"The gradient is identical in form to that of linear regression with the MSE cost (up to a factor of 2, depending on the conventions adopted), but with a different prediction function:\n",
"\n",
"$$h_\\theta(x) = \\sigma(\\theta^{\\rm T} x), $$\n",
"\n",
"instead of\n",
"\n",
"$$h_\\theta(x) = \\theta^{\\rm T} x. $$\n"
]
},
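{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the gradient has the same form as for linear regression, batch gradient descent carries over with only the prediction function changed. A minimal sketch on synthetic, overlapping 1-D data, without regularisation; the data, the learning rate `eta` and the number of iterations are illustrative assumptions. For a small enough learning rate the cost should never increase, and the gradient should approach zero at the minimum."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"m = 200\n",
"x_gd = rng.normal(size=m)\n",
"# Noisy labels: class 1 is more likely for larger x, so the classes overlap\n",
"y_gd = (x_gd + rng.normal(scale=0.8, size=m) > 0).astype(int)\n",
"X_gd = np.c_[np.ones(m), x_gd]  # bias column + one feature\n",
"\n",
"def sigmoid(t):\n",
"    return 1 / (1 + np.exp(-t))\n",
"\n",
"def cost(theta):\n",
"    p = sigmoid(X_gd @ theta)\n",
"    return -np.mean(y_gd * np.log(p) + (1 - y_gd) * np.log(1 - p))\n",
"\n",
"def grad(theta):\n",
"    # Same form as for linear regression, with sigma(X theta) as the prediction\n",
"    return X_gd.T @ (sigmoid(X_gd @ theta) - y_gd) / m\n",
"\n",
"eta, n_iter = 0.5, 2000  # learning rate and number of steps\n",
"theta = np.zeros(2)  # start from theta = 0\n",
"costs = [cost(theta)]\n",
"for _ in range(n_iter):\n",
"    theta = theta - eta * grad(theta)  # batch gradient descent step\n",
"    costs.append(cost(theta))\n",
"\n",
"assert np.all(np.diff(costs) <= 1e-12)  # the cost never increases\n",
"assert np.linalg.norm(grad(theta)) < 1e-6  # gradient is (close to) zero\n",
"print(theta, costs[-1])"
]
},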
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Example of logistic regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Consider [Iris flower data](https://en.wikipedia.org/wiki/Iris_flower_data_set) again."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:57.939374Z",
"iopub.status.busy": "2024-01-10T00:19:57.937962Z",
"iopub.status.idle": "2024-01-10T00:19:58.295225Z",
"shell.execute_reply": "2024-01-10T00:19:58.294555Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"['data',\n",
" 'target',\n",
" 'frame',\n",
" 'target_names',\n",
" 'DESCR',\n",
" 'feature_names',\n",
" 'filename',\n",
" 'data_module']"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn import datasets\n",
"iris = datasets.load_iris()\n",
"list(iris.keys())"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:58.299485Z",
"iopub.status.busy": "2024-01-10T00:19:58.298085Z",
"iopub.status.idle": "2024-01-10T00:19:58.303784Z",
"shell.execute_reply": "2024-01-10T00:19:58.303189Z"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Train model\n",
"\n",
"Use petal width to classify whether or not a flower is Iris-Virginica."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:58.307016Z",
"iopub.status.busy": "2024-01-10T00:19:58.306450Z",
"iopub.status.idle": "2024-01-10T00:19:58.311305Z",
"shell.execute_reply": "2024-01-10T00:19:58.310697Z"
}
},
"outputs": [],
"source": [
"# Set up training data\n",
"X_1d = iris[\"data\"][:, 3:] # petal width\n",
"y = (iris[\"target\"] == 2).astype(int) # 1 if Iris-Virginica, else 0"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:58.315162Z",
"iopub.status.busy": "2024-01-10T00:19:58.313802Z",
"iopub.status.idle": "2024-01-10T00:19:58.370892Z",
"shell.execute_reply": "2024-01-10T00:19:58.370148Z"
}
},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"log_reg = LogisticRegression(random_state=42)\n",
"log_reg.fit(X_1d, y);"
]
},
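{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that Scikit-Learn adds an $\\ell_2$ penalty to the cost function by default, controlled by the hyperparameter `C` (the inverse of the regularisation strength, so larger `C` means weaker regularisation)."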
]
},
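{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short sketch of how the strength of the $\\ell_2$ penalty affects the fit on the petal-width data. In Scikit-Learn the penalty is set through `C`, the inverse regularisation strength; the three values tried here, and the variable names, are arbitrary choices for illustration. As `C` decreases we would expect the magnitude of the fitted coefficient to shrink towards zero."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"iris_data = datasets.load_iris()\n",
"X_pw = iris_data['data'][:, 3:]  # petal width (cm)\n",
"y_virg = (iris_data['target'] == 2).astype(int)  # 1 if Iris-Virginica, else 0\n",
"\n",
"for C in [0.01, 1.0, 100.0]:  # smaller C means a stronger l2 penalty\n",
"    clf = LogisticRegression(C=C).fit(X_pw, y_virg)\n",
"    print(f'C = {C:>6}: coef = {clf.coef_[0, 0]:.3f}, intercept = {clf.intercept_[0]:.3f}')"
]
},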
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Prediction"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:58.375263Z",
"iopub.status.busy": "2024-01-10T00:19:58.373897Z",
"iopub.status.idle": "2024-01-10T00:19:58.555888Z",
"shell.execute_reply": "2024-01-10T00:19:58.555274Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"X_1d_new = np.linspace(0, 3, 1000).reshape(-1, 1)\n",
"y_1d_proba = log_reg.predict_proba(X_1d_new)\n",
"\n",
"plt.figure(figsize=(8,4))\n",
"plt.plot(X_1d_new, y_1d_proba[:, 1], \"g-\", linewidth=2, label=\"Iris-Virginica\")\n",
"plt.plot(X_1d_new, y_1d_proba[:, 0], \"b--\", linewidth=2, label=\"Not Iris-Virginica\")\n",
"plt.xlabel(\"Petal width (cm)\", fontsize=14)\n",
"plt.ylabel(\"Probability\", fontsize=14)\n",
"plt.legend(loc=\"center left\", fontsize=14);"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Decision boundary"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"Recall the decision boundary is given by $\\hat{p}=0.5$ or, equivalently, $\\theta^{\\rm T} x=0$."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:58.559105Z",
"iopub.status.busy": "2024-01-10T00:19:58.558652Z",
"iopub.status.idle": "2024-01-10T00:19:58.565651Z",
"shell.execute_reply": "2024-01-10T00:19:58.565067Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([1.66066066])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"decision_boundary = X_1d_new[y_1d_proba[:, 1] >= 0.5][0]\n",
"decision_boundary"
]
},
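{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the decision boundary satisfies $\\theta_0 + \\theta_1 x = 0$ for a single feature, it can also be read off the fitted parameters directly as $x = -\\theta_0 / \\theta_1$, without scanning a grid. A minimal sketch, refitting the same model so the cell is self-contained (`intercept_` and `coef_` are the fitted attributes of Scikit-Learn's `LogisticRegression`; the variable names are our own). The result should agree with the grid-based estimate above (about 1.66 cm) to within the grid spacing of roughly 0.003 cm."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn import datasets\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"iris_data = datasets.load_iris()\n",
"X_pw = iris_data['data'][:, 3:]  # petal width (cm)\n",
"y_virg = (iris_data['target'] == 2).astype(int)  # 1 if Iris-Virginica, else 0\n",
"clf = LogisticRegression(random_state=42).fit(X_pw, y_virg)\n",
"\n",
"theta_0 = clf.intercept_[0]  # fitted bias term\n",
"theta_1 = clf.coef_[0, 0]  # fitted weight on petal width\n",
"boundary = -theta_0 / theta_1  # solves theta_0 + theta_1 x = 0\n",
"\n",
"# At the boundary the predicted probability of class 1 is 0.5\n",
"assert np.isclose(clf.predict_proba([[boundary]])[0, 1], 0.5)\n",
"print(boundary)"
]
},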
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Updating plot with decision boundary and training data"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:58.568895Z",
"iopub.status.busy": "2024-01-10T00:19:58.568352Z",
"iopub.status.idle": "2024-01-10T00:19:58.801548Z",
"shell.execute_reply": "2024-01-10T00:19:58.800895Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsMAAAF4CAYAAACxafRpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAACe+ElEQVR4nOzdd1gUVxfA4d/SFhRBFEED2HvvvXdj7y32EjUmRky+aGKNUTSWaKLRWNFYY9fYgxq7xh5rYhcUFAsg0ne+PyasrBQpC0s57/PMw86dOzNnUfRw9865GkVRFIQQQgghhMiCzEwdgBBCCCGEEKYiybAQQgghhMiyJBkWQgghhBBZliTDQgghhBAiy5JkWAghhBBCZFmSDAshhBBCiCxLkmEhhBBCCJFlSTIshBBCCCGyLAtTB5Ae6HQ6Hj9+TI4cOdBoNKYORwghhBBCvENRFIKCgvjggw8wMzPeeK4kw8Djx49xc3MzdRhCCCGEEOI9Hj16hKurq9Gul+6S4aNHjzJr1izOnz/PkydP2LZtGx06dEjwnCNHjuDu7s61a9dwc3Nj/Pjx9O/fP9H3zJEjB6B+c+3s7FIQvRBCCCGESA2BgYG4ubnp8zajUdKZPXv2KN98842ydetWBVC2bduWYP+7d+8q2bJlU9zd3ZXr168rP/30k2Jubq7s27cv0fcMCAhQACUgICCF0QshRPIcvHNQKbWglHLwzkGjXG/m8ZmK1VQrZebxmUa5nqIYP8bUuKbVt1YKk1GsvrUyyvUURVGq/VJNYTJKtV+qGe2axX8srjAZpfiPxdPl9RRFUQbvGKwwGWXwjsFGu2bumbkVJqPknpnbKNdLjRhT45rCOFIrX9MoiqIYN702Ho1G896R4a+++ordu3dz9epVfVuPHj149eoV+/btS9R9AgMDsbe3JyAgQEaGhRBpTlEUaiyrwV+P/6LaB9U4M/hMip5f0Ol05JiRgzcRb8hmmY2gsUEpnl9n7BhT45rBwcHYzrbV77/+4jXZs2dPUYwRERFYTbfS74d/HY6lpWWKrhkeHo7WQ6vfDxsXhpWVVQJnpO31AKKiorD8zhIFBQ0aIsZHYG5unqJrhoaGYjPTRr8f8lUI1tbW6SrG1LimMJ7UytfS3TSJpDp16hRNmzY1aGvRogWff/55kq/13XdgYwMaDZiZxd40GqhaFRo1Mjxv0aL4z4m5NWgA+fK9Pe/pUzhzxvD68d23bl31azRfX3j5EiwsDDdzc8N9S0t1E0KkXwfuHOCvx38B8Nfjvzhw5wAtirZI9vU8jnvwJuINAG8i3uBx3INv6n+T5jEWLlwYHx8fXFxcuHv3rlGumRCHuQ6x9sMnhCf7egA1lteItX9h2IUUXbP0otKx9m+Pup1urgcwaOcgFNSxMgWFQTsH4dnRM0XXdJnnEmv/+djnyb5easSYGtcU6V+GT4Z9fX1xdnY2aHN2diYwMJCQkBBsbGxinRMWFkZYWJh+PzAwEIBZs95/v88/N0yGFQVGjEhcrHv3GibDFy9Cu3aJO1enM9yfPRvmzHn/eU2bwsGDhm01a8LNm+9PpEePhphTr1++hMGDwcoKtFr1a3yv+/WDPHnenvvwIVy+rB6ztlZ/6ciWLfZXS0vDpF+IzE5RFCYcnoC5xpwoJQpzjTkTDk+geZHmyRol1el0TD8+3aBt+vHpjKs7Ltmjw8mNMTw8XL8Z65rxCQ4OJkIXYdAWoYsgODg42aPDERERXPS7aNB20e8iERERyR4dDg8P586rOwZtd17dITw8PFmjuca+Hqijo6uvrDZoW31lNcvbLU/2KGloaCgvwl4YtL0Ie0FoaGiyRodTI8bUuGZmFKmLJDQylLDIMEIjQ9XXUW9fv3vs3eMxj4VHhROhizD8GhX/fsjrkFR5Txk+GU4ODw8PpkyZkqxz3/2/JCmTTN49990ENz4aTewEMTIycefG9fMbEKBu7/PypeH+69ewdWvi7tuqlW
EyfPCgmki/zwcfgI+PYdv48XDkiJosRyfO2bODrS3Y2UGOHOpWtqw6gh7T06dq32zZJMkW6VPM0VGAKCUqRaOkMUeFo6V0dDi5MZYtWxYnJyecnJyMds34vDsqHLM9uaPD744Kx2xP7ujwu6O4MduTM5pr7OuB4ehotJSOkr47KhyzPTmjw6kRY2pc05QUReF1+GsCwgIICA0gMCyQgDD16+vw1wSHBxMcEWzw9U3kmzjbY36N1CUyAUkNoalz2QyfDOfNmxc/Pz+DNj8/P+zs7OIcFQYYN24c7u7u+v3opxO3b1cTLZ1OTXJ1uthbsWKxr7dqVfz9Y24lSxqeV7w4TJtm2Ceu68SVcFevro7aRkWpiXHMLWZbxYqxzy1cWL3mu33fPf/dX9ZjDKa/17sDEok9V6uN3XbtGpw48f5zBw2KnQwXLQpBQeovIra2atJsbw+5coGDg/o1Vy51JLtChbfnvXkDjx+rfXLmjPuXCiFS6t3R0WjJHSWNa1Q4WnJHh1MSY3zPbRj7fcc1KhwtuaPDcY0KR0vu6HBco7jRkjOaa+zrQdyjo9GSO0oa16hwtOSMDqdGjKlxTWPQKTpehrzkechznr95jv8bf/3r5yHPeRX6Sp/sRie6MV/rlESOuqVz5hpzrMytsNBaEESQ0a+f4ZPhWrVqsWfPHoO2gwcPUqtWrXjP0Wq1aOPIuho1Ukcak8LMDPr2Tdo50YoUga+/Tt65vXqpW3Ls3p288woWBG9vCA9Xt7Cw+F/HnA4CUKOGmviHhUFoqJpshoTE/po3b+z7hibyN8F3/+wURR3NBvWXisBAdXt35BmgYUPDZPjcOXWOdzQHB3ByAmdn9WvM1wMHxk7+hUiMd0dHoyV3lDSuUeFoyR0dNnaMqXHN+EaFYx5P6uhwfKPCMY8ndXQ4vlHcmMeTMppr7OtB3KOj0ZI7ShrfqHDM40kZHU6NGFPjmvF5Hf4a39e+PAl6gu9rX4Pt6Zun+kTX/40/L0NexhtXWrAytyK7ZXayW2XXf7WxsMHawtpg05pr37620CaqXWuhRWuuxdLcEitzK6zMrbA0U19Ht1maWWJpbomZRv0lPjAwEPsp9kZ/n+mumsTr16+5fVv94a1UqRJz586lUaNG5MqVi/z58zNu3Dh8fHxYvVr9De7evXuULVuWTz75hIEDB3Lo0CE+++wzdu/eTYsWifvHVKpJpH8REYZJc3CwOuIbvQUGQqlS6nzoaOHh0K3b2+PRfV+9Uq8T07FjhqPK27dDx47vj0ujUe9jEePXyqlT4ddfwdU1/s3RMfa0GZG1RFdSOP/4PDpij96YYUaVD6okusJCzAoS8UlqZQljx5ga13y3gkR8klJZ4t0KEvFJSmWJdys+xCexlSCMfT0wrKQQn6RWWHi3gkR8EltZIjViNNY1dYqOp8FPeRjwkEcBj3gY8JCHAQ/xDvI2SHyDI4ITFVdyaM212FvbY6e1w15rj721vcFXO60ddlo7bK1s9cltzNfvfrUwS19jplmmmsS5c+doFOMJtejpDP369cPT05MnT57w8OFD/fFChQqxe/duRo8ezfz583F1dWXZsmWJToRFxhBdFSMpf/etrNSkNi7h4eqc6Bcv1K1cOcPjefNC795vjz97ps4/jh5pjuboaJgIA9y+Df/+q24Jxda1K6xZY9h+/bo64pwrl8xxzuzCo8J5GPAwzoQQQIeOR4GPCI8KR2vx/qTndfhrQiMS/hglNDKU1+GvsbNO3A+SsWNMjWse8zmWqPse8zlGy+ItE9X3YcDD93f6r18RxyKJ6nv7ZeJGaG+/vE1p54RHfFPjegAvQl68dxRSQeFFyAvy2OZJsF+0K8+uJLpfdbfqJokxsdd8/uY5UURx+8Vtbr+4zZ2Xd9TEN1BNfL0DvQmPSln1kmh2Wjty2+Qmd7bcOGZzVF//t5/b5r+2bLlxsHZQE9//kt3E/hwKQ+luZNgUZGRYJNabN2pS/PQp+Pmp0z66dD
HsM3AgbNmijkYnZPBgWLrUsM3JSU287eygUCF1fnfx4uqod6lS6rxz+SuaeTwKeMSzN8/iPe6U3QlXu8QvOXrG+0y
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(8, 4))\n",
"plt.plot(X_1d[y==0], y[y==0], \"bs\")\n",
"plt.plot(X_1d[y==1], y[y==1], \"g^\")\n",
"plt.plot([decision_boundary, decision_boundary], [-1, 2], \"k:\", linewidth=2)\n",
"plt.plot(X_1d_new, y_1d_proba[:, 1], \"g-\", linewidth=2, label=\"Iris-Virginica\")\n",
"plt.plot(X_1d_new, y_1d_proba[:, 0], \"b--\", linewidth=2, label=\"Not Iris-Virginica\")\n",
"plt.text(decision_boundary[0]+0.02, 0.15, \"Decision boundary\", fontsize=14, color=\"k\", ha=\"center\")\n",
"plt.arrow(decision_boundary[0], 0.08, -0.3, 0, head_width=0.05, head_length=0.1, fc='b', ec='b')\n",
"plt.arrow(decision_boundary[0], 0.92, 0.3, 0, head_width=0.05, head_length=0.1, fc='g', ec='g')\n",
"plt.xlabel(\"Petal width (cm)\", fontsize=14)\n",
"plt.ylabel(\"Probability\", fontsize=14)\n",
"plt.legend(loc=\"center left\", fontsize=14)\n",
"plt.axis([0, 3, -0.02, 1.02]);"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Predictions depend on which side of the decision boundary an instance falls, i.e. whether $\\hat{p}$ is above or below 0.5."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:58.804779Z",
"iopub.status.busy": "2024-01-10T00:19:58.804533Z",
"iopub.status.idle": "2024-01-10T00:19:58.811352Z",
"shell.execute_reply": "2024-01-10T00:19:58.810755Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 0])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"log_reg.predict([[1.7], [1.5]])"
]
},
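{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a binary classifier, `predict` effectively thresholds the estimated probability at 0.5. This is the same as checking the sign of the score $\\theta^{\\rm T} x$, since $\\sigma(t) \\geq 0.5$ exactly when $t \\geq 0$. The cell below is a minimal, self-contained sketch of this check: it refits a fresh model on petal width, and its variable names are illustrative, chosen so as not to overwrite those used elsewhere in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"iris_demo = load_iris()\n",
"X_pw = iris_demo[\"data\"][:, 3:]  # petal width (cm)\n",
"y_virg = (iris_demo[\"target\"] == 2).astype(int)  # 1 if Iris-Virginica, else 0\n",
"clf_1d = LogisticRegression(random_state=42).fit(X_pw, y_virg)\n",
"\n",
"X_query = np.array([[1.7], [1.5]])\n",
"scores_1d = clf_1d.decision_function(X_query)  # theta^T x (including the bias term)\n",
"proba_1d = clf_1d.predict_proba(X_query)[:, 1]  # estimated P(Iris-Virginica)\n",
"\n",
"# The probability is the sigmoid of the score, and predict() thresholds the score at 0\n",
"print(np.allclose(proba_1d, 1 / (1 + np.exp(-scores_1d))))\n",
"print(clf_1d.predict(X_query), (scores_1d > 0).astype(int), (proba_1d > 0.5).astype(int))"
]
},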
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercises:** *You can now complete Exercise 1 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Extending to two features"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:58.814372Z",
"iopub.status.busy": "2024-01-10T00:19:58.813922Z",
"iopub.status.idle": "2024-01-10T00:19:58.827801Z",
"shell.execute_reply": "2024-01-10T00:19:58.827253Z"
}
},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"X = iris[\"data\"][:, (2, 3)] # petal length, petal width\n",
"y = (iris[\"target\"] == 2).astype(int) # 1 if Iris-Virginica, else 0\n",
" \n",
"C = 1000 # inverse regularization (smaller values correspond to stronger regularization)\n",
"log_reg = LogisticRegression(C=C, random_state=42)\n",
"log_reg.fit(X, y)\n",
"\n",
"x0, x1 = np.meshgrid(\n",
" np.linspace(2.9, 7, 500).reshape(-1, 1),\n",
" np.linspace(0.8, 2.7, 200).reshape(-1, 1),\n",
" )\n",
"X_new = np.c_[x0.ravel(), x1.ravel()]\n",
"\n",
"y_proba = log_reg.predict_proba(X_new)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:58.830821Z",
"iopub.status.busy": "2024-01-10T00:19:58.830145Z",
"iopub.status.idle": "2024-01-10T00:19:59.072590Z",
"shell.execute_reply": "2024-01-10T00:19:59.071979Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2YAAAHFCAYAAACZwuqXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdZXRU19eA8Wdm4u6QIEmwCO5uhSLFrS0ttBQoGtzd3UmwQnFpcYpbcXeLQAgeIO4ydt8P84Z/aUILE5JJ4fzWylolc8++ey4Tmn33uefIJEmSEARBEARBEARBEAxGbugEBEEQBEEQBEEQPneiMBMEQRAEQRAEQTAwUZgJgiAIgiAIgiAYmCjMBEEQBEEQBEEQDEwUZoIgCIIgCIIgCAYmCjNBEARBEARBEAQDE4WZIAiCIAiCIAiCgYnCTBAEQRAEQRAEwcCMDJ3Af4VWqyU8PBxra2tkMpmh0xEEQRAEQRAEwUAkSSIxMRE3Nzfk8o/T6xKF2XsKDw+nUKFChk5DEARBEARBEIQ84tmzZxQsWPCjxBKF2XuytrYGdBffxsYmy2M0sfHETFtKwqrtIGnB3Bz74d2x69sJuanJW8f+9NNP7Ny5kyJFijB79my+/PLLTPG0SYnErppN3KYloNaAiQn2XQZi12UQcnOLbL2ftJRYrh6Zwb3zvyAhYWRsTvn6QyhXtz8KY9NsxdaHJEmcfrWV9Q/GEpceAUAl56b8VGIG+S08cz0fQRAEQRAEQXiXhIQEChUq9KZG+BhkkiRJHy3aJywhIQFbW1vi4+PfWZhlSL8ZSITfZNLOXQfAuLgHzovHYtmkDgAnTpzgiy++eGtMq1atWLBgAZ6emYsQZWgQkdMGkHLuKABGboVxHrUAy0Ztsj2tMvLFLU5t8+Nl2FkAbJ2LUbedP+6+TbIVV1/J6gQ2h05iz5PFaCQ1xnJTvikyivaewzFVmBskJ0EQBEEQBEH4qw+pDd6XKMze04defEmSSNz0B1HDZqN5FQmAZasGOM4bSaXWzbl7926mMWZmZowcOZLhw4djbm6eKV7ykV1EzhiEOvwpABa1GuE8djEmRbyy9d4kSeL+tS2c2z2U5ISXAHiWakHttguxdSqSrdj6epoUxPKg/tyIPgZAPnMPfvaeTw2X1uIZP0EQBEEQBMGgRGFmQPpefE1CEjGT/IlbtB40GmRmpjzu3ISxN09x6cqVLMd4enqycOFCWrRokakI0aamELN8OnGr5iCplGBsjH2XQTj0GYfc0ipb71GZmsClQ5O4fWoxWq0ahZEpFRuOpGLDERiZ5H63SpIkzrzaxsqQIUSlPQegklMTevkspoBl8VzPRxAEQRAEQRBAFGYGld2Ln37vAZH9ppB64iIAco8CHGlRiQm/rSMyMjLLMU2bNmXRokUUL565CFE+CSVy2kBSTu4HwChfAZxGzMGq2bfZ7ijFvAri1PZ+PL9/HABrB3dqt1lAkTKG6ValqZP5PWwG2x/NQS0pMZIZ08ZjMB2LjsXcKHvFqCAIgiAIgiB8KFGYGVDGxY+JiMLe2VGvGJIkkbTtIFFDZqJ+/goAZcOqLHVVsHzTBrRabaYxJiYmDB06lNGjR2NpaZnp9aQ/9xI1bSCqZ2EAmFepi/P4AExLlNIrx7/mGnpzO2d3DyEp9hkAhX2aUKfdIuxdSmQrtr6eJ99neVB/rkUdBsDJrCA9vOdTK197Mb1REARBEARByDWiMDOgjIu/oshoGi/5HvcmvnrH0iYlEzNtGbHz1oBKhczEmBedmzAu6BJnzp/LckzBggWZP38+7dtnLkK06WnE/TqXmOXTkdJSQaHArpMfDv0nobC21TtPAFV6MlePTuf68bloNUrkCmPK1R9M5cZjMTHN/W6VJElcitzLiqCBvEp9BEBZh/r09vHH3bpkrucjCIIgCIIgfH5EYWZAGRd/Fj0xxwTPlqWovbAdtp76dc8AlA8eE9l/CimHzgCgKJCPE2
2rM277Rl6+fJnpeGNjYx48eIC7u3uW8VQvnhA5YzDJR3bq4jm64DRsNtatOyPL5sZ3sRH3ObNzIE8CDwJgZVeQWq3nUax8B4N0q9I1qWwLm8W2R7NQatNQyIxo5T6A74qNx9Lo4/xwCIIgCIIgCEJWRGFmQBkX/2CfDTz85SpatRaFmTEVRzak4vAGGJmb/HuQLEiSRPKeY0QOmoH6sW6BC02divxS1BL/DetQq9Vvjh05ciQzZsz415jJZ48QOaU/qkchAJiVr47z+ADMSlbQK8e/5vr47j5O7xxAQrSuW1WweH3qtFuMo1v2pk7q63XKY1YED+JCxG4A7E3y8ZPXLBq6/SCmNwqCIAiCIAg5QhRmBvTXi69+nsypfjt4/ud9AKzdHai9sC1FWpXWuxjQpqYRO3slsTN/QUpLByMjIr5vzIQntzh+8gQFCxYkODg4y+fMsiIplcSuW0jMkslIKckgk2HbsReOA6eisHPQK8cMamUq147P5tqxmWhUacjkCsrW6U/VphMxMTdMt+pK5EFWBA3gRcoDAEra16K3jz9FbcoZJB9BEARBEATh0yUKMwP6+8WXJInQbTc4O2Q3Sc/jAHBv6kudRW2xK+6i93lUj54ROWg6yXt0KyLK8zlx/pva2DWsRYsWLbIck5KSQmhoKGXKlMn0mvrVCyJnDyNp3xZdPDsHnAZPx6ZDd2QKhd55AiREP+bMrkGE3d4NgIV1Pmq0mo135c4G6VaptEp2PV7A5oeTSdekIEdO00I9+bHENKyN7XM9H0EQBEEQBOHTJAozA3rXxVclp3N1+lGuzz2OVqlBbqyg/JD6VB7bGGNLU73Pl3zwFJEDpqF68BgAs1oVcQkYj2lZn0zHTpgwgalTp9KnTx8mT56MvX3mIiTl0kkip/RDeV+3sbVpqUq4TAjArGxVvXPM8CTwEKd3DiAuQtdBdC1Si7rt/XEuWC7bsfURkfqUX0OGcfrVVgBsjJ3o6jWTLwv8hFyWvWftBEEQBEEQBEEUZgb0bxc/9n4EZwbu5MnBQACsCtpRa15rinUor//0xnQlcfNXEzN1GVJKKsjl2Pb5DsfJA1DY61ZbDAsLw9fXl/T0dACcnJyYOXMmP/30E/K/LfghqdXEbVpCzKLxaJMSALBp9xOOQ2di5Kh/lw9Ao0rn5smFXD48GbUyBZlMTqmavajWbApmltmbOqmv29EnWRLUl6dJur+T4jaV6FtyKV62lQ2SjyAIgiAIgvBpEIWZAb3PxZckiUd773JmwA4SHscAULB+ceoGtMfB11Xvc6uevSRq6EyStupWRFQ42eM4cyg2P7WjVevW7N27N9OYKlWqEBAQQOXKmYsQdeQrouaNInHnWgDk1rY4DpiC7Xe9kRkZ6Z0nQGLsM87uHkroDV23yszSiRotZuBbrWu2V4bUh1qr4o8n/mx6OIkUdQIyZDQu2J0fS0zDzsQ51/MRBEEQBEEQ/vtEYWZAH3Lx1alKrs0+zrWZx9CkqZAbySnTvy5VJzTBxMZc7xxS/rxAZL8pKANDATCuXJrV5fMxZ/0a0tLSMh0vk8no3r0706dPx8nJKdPrqTcuEDmpL+mBNwAw8SqNy/glmFeurXeOGZ4/OMmp7X7EvLwHgEvhStTtsIT87lWyHVsfMemvWB0ynOPhGwCwMrLjh+JT+apwLxSy7D1rJwiCIAiCIHxeRGFmQPpc/ITH0ZweuJNHe+4AYJHPmppzWuHVqbLe0xsllYo4/w3ETPRHm6hbbTHu64ZMT3rC7v37sxxjb2/P1KlT6dmzJ4q/LfghaTTEb11J9PzRaONjAbBu+T1Ow2ZjlM9NrxwzaDQqbp8O4PLBiSjTEkAmo2T17lRvNg1za8N0q+7GnmV5UH8eJuiK0aI25enls5hS9rUMko8gCIIgCILw3yMKMwPKzsV/ciiQ0/13EPcgEgDXWkWo698e53IF9c5H/TKCqOGzSdz4BwBye1tufvcFo4/u5v79+1mOKVeuHE
uWLKFGjRqZXtPERBG1YAwJW1eCJCGztMLRbwJ2nfsjM9Fvj7YMyQmvOLdnOCFXdN0qU3M7qjWbSqlavZDLc79bpZE
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 5))\n",
"plt.plot(X[y==0, 0], X[y==0, 1], \"bs\")\n",
"plt.plot(X[y==1, 0], X[y==1, 1], \"g^\")\n",
"\n",
"zz = y_proba[:, 1].reshape(x0.shape)\n",
"contour = plt.contour(x0, x1, zz, cmap=plt.cm.brg)\n",
"\n",
"# Solve theta^T x = 0 to determine boundary\n",
"left_right = np.array([2.9, 7]) \n",
"boundary = -(log_reg.coef_[0][0] * left_right + log_reg.intercept_[0]) / log_reg.coef_[0][1]\n",
"\n",
"plt.clabel(contour, inline=1, fontsize=12)\n",
"plt.plot(left_right, boundary, \"k--\", linewidth=3)\n",
"plt.text(3.5, 1.5, \"Not Iris-Virginica\", fontsize=14, color=\"b\", ha=\"center\")\n",
"plt.text(6.5, 2.3, \"Iris-Virginica\", fontsize=14, color=\"g\", ha=\"center\")\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.ylabel(\"Petal width\", fontsize=14)\n",
"plt.axis([2.9, 7, 0.8, 2.7]);"
]
},
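{
"cell_type": "markdown",
"metadata": {},
"source": [
"With two features, the decision boundary is the line $\\theta_0 + \\theta_1 x_1 + \\theta_2 x_2 = 0$, on which $\\hat{p} = \\sigma(0) = 0.5$. Here $\\theta_0$ is the bias (`intercept_` in scikit-learn) and $\\theta_1, \\theta_2$ are the weights (`coef_`). The self-contained sketch below, which uses illustrative variable names, refits the same model, solves for $x_2$ along that line as in the plotting code above, and checks that the predicted probability there is 0.5."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"iris_demo = load_iris()\n",
"X_petal = iris_demo[\"data\"][:, (2, 3)]  # petal length, petal width\n",
"y_virg = (iris_demo[\"target\"] == 2).astype(int)  # 1 if Iris-Virginica, else 0\n",
"clf_2d = LogisticRegression(C=1000, random_state=42).fit(X_petal, y_virg)\n",
"\n",
"theta_0 = clf_2d.intercept_[0]  # bias term\n",
"theta_1, theta_2 = clf_2d.coef_[0]  # weights for petal length, petal width\n",
"\n",
"# Points on the line theta_0 + theta_1 * x1 + theta_2 * x2 = 0\n",
"x1_line = np.linspace(2.9, 7, 5)\n",
"x2_line = -(theta_1 * x1_line + theta_0) / theta_2\n",
"p_line = clf_2d.predict_proba(np.c_[x1_line, x2_line])[:, 1]\n",
"print(p_line)  # approximately 0.5 everywhere on the boundary"
]
},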
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercises:** *You can now complete Exercise 2 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Softmax regression\n",
"\n",
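"Logistic regression can be generalised to support multiple classes directly, without training and combining several binary classifiers. This is known as *softmax regression* (or *multinomial logistic regression*)."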
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Softmax score\n",
"\n",
"Consider the softmax score function for class $k$:\n",
"\n",
"$$s_k(x) = \\left(\\theta^{(k)}\\right)^{\\rm T} x .$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"**Important note**: each class $k$ has its own score and set of parameters $\\theta^{(k)}$, for $K$ classes (i.e. $k=1,\\ldots,K$).\n",
"\n",
"Define:\n",
"- Parameter matrix: $\\Theta_{K \\times n} = [ \\theta^{(1)},\\ \\theta^{(2)},\\ \\ldots,\\ \\theta^{(K)}]^{\\rm T}$."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Softmax function\n",
"\n",
"The predicted probability for each class $k$ is then given by the softmax function $\\sigma_k(s(x))$ applied to the vector of scores $s(x)$:\n",
"\n",
"$$ \n",
"\\hat{p}_k = \\sigma_k(s(x)) = \\frac{\\exp\\left(s_k(x)\\right)}{\\sum_{k^\\prime=1}^K \\exp\\left(s_{k^\\prime}(x)\\right)}\n",
".\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Normalised such that\n",
"- $\\sum_k \\hat{p}_k = 1$\n",
"- $0 \\leq \\hat{p}_k \\leq 1$"
]
},
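{
"cell_type": "markdown",
"metadata": {},
"source": [
"The softmax function defined above can be sketched in NumPy as follows. Subtracting the largest score before exponentiating is a common numerical-stability trick that is not part of the definition itself. It leaves $\\hat{p}_k$ unchanged, because the factor $\\exp(-\\max_{k^\\prime} s_{k^\\prime}(x))$ cancels between numerator and denominator, but it avoids overflow for large scores. The scores used here are arbitrary illustrative values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def softmax(scores):\n",
"    # Softmax along the last axis; scores has shape (..., K)\n",
"    shifted = scores - np.max(scores, axis=-1, keepdims=True)  # avoid overflow in exp\n",
"    exp_scores = np.exp(shifted)\n",
"    return exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)\n",
"\n",
"# Illustrative scores s_k(x) for K = 3 classes and two instances\n",
"s = np.array([[2.0, 1.0, 0.1],\n",
"              [1000.0, 1001.0, 999.0]])  # np.exp(1000.0) on its own would overflow\n",
"p_hat = softmax(s)\n",
"print(p_hat)\n",
"print(p_hat.sum(axis=1))  # each row sums to 1"
]
},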
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Predictions\n",
"\n",
"Can then make class predictions based on which class has the highest predicted probability, i.e.\n",
"\n",
"$$ \n",
"\\hat{y}\n",
"= \\arg\\max_k \\hat{p}_k \n",
"= \\arg\\max_k s_k(x)\n",
"= \\arg\\max_k \\left(\\theta^{(k)}\\right)^{\\rm T} x\n",
",\n",
"$$\n",
"\n",
"where we recall \n",
"$\\hat{p}_k = \\sigma_k(s(x)) = \\frac{\\exp\\left(s_k(x)\\right)}{\\sum_{k^\\prime=1}^K \\exp\\left(s_{k^\\prime}(x)\\right)}\n",
"\\quad \\text{and} \\quad \n",
"s_k(x) = \\left(\\theta^{(k)}\\right)^{\\rm T} x.$"
]
},
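{
"cell_type": "markdown",
"metadata": {},
"source": [
"The last equality holds because the exponential is strictly increasing and all $\\hat{p}_k$ share the same denominator, so ranking the classes by $\\hat{p}_k$ or by $s_k(x)$ gives the same answer; the probabilities need not be computed just to predict a class. A quick numerical check, using randomly generated (hypothetical) parameters $\\Theta$ and inputs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"Theta = rng.normal(size=(3, 2))  # hypothetical parameters: K = 3 classes, n = 2 features\n",
"X_demo = rng.normal(size=(5, 2))  # 5 hypothetical instances\n",
"\n",
"scores = X_demo @ Theta.T  # row i, column k: s_k(x^(i)) = theta^(k)T x^(i)\n",
"exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))\n",
"p_hat = exp_scores / exp_scores.sum(axis=1, keepdims=True)\n",
"\n",
"y_hat = np.argmax(p_hat, axis=1)  # predicted class for each instance\n",
"print(y_hat, np.argmax(scores, axis=1))  # identical"
]
},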
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Cost function\n",
"\n",
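"The logistic regression cost function generalises to the *cross-entropy*, which measures how far the predicted probability distribution is from the target distribution:\n",
"\n",
"$$\n",
"C(\\Theta) = -\\frac{1}{m} \\sum_{i=1}^m \\sum_{k=1}^K y_k^{(i)} \\log \\left(\\hat{p}_k^{(i)}\\right)\n",
",\n",
"$$\n",
"\n",
"where $y_k^{(i)}$ is the target probability that the $i$-th instance belongs to class $k$ (1 if its class is $k$, and 0 otherwise)."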
]
},
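{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of the cross-entropy follows. It assumes integer class labels, which it converts to one-hot targets $y_k^{(i)}$. Because exactly one $y_k^{(i)}$ is 1 for each instance, only the log-probability of the true class contributes. Clipping the probabilities away from zero is an added safeguard against $\\log(0)$ and is not part of the formula; the probabilities below are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def cross_entropy(p_hat, y, eps=1e-15):\n",
"    # p_hat: predicted probabilities, shape (m, K); y: integer labels in {0, ..., K-1}\n",
"    m, K = p_hat.shape\n",
"    Y = np.eye(K)[y]  # one-hot targets: y_k^(i) = 1 if instance i belongs to class k\n",
"    return -np.sum(Y * np.log(np.clip(p_hat, eps, 1.0))) / m\n",
"\n",
"p_example = np.array([[0.7, 0.2, 0.1],\n",
"                      [0.1, 0.8, 0.1]])\n",
"y_example = np.array([0, 1])\n",
"print(cross_entropy(p_example, y_example))  # equals -(log(0.7) + log(0.8)) / 2"
]
},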
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"For the case $K=2$, the cost function reduces to the standard cost function for logistic regression."
]
},
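{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see this, write $\\hat{p}^{(i)} = \\hat{p}_2^{(i)}$ and $y^{(i)} = y_2^{(i)}$, so that $\\hat{p}_1^{(i)} = 1 - \\hat{p}^{(i)}$ and $y_1^{(i)} = 1 - y^{(i)}$ (the targets are one-hot). The inner sum then has just two terms:\n",
"\n",
"$$\n",
"C(\\Theta) = -\\frac{1}{m} \\sum_{i=1}^m \\left[ y^{(i)} \\log\\left(\\hat{p}^{(i)}\\right) + \\left(1 - y^{(i)}\\right) \\log\\left(1 - \\hat{p}^{(i)}\\right) \\right]\n",
".\n",
"$$\n",
"\n",
"Moreover, dividing the numerator and denominator of the softmax by $\\exp\\left(s_2(x)\\right)$ gives\n",
"\n",
"$$\n",
"\\hat{p}_2 = \\frac{\\exp\\left(s_2(x)\\right)}{\\exp\\left(s_1(x)\\right) + \\exp\\left(s_2(x)\\right)} = \\frac{1}{1 + \\exp\\left(-\\left(s_2(x) - s_1(x)\\right)\\right)} = \\sigma\\left(\\left(\\theta^{(2)} - \\theta^{(1)}\\right)^{\\rm T} x\\right),\n",
"$$\n",
"\n",
"i.e. logistic regression with parameter vector $\\theta = \\theta^{(2)} - \\theta^{(1)}$."
]
},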
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Minimising the cost function\n",
"\n",
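"There is no closed-form solution for the minimiser, but the cost function is convex, so it can be minimised by gradient descent.\n",
"\n",
"The gradient of the cost function with respect to $\\theta^{(k)}$ is given by\n",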
"\n",
"$$\n",
"\\frac{\\partial C}{\\partial \\theta^{(k)}}\n",
"= \\frac{1}{m} \\sum_{i=1}^m \\left(\\hat{p}_k^{(i)} - y_k^{(i)} \\right) x^{(i)} .\n",
"$$"
]
},
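{
"cell_type": "markdown",
"metadata": {},
"source": [
"The gradient above can be used directly in batch gradient descent. The sketch below trains softmax regression on the two petal features using plain NumPy. It stacks the $\\theta^{(k)}$ as the rows of $\\Theta$ and adds a bias feature $x_0 = 1$. The learning rate and number of iterations are illustrative choices, not tuned values, and no regularisation is applied (unlike scikit-learn's `LogisticRegression`, which regularises by default)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"\n",
"iris_gd = load_iris()\n",
"X_gd = iris_gd[\"data\"][:, (2, 3)]  # petal length, petal width\n",
"y_gd = iris_gd[\"target\"]\n",
"m_gd, K_gd = len(y_gd), 3\n",
"\n",
"X_b = np.c_[np.ones((m_gd, 1)), X_gd]  # add x_0 = 1 to every instance (bias term)\n",
"Y_onehot = np.eye(K_gd)[y_gd]  # y_k^(i)\n",
"\n",
"def softmax_proba(Theta, X_b):\n",
"    scores = X_b @ Theta.T  # s_k(x^(i)), shape (m, K)\n",
"    exp_s = np.exp(scores - scores.max(axis=1, keepdims=True))\n",
"    return exp_s / exp_s.sum(axis=1, keepdims=True)\n",
"\n",
"def cost(Theta):\n",
"    P = softmax_proba(Theta, X_b)\n",
"    return -np.mean(np.sum(Y_onehot * np.log(P + 1e-15), axis=1))  # cross-entropy\n",
"\n",
"eta = 0.1  # learning rate (illustrative)\n",
"n_iterations = 5000\n",
"Theta_gd = np.zeros((K_gd, X_b.shape[1]))  # row k is theta^(k)\n",
"costs = [cost(Theta_gd)]\n",
"for _ in range(n_iterations):\n",
"    P = softmax_proba(Theta_gd, X_b)\n",
"    gradients = (P - Y_onehot).T @ X_b / m_gd  # row k is dC/dtheta^(k)\n",
"    Theta_gd -= eta * gradients\n",
"    costs.append(cost(Theta_gd))\n",
"\n",
"accuracy = np.mean(np.argmax(softmax_proba(Theta_gd, X_b), axis=1) == y_gd)\n",
"print(costs[0], costs[-1], accuracy)"
]
},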
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Example of softmax regression"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:59.076992Z",
"iopub.status.busy": "2024-01-10T00:19:59.076312Z",
"iopub.status.idle": "2024-01-10T00:19:59.088101Z",
"shell.execute_reply": "2024-01-10T00:19:59.087548Z"
}
},
"outputs": [],
"source": [
"X = iris[\"data\"][:, (2, 3)] # petal length, petal width\n",
"y = iris[\"target\"] # consider all three target classes\n",
"\n",
"C = 10\n",
"softmax_reg = LogisticRegression(solver=\"lbfgs\", C=C, random_state=42)  # lbfgs fits a multinomial (softmax) model for 3+ classes\n",
"softmax_reg.fit(X, y);"
]
},
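{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check the formulas above, compare a fitted multinomial model's class probabilities with the softmax of its scores $s_k(x)$. The two should be equal, and scikit-learn exposes the scores, including the intercept, through `decision_function`. This minimal sketch assumes scikit-learn 0.22 or later, where `LogisticRegression` with the `lbfgs` solver fits a multinomial (softmax) model when there are three or more classes. It uses SciPy's `scipy.special.softmax` for the softmax, and the variable names are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.special import softmax as scipy_softmax\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"iris_check = load_iris()\n",
"X_check = iris_check[\"data\"][:, (2, 3)]  # petal length, petal width\n",
"y_check = iris_check[\"target\"]  # all three classes\n",
"clf_softmax = LogisticRegression(solver=\"lbfgs\", C=10, random_state=42).fit(X_check, y_check)\n",
"\n",
"scores_check = clf_softmax.decision_function(X_check)  # s_k(x) = theta^(k)T x + bias, shape (150, 3)\n",
"proba_manual = scipy_softmax(scores_check, axis=1)\n",
"print(np.allclose(proba_manual, clf_softmax.predict_proba(X_check)))\n",
"print(clf_softmax.coef_.shape)  # one row of weights theta^(k) per class: (K, n) = (3, 2)"
]
},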
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:59.091075Z",
"iopub.status.busy": "2024-01-10T00:19:59.090368Z",
"iopub.status.idle": "2024-01-10T00:19:59.115277Z",
"shell.execute_reply": "2024-01-10T00:19:59.114655Z"
}
},
"outputs": [],
"source": [
"x0, x1 = np.meshgrid(\n",
" np.linspace(0, 8, 500).reshape(-1, 1),\n",
" np.linspace(0, 3.5, 200).reshape(-1, 1),\n",
" )\n",
"X_new = np.c_[x0.ravel(), x1.ravel()]\n",
"\n",
"y_proba = softmax_reg.predict_proba(X_new)\n",
"y_predict = softmax_reg.predict(X_new)\n",
"\n",
"# Probability contours for one class: uncomment the line for the class to plot (0: Setosa, 1: Versicolor, 2: Virginica)\n",
"# zz1 = y_proba[:, 0].reshape(x0.shape)\n",
"zz1 = y_proba[:, 1].reshape(x0.shape)\n",
"# zz1 = y_proba[:, 2].reshape(x0.shape)\n",
"zz = y_predict.reshape(x0.shape)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:59.118315Z",
"iopub.status.busy": "2024-01-10T00:19:59.117863Z",
"iopub.status.idle": "2024-01-10T00:19:59.387990Z",
"shell.execute_reply": "2024-01-10T00:19:59.387264Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1cAAAHKCAYAAADvpbFEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3gU1frA8e/M1mySTYUkhFBDCSWhSFdBlG7Bdu2AgL1cu3IVu6L32vWnYkGwgAU7vaNIh9BCJ4FASALpbfvM749NQoetEML5PE8eIMyeObOE3X3nPed9JVVVVQRBEARBEARBEAS/yOd6AoIgCIIgCIIgCPWBCK4EQRAEQRAEQRACQARXgiAIgiAIgiAIASCCK0EQBEEQBEEQhAAQwZUgCIIgCIIgCEIAiOBKEARBEARBEAQhAERwJQiCIAiCIAiCEAAiuBIEQRAEQRAEQQgAEVwJgiAIgiAIgiAEgAiuBEEQBEEQBEEQAqDOBVeffPIJqampmM1mzGYzvXr1Yvbs2ac8fvLkyUiSdMyX0Wg8izMWBEEQBEEQBEEA7bmewPEaN27MG2+8QatWrVBVlSlTpnDNNdeQnp5O+/btT/oYs9nMjh07av8sSdLZmq4gCIIgCIIgCAJQB4Orq6666pg/v/baa3zyySesXLnylMGVJEnEx8efjekJgiAIgiAIgiCcVJ0Lro7mcrn46aefqKyspFevXqc8rqKigqZNm6IoCl26dOH1118/ZSAGYLPZsNlstX9WFIWioiJiYmJE1ksQBEEQBEEQLmCqqlJeXk6jRo2QZe92UdXJ4Grz5s306tULq9VKWFgYv/76K+3atTvpsW3atGHSpEmkpqZSWlrKW2+9Re/evcnIyKBx48YnfcyECRN46aWXgnkJgiAIgiAIgiCcx/bv33/KeOJUJFVV1SDNx2d2u53s7GxKS0uZPn06X3zxBUuXLj1lgHU0h8NBSkoKt9xyC6+88spJjzk+c1VaWkqTJk3IyvqC8HBTwK5DEITzxzffDMN4x9RzPQ1BEARBEM4xS7mFx5o/RklJCREREV49tk4GV8e74ooraNmyJRMnTvTo+BtvvBGtVsu0adM8Or6srIyIiAgKCqZiNovgShAuVJMmXYNx9JRzPQ1BEARBEM4hS5mF+2Lvo7S0FLPZ7NVj6+SywOMpinJMpul0XC4XmzdvZujQoUGelSAI9c2IO37no1evR+n5J45SCSTQhqkYGyqENnehj6jz96IEQRAEQTiH6lxwNW7cOIYMGUKTJk0oLy9n6tSpLFmyhLlz5wIwYsQIEhMTmTBhAgAvv/wyPXv2JDk5mZKSEv73v/+xb98+xo4dey4vQxCE84StCv76Ueav7zVsXylhq1oGRJ302JBGLqK7OIjp6SDuMjtRXZzImrM7X0EQBEEQ6q46F1wdOnSIESNGkJubS0REBKmpqcydO5cBAwYAkJ2dfUzVjuLiYu666y7y8vKIioqia9euLF++3KP9WYIgXLhUFf7+Uebzx7WUHjpSJTQkXKVBkkpVlBNUcJZLWPI02A7LWA5qyDmoIWeGu1G5PlohYbCNxtfYSBhsQxtyrq5GEARBEIS64LzYcxVsYs+VIFxYbFXwzigtK35zp50aNlUZfJeL7sMUGqeoyPKJ+6/sJRKlGVoKV+s4/I+OQ0v1OEqP3OjRhiskXW+j+R0WGlzsQHR1EARBEITzU73fcyUIghAodiu8er2OjQtltDqVf41zcf1TLnT60z9OH6nSoI+DBn0ctH0UFCcUrNSR84eB/T8bqdqvIWtyCFmTQwhv5aTlGAvNR1kwRF/w968EQRAE4YIhMleIzJUgXChUFd68RcvyXzQYQ1VenOGgXZ9TvwROpcSzcRU4/I+OrG9C2D/dgLPCndHSGFWa3mahzUNVRLRzBeISBEEQBEEIMn8yV961HBYEQTiPLZkqs/wXDVq9yvjfTh9YAdxKpEfjSjI0vMRBj8/KuCa7gG
6flhGZ5sBllcj80sTsTrEsvSqS/CU6xO0sQRAEQai/RHAlCMIFofQwfPGEeyX0LeNddOzrWZTjaYBVQxem0nK0hUGri+i/qIjGw61IskruXAOLB0azoG8UOTP1IsgSBEEQhHpIBFeCIFwQfnlbQ3mhRLOOCtc+5t0SPW8DLABJgoYXO7j4x1KGbS0k+d4qZINK4Uo9f18bxdzu0ez/1SCCLEEQBEGoR0RwJQhCvVdRAnM+d1cGvOMVF1qd92P4EmDVCGvh4qIPyrlqVwFtH69EG6ZQslHHPzdFMq9ntMhkCYIgCEI9IYIrQRDqvb++l7GUSySlKFw0RPF5HH8CLICQeIVOEyq4alcB7Z6pQBumUJyu4+9ro1hwSRT5i32I+gRBEARBqDNEcCUIQr238Gt31mrgGMXv/lPWSSP9no8hRiX15Uqu3OHOZGlCVApX61k8KJolwyIpShddMgRBEAThfCSCK0EQ6rX8vbBrrYwsq/S92f9y6KNH/x6QAAvA2ECl04QKrtxRQKv7q5B1KnnzDczrEcPy2yKoyNQE5DyCIAiCIJwdIrgSBKFeWzvb/TKX0kclsuE5nswphMQrdH2vnKGbC2l6iwUkleyfjMzqGEP6k2HYivxMtwmCIAiCcFaI4EoQhHpt40L3y1zXQb7vtTpbwlq46DWljEGri4i7wobikNjxfigz2say/T0TLvu5nqEgCIIgCKcjgitBEOotRYGMf9wvcx36Bi64CuTSwJOJSnPSb2YJfWcUE9HBgaNEZsNT4czpFMOB30X5dkEQBEGoq0RwJQhCvZW3R6K8UEJvVEnuEtiIJNgBliRBwkA7g9YU0W1iKcY4F+W7tSy7MZLFg6Io2SSKXgiCIAhCXSOCK0EQ6q1d69x7lZqlqj71tjqTYAdYALIGWt5pZdjWQto9XYFsUDm0RM/c7tGsfTAc62GxH0sQBEEQ6goRXAmCUG/t3ewOPFp0Ct46urMRYAHowlVSX6lk2JYCkq6zoioSuz8zMbN9LDs/CkFxBn0KgiAIgiCcgQiuBEGot/ZvdQdXTdoFt5jF6NG/B3X8o4U2VejzfSn9FxYRmerej7X+MTNzu0WTv0Q0IRYEQRCEc0kEV4Ig1Fs5u9zBVeM2wa8AcSuRQT/H0Rpe4mDgqiIu+qgMfbRCaYaOxQOj+efWCKoOiJd2QRAEQTgXxDuwIAj1kssF+Vnu4KpR8tkpr3e2AyxZA8l3Wxi2tYDke6uQZJX90939sba+acJlO6vTEQRBEIQLngiuBEGol4pzwemQ0GhVYhqfvfOe7QALwBCtctEH5QxcWURsLzvOSplN48OZ0yWGvAX6sz4fQRAEQbhQieBKEIR66VC2O2sV0xg0mnM8mbMkqpOTy5cU0/Or6tLtu7QsGRrFPzdHULlfvNwLgiAIQrCJd1tBEOqlohx3cBXb+MLquCtJ0Ow2K0O3FNL6weqlgr8YmZ0aw/Z3TCiOcz1DQRAEQai/RHAlCEK9VJjrDq6i489+cHUulgYeTx+h0uWdcgatLiK2j3up4IZnwpnbLYZDf4mqgoIgCIIQDCK4EgShXio95P416hwEV1A3AiyAyFQnly8spvtnpehjFEq3all0RTQrR5tFA2JBEARBCDARXAmCUC+VVgcOEQ3O3RzqSoAlydBilJVhGQUk310Fksreb0OY1SGW3Z+HoAa3DZggCIIgXDBEcCUIQr1UXuj+1Rx7bvdcWSeNPKfnP5ohWuWij8oZ8HcRkWkO7MUyax8ws+DSKIo3as/19ARBEAThvCfeTQVBqJfKityZq7DIczeHyoM2LjFNZPGQBsgRG1ALnQBIITJSQx1Skh5NWyNymgm5tQFJPjvL9GK6Oxm4oojdE0PY9HwYhav1zOsZTesHq+jwQiW6sAurCIggCIIgBIoIrgRBqJcqS9y/hkad/UBh/4JCNr2/n/3zCqH69GdceRelQXtJONorwtEOjURuEtz+VLIWWj9gofFwG+mPh7P/FyM73g
8l+2cjXd8rp/HVogOxIAiCIHhLBFeCINRLlaXuLFBoxNk7p6XAzj+P7WT39/m132vQNZy4nhFEJJtYuaUX+n5/oVo
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 5))\n",
"plt.plot(X[y==2, 0], X[y==2, 1], \"g^\", label=\"Iris-Virginica\")\n",
"plt.plot(X[y==1, 0], X[y==1, 1], \"bs\", label=\"Iris-Versicolor\")\n",
"plt.plot(X[y==0, 0], X[y==0, 1], \"yo\", label=\"Iris-Setosa\")\n",
"\n",
"from matplotlib.colors import ListedColormap\n",
"custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])\n",
"\n",
"plt.contourf(x0, x1, zz, cmap=custom_cmap)\n",
"contour = plt.contour(x0, x1, zz1, cmap=plt.cm.brg)\n",
"plt.clabel(contour, inline=1, fontsize=12)\n",
"plt.xlabel(\"Petal length\", fontsize=14)\n",
"plt.ylabel(\"Petal width\", fontsize=14)\n",
"plt.legend(loc=\"center left\", fontsize=14)\n",
"plt.axis([0, 7, 0, 3.5]);"
]
}
],
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}