spce0038-machine-learning-w.../week3/lectures/Lecture07_TrainingIII.ipynb

1214 lines
410 KiB
Plaintext
Raw Normal View History

2025-02-22 19:17:44 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Lecture 7: Training III"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"source": [
"![](https://www.tensorflow.org/images/colab_logo_32px.png)\n",
"[Run in colab](https://colab.research.google.com/drive/151rKfOzxrWMK7fiViIGsosquddN07yJg )"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:44.467687Z",
"iopub.status.busy": "2024-01-10T00:19:44.467138Z",
"iopub.status.idle": "2024-01-10T00:19:44.475581Z",
"shell.execute_reply": "2024-01-10T00:19:44.475048Z"
},
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Last executed: 2024-01-10 00:19:44\n"
]
}
],
"source": [
"import datetime\n",
"now = datetime.datetime.now()\n",
"print(\"Last executed: \" + now.strftime(\"%Y-%m-%d %H:%M:%S\"))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Polynomial regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"So far we have considered only linear regression. Polynomial regression can also be performed with a model that is linear (in the parameters)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:44.516067Z",
"iopub.status.busy": "2024-01-10T00:19:44.515424Z",
"iopub.status.idle": "2024-01-10T00:19:44.928255Z",
"shell.execute_reply": "2024-01-10T00:19:44.927538Z"
}
},
"outputs": [],
"source": [
"# Common imports\n",
"import os\n",
"import numpy as np\n",
"np.random.seed(42) # To make this notebook's output stable across runs\n",
"\n",
"# To plot pretty figures\n",
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['axes.labelsize'] = 14\n",
"plt.rcParams['xtick.labelsize'] = 12\n",
"plt.rcParams['ytick.labelsize'] = 12"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Example data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:44.931698Z",
"iopub.status.busy": "2024-01-10T00:19:44.931400Z",
"iopub.status.idle": "2024-01-10T00:19:44.935887Z",
"shell.execute_reply": "2024-01-10T00:19:44.935221Z"
}
},
"outputs": [],
"source": [
"import numpy.random as rnd\n",
"np.random.seed(42)\n",
"m = 100\n",
"X = 6 * np.random.rand(m, 1) - 3\n",
"y = 0.5 * X**2 + X + 2 + np.random.randn(m, 1)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:44.938729Z",
"iopub.status.busy": "2024-01-10T00:19:44.938267Z",
"iopub.status.idle": "2024-01-10T00:19:45.153988Z",
"shell.execute_reply": "2024-01-10T00:19:45.153313Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA00AAAHVCAYAAAAtqRArAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA1TElEQVR4nO3deZRcdZk//qe7mYRA6NbIkoQOCZhGjiyCLBGQ2KISBMmMyjbqCENzEuawDIrANwoHwqoHF4K4JGMPizgqwjgCOgPiaSZsA6ig4oBJ5tiQSgJIJN0Ekga6+/dH/bJ3Kr1U1b117+t1Tk5Rt6q6nypu3b7v+9nq+vv7+wMAAIAB1SddAAAAQJoJTQAAACUITQAAACUITQAAACUITQAAACUITQAAACUITQAAACUITQAAACUITQAAACUITQAAACVULTStXr06Lrvssjj22GNj3LhxUVdXFzfffPOAz33mmWfi2GOPjbFjx8a4cePiH/7hH+Ivf/lLtUoFAABYb7tq/aKXX345rrjiithjjz3iPe95TzzwwAMDPq9QKMT06dOjqakprrnmmli9enV89atfjT/84Q/x+OOPx6hRo6pVMgAAQPVC04QJE2LFihUxfvz4+PWvfx2HHnrogM+75ppr4rXXXovf/OY3sccee0RExGGHHRYf+chH4uabb45Zs2ZVq2QAAIDqdc8bPXp0jB8/fpvPu/POO+NjH/vY+sAUEfHhD3849t5777j99tsrWSIAAMAWUjURxLJly+Kll16KQw45ZIvHDjvssHjyyScTqAoAAMizqnXPG4wVK1ZERLEr3+YmTJgQf/3rX6OnpydGjx69xeM9PT3R09Oz/n5fX1/89a9/jXe84x1RV1dXuaIBAIBU6+/vj1dffTUmTpwY9fVDbzdKVWhas2ZNRMSAoWj77bdf/5yBHr/22mtj7ty5lS0QAACoWUuXLo3m5uYhvy5VoWnMmDEREZu0GK2zdu3aTZ6zuTlz5sTnP//59fe7urpijz32iKVLl0ZjY2MFqgUAAJKwcGHECSdsuf2eeyKOOmrL7d3d3TFp0qTYaaedhvX7UhWa1nXLW9dNb2MrVqyIcePGDdjKFFFsnRroscbGRqEJAAAy5MADI+rrI/r6NmxraIh4z3siSp36D3fYTqomgth9991jl112iV//+tdbPPb444/HgQceWP2iAACAVGlujliwoBiUIoq38+cXt1dCqkJTRMQnP/nJuOeee2Lp0qXrt/3qV7+KRYsWxUknnZRgZQAAQFq0tUV0dkZ0dBRv29oq97vq+vv7+yv34zd14403xqpVq2L58uXxne98Jz7xiU/EQQcdFBER5557bjQ1NcXSpUvjoIMOire97W3xz//8z7F69eq47rrrorm5OZ544omtds/bXHd3dzQ1NUVXV5fueQAAkGMjzQZVDU1TpkyJ5557bsDH/vznP8eUKVMiIuKPf/xjfP7zn4+HHnooRo0aFccff3x87Wtfi912223Qv0toAgAAImosNFWT0AQAAESMPBukbkwTAABAmghNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAAAAJQhNAABA4gqFiI6O4m3aCE0AAECi2tsjJk+OOPro4m17e9IVbUpoAgAAElMoRMyaFdHXV7zf1xcxe3a6WpyEJgAAIDGLF28ITOv09kYsWZJMPQMRmgAAgMS0tETUb5ZKGhoipk5Npp6BCE0AAEBimpsjFiwoBqWI4u38+cXtabFd0gUAAAD51tYWMWNGsUve1KnpCkwRQhMAAJACzc3pC0vr6J4HAABQgtAEAABQgtAEAABQgtAEAABQgtAEAABQgtAEAABQgtAEAABQgtAEAA
BQgtAEAABQgtAEAABQgtAEAABQgtAEAABQgtAEAABQgtAEAABQgtAEAABQgtAEAABQgtAEAABQgtAEAABQgtAEAABQgtAEAABQgtAEAADDVChEdHQUb8kuoQkAAIahvT1i8uSIo48u3ra3J10RlSI0AQDAEBUKEbNmRfT1Fe/39UXMnq3FKauEJgAAGKLFizcEpnV6eyOWLEmmHipLaAIAgCFqaYmo3+xMuqEhYurUZOqhsoQmAAAYoubmiAULikEpong7f35xO9mzXdIFAABALWpri5gxo9glb+pUgalQKHZbbGnJ3mehpQkAAIapuTmitTU9ISGpKdCzPpOg0AQAABlQ6eCytUCWh5kEhSYAAKhxlQ4upQJZHmYSFJoAAKDGVTK4bCuQ5WEmQaEJAABqXCWDy7YCWR5mEhSaAACgxlUyuAwmkLW1RXR2Fsc8dXYW72dJKkPT4sWL49RTT43m5ubYYYcdYp999okrrrgiXn/99aRLAwCAVKpUcBlsIGtuLgapxYuzNQlERERdf39/f9JFbGzp0qVxwAEHRFNTU5x11lkxbty4ePTRR+Pmm2+OmTNnxs9+9rNB/Zzu7u5oamqKrq6uaGxsrHDVAACQbYVC6TWp2ts3jH2qry8GrbS0OI00G6Rucdvvf//7sWrVqnjooYdi3333jYiIWbNmRV9fX9x6663xyiuvxNvf/vaEqwQAgHxpbt56d7+tTRYxY0Y2xjalrnted3d3RETstttum2yfMGFC1NfXx6hRo5IoCwAAKiqphWnLIevTjqcuNLW2tkZERFtbWzz11FOxdOnS+PGPfxzf+c534rzzzosdd9wx2QIBAKDMKr0wbaVlfdrx1I1pioi46qqr4pprrok1a9as3/alL30prrrqqq2+pqenJ3p6etbf7+7ujkmTJhnTBABAqhUKxaC0cUtNQ0NxModa6trW3l7sktfbu2GyCGOaKmjKlCkxffr0+OQnPxnveMc74uc//3lcc801MX78+DjnnHMGfM21114bc+fOrXKlAAAwMqW6ttVSaGprK45hKjVZRK1KXUvTj370ozjjjDNi0aJF0bzRJ/2P//iPcfvtt8fzzz8f73jHO7Z4nZYmAABqUVZamtJspC1NqRvT9O1vfzsOOuigTQJTRMTMmTPj9ddfjyeffHLA140ePToaGxs3+QcAAGlXyYVpKY/Udc978cUXB5xS/M0334yIiLfeeqvaJQEAQEVluWtbFqSupWnvvfeOJ598MhYtWrTJ9h/+8IdRX18fBxxwQEKVAQBA5TQ3R7S2Vicw1fL05klIXWi68MILo7e3N4466qi48sor49vf/nYcd9xx8R//8R9xxhlnxMSJE5MuEQAAalZapzdPc5BL3UQQERGPP/54XH755fHkk0/GypUrY88994zTTjstLrroothuu8H1KBzpYC8AAMiatE460d4eMWtWsa76+uIYr3JOVz7SbJDK0FQOQhMAAGyqo6PYwjTQ9tbWqpcTEdUJcpmbPQ8AAKiMlpZiS87GGhqKk08kpdQ6VWkhNAEAQE6kcXrzNAa5zQlNAACQI21txa5vHR3F23KOHRqONAa5zRnTBAAAJK5QqNw6VSPNBqlb3BYAAMif5uZ0tS5tTPc8AACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAACAEoQmAAAgEYVCREdH8TbNhCYAAKDq2tsjJk+OOPro4m17e9IVbZ3QBABA5tRKC0ZeFQoRs2ZF9PUV7/f1Rcyend7/X0ITAACZUkstGHm1ePGGwLROb2/EkiXJ1L
MtQhMAAJlRay0Yg5HFVrOWloj6zZJIQ0PE1KnJ1LMtQhMAAJlRay0Y25LVVrPm5ogFC4pBKaJ4O39+cXsa1fX39/c
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10,5))\n",
"plt.plot(X, y, \"b.\")\n",
"plt.xlabel(\"$x_1$\", fontsize=18)\n",
"plt.ylabel(\"$y$\", rotation=0, fontsize=18)\n",
"plt.axis([-3, 3, 0, 10]);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Clearly a straight line will not fit the data well."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Construct new features\n",
"\n",
"Can use a linear model by constructing additional features that are powers of existing features:\n",
"\n",
"$$\\hat{y} = \\theta_0 + \\theta_1 x_1 + \\theta_2 x_1^2 + \\theta_3 x_2 + \\theta_4 x_2^2 + \\theta_5 x_1 x_2 + ... $$\n",
"\n",
"Model remains linear in the parameters $\\theta_j$.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Generate polynomial features"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:45.157478Z",
"iopub.status.busy": "2024-01-10T00:19:45.157015Z",
"iopub.status.idle": "2024-01-10T00:19:45.487951Z",
"shell.execute_reply": "2024-01-10T00:19:45.487213Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(array([-0.75275929]), array([-0.75275929, 0.56664654]))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.preprocessing import PolynomialFeatures\n",
"poly_features = PolynomialFeatures(degree=2, include_bias=False)\n",
"X_poly = poly_features.fit_transform(X)\n",
"X[0], X_poly[0]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:45.491035Z",
"iopub.status.busy": "2024-01-10T00:19:45.490576Z",
"iopub.status.idle": "2024-01-10T00:19:45.495240Z",
"shell.execute_reply": "2024-01-10T00:19:45.494597Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"((100, 1), (100, 2))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape, X_poly.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Fit model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:45.498108Z",
"iopub.status.busy": "2024-01-10T00:19:45.497659Z",
"iopub.status.idle": "2024-01-10T00:19:45.555973Z",
"shell.execute_reply": "2024-01-10T00:19:45.555227Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(array([1.78134581]), array([[0.93366893, 0.56456263]]))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import LinearRegression\n",
"lin_reg = LinearRegression()\n",
"lin_reg.fit(X_poly, y)\n",
"lin_reg.intercept_, lin_reg.coef_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Parameters are close to the model used to generate the data (2, 1 and 0.5 respectively)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Predictions"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:45.559266Z",
"iopub.status.busy": "2024-01-10T00:19:45.558700Z",
"iopub.status.idle": "2024-01-10T00:19:45.735843Z",
"shell.execute_reply": "2024-01-10T00:19:45.735234Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA00AAAHVCAYAAAAtqRArAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAABmFUlEQVR4nO3dd3jT1R7H8U9aSpktlg1lU/DKRgRcLJEpCAqoiIIUwYUT9XLFy3DgQBRFUbAyrgvEgYAIiCB7qIiiIKCslCUgLaulI/ePYxpKBx1Jfhnv1/PkCb9fRr+tsc0n55zvsTkcDocAAAAAANkKsboAAAAAAPBlhCYAAAAAyAWhCQAAAAByQWgCAAAAgFwQmgAAAAAgF4QmAAAAAMgFoQkAAAAAckFoAgAAAIBcEJoAAAAAIBeEJgAAAADIhddC06lTpzR69Gh16dJFUVFRstlsmjFjRrb33bZtm7p06aJSpUopKipKd9xxh/766y9vlQoAAAAAGYp46wsdPXpU48aNU/Xq1dWkSROtWLEi2/vZ7Xa1adNGkZGRev7553Xq1ClNmDBBv/zyizZu3KiiRYt6q2QAAAAA8F5oqly5sg4ePKhKlSrp+++/1xVXXJHt/Z5//nmdPn1aP/zwg6pXry5Jatmypa6//nrNmDFDQ4cO9VbJAAAAAOC96Xnh4eGqVKnSRe/36aef6oYbbsgITJLUsWNH1atXT3PmzPFkiQAAAACQhU81goiPj9eRI0fUokWLLLe1bNlSmzdvtqAqAAAAAMHMa9Pz8uLgwYOSzFS+C1WuXFnHjx9XcnKywsPDs9yenJys5OTkjOP09HQdP35cZcuWlc1m81zRAAAAAHyaw+HQyZMnVaVKFYWE5H/cyKdC09mzZyUp21BUrFixjPtkd/v48eM1duxYzxYIAAAAwG/t379f0dHR+X6cT4Wm4sWLS1KmESOnpKSkTPe50MiRI/Xoo49mHCckJKh69erav3+/IiIiPFAtAAAAACusXCn16JH1/IIF0rXXZj2fmJioatWqqXTp0gX6ej4VmpzT8pzT9M538OBBRUVFZTvKJJnRqexui4iIIDQBAAAAAaRpUykkREpPd50LDZWaNJFye+tf0GU7PtUIomrVqipfvry+//77LLdt3LhRTZs29X5RAAAAAHxKdLQ0daoJSpK5fucdc94TfCo0SdLNN9+sBQsWaP/+/Rnnli1bph07dqhv374WVgYAAADAV8TGSnv2SMuXm+vYWM99LZvD4XB47ukzmzx5sk6cOKEDBw5oypQpuummm9SsWTNJ0vDhwxUZGan9+/erWbNmKlOmjB566CGdOnVKL7/8sqKjo7Vp06Ycp+ddKDExUZGRkUpISGB6HgAAABDECpsNvBqaatasqb1792Z72+7du1WzZk1J0q+//qpHH31Uq1evVtGiRdW9e3e98sorqlixYp6/FqEJAAAAgORnocmbCE0AAAAApMJnA5/qnucLUlJSlJaWZnUZwEWFhoYqLCzM6jIAAAACHqHpH4mJiTp69Gi2e0QBvio8PFzlypVjNBUAAMCDCE0ygSk+Pl6lSpVSuXLlFBYWVuAe7oA3OBwOpaSkKCEhQfHx8ZJEcAIAAMhJIQdGCE2Sjh49qlKlSik6OpqwBL9RvHhxlS5dWna7XUePHiU0AQAA5GTkyEI93Of2afK2lJQUJScnKzIyksAEv2Oz2RQZGank5GSlpKRYXQ4AAIDvmTlTiosr1FMEfWhyNn1gQT38lfO1SwMTAACAC2zeLN1zT6GfJuhDkxOjTPBXvHYBAACycfy4dPPNUlJSoZ+K0AQAAAAgsKSnSwMGSLt3m+PLLy/U0xGaAAAAAASWceOkRYvMv8uXl2bNKtTTEZrgk/bs2SObzaZBgwZlOt+uXTuPTkerWbOmatas6bHnBwAAgIctXCiNHWv+HRIiffyxFB1dqKckNCEjoJx/KVq0qKpVq6b+/fvr559/trpEtxk0aJ
BsNpv27NljdSkAAABwtz/+MNPynMaPlzp0KPTTsk8TMtSpU0cD/nmRnTp1SuvXr9dHH32kzz77TMuWLdPVV19tcYXSrFmzdObMGY89/7Jlyzz23AAAAPCgM2dM44cTJ8zxTTdJjz/ulqcmNCFD3bp1NWbMmEznRo0apeeee05PPfWUVqxYYUld56tevbpHn79OnToefX4AAAB4gMNhWotv2WKO69eXpk+X3LSsg+l5yNXw4cMlSZs2bZJk2lu3a9dO8fHxuvPOO1WpUiWFhIRkClQrV65Ujx49VK5cOYWHhysmJkajRo3KdoQoLS1NL774ourWratixYqpbt26Gj9+vNLT07OtJ7c1TfPmzVOnTp1UtmxZFStWTDVr1tQdd9yhrVu3SjLrlWbOnClJqlWrVsZUxHbt2mU8R05rmk6fPq3Ro0fr0ksvVbFixRQVFaXu3btrzZo1We47ZswY2Ww2rVixQh9++KGaNm2q4sWLq3LlynrooYd09uzZLI/59NNP1bZtW1WoUEHFihVTlSpV1LFjR3366afZfq8AAAA4z5Qp0v/+Z/5dsqT02WdSRITbnp6RJuTJ+UHl2LFjuvLKKxUVFaVbb71VSUlJivjnRTllyhTdf//9KlOmjHr06KEKFSro+++/13PPPafly5dr+fLlKlq0aMZzDR06VO+9955q1aql+++/X0lJSZo4caLWrl2br/oee+wxTZw4UVFRUerVq5cqVKig/fv365tvvtHll1+uhg0b6uGHH9aMGTO0ZcsWPfTQQypTpowkXbTxQ1JSkjp06KCNGzeqefPmevjhh3X48GHNnj1bixcv1kcffaS+fftmedzkyZP19ddf68Ybb1SHDh309ddf6/XXX9fRo0f1wQcfZNxvypQpuu+++1S5cmX17t1bZcuW1aFDh7Rx40Z9/vnnuvnmm/P1swAAAAgqa9ZIDz/sOp4+XbrsMvd+DUeASkhIcEhyJCQk5Hq/s2fPOn777TfH2bNnvVSZ79m9e7dDkqNz585Zbvvvf//rkORo3769w+FwOCQ5JDnuuusuR2pqaqb7/vrrr44iRYo4mjRp4jh69Gim28aPH++Q5JgwYULGueXLlzskOZo0aeI4depUxnm73e4oV66cQ5Jj4MCBmZ6nbdu2jgtftvPnz3dIcjRq1CjL101JSXEcOnQo43jgwIEOSY7du3dn+7OoUaOGo0aNGpnOjR071iHJcfvttzvS09Mzzv/444+OokWLOsqUKeNITEzMOD969GiHJEdkZKRj+/btGefPnDnjqFevniMkJMQRHx+fcb558+aOokWLOg4fPpylngu/n+zwGgYAAEErPt7hqFTJ4TAT9ByOxx7L9m55zQY5YaTpYlq0kA4dsrqK3FWqJH3/faGfZteuXRlrmk6fPq0NGzZo1apVKlasmJ577rmM+xUtWlQvvfSSQkNDMz3+nXfeUWpqqt544w2VLVs2021PPPGEJk6cqI8++kiPPfaYJNPUQZL++9//qmTJkhn3rVq1qh566CE9/fTTear7rbfekiRNmjQpy9ctUqSIKlasmKfnycnMmTMVFhamF154IdOIW7NmzTRw4EBNmzZNX3zxhe64445Mj3vooYdUv379jOPixYvrtttu09ixY/XDDz+oSpUqGbeFhYUpLCwsy9e+8PsBAADAP86dk/r0cb1X79BBeuEFj3wpQtPFHDokxcdbXYVX/PHHHxr7T0/7sLAwVaxYUf3799e///1vNWrUKON+tWrVUrly5bI8fv369ZKkxYsXZ9uFLiwsTNu3b8843vLPQr1rr702y32zO5eTjRs3Kjw8XG3bts3zY/IqMTFRf/75p/71r38pOpv+/u3bt9e0adP0008/ZQlNl2ez87TzOU44u7pIuvXWW/XEE0+oYcOG6t+/v9q3b69rrrkmY8ojAAAAsvHQQ9K6debf1aub/ZiKeCbeEJouplIlqyu4ODfV2LlzZ3399dcXvV9OIzfHjx+XpEyjUr
lJSEhQSEhItgEsP6NDCQkJqlq1qkJC3N/XJDExMdd6KleunOl+58su9BT553/ktLS0jHMjRoxQ2bJlNWXKFL3yyiu
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"X_new = np.linspace(-3, 3, 100).reshape(100, 1)\n",
"X_new_poly = poly_features.transform(X_new)\n",
"y_new = lin_reg.predict(X_new_poly)\n",
"plt.figure(figsize=(10,5))\n",
"plt.plot(X, y, \"b.\")\n",
"plt.plot(X_new, y_new, \"r-\", linewidth=2, label=\"Predictions\")\n",
"plt.xlabel(\"$x_1$\", fontsize=18)\n",
"plt.ylabel(\"$y$\", rotation=0, fontsize=18)\n",
"plt.legend(loc=\"upper left\", fontsize=14)\n",
"plt.axis([-3, 3, 0, 10]);"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercise 1 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Learning curves"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"How can we determine whether a model is overfitting or underfitting?"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Overfitting with high degree polynomials"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:45.739185Z",
"iopub.status.busy": "2024-01-10T00:19:45.738769Z",
"iopub.status.idle": "2024-01-10T00:19:45.934670Z",
"shell.execute_reply": "2024-01-10T00:19:45.934005Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmUAAAE7CAYAAACc4/Y9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAACYrklEQVR4nO2dd3iTZReH7yQddDNLW8qesmQICCqyZC+ZigoICCpDUBFxIYK4xYEDlI8hIjJEQFFQpiCgICgoG1poKbPQPZN8fzxN2rRpk3Qlbc99XbnSvCtP0iTv7z3nPL+jMRqNRgRBEARBEASnonX2AARBEARBEAQRZYIgCIIgCC6BiDJBEARBEAQXQESZIAiCIAiCCyCiTBAEQRAEwQUQUSYIgiAIguACiCgTBEEQBEFwAUSUCYIgCIIguAAiygRBEARBEFwAEWWCIAiCIAguQLGJsvj4eGbNmkXPnj2pWLEiGo2GpUuXWt32+PHj9OzZE19fXypWrMgjjzzCtWvXimuogiAIgiAIxY5bcT3R9evXee2116hRowa33347O3futLpdREQEHTt2JCAggHnz5hEfH8+7777L0aNH+eOPP/Dw8CiuIQuCIAiCIBQbxSbKgoODiYqKIigoiIMHD9KmTRur282bN4+EhAQOHTpEjRo1AGjbti333XcfS5cuZfz48cU1ZEEQBEEQhGKj2NKXnp6eBAUF2dxu3bp19O3b1yzIALp160aDBg1YvXp1UQ5REARBEATBabhUoX9kZCRXr17ljjvuyLGubdu2HD582AmjEgRBEARBKHqKLX1pD1FRUYBKdWYnODiY6OhoUlJS8PT0zLE+JSWFlJQU82ODwUB0dDSVKlVCo9EU3aAFQRAEQRAAo9FIXFwcISEhaLWOx71cSpQlJSUBWBVd5cqVM29jbf0bb7zB7Nmzi3aAgiAIgiAINrh48SKhoaEO7+dSoszLywvAIuJlIjk52WKb7MycOZOnn37a/DgmJoYaNWpw8eJF/P39zcurvF2FuV3mMuGOCfka4zNbnuFA5AH2jNljc9uH1j1Ecnoy64avy9dzCYIgCEJJJfT9UOJS4ph25zRe7fxqnttGxETQ5NMmrBu2jm51u5mX/3L2F4asHsLxiccJ8Q9xeAzfn/ieUetHET4tnPLlyue57Yq/VzBx80RuPn8TrSYzytX006YMbzKcMfVepkkTMBoz99Fq4dgxqFZNPY6NjaV69er4+fk5PFZwMVFmSlua0phZiYqKomLFilajZKCia9bW+fv7W4gyTTkNXr5eFsscwcPbA52Xzq793b3d0afr8/1cgiAIglBSMXoaQQMePh42z4N+Rj8oBz5+Phbbevt6Qznw8/fL17nUtL+/vz/+5fLe38vXy7xtVlGmLafF08eT227z54svYMIE0OtBp4OFC+G223IeK79lUy4lyqpVq0aVKlU4ePBgjnV//PEHLVq0KP5BCYIgCILgMHqDHlB1VqWFsWOhRw84cwbq1YN8ZCjzxKVmXwIMHjyYH374gYsXL5qXbdu2jVOnTjF06FAnjkwQBEEQBHvRG5UoMxgNTh5J4RIaCp06Fb4gg2KOlC1YsIBbt25x6dIlADZt2kRERAQAkydPJiAggBdeeIE1a9bQuXNnnnrqKeLj43nnnXdo1qwZjz76aHEOVxAEQRCEfGISY0ZKT6SsqClWUfbuu+8SHh5ufvzdd9/x3XffAfDwww8TEBBA9erV2bVrF08//TTPP/88Hh4e9OnTh/feey/XerL8YjQaSU9PR6/X271PgC6AoHJB5okHeVHJvRIp2hS7tnVl3N3d0el0zh6GIAiCUIIojenLoqZYRVlYWJhd2zVp0oQtW7YU6VhSU1OJiooiMTHRof36BPahW8VunD9/3ua2j9R4BCNGu7Z1ZTQaDaGhofj6+jp7KIIgCEIJwGg0miNk9qQvS3o0zWBQMzELiksV+hcXGjScP38enU5HSEgIHh4eds
+U8Ij1ICk9idoVa9vcVntLi9FopFaFWgUcsfMwGo1cu3aNiIgI6tevLxEzQRAEwSZZhZgjgqskmr2fOQP33w9LlkCDBgU7VpkUZZ54YjAYqF69Ot7e3g7t65bshlajNZvZ5oXOQ4fRaLRrW1emSpUqhIWFkZaWJqJMEARBsImpyB9Kd/oyKkrNxjx3Djp3hvXrC3Y8l5t9WZzkpwWCo5T0kCyUzCsXQRAEwXnkN1JWkoiJgV69lCADqFUL6tcv2DHLZKRMEARBEISiw1TkD6XPEgMgORn694e//1aPa9aEn3+GfBr5mynTkTJBEARBEAofi0hZKUtfpqfDgw/C7t3qceXKsGVLZqulgiCirITx2Wef0bx5c3P7qPbt2/PTTz+Z1ycnJzNx4kQqVaqEr68vgwcP5sqVKxbHuHDhAn369MHb25vAwECmT59Oenp6cb8UQRAEoZRiUVNWitKXBgM89hh8/7167OMDmzdDw4aFc3wRZSWM0NBQ3nzzTQ4dOsTBgwfp0qULAwYM4N9//wVg2rRpbNq0iTVr1rBr1y4uXbrEoEGDzPvr9Xr69OlDamoqv//+O8uWLWPp0qW88sorznpJgiAIQinD0fRlSYimGY3w7LOwdKl67O6uCvvbtCm855CashJGv379LB6//vrrfPbZZ+zfv5/Q0FAWL17MypUr6dKlCwBLlizhtttuY//+/dx5551s3bqV//77j19//ZWqVavSokUL5syZw4wZM3j11Vfx8PBwxssSBEEQShH5TV9qcN2JZX/8AfPnq7+1Wli5Eu67r3CfQyJlJRi9Xs+qVatISEigffv2HDp0iLS0NLp162beplGjRtSoUYN9+/YBsG/fPpo1a0bVqlXN2/To0YPY2FhztE0QBEEQCkJpTF+2aweLF4NOB4sWwZAhhf8cEinLQmJaIieun8hzm6i4KJLSk0hOt9066WLMRQxGA/Gp8Xlu16hyI7zd7fdLO3r0KO3btyc5ORlfX1/Wr19P48aNOXLkCB4eHpQvX95i+6pVq3L58mUALl++bCHITOtN6wRBEAShoGSNlJWU2ZfRV7zhfCciIqBGdevbjBkDHTtCvXpFMwYRZVk4cf0ErRe1LvbnPTT+EK2CW9m9fcOGDTly5AgxMTGsXbuWUaNGsWvXriIcoSAIgiDYT9aaspJQL7Z4MUwfPxQMw6n9lZFFi2DsWLhxI+e2RSXIQESZBY0qN+LQ+EN5bmOKlNWpUMfm8UyRsprla9p8Xkfw8PCgXsanonXr1vz55598+OGHDB8+nNTUVG7dumURLbty5QpBQUEABAUF8ccff1gczzQ707SNIAiCIBSEkpS+jIiA8ePBaFAVXQaDhgkTwN9fRcY8BwyE5sUzFhFlWfB297YZsQr3CichLYHGVRrbPJ6/pz8Go4EGlQrYDMsGBoOBlJQUWrdujbu7O9u2bWPw4MEAnDx5kgsXLtC+fXsA2rdvz+uvv87Vq1cJDAwE4JdffsHf35/GjW2/JkEQBEGwRUly9D99WlldZEWvh1GjICkJ4le+z9lmy6FL0Y9FRFkRUhSzSGbOnEmvXr2oUaMGcXFxrFy5kp07d7JlyxYCAgIYO3YsTz/9NBUrVsTf35/JkyfTvn177rzzTgC6d+9O48aNeeSRR3j77be5fPkyL730EhMnTsTT07PQxysIgiCUPRy2xHCicKtfX82mzC7MkpLUvVejndRqcb5YxiKirIRx9epVRo4cSVRUFAEBATRv3pwtW7ZwX8a83Pnz56PVahk8eDApKSn06NGDTz/91Ly/Tqfjhx9+4IknnqB9+/b4+PgwatQoXnvtNWe9JEEQBKGUkW9LDCf0Wg4NVbMpx483YDBoASNkBFXuvRfO9X4CnfvQYhmLiLISxuLFi/NcX65cOT755BM++eSTXLepWbMmmzdvLuyhCYIgCAKQWVPmpnVz+fQlqKL+MI8fmPtEW0hQ9dVt2sDGjdD8fynFNg4RZYIgCIIgFCqm9K
Wb1q1EWGJcvQqLX+kICeUBaNoUfvpJFfsXJ2IeKwiCIAhCoWISYm5aN5e3xIiOhm7dICqsPAB16xrZuhUqVSr+sYg
"text/plain": [
"<Figure size 700x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"plt.figure(figsize=(7,3))\n",
"for style, width, degree in ((\"g-\", 1, 300), (\"b--\", 2, 2), (\"r-+\", 2, 1)):\n",
" polybig_features = PolynomialFeatures(degree=degree, include_bias=False)\n",
" std_scaler = StandardScaler()\n",
" lin_reg = LinearRegression()\n",
" polynomial_regression = Pipeline(\n",
" ((\"poly_features\", polybig_features), (\"std_scaler\", std_scaler), (\"lin_reg\", lin_reg)) )\n",
" polynomial_regression.fit(X, y)\n",
" y_newbig = polynomial_regression.predict(X_new)\n",
" plt.plot(X_new, y_newbig, style, label=str(degree), linewidth=width)\n",
"\n",
"plt.plot(X, y, \"b.\", linewidth=3)\n",
"plt.legend(loc=\"upper left\")\n",
"plt.xlabel(\"$x_1$\", fontsize=18)\n",
"plt.ylabel(\"$y$\", rotation=0, fontsize=18)\n",
"plt.axis([-3, 3, 0, 10]);"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Training data is underfitted for degree 1, fitted well for degree 2, and overfitted for degree 300."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Learning curves provide another way to determine whether a model is underfitted or overfitted."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Consider performance on training and validation set *as size of the training set increases*."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Plotting learning curves"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:45.938531Z",
"iopub.status.busy": "2024-01-10T00:19:45.937877Z",
"iopub.status.idle": "2024-01-10T00:19:45.944402Z",
"shell.execute_reply": "2024-01-10T00:19:45.943841Z"
}
},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def plot_learning_curves(model, X, y):\n",
" X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=10)\n",
" train_errors, val_errors = [], []\n",
" for m in range(1, len(X_train)):\n",
" model.fit(X_train[:m], y_train[:m])\n",
" y_train_predict = model.predict(X_train[:m])\n",
" y_val_predict = model.predict(X_val)\n",
" train_errors.append(mean_squared_error(y_train_predict, y_train[:m]))\n",
" val_errors.append(mean_squared_error(y_val_predict, y_val))\n",
"\n",
" plt.figure(figsize=(8,4))\n",
" plt.plot(np.sqrt(train_errors), \"r-+\", linewidth=2, label=\"Training set\")\n",
" plt.plot(np.sqrt(val_errors), \"b-\", linewidth=3, label=\"Validation set\")\n",
" plt.legend(loc=\"upper right\", fontsize=14) \n",
" plt.xlabel(\"Training set size\", fontsize=14) \n",
" plt.ylabel(\"RMSE\", fontsize=14) "
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Underfitted learning curves"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Learning curve for linear model"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:45.947504Z",
"iopub.status.busy": "2024-01-10T00:19:45.946932Z",
"iopub.status.idle": "2024-01-10T00:19:46.221925Z",
"shell.execute_reply": "2024-01-10T00:19:46.221338Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAscAAAGDCAYAAADH173JAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB3QUlEQVR4nO3dd3iT1dsH8G+60pbS1tICLaNlI3sv2SBTKKMURYECylBABBFBsQwBcQAiouiLgANRlsj8MQRERtkKsrEUaEvLagp0t+f945A0aZI2SdMmbb+f63qupOcZOTnNuHOe+5xHIYQQICIiIiIiONi6AkRERERE9oLBMRERERHRUwyOiYiIiIieYnBMRERERPQUg2MiIiIioqcYHBMRERERPcXgmIiIiIjoKQbHRERERERPMTgmIiIiInqKwTERERER0VN2Fxz/+++/GDRoEKpWrQp3d3f4+vqiffv22Lp1q0n7JyQkYPTo0fDz80OpUqXQqVMnnD59uoBrTURERETFgZOtK5BTVFQUHj16hOHDhyMgIABJSUnYuHEj+vbtixUrVmD06NFG983KykLv3r3x999/Y+rUqfD19cXy5cvRsWNHnDp1CjVq1CjEZ0JERERERY1CCCFsXYm8ZGZmomnTpkhJScGlS5eMbvfrr79i8ODBWL9+PUJCQgAAd+/eRc2aNdGzZ0+sXbu2sKpMREREREWQ3aVVGOLo6IhKlSohISEh1+02bNiAcuXKYcCAAZoyPz8/hIaGYsuWLUhNTS3gmhIRERFRUWa3wfGTJ09w7949XL9+HYsXL8bOnTvRpUuXXPc5c+YMmjRpAgcH3afVokULJCUl4cqVKwVZZSIiIiIq4uwu51htypQpWLFiBQDAwcEBAwYMwLJly3LdJzY2Fu3bt9cr9/f3BwDExMSgfv36BvdNTU3V6VnOysrCgwcPUKZMGSgUCkufBhEREREVECEEHj16hICAAL3OUUvZbXA8adIkhISEICYmBr/++isyMzORlpaW6z7JyclQKpV65a6urpr1xixYsACzZ8/OX6WJiIiIqNDdunULFStWtMqx7DY4rl27NmrXrg0AGDZsGLp164Y+ffogIiLCaE+um5ubwbzilJQUzXpjpk+fjsmTJ2v+VqlUqFy5Mm7dugVPT8/8PBUiIiIiKgCJiYmoVKkSSpcubbVj2m1wnFNISAjGjBmDK1euoFatWga38ff3R2xsrF65uiwgIMDo8ZVKpcFeZ09PTwbHRERERHbMmimwdjsgLyd1SoRKpTK6TaNGjXD69GlkZWXplEdERMDd3R01a9Ys0DoSERERUdFmd8FxfHy8Xll6ejq+//57uLm5oU6dOgBkb/ClS5eQnp6u2S4kJARxcXHYtGmTpuzevXtYv349+vTpY7BnmIiIiIhIze7SKsaMGYPExES0b98eFSpUwJ07d/DTTz/h0qVL+Oyzz+Dh4QFA5givWbMGkZGRCAoKAiCD41atWmHEiBG4cOGC5gp5mZmZHGxHRERERHmyu+B48ODBWLlyJb766ivcv38fpUuXRtOmTbFw4UL07ds3130dHR2xY8cOTJ06FUuXLkVycjKaN2+O1atXG81TJiIiIiJSKxKXj7aFxMREeHl5QaVScUAeERERkR0qiHjN7nqOiYiISorMzEydsTNEJZ2joyOcnZ1tWgcGx0RERIVMCIE7d+5ApVKBJ3CJdCmVSvj6+trszD2DYyIiokKmUqmQkJAAPz8/lCpVyqpztBIVVUIIpKenQ6VSITo6GgBsEiAzOCYiIipEQgjEx8fD09MTvr6+tq4OkV1xc3ND6dKlcfv2bdy7d88mwbHdzXNMRERUnGVmZiIzM5ODvYmMUCgU8PLyQmpqqk1y8hkcExERFaKMjAwAgJMTT94SGaMelJeZmVnoj83gmIiIyAaYZ0xknC3fHwyOiYiIiIieYnBMRERERPQUg2MiIiIioqcYHBMREVGxp1Ao0LFjx3
wd48CBA1AoFJg1a5ZV6kT2iUNliYiIqFCYO8iKVw+0naCgIADAjRs3bFoPW2BwTERERIUiPDxcr2zJkiVQqVQG11nTxYsX4e7unq9jtGjRAhcvXuTFW4o5heDPMoMSExPh5eUFlUrFidqJiMhqUlJSEBkZiSpVqsDV1bVgHyw2FlixAhgzBvD3L9jHslBQUBCioqLYS2xnbN1zbOr7pCDiNeYcExERFVexscDs2fK2CLlx4wYUCgXCwsJw8eJF9O/fH2XKlIFCodAEa5s3b8ZLL72E6tWrw93dHV5eXmjXrh02btxo8JiGco7DwsKgUCgQGRmJpUuXonbt2lAqlQgMDMTs2bORlZWls72xnOOgoCAEBQXh8ePHePPNNxEQEAClUokGDRpgw4YNRp/j4MGD4ePjAw8PD3To0AF//vknZs2aBYVCgQMHDpjUVqdPn0ZISAgqV64MpVIJPz8/NG/eHPPmzdPbNj4+Hm+99RaqV68OpVIJX19fDBw4EOfPn9epl0KhQFRUFKKioqBQKDRLScm1ZloFERER2aVr166hVatWqF+/PsLCwnD//n24uLgAAKZPnw4XFxe0bdsW/v7+uHv3Ln7//XeEhIRg6dKlmDBhgsmPM3XqVBw8eBAvvPACunfvjt9++w2zZs1CWlqawSDTkPT0dHTr1g0PHz7EwIEDkZSUhHXr1iE0NBS7du1Ct27dNNtGR0ejTZs2iI2NRY8ePdC4cWNcvnwZzz//PDp37mxyvc+ePYs2bdrA0dERwcHBCAwMREJCAi5cuIBvvvkG7733nmbb69evo2PHjrh9+za6deuGfv36IT4+Hhs3bsT//vc/7Nu3Dy1btoS3tzfCw8OxZMkSAMCkSZM0x8jvgMYiQ5BBKpVKABAqlcrWVSEiomIkOTlZXLhwQSQnJ+uvbNpUiAoV8reULy+En59cvL2FAOStuqx8+fwdv2lTq7ZHYGCgyBmOREZGCgACgPjggw8M7nf9+nW9skePHon69esLLy8v8eTJE511AESHDh10yoYPHy4AiCpVqoiYmBhN+d27d4W3t7coXbq0SE1N1ZTv379fABDh4eEGn0NwcLDO9nv37hUARPfu3XW2f+WVVwQAMW/ePJ3ylStXap73/v37DT5vbZMnTxYAxG+//aa37t69ezp/t2nTRjg6Oopdu3bplF++fFmULl1a1K9fX+85BQYG5lmHgpLr+0RLQcRr7DkmIiKyF3fuANHR1j9uQoL1j1kIypcvr9P7qa1q1ap6ZR4eHggLC8OUKVNw4sQJdOjQwaTHmTlzJvy1crJ9fX0RHByMNWvW4PLly6hfv75Jx1m8eLGmZxsAunTpgsDAQJw4cUJTlpqaivXr16Ns2bKYMmWKzv4jRozAxx9/jMuXL5v0eGpubm56ZWXKlNHcP3PmDI4cOYKRI0eie/fuOtvVrFkTr732GhYtWoTz58+jXr16Zj12ccTgmIiIyF6UL5//Y2RmygUA0tNlYOztDTg7yzJHR7lYyhp1NFHDhg11gk1t8fHx+Oijj7Bz505ERUUhOTlZZ31MTIzJj9O0aVO9sooVKwIAEkz8YeHt7Y0qVaoYPM7Ro0c1f1++fBmpqalo1qwZlEqlzrYKhQJt2rQxOTgODQ3FkiVL0L9/fwwePBjPP/882rdvjwoVKuhsd+zYMQBAXFycwbzhS5cuaW4ZHDM4JiIish8nT1r3eKdPA02bAvv2AU2aWPfYhaBcuXIGyx88eIDmzZvj5s2beO6559C1a1d4e3vD0dERZ8+exZYtW5Cammry4xia5cDJSYZImeofGnnw8vIyWO7k5KQzsC8xMREAULZsWYPbG3vOhrRs2RIHDhzA/PnzsXbtWqxatQoA0Lx5cyxcuBCdOnUCINsLALZv347t27cbPd6TJ09MfuzijMExERER2SVjFw1ZuXIlbt68iblz5+L999/XWffRRx9hy5YthVE9i6gD8fj4eIPr4+LizDpeu3btsHPnTiQnJy
MiIgJbt27F8uXL0bt3b5w/fx5Vq1bVPOYXX3yB8ePH5+8JlACcyo2IiKi48vcHwsPtdo5jS12/fh0AEBwcrLfu0KF
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"lin_reg = LinearRegression()\n",
"plot_learning_curves(lin_reg, X, y)\n",
"plt.axis([0, 80, 0, 3]);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- RMSE on training set small for small training set size since model can fit data well for very few data points (perfect for one or two points).\n",
"\n",
"- RMSE performance on training and validation eventually similar but high since linear model cannot fit the data well (recall data generated by quadratic)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Overfitted learning curves "
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Learning curve for polynomial model of degree 10"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:46.225436Z",
"iopub.status.busy": "2024-01-10T00:19:46.224761Z",
"iopub.status.idle": "2024-01-10T00:19:46.559315Z",
"shell.execute_reply": "2024-01-10T00:19:46.558734Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAscAAAGDCAYAAADH173JAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB1HklEQVR4nO3deVzT5QMH8M+4Bsglggoq4G3e95UXah55K2qZZ5ZlaZmmZllemfmrtMw0LfOo1PK+zTQlMyUV7wM8EBUm4MFQueH7++NxG2MbbDDYBp/36/V9bXu+x559Rf3w7DlkkiRJICIiIiIi2Fm6AkRERERE1oLhmIiIiIjoGYZjIiIiIqJnGI6JiIiIiJ5hOCYiIiIieobhmIiIiIjoGYZjIiIiIqJnGI6JiIiIiJ5hOCYiIiIieobhmIiIiIjoGasLx5cuXcLgwYNRrVo1uLq6wsfHBx06dMCuXbuMOj8xMRHjxo2Dr68vypQpg+DgYISHhxdxrYmIiIioJHCwdAVyi46OxuPHjzFq1Cj4+/sjOTkZW7ZsQd++fbFixQqMGzfO4LnZ2dno1asXzp07h6lTp8LHxwfLli1Dp06dcPr0adSsWbMYPwkRERER2RqZJEmSpSuRn6ysLDRr1gypqam4evWqweN+//13DB06FJs2bUJISAgAICEhAbVq1ULPnj2xfv364qoyEREREdkgq+tWoY+9vT2qVKmCxMTEPI/bvHkzKlSogIEDB6rLfH19MWTIEOzYsQNpaWlFXFMiIiIismVWG46fPn2K+/fv48aNG1i8eDH27duHLl265HnOmTNn0LRpU9jZaX+sli1bIjk5GZGRkUVZZSIiIiKycVbX51hlypQpWLFiBQDAzs4OAwcOxNKlS/M8R6FQoEOHDjrlfn5+AIDY2Fg0aNBA77lpaWlaLcvZ2dl4+PAhypUrB5lMVtCPQURERERFRJIkPH78GP7+/jqNowVlteF40qRJCAkJQWxsLH7//XdkZWUhPT09z3NSUlIgl8t1yp2dndX7DVmwYAHmzJlTuEoTERERUbG7c+cOKleubJZrWW04rlOnDurUqQMAGDlyJLp164Y+ffogLCzMYEuui4uL3n7Fqamp6v2GzJgxA5MnT1a/ViqVCAgIwJ07d+Dh4VGYj6IWEQG0bKld9ugRYKZfdEqlqCigcWPtsnv3gDz+qImIiKiESEpKQpUqVeDu7m62a1ptOM4tJCQEb7zxBiIjI1G7dm29x/j5+UGhUOiUq8r8/f0NXl8ul+ttdfbw8DBbOHZz0y3z8GA4Lgx9fzQeHgzHREREpYk5u8DaTCxTdYlQKpUGj2ncuDHCw8ORnZ2tVR4WFgZXV1fUqlWrSOtI1sH6JyckIiIia2V14Tg+Pl6nLCMjA+vWrYOLiwvq1q0LQLQGX716FRkZGerjQkJCEBcXh61bt6rL7t+/j02bNqFPnz56W4bJtnGsJBEREZmT1XWreOONN5CUlIQOHTqgUqVKuHfvHn799VdcvXoVX331Fdye9U2YMWMG1q5di6ioKAQFBQEQ4bh169YYM2YMLl++rF4hLysri4PtShG2HBMREVFBWV04Hjp0KFatWoXly5fjwYMHcHd3R7NmzbBw4UL07ds3z3Pt7e2xd+9eTJ06FUuWLEFKSgpatGiBNWvWGOynTLaNLcdERERkTjaxfLQlJCUlwdPTE0ql0mwD8i5fBurV0y7LzmbAK4zoaODZFwdqjx/rH/xIREREJUtR5DWrazkmMgV/sSAiW5aVlaU1doaotLO3t4ejo6NF68BwTCUOvwshImsnSRLu3bsHpVIJfoFLpE0ul8PHx8dsLcGmYjgmm6av5Zj/zxCRtVMqlUhMTISvry/KlClj1jlaiWyVJEnIyMiAUqlETEwMAFgkIDMck03j/ydEZGskSUJ8fDw8PDzg4+Nj6eoQWRUXFxe4u7vj7t27uH//vkXCsdXNc0xUWGw5JiJrlpWVhaysLI
t9ZUxk7WQyGTw9PZGWlmaRPvkMxxbGls/C4f0jIluTmZkJAHBw4Je3RIaoBuVlZWUV+3szHBcjtmgWD95nIrIF7GdMZJgl/34wHJNN4/8tREREZE4Mx1TisOWYiIiICorhmGwaW46JiIjInBiOqcRhyzEREeUmk8nQqVOnQl3jyJEjkMlkmD17tlnqRNaJ4dhESiUwfDhQuzbw4YeABQZRUg5cBISIyHbIZDKTNrKcoKAgBAUFWboaFsF5ZEy0aBHw66/i+YIFQNu2QO/elq0TERGRLZg1a5ZO2ddffw2lUql3nzlduXIFrq6uhbpGy5YtceXKFS7eUsIxHJto/37t10ePGh+O2aJpfmw5JiLKg0IBrFgBvPEG4Odn6dro7Y6wZs0aKJXKIu+qUKdOnUJfw9XV1SzXIevGbhUmyMwEzp/XLrt71zJ1IYHfuhER5UGhAObMEY825NatW5DJZBg9ejSuXLmCAQMGoFy5cpDJZLh16xYAYNu2bXj55ZdRo0YNuLq6wtPTE+3bt8eWLVv0XlNfn+PRo0dDJpMhKioKS5YsQZ06dSCXyxEYGIg5c+YgOztb63hDfY5VXRCePHmCd999F/7+/pDL5WjYsCE2b95s8DMOHToU3t7ecHNzQ8eOHfH3339j9uzZkMlkOHLkiFH3Kjw8HCEhIQgICIBcLoevry9atGiB+fPn6xwbHx+P9957DzVq1IBcLoePjw8GDRqEixcvatVLJpMhOjoa0dHRWt1cSktfa7Ycm+DaNSA1VbuM4dj6sOWYiKhkuH79Olq3bo0GDRpg9OjRePDgAZycnAAAM2bMgJOTE9q1awc/Pz8kJCRg586dCAkJwZIlSzBx4kSj32fq1KkIDQ1F79690b17d2zfvh2zZ89Genq63pCpT0ZGBrp164ZHjx5h0KBBSE5OxsaNGzFkyBDs378f3bp1Ux8bExODtm3bQqFQoEePHmjSpAkiIiLwwgsvoHPnzkbX++zZs2jbti3s7e3Rr18/BAYGIjExEZcvX8bKlSvx0UcfqY+9ceMGOnXqhLt376Jbt27o378/4uPjsWXLFvzxxx84dOgQWrVqBS8vL8yaNQtff/01AGDSpEnqaxR2QKOtYDg2wblzumUMx5bFlmMiKlGaNwfu3SvcNbKyNKPFMzLEY5cuwLPleGFvL7aCqlgROHWqcHU00rFjx/DJJ59gzpw5Ovv27t2LatWqaZU9efIEbdu2xccff4yxY8ca3cc4PDwc58+fh9+zricff/wxatasiW+//RazZs1SB/K8xMbGokWLFjhy5Ij6+GHDhqFr165YtGiRVjj+4IMPoFAoMH/+fHz44Yfq8p9++gljx441qs4A8PPPPyMtLQ3bt29Hv379tPY9ePBA6/XIkSOhUCiwf/9+dO/eXV0+c+ZMNG/eHK+//jrOnz8PLy8vzJ49G2vWrAGgvytMScdwbIKzZ3XL7t4VLZUFCWkMdkWDLcdEZLPu3QNiYsx/3cRE81+zGFSsWFGr9TOn3MEYANzc3DB69GhMmTIFJ0+eRMeOHY16n48//lgdjAHAx8cH/fr1w9q1axEREYEGDRoYdZ3FixdrBekuXbogMDAQJ0+eVJelpaVh06ZNKF++PKZMmaJ1/pgxY/C///0PERERRr2fiouLi05ZuXLl1M/PnDmDf//9F6+++qpWMAaAWrVq4fXXX8eiRYtw8eJF1K9f36T3LokYjk2gLxynpwP37wO+vsVeHQJ/wSCiEqZixcJfI3fLcWIi4OVl3pbjYtKoUSODrbbx8fH4/PPPsW/fPkRHRyMlJUVrf2xsrNHv06xZM52yypUrAwASjfzFwsvLC1WrVtV7nePHj6tfR0REIC0tDc2bN4dcLtc6ViaToW3btkaH4yFDhuDrr7/GgAEDMHToULzwwgvo0KEDKlWqpHXciRMnAABxcXF6W4KvXr2qfmQ4Zjg2ib5wDIjWY4Zj68GWYyKyWeburhAeDjRrBhw6BDRtat5rF4
MKFSroLX/48CFatGiB27dv4/nnn0fXrl3h5eUFe3t7nD17Fjt27EBaWprR7+Ph4aFT5uAgIlKWkQsaeHp66i13cHD
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.pipeline import Pipeline\n",
"polynomial_regression = Pipeline((\n",
" (\"poly_features\", PolynomialFeatures(degree=10, include_bias=False)),\n",
" (\"lin_reg\", LinearRegression()),\n",
" ))\n",
"plot_learning_curves(polynomial_regression, X, y)\n",
"plt.axis([0, 80, 0, 3]); "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- RMSE now much smaller.\n",
"- But training and validation set errors remain quite different.\n",
"\n",
"Model performs much better on training set than validation set, suggesting overfitting."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Well-fitted learning curves "
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Learning curve for polynomial model of degree 2"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:46.562771Z",
"iopub.status.busy": "2024-01-10T00:19:46.562122Z",
"iopub.status.idle": "2024-01-10T00:19:46.882864Z",
"shell.execute_reply": "2024-01-10T00:19:46.882258Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAscAAAGDCAYAAADH173JAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAABxCklEQVR4nO3deXyM1/4H8M9km2ySyIKELAjV2ndSYr1C1VINuiGq1fqVVqm62qrtqrqtpapUexXtLWrfudaoEiFCS+1EkEQWZILsyfn9ccxMJjOTPZlJ8nm/Xs9rZs6znTmZzHznzPecRyGEECAiIiIiIliYugJEREREROaCwTERERER0VMMjomIiIiInmJwTERERET0FINjIiIiIqKnGBwTERERET3F4JiIiIiI6CkGx0RERERETzE4JiIiIiJ6isExEREREdFTZhcc//333xg6dCgaNGgAe3t7uLu7IzAwEDt37izS/snJyRg7diw8PDzg4OCAHj16IDIyspxrTURERERVgZWpK5BfdHQ0Hj16hFGjRsHLywupqanYvHkzBg4ciBUrVmDs2LFG983NzUX//v3x559/YsqUKXB3d8eyZcvQvXt3nDlzBo0aNarAZ0JERERElY1CCCFMXYnC5OTkoG3btkhPT8fly5eNbrdhwwYMHz4cGzduRHBwMAAgMTERjRs3Rr9+/bB27dqKqjIRERERVUJml1ZhiKWlJby9vZGcnFzgdps2bULt2rUxZMgQTZmHhweGDRuG7du3IyMjo5xrSkRERESVmdkGx0+ePEFSUhJu3LiBRYsWYe/evejVq1eB+5w9exZt2rSBhYXu0+rQoQNSU1Nx9erV8qwyEREREVVyZpdzrDZ58mSsWLECAGBhYYEhQ4Zg6dKlBe4TFxeHwMBAvXJPT08AQGxsLJo3b25w34yMDJ2e5dzcXDx48ABubm5QKBQlfRpEREREVE6EEHj06BG8vLz0OkdLymyD44kTJyI4OBixsbHYsGEDcnJykJmZWeA+aWlpUCqVeuW2traa9cbMmzcPs2bNKl2liYiIiKjC3blzB/Xq1SuTY5ltcNykSRM0adIEADBy5Ej06dMHAwYMQHh4uNGeXDs7O4N5xenp6Zr1xkybNg2TJk3SPFapVPDx8cGdO3fg5ORUmqdCREREROUgJSUF3t7eqFGjRpkd02yD4/yCg4Pxzjvv4OrVq3jmmWcMbuPp6Ym4uDi9cnWZl5eX0eMrlUqDvc5OTk4MjomIiIjMWFmmwJrtgLz81CkRKpXK6DatWrVCZGQkcnNzdcrDw8Nhb2+Pxo0bl2sdiYiIiKhyM7vgOCEhQa8sKysLP//8M+zs7PDcc88BkL3Bly9fRlZWlma74OBgxMfHY8uWLZqypKQkbNy4EQMGDDDYM0xEREREpGZ2aRXvvPMOUlJSEBgYiLp16+LevXv49ddfcfnyZSxYsACOjo4AZI7wmjVrEBUVBT8/PwAyOO7UqRNGjx6Nixcvaq6Ql5OTw8F2RERERFQoswuOhw8fjpUrV2L58uW4f/8+atSogbZt22L+/PkYOHBggftaWlpiz549mDJlCpYsWYK0tDS0b98eq1evNpqnTERERESkVikuH20KKSkpcHZ2hkql4oA8IiIiIjNUHvGa2fUcExERVRc5OTk6Y2eIqjtLS0tYW1ubtA4MjomIiCqYEAL37t2DSqUCf8Al0qVUKuHu7m6yX+4ZHBMREVUwlUqF5ORkeHh4wMHBoUznaCWqrIQQyMrKgkqlQkxMDACYJEBmcExERFSBhBBISEiAk5MT3N3dTV0dIrNiZ2eHGjVq4O7du0hKSjJJcGx28xwTERFVZTk5OcjJyeFgbyIjFAoFnJ2dkZGRYZKcfAbHREREFSg7OxsAYGXFH2+JjFEPysvJyanwczM4JiIiMgHmGRMZZ8r/DwbHRERERERPMTgmIiIiInqKwTERERER0VMMjomIiKjKUygU6N69e6mOER
oaCoVCgZkzZ5ZJncg8cagsERERVYjiDrLi1QNNx8/PDwBw69Ytk9bDFBgcExERUYWYMWOGXtnixYuhUqkMritLly5dgr29famO0aFDB1y6dIkXb6niFIJfywxKSUmBs7MzVCoVJ2onIqIyk56ejqioKNSvXx+2trble7K4OGDFCuCddwBPz/I9Vwn5+fkhOjqavcRmxtQ9x0X9PymPeI05x0RERFVVXBwwa5a8rURu3boFhUKBkJAQXLp0CS+99BLc3NygUCg0wdrWrVvx6quvwt/fH/b29nB2dkbXrl2xefNmg8c0lHMcEhIChUKBqKgoLFmyBE2aNIFSqYSvry9mzZqF3Nxcne2N5Rz7+fnBz88Pjx8/xgcffAAvLy8olUq0aNECmzZtMvochw8fDldXVzg6OqJbt274/fffMXPmTCgUCoSGhhaprSIjIxEcHAwfHx8olUp4eHigffv2mDt3rt62CQkJ+PDDD+Hv7w+lUgl3d3e8/PLLuHDhgk69FAoFoqOjER0dDYVCoVmqS6410yqIiIjILF2/fh2dOnVC8+bNERISgvv378PGxgYAMG3aNNjY2KBLly7w9PREYmIiduzYgeDgYCxZsgQTJkwo8nmmTJmCo0eP4sUXX0RQUBC2bduGmTNnIjMz02CQaUhWVhb69OmDhw8f4uWXX0ZqairWr1+PYcOGYd++fejTp49m25iYGAQEBCAuLg59+/ZF69atceXKFfzjH/9Az549i1zvc+fOISAgAJaWlhg0aBB8fX2RnJyMixcv4ocffsCnn36q2fbGjRvo3r077t69iz59+mDw4MFISEjA5s2b8b///Q+HDh1Cx44d4eLighkzZmDx4sUAgIkTJ2qOUdoBjZWGIINUKpUAIFQqlamrQkREVUhaWpq4ePGiSEtL01/Ztq0QdeuWbqlTRwgPD7m4uAgByFt1WZ06pTt+27Zl2h6+vr4ifzgSFRUlAAgA4vPPPze4340bN/TKHj16JJo3by6cnZ3FkydPdNYBEN26ddMpGzVqlAAg6tevL2JjYzXliYmJwsXFRdSoUUNkZGRoyo8cOSIAiBkzZhh8DoMGDdLZ/uDBgwKACAoK0tn+jTfeEADE3LlzdcpXrlyped5Hjhwx+LzzmjRpkgAgtm3bprcuKSlJ53FAQICwtLQU+/bt0ym/cuWKqFGjhmjevLnec/L19S20DuWlwP+TPMojXmPPMRERkbm4dw+IiSn74yYnl/0xK0CdOnV0ej/zatCggV6Zo6MjQkJCMHnyZJw+fRrdunUr0nmmT58Ozzw52e7u7hg0aBDWrFmDK1euoHnz5kU6zqJFizQ92wDQq1cv+Pr64vTp05qyjIwMbNy4EbVq1cLkyZN19h89ejT+/e9/48qVK0U6n5qdnZ1emZubm+b+2bNnceLECbz55psICgrS2a5x48Z4++23sXDhQly4cAHNmjUr1rmrIgbHRERE5qJOndIfIydHLgCQlSUDYxcXwNpalllayqWkyqKORdSyZUudYDOvhIQEfPnll9i7dy+io6ORlpamsz42NrbI52nbtq1eWb169QAAyUX8YuHi4oL69esbPE5YWJjm8ZUrV5CRkYF27dpBqVTqbKtQKBAQEFDk4HjYsGFYvHgxXnrpJQwfPhz/+Mc/EBgYiLp16+psd/LkSQBAfHy8wbzhy5cva24ZHDM4JiIiMh8REWV7vMhIoG1b4NAhoE2bsj12Bahdu7bB8gcPHqB9+/a4ffs2nn/+efTu3RsuLi6wtLTEuXPnsH37dmRkZBT5PIZmObCykiFSjvqLRiGcnZ0NlltZWekM7EtJSQEA1KpVy+D2xp6zIR07dkRoaCi++OILrF27FqtWrQIAtG/fHvPnz0ePHj0AyPYCgN27d2P37t1Gj/fkyZMin7sqY3BMREREZsnYRUNWrlyJ27dvY86cOfjss8901n355ZfYvn17RVSvRNSBeEJCgsH18fHxxTpe165dsXfvXqSlpSE8PB
w7d+7EsmXL0L9/f1y4cAENGjTQnPPbb7/F+PHjS/cEqgFO5UZERFRVeXoCM2aY7RzHJXXjxg0AwKBBg/TWHTt2rKK
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.pipeline import Pipeline\n",
"polynomial_regression = Pipeline((\n",
" (\"poly_features\", PolynomialFeatures(degree=2, include_bias=False)),\n",
" (\"lin_reg\", LinearRegression()),\n",
" ))\n",
"plot_learning_curves(polynomial_regression, X, y)\n",
"plt.axis([0, 80, 0, 3]); "
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Bias-variance tradeoff"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Bias-variance tradeoff refers to the problem of simultaneously reducing the two types of error that prevent supervised learning algorithms from generalising to other data.\n",
"\n",
"- **Bias**: Expected difference between the model's prediction and the true underlying function.\n",
"\n",
"- **Variance**: Expected variability of the model's prediction about its mean, i.e. how much the fitted model fluctuates."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"One seeks a model that accurately fits the training data, while also generalising to unseen data.\n",
"\n",
"Typically impossible to do both simultaneously.\n",
"\n",
"- On one hand, high-variance models may fit training data well but typically overfit to noise or unrepresentative training data.\n",
"\n",
"- On the other hand, low-complexity models with a high bias typically underfit training data.\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Contributions to the mean square error "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Consider the underlying (true) model:\n",
"$ y = f(x) + \\epsilon,$\n",
"\n",
"where\n",
"- $y$ is the target and $x$ are the features.\n",
"- $f$ is the true model, that we will approximate by $h$.\n",
"- $\\epsilon$ is the noise, with zero mean and variance $\\sigma^2$.\n",
"\n",
"\n",
"Approximate $f$ by $h$, which is fitted by a learning algorithm and training data."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Expected value of the mean square error is given by\n",
"\n",
"$$\\text{E} \\left[ \\left(y - h(x)\\right)^2 \\right]\n",
"= \n",
"\\text{Bias}^2\\left[ h(x) \\right]\n",
"+\n",
"\\text{Var}\\left[ h(x) \\right]\n",
"+\n",
"\\sigma^2$$\n",
"\n",
"Three contributions to the error:\n",
" \n",
"1. Bias: $\\text{Bias}\\left[ h(x) \\right] = \\text{E} \\left[ h(x) - f(x) \\right].$\n",
"\n",
"2. Variance: $\\text{Var}\\left[ h(x) \\right] = \\text{E} \\left[ h^2(x)\\right] - \\text{E} \\left[ h(x)\\right]^2. $\n",
"\n",
"3. Irreducible error $\\sigma^2$ due to noise in observations."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Tradeoff \n",
"\n",
"By choosing a complex model, the bias can be made small but the variance will be large.\n",
"\n",
"By choosing a simple model, the variance can be made small but the bias will be large.\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture07_Images/biasvariance.png\" width=\"700px\" style=\"display:block; margin:auto\"/> \n",
"\n",
"[Image source](http://francescopochetti.com/bias-v-s-variance-tradeoff/)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Regularization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One approach to mitigating the bias-variance tradeoff is *regularization*.\n",
"\n",
"Consider a complex model but place additional constraints to reduce its variance.\n",
"\n",
"As a consequence the bias is increased, but a regularization parameter can be introduced to control the tradeoff."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Add regularization term $R(\\theta)$ to the cost function:\n",
"\n",
"$$C_\\lambda(\\theta) = C(\\theta) + \\lambda\\ R(\\theta).$$\n",
"\n",
"The regularization parameter $\\lambda$ controls the amount of regularization."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Regularization should only be added when training. When using the fitted model to make predictions, the cost should be evaluated without the regularization term."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Tikhonov regularization \n",
"\n",
"*Tikhonov* regularization adopts an $\\ell_2$ regularising term (also called *Ridge regression*):\n",
"\n",
"$$ R(\\theta) = \\frac{1}{2} \\sum_{j=1}^n \\theta_j^2 = \\frac{1}{2} \\theta^{\\rm T}\\theta.$$\n",
"\n",
"\n",
"\n",
"Acts to keep parameters small.\n",
"\n",
"Note that the bias term $\\theta_0$ is not regularized (i.e. sum starts from 1 not 0)."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:46.887064Z",
"iopub.status.busy": "2024-01-10T00:19:46.886396Z",
"iopub.status.idle": "2024-01-10T00:19:46.893828Z",
"shell.execute_reply": "2024-01-10T00:19:46.893284Z"
},
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"from sklearn.linear_model import Ridge\n",
"\n",
"np.random.seed(42)\n",
"m = 20\n",
"X = 3 * np.random.rand(m, 1)\n",
"y = 1 + 0.5 * X + np.random.randn(m, 1) / 1.5\n",
"X_new = np.linspace(0, 3, 100).reshape(100, 1)\n",
"\n",
"def plot_model(model_class, polynomial, alphas, **model_kargs): \n",
" # Use alpha for regularization parameter (lambda used already)\n",
" for alpha, style in zip(alphas, (\"b-\", \"g--\", \"r:\")):\n",
" model = model_class(alpha, **model_kargs) if alpha > 0 else LinearRegression()\n",
" if polynomial:\n",
" model = Pipeline((\n",
" (\"poly_features\", PolynomialFeatures(degree=10, include_bias=False)),\n",
" (\"std_scaler\", StandardScaler()),\n",
" (\"regul_reg\", model),\n",
" ))\n",
" model.fit(X, y)\n",
" y_new_regul = model.predict(X_new)\n",
" lw = 2 if alpha > 0 else 1\n",
" plt.plot(X_new, y_new_regul, style, linewidth=lw, label=r\"$\\lambda = {}$\".format(alpha))\n",
" plt.plot(X, y, \"b.\", linewidth=3)\n",
" plt.legend(loc=\"upper left\", fontsize=15)\n",
" plt.xlabel(\"$x_1$\", fontsize=18)\n",
" plt.axis([0, 3, 0, 4])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:46.896773Z",
"iopub.status.busy": "2024-01-10T00:19:46.896231Z",
"iopub.status.idle": "2024-01-10T00:19:47.338341Z",
"shell.execute_reply": "2024-01-10T00:19:47.337595Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1oAAAHVCAYAAADo24q6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAC0AUlEQVR4nOzdeVhU1RsH8O8w7LuICgoIKrgv4L4vmaamlZm2WGqaVi6VZqVZmlYupS2aqWXaYv1Kc99Sy8w0TVlVJDdAEFBQGZCdmfv748jAwLAMzArfz/PwAHeZe0C877z3nPMemSRJEoiIiIiIiEhvrEzdACIiIiIiotqGiRYREREREZGeMdEiIiIiIiLSMyZaREREREREesZEi4iIiIiISM+YaBEREREREekZEy0iIiIiIiI9Y6JFRERERESkZ0y0iIiIiIiI9IyJFhERERERkZ6ZLNH64IMPIJPJ0K5duyodf+PGDYwdOxbu7u5wdXXFI488gmvXrhm4lUREVJcwNhERkb7IJEmSjH3RxMREtGzZEjKZDP7+/jh//nyFx9+7dw8hISFQKBSYM2cObGxs8Mknn0CSJERERKB+/fpGajkREdVWjE1ERKRP1qa46Ouvv44ePXpAqVQiLS2t0uPXrl2Ly5cv499//0XXrl0BAMOGDUO7du2wcuVKfPjhh4ZuMhER1XKMTUREpE9G79H666+/MGjQIISHh2PmzJlIS0ur9Klht27dAAD//vuvxvahQ4fi6tWruHLlisHaS0REtR9jExER6ZtR52gplUrMnDkTU6ZMQfv27at0jkqlQlRUFLp06VJmX7du3XD16lVkZmbqu6lERFRHMDYREZEhGHXo4Lp16xAfH48jR45U+Zw7d+4gLy8P3t7eZfYVbUtKSkLLli21np+Xl4e8vDz19yqVCnfu3EH9+vUhk8l0/AmIiKi6JElCZmYmGjduDCsr8yl6a+zYxLhERGQ+DBmbjJZo3b59G++++y7eeecdNGjQoMrn5eTkAADs7OzK7LO3t9c4RpulS5fivffe07G1RERkKAkJCfDx8TF1MwCYJjYxLhERmR9DxCajJVoLFiyAh4cHZs6cqdN5Dg4OAKDx9K9Ibm6uxjHazJs3D7Nnz1Z/r1Ao4Ofnh4SEBLi6uurUFiIiqr6MjAz4+vrCxcXF1E1RM0VsYlwiIjIfhoxNRkm0Ll++jA0bNuDTTz9FUlKSentubi4KCgoQFxcHV1dXeHh4lDnXw8MDdnZ2SE5OLrOvaFvjxo3LvbadnZ3WJ46urq4MaEREJmAuw+NMFZsYl4iIzI8hYpNRBsnfuHEDKpUKs2bNQkBAgPrj9OnTuHTpEgICArB48WLtDbSyQvv27XH27Nky+06fPo1mzZqZ1dNRIiKyDIxNRERkSEbp0WrXrh127NhRZvuCBQuQmZmJzz77DM2bNwcAXL9+HdnZ2WjVqpX6uDFjxuCtt97C2bNn1RWe/vvvP/zxxx94/fXXjfEjEBFRLcPYREREhmT0dbRKGjBgQJm1SgYMGIBjx46hZLMyMzMRHByMzMxMvP7667CxscGqVaugVCoRERGh0wTmjIwMuLm5QaFQcIgGEZERWcr919ixyVJ+L0REtZEh78FGLe9eXS4uLvjzzz/x2muv4f3334dKpcKAAQPwySef6JRkERER6QtjExERVcSkPVqmwCeHRESmwfuvdvy9EBGZjiHvweazYiQREREREVEtYRFDB81BQUEBlEqlqZtBtYRcLoeNjY2pm0FEFo6xieoCxkyyVEy0KpGRkYG0tDSti1IS1YSdnR08PT05VIiIdMbYRHUNYyZZIiZaFcjIyMCNGzfg7OwMT09P2NjYmM1Cm2S5JElCQUEBFAoFbty4AQAMHERUZYxNVJcwZpIlY6JVgbS0NDg7O8PHx4dBjPTKwcEBLi4uSExMRFpaGoMGEVUZYxPVNYyZZKlYDKMcBQUFyMvLg5ubGwMZGYRMJoObmx
vy8vJQUFBg6uYQkQVgbKK6ijGTLBETrXIUTS7m5EsypKK/L05mJ6KqYGyiuowxkywNE61K8IkhGRL/voioOnjvoLqIf/dkaZhoERERERER6RkTLSIiIiIiIj1jokVERERERKRnTLSoWk6cOAGZTAYXFxc8/PDDSEtLM3WTAAA5OTl49913ERQUBHt7ezRu3BjPP/+8eu0NIiKqnRiXiMjcMNGiapEkCc888wxsbW2xb98+LF682NRNQm5uLgYNGoQlS5bg3r17eOSRR+Dr64tNmzYhODgY165dM3UTiYjIQBiXiMjcMNGiaunTpw9++OEHfPfddwDEk0RTe//993Hq1Cn07NkTly5dws8//4zTp09j5cqVSE1NxfPPP2/qJhIRkYEwLhGRuWGiRTXSr18/yGQyREdHm3Rdi/z8fKxZswYA8MUXX8DZ2Vm9b/bs2ejQoQOOHTuG0NBQUzWRiIiMgHGJiMwFEy2qERcXF/j5+SE3NxdXrlwxWTtOnDgBhUKB5s2bIzg4uMz+MWPGAAD27Nlj7KYREZERMS4RkblgokU1cv78eSQlJQEAoqKiTNaOyMhIAEBISIjW/UXbTdlGIiIyPMYlIjIXTLSo2lQqFaZMmYKCggIAwLlz56p87oABAyCTyXT62Lx5c7mvd/36dQCAj4+P1v1F2+Pj46vcRiIisiyMS0RkTqxN3QBLlZ0NxMSYuhVV16oV4Oio39f87LPPcPr0afj4+CAxMVGngPbQQw/B399fp+u1aNGi3H337t0DADiW80M6OTkBADIzM3W6JhGRJanrsYlxiYh0EZ8eD3ulvcFen4lWNcXEAJ07m7oVVRcaCpQzeqFaYmNj8c4778DLywtbtmxB//79dQpob731lv4aQ0REAOp2bGJcIiJdSJKEtmvbIiszy2DXYKJVTa1aiQBhKVq10u/rTZs2DVlZWfjuu+/Qq1cv2NnZ4dq1a8jKylI/pTOmompO2dnZWvdnZYn/RC4uLkZrExGRsdXl2MS4RES6yMjLQFaB4ZIsgIlWtTk66reHyJJs3rwZhw8fxpgxYzB69GgAQJs2bRAeHo7z58+je/fulb7GsmXLEKPj+JYpU6agT58+Wvf5+fkBABITE7XuL9retGlTna5JRGRJ6mpsYlwiIl0lZmj/v6lPTLRIJzdv3sScOXPg4eGhXh8EADp27Ijw8HCcO3euSgHt4MGDOHbsmE7XHjBgQLkBrWPHjgCAsLAwrfuLtnfo0EGnaxIRkXljXCKi6mCiRWZn5syZuHPnDr7//ns0atRIvb0ooFR1PPyff/6p13b17t0bbm5uuHr1KiIiItCpUyeN/du2bQMAjBw5Uq/XJSIi02JcIqLquJF5w+DXYHl3qrJdu3Zh69atGD58OMaPH6+xT9eApm+2traYMWMGAGD69Onqse8AsGrVKkRFRaF///7obEmzxImIqEKMS0RUXd2bdMeqIaswvet0g11DJkmSZLBXN0MZGRlwc3ODQqGAq6trucfl5uYiNjYWAQEBsLc3XNlHS6FQKNCmTRvcu3cPFy5cKLMuyJ07d1C/fn14enoiNTXVJG3Mzc3FgAEDcPr0aXh7e6Nv376Ij4/H6dOn0aBBA5w6dQrNmjUzSdvKw78zqkuqev+ta3T5vfCeUYxxqe7h3z8ZgiFjE3u0qEreeOMNJCUlYcWKFVoXX/Tw8ICPjw/S0tKQnJxsghYC9vb2OHr0KN555x04Ojpi586diI+Px8SJExEWFsZgRkRUizAuEZG5Y49WOfjUhIyBf2dUl7BHSzv2aBFVDf/+yRDYo0VERERERKQnESkRuHnvJlSSymDXYNVBIiIiIiKqM3IKchC8PhgA0KthL4Ndhz1aRERERERUZ5Qs7e7l5GWw6zDRIiIiIiKiOuNGRnGi5e3sbbDrMNEiIiIiIqI6IzEjUf11E9cmBrsOEy0iIiIiIqozSiZa3i61oEfrwoULeOKJJ9CsWTM4OjrC09MT/fr1w549ey
o9d/PmzZDJZFo/UlJSjNB6IiKqbRiXiIjqppJztJq4GK5Hy2hVB+Pj45GZmYkJEyagcePGyM7Oxq+//opRo0Zh/fr
"text/plain": [
"<Figure size 1000x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10,5))\n",
"plt.subplot(121)\n",
"plot_model(Ridge, polynomial=False, alphas=(0, 10, 100), random_state=42)\n",
"plt.ylabel(\"$y$\", rotation=0, fontsize=18)\n",
"plt.subplot(122)\n",
"plot_model(Ridge, polynomial=True, alphas=(0, 10**-5, 1), random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercises 2 and 3 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Lasso regularization\n",
"\n",
"\n",
"*Lasso* regularization adopts an $\\ell_1$ regularising term:\n",
"\n",
"$$ R(\\theta) =\\sum_{j=1}^n \\left\\vert \\theta_j \\right\\vert .$$\n",
"\n",
"Acts to promote sparsity.\n",
"\n",
"Again, note that the bias term $\\theta_0$ is not regularized (i.e. sum starts from 1 not 0).\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:47.341641Z",
"iopub.status.busy": "2024-01-10T00:19:47.341174Z",
"iopub.status.idle": "2024-01-10T00:19:47.722342Z",
"shell.execute_reply": "2024-01-10T00:19:47.721671Z"
},
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1oAAAHVCAYAAADo24q6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAACbXUlEQVR4nOzdd3iTVRsG8DsN0D0oBToZhbIpsy17K3uIDAVEFIRPhoOhoiKCCiIKioiAoojgAkFkKaCAyCjQlpZVZhmdtIx0p236fn8cm5I23dm5f9fVK+SdJ23Ik+d9z3mOTJIkCURERERERKQzNsZuABERERERkaVhokVERERERKRjTLSIiIiIiIh0jIkWERERERGRjjHRIiIiIiIi0jEmWkRERERERDrGRIuIiIiIiEjHmGgRERERERHpGBMtIiIiIiIiHWOiRUREREREpGNGS7Q++OADyGQytGrVqlzbx8XFYcyYMXBzc4OLiwuGDx+OGzdu6LmVRERkTRibiIhIV2SSJEmGPmlsbCyaNm0KmUyGBg0a4Pz586Vun56ejvbt20OhUGDOnDmoXr06Vq5cCUmScPbsWdSqVctALSciIkvF2ERERLpUzRgnnTt3Ljp16gSVSoWUlJQyt1+zZg2uXr2KU6dOISgoCAAwcOBAtGrVCp988gmWLFmi7yYTEZGFY2wiIiJdMvgdrX/++Qd9+vRBREQEZs2ahZSUlDKvGgYHBwMATp06pbG8f//+uH79Oq5du6a39hIRkeVjbCIiIl0z6BgtlUqFWbNmYcqUKWjdunW59snPz0dUVBQ6duxYbF1wcDCuX7+OtLQ0XTeViIisBGMTERHpg0G7Dq5duxa3bt3CwYMHy73P/fv3oVQq4eXlVWxdwbL4+Hg0bdpU6/5KpRJKpVL9PD8/H/fv30etWrUgk8kq+AqIiKiyJElCWloavL29YWNjOkVvDR2bGJeIiEyHPmOTwRKte/fu4Z133sGCBQtQu3btcu+XlZUFALC1tS22zs7OTmMbbZYuXYpFixZVsLVERKQvd+7cga+vr7GbAcA4sYlxiYjI9OgjNhks0Xr77bfh7u6OWbNmVWg/e3t7ANC4+lcgOztbYxtt5s+fj9mzZ6ufKxQK1KtXD3fu3IGLi0uF2kJERJWXmpoKPz8/ODs7G7spasaITYxLRESmQ5+xySCJ1tWrV7F+/Xp8+umniI+PVy/Pzs5Gbm4ubt68CRcXF7i7uxfb193dHba2tkhISCi2rmCZt7d3iee2tbXVesXRxcWFAY2IyAhMpXucsWIT4xIRkenRR2wySCf5uLg45Ofn46WXXkLDhg3VP6Ghobhy5QoaNmyIxYsXa2+gjQ1at26NM2fOFFsXGhoKf39/k7o6SkRE5oGxiYiI9Mkgd7RatWqFHTt2FFv+9ttvIy0tDZ999hkaNWoEALh9+zYyMzPRrFkz9XajRo3CG2+8gTNnzqgrPF2+fBl///035s6da4iXQEREFoaxiYiI9Mng82g9qlevXsXmKunVqxeOHDmCR5uVlpaGdu3aIS0tDXPnzkX16tWxYsUKqFQqnD17tkIDmFNTU+Hq6gqFQsEuGkREBmQun7+Gjk3m8nshIrJE+vwMNmh598pydnbG4cOH8eqrr+L9999Hfn4+evXqhZUrV1YoySIiItIVxiYiIiqNUe9oGQOvHBIRGQc/f7Xj74WIyHj0+RlsOjNGEhERERERWQiz6DpoCnJzc6FSqYzdDDIjcrkc1atXN3YziMiCMTaRNWA8JXPFRKsMqampSElJ0TopJVFZbG1t4eHhwe5ARKRTjE1kbRhPyRwx0SpFamoq4uLi4OTkBA8PD1SvXt1kJtok0yZJEnJzc6FQKBAXFwcADA5EpBOMTWRNGE/JnDHRKkVKSgqcnJzg6+vLIEYVZm9vD2dnZ8TGxiIlJYWBgYh0grGJrA3jKZkrFsMoQW5uLpRKJVxdXRnIqNJkMhlcXV2hVCqRm5
tr7OYQkZljbCJrxXhK5oiJVgkKBhdz8CVVVcF7iAPWiaiqGJvImjGekrlholUGXjGkquJ7iIh0jZ8rZI34vidzw0SLiIiIiIhIx5hoERERERER6RgTLSIiIiIiIh1jokWVcuzYMchkMjg7O2PIkCFISUkxdpMAAFlZWXjnnXfQpEkT2NnZwdvbG88//7x67g0iIrJMjEtEZGqYaFGlSJKE8ePHo0aNGtizZw8WL15s7CYhOzsbffr0wXvvvYf09HQMHz4cfn5++Pbbb9GuXTvcuHHD2E0kIiI9YVwiIlPDRIsqpVu3bti8eTM2bdoEQFxJNLb3338fJ0+eROfOnXHlyhX8/PPPCA0NxSeffILk5GQ8//zzxm4iERHpCeMSEZkaJlpUJT169IBMJsPFixeNOq9FTk4OVq9eDQD44osv4OTkpF43e/ZsBAYG4siRIwgLCzNWE4mIyAAYl4jIVDDRoipxdnZGvXr1kJ2djWvXrhmtHceOHYNCoUCjRo3Qrl27YutHjRoFANi1a5ehm0ZERAbEuEREpoKJFlXJ+fPnER8fDwCIiooyWjsiIyMBAO3bt9e6vmC5MdtIRET6x7hERKaCiRZVWn5+PqZMmYLc3FwAwLlz58q9b69evSCTySr0s3HjxhKPd/v2bQCAr6+v1vUFy2/dulXuNhIRkXlhXCIiU1LN2A0wV5mZQHS0sVtRfs2aAQ4Ouj3mZ599htDQUPj6+iI2NrZCAW3AgAFo0KBBhc7XuHHjEtelp6cDABxKeJGOjo4AgLS0tAqdk4jInFh7bGJcIiJTwkSrkqKjgQ4djN2K8gsLA0rovVApMTExWLBgATw9PbFlyxb07NmzQgHtjTfe0F1jiIgIgHXHJsYlIjI1TLQqqVkzESDMRbNmuj3etGnTkJGRgU2bNqFLly6wtbXFjRs3kJGRob5KZ0gF1ZwyMzO1rs/IyAAgBkkTEVkqa45NjEtEZGqYaFWSg4Nu7xCZk40bN+LAgQMYNWoURo4cCQBo0aIFIiIicP78eYSEhJR5jA8//BDRFezfMmXKFHTr1k3runr16gEAYmNjta4vWF6/fv0KnZOIyJxYa2xiXCIiU8REiyokKSkJc+bMgbu7u3p+EABo06YNIiIicO7cuXIFtD/++ANHjhyp0Ll79epVYkBr06YNACA8PFzr+oLlgYGBFTonERGZNsYlIjJVTLSoQmbNmoX79+/j+++/R926ddXLCwJKefvDHz58WKft6tq1K1xdXXH9+nWcPXsWbdu21Vi/bds2AMDQoUN1el4iIjIuxiUiMlUs707ltnPnTmzduhWDBg3ChAkTNNZVNKDpWo0aNTBz5kwAwIwZM9R93wFgxYoViIqKQs+ePdHBnEaJExFRqRiXiMiU8Y4WlYtCocD06dPh4uKCdevWFVtv7IAGAG+//TYOHjyI48ePIyAgAN27d8etW7cQGhqK2rVr45tvvjFa24iISLcYl4jI1PGOFpXLa6+9hvj4eHz00UdaJ190d3eHr68vUlJSkJCQYIQWAnZ2djh06BAWLFgABwcH/Pbbb7h16xYmTZqE8PBw+Pv7G6VdRESke4xLRGTqZJIkScZuhCGlpqbC1dUVCoUCLi4uJW6XnZ2NmJgYNGzYEHZ2dgZsIVkavpeIhPJ+/lqbivxe+HlC1ozvf9IHfcYm3tEiIiIiIiLSMSZaREREREREOsZEi4iIiIiISMeYaBEREREREekYEy0iIiIiIiIdY6JFRERERESkYwZLtC5cuIDRo0fD398fDg4O8PDwQI8ePbBr164y9924cSNkMpnWn8TERAO0noiILA3jEhER6VM1Q53o1q1bSEtLw7PPPgtvb29kZmbi119/xbBhw7Bu3TpMnTq1zGMsXrwYDRs21Fjm5uampxYTEZElY1wiIiJ9MliiNWjQIAwaNEhj2cyZM9GhQwesWLGiXAFt4MCB6Nixo76aSEREVoRxiYiI9MmoY7Tkcjn8/Pzw8OHDcu+Tlp
YGlUqlv0YREZHVYlwiIiJdMXiilZGRgZSUFFy/fh0rV67Evn370Ldv33Lt27t3b7i4uMDBwQHDhg3D1atX9dxaIiK
"text/plain": [
"<Figure size 1000x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.linear_model import Lasso\n",
"\n",
"plt.figure(figsize=(10,5))\n",
"plt.subplot(121)\n",
"plot_model(Lasso, polynomial=False, alphas=(0, 0.1, 1), random_state=42)\n",
"plt.ylabel(\"$y$\", rotation=0, fontsize=18)\n",
"plt.subplot(122)\n",
"plot_model(Lasso, polynomial=True, alphas=(0, 10**-7, 1), tol=1, random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Differentiability\n",
"\n",
"Note that the Lasso penalty is non-differentiable at zero.\n",
"\n",
"Gradient descent can still be used but with gradients replaced by [sub-gradients](https://en.wikipedia.org/wiki/Subderivative) when any $\\theta_j=0$."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Elastic Net regularization\n",
"\n",
"\n",
"Provides a mix of Tikhonov and Lasso regularization, controlled by mix ratio $r$:\n",
"\n",
"$$R(\\theta)\n",
"= \n",
"r\\sum_{j=1}^n \\left\\vert \\theta_j \\right\\vert\n",
"+\n",
"\\frac{1-r}{2} \\sum_{j=1}^n \\theta_j^2.$$\n",
"\n",
"- For $r=0$, corresponds to Tikhonov regularization.\n",
"- For $r=1$, corresponds to Lasso regularization."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Stopping early\n",
"\n",
"Compute the RMSE on the validation set during training and stop when it starts to increase."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:47.725983Z",
"iopub.status.busy": "2024-01-10T00:19:47.725495Z",
"iopub.status.idle": "2024-01-10T00:19:47.734570Z",
"shell.execute_reply": "2024-01-10T00:19:47.733897Z"
}
},
"outputs": [],
"source": [
"from sklearn.linear_model import SGDRegressor\n",
"np.random.seed(42)\n",
"m = 100\n",
"X = 6 * np.random.rand(m, 1) - 3\n",
"y = 2 + X + 0.5 * X**2 + np.random.randn(m, 1)\n",
"\n",
"X_train, X_val, y_train, y_val = train_test_split(X[:50], y[:50].ravel(), test_size=0.5, random_state=10)\n",
"\n",
"poly_scaler = Pipeline((\n",
" (\"poly_features\", PolynomialFeatures(degree=90, include_bias=False)),\n",
" (\"std_scaler\", StandardScaler()),\n",
" ))\n",
"\n",
"X_train_poly_scaled = poly_scaler.fit_transform(X_train)\n",
"X_val_poly_scaled = poly_scaler.transform(X_val)\n",
"\n",
"sgd_reg = SGDRegressor(max_iter=1,\n",
" penalty=None,\n",
" eta0=0.0005,\n",
" warm_start=True,\n",
" learning_rate=\"constant\",\n",
" random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:47.737763Z",
"iopub.status.busy": "2024-01-10T00:19:47.737190Z",
"iopub.status.idle": "2024-01-10T00:19:48.426308Z",
"shell.execute_reply": "2024-01-10T00:19:48.425611Z"
}
},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings(action='ignore')\n",
"n_epochs = 500\n",
"train_errors, val_errors = [], []\n",
"for epoch in range(n_epochs):\n",
" sgd_reg.fit(X_train_poly_scaled, y_train)\n",
" y_train_predict = sgd_reg.predict(X_train_poly_scaled)\n",
" y_val_predict = sgd_reg.predict(X_val_poly_scaled)\n",
" train_errors.append(mean_squared_error(y_train_predict, y_train))\n",
" val_errors.append(mean_squared_error(y_val_predict, y_val))\n",
"\n",
"best_epoch = np.argmin(val_errors)\n",
"best_val_rmse = np.sqrt(val_errors[best_epoch])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:48.429987Z",
"iopub.status.busy": "2024-01-10T00:19:48.429495Z",
"iopub.status.idle": "2024-01-10T00:19:48.654702Z",
"shell.execute_reply": "2024-01-10T00:19:48.653840Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1cAAAHJCAYAAABpMcPqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAACH40lEQVR4nO3dd1gUxxsH8O/Rjg4iqICIXWONvfeuUaMSTIyxpBgTNRpL8tMkoqaYZokaTUyssSTWWGNsYOyxRo1ixQYoTQ6UDvv7Y3IHyx1wwB13B9/P8+zD7uzM7stlg7zM7IxCkiQJREREREREVCxWpg6AiIiIiIioNGByRUREREREZABMroiIiIiIiAyAyRUREREREZEBMLkiIiIiIiIyACZXREREREREBsDkioiIiIiIyACYXBERERERERmAjakDMFdZWVmIiIiAi4sLFAqFqcMhIiIiIiITkSQJiYmJ8PHxgZVV3v1TTK7yEBERAT8/P1OHQUREREREZuLBgweoXLlynueZXOXBxcUFgPgAXV1dTRwNERERERGZSkJCAvz8/DQ5Ql6YXOVBPRTQ1dWVyRURERERERX4uhAntCAiIiIiIjIAJldEREREREQGwOSKiIiIiIjIAJhcERERERERGQCTKyIiIiIiIgPgbIFEREREpVxmZibS09NNHQaR2bC2toatra3Br8vkioiIiKiUkiQJjx49gkqlgiRJpg6HyKwolUp4enoadNklJldEREREpZRKpUJ8fDy8vLzg5ORU4Bo9RGWBJElIT0+HSqVCeHg4ABgswWJyRURERFQKSZKEqKgouLq6wtPT09ThEJkVBwcHuLi44OHDh4iJiTFYcsUJLYiIiIhKoczMTGRmZhp0yBNRaaJQKODm5obU1FSDvZPI5IqIiIioFMrIyAAA2NhwoBJRXtSTWmRmZhrkekyuLERSEhAZaeooiIiIyNLwPSuivBn6/w8mV2YsOBgYORJo0ABwcQEmTTJ1RERERERElBf2E5ux27eBtWuzj8+dM10sRERERESUP/ZcmbFmzeTHt28D8fEmCYWIiIiIcli9ejUUCgVWr14tK69atSqqVq1a7OsY0qxZs6BQKBASEmK0e5DA5MqM1a8P2NnJy86fN00sRERERJZi2LBhUCgU2LhxY771EhIS4OjoCHd3dyQnJ5dQdIYXEhIChUKBWbNmmToUkzKHJJLJlRmzswMaNZKXcWggERERUf7eeOMNAMDKlSvzrbdx40YkJyfjlVdegYODg0HufejQIRw6dMgg1zKU8ePH49q1a2jZsqWpQyn1+M6VmWvWDDh7NvuYyRURERFR/rp27Ypq1arh8OHDuH//PqpUqaKznjr5UidjhlCjRg2DXctQPD09uZB0CWHPlZnL/d4VkysiIiKi/CkUCowePRpZWVlYtWqVzjr//vsv/v77bzRq1AjNmzeHSqXCV199hU6dOsHHxwd2dnbw8fHBiBEjcPv2bb3vndc7V3FxcRg7diwqVqwIR0dHtGjRAtu3b8/zOitXrsTAgQNRtWpV2Nvbw8PDA7169UJwcLCs3qxZs9ClSxcAwOzZs6FQKDTb3bt3NXXyGi63a9cudOnSBW5ubnBwcEDjxo0xf/58zTppanfv3oVCocCoUaNw69YtDBo0COXKlYOTkxO6d++Of/75R+/PSKVSYebMmahXrx6cnZ3h6uqKmjVrYuTIkbh3756sriRJWLlyJdq1awdXV1c4OjqiefPmWr2SnTt3xuzZswEAXbp00XwGhXn/zRDYc2XmcidXt26JSS3c3U0RDREREVmyrCwgNtbUUeivfHnAqohdAaNGjcKsWbOwevVqzJw5U2s9I3XSpe61unbtGmbOnIkuXbpg0KBBcHJyQmhoKDZs2IA9e/bg/Pnz8Pf3L1IsSUlJ6Ny5My5fvow2bdqgU6dOePDgAYYOHYqePXvqbDNu3Dg0btwY3bt3h5eXF8LDw/H777+je/fu2L
ZtGwYOHAhAJBV3797FmjVr0KlTJ3Tu3FlzDfcCfmGcP38+pkyZAg8PDwwbNgxOTk7YuXMnpkyZgqNHj2Lbtm1an9vdu3fRunVr1K9fH6+//jpu376NHTt2oEuXLrh27RoqVqyY7z0lSUKvXr1w+vRptGvXDr1794aVlRXu3buHnTt34rXXXtN8zpIk4dVXX8XGjRtRq1YtDBs2DHZ2djhw4ADeeOMNXL16Fd9++y0A8d8bAI4cOYKRI0dqkqqCPgODk0gnlUolAZBUKpVJ40hNlSQ7O0kCsrdDh0waEhEREVmA5ORk6erVq1JycrKmLCpK/juFuW9RUcX7DHr37i0BkA4ePCgrT09PlypWrCgplUopNjZWkiRJio+P1+zndPjwYcnKykp68803ZeWrVq2SAEirVq2Slfv7+0v+/v6ysqCgIAmA9NZbb8nK9+3bJwHQeZ07d+5oxRIRESH5+PhItWrVkpUHBwdLAKSgoCCtNjnvHxwcrCm7deuWZGNjI1WoUEG6f/++pjwlJUVq3769BEBau3atpjwsLEwT65dffim7/scffywBkObOnavz/jldunRJAiC9+OKLWudSUlKkxMREzfHy5cslANLo0aOltLQ0TXlqaqrUv39/CYB09uzZfL/Pguj6/0QXfXMDDgs0c3Z2QMOG8jIODSQiIiIqWF4TW+zevRuPHz/GwIED4eHhAQBwc3PT7OfUpUsX1K9fHwcPHixyHGvXroWdnR3mzJkjK+/Vqxe6deums021atW0yry9vTFkyBDcvHlTa/hcYW3YsAEZGRmYMmUK/Pz8NOVKpRJfffUVAOicHr5atWqYNm2arEz9OZ85c0bv++uaQESpVMLZ2VlzvGTJEjg5OeH777+Hra2tptzOzg6ff/45ABQ4I2RJ47BAC9CsmTyhYnJFREREVLCBAwfCy8sL27dvh0qlgpubG4C8J7IICQnBwoULcfr0acTExMjeO7LLvT6OnhISEhAWFoZ69eqhUqVKWuc7dOigc3bBO3fuYO7cuTh8+DDCw8ORmpoqOx8REVHkYYoAcOHCBQCQDSNUa9OmDezt7XHx4kWtc88//zysco3VrFy5MgAgXo8FWZ977jk0atQIGzduxMOHD/Hiiy+ic+fOWtdNSkrC5cuX4ePjo0n2ckpPTwcAhIaGFnjPksTkygJwUgsiIiKiwrO1tcVrr72G+fPnY8OGDXjnnXfw6NEj/PHHH6hSpQq6d++uqbt582YMHToUzs7O6NWrF6pWrQpHR0fNAr9F7SlKSEgAAFSoUEHneV3vKN26dQstW7ZEQkICunTpgv79+8PV1RVWVlYICQnBkSNHtJKtosal6/4KhQIVK1ZEeHi41jlXV1etMhsbkVJkZmYWeF8bGxscPnwYs2bNwtatWzFlyhQAgJeXF8aPH4+PPvoI1tbWePLkCSRJQnh4uGaiCl2ePXtW4D1LEpMrC9C8ufz41i1ApQL+++MLERERkV7Klweiokwdhf7Kly/+Nd544w3Mnz8fK1aswDvvvINffvkFGRkZGD16tKynZNasWbC3t8e5c+dQq1Yt2TV+/fXXIt9fnYxE5fHBP378WKtswYIFePLkCX755RcMHz5cdm7s2LE4cuRIkePJHdfjx4+1esAkScLjx491JlKGUL58eSxevBiLFi1CaGgoDh8+jMWLFyMoKAi2traYPn265t7NmjXD2ZzrEpk5JlcWoEED8e5VWlp22fnzwH+zbhIRERHpxcoK8PIydRQlq169emjdujVOnTqFS5cuYdWqVZqp2nO6ffs26tevr5VYRUZG4s6dO0W+v6urK6pVq4Zbt27h0aNHWkMDjx49qtVGPfW7ekZANUmScPz4ca361tbWAPTrOVJr0qQJtm/fjpCQEK3FhU+fPo2UlBS0bdtW7+sVhUKhwHPPPYfnnnsOAwYMQJUqVbBz505Mnz4dLi4ueO6553Dt2jXEx8frNetfUT4HQ+OEFhaAk1oQERERFZ363a
p3330X165dQ/fu3bV6a/z9/XHr1i1ZT1JKSgreeecdzfs9RfXaa68hLS0NM2fOlJXv379f5/tW6tiOHTsmK//yyy9
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10,5))\n",
"plt.annotate('Best model',\n",
" xy=(best_epoch, best_val_rmse),\n",
" xytext=(best_epoch, best_val_rmse + 1),\n",
" ha=\"center\",\n",
" arrowprops=dict(facecolor='black', shrink=0.05),\n",
" fontsize=16,\n",
" )\n",
"\n",
"best_val_rmse -= 0.03 # just to make the graph look better\n",
"plt.plot([0, n_epochs], [best_val_rmse, best_val_rmse], \"k:\", linewidth=2)\n",
"plt.plot(np.sqrt(val_errors), \"b-\", linewidth=3, label=\"Validation set\")\n",
"plt.plot(np.sqrt(train_errors), \"r--\", linewidth=2, label=\"Training set\")\n",
"plt.legend(loc=\"upper right\", fontsize=14)\n",
"plt.xlabel(\"Epoch\", fontsize=14)\n",
"plt.ylabel(\"RMSE\", fontsize=14);"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Note that for stochastic or mini-batch gradient descent the RMSE is noisy, so it may be difficult to know when the minimum has been reached (simple strategies can be employed to deal with that)."
]
}
],
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}