{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Lecture 6: Training II"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"source": [
"![](https://www.tensorflow.org/images/colab_logo_32px.png)\n",
"[Run in colab](https://colab.research.google.com/drive/18org-B5m6ZtN7E9GBSoLtkXb6-Fc2bgh)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:34.924839Z",
"iopub.status.busy": "2024-01-10T00:19:34.924473Z",
"iopub.status.idle": "2024-01-10T00:19:34.932173Z",
"shell.execute_reply": "2024-01-10T00:19:34.931501Z"
},
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Last executed: 2024-01-10 00:19:34\n"
]
}
],
"source": [
"import datetime\n",
"now = datetime.datetime.now()\n",
"print(\"Last executed: \" + now.strftime(\"%Y-%m-%d %H:%M:%S\"))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Stochastic gradient descent"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Problems with batch gradient descent\n",
"\n",
"- Uses the entire training set to compute gradients at every step (slow when the training set is large).\n",
"\n",
"- Full training set needs to be held in memory."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Properties of stochastic gradient descent\n",
"\n",
"- Uses a single, randomly chosen instance from the training set to compute the gradient at each iteration (fast, since very little data is considered at each iteration).\n",
"- Only one training instance then needs to be held in memory at a time.\n",
"- Less regular than batch gradient descent.\n",
"    - Helps to escape local minima.\n",
"    - Ends up close to the minimum but continues to explore its vicinity (\"bounces\" around)."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Simulated annealing\n",
"\n",
"To mitigate the issue of bouncing around the minimum, we can gradually reduce the learning rate as the algorithm proceeds.\n",
"\n",
"This is called *simulated annealing*, by analogy with annealing in metallurgy.\n",
"\n",
"The *learning schedule* defines how the learning rate changes over time.\n",
"\n",
"- If the learning rate is reduced too quickly, the algorithm may get stuck in a local minimum or end up frozen half-way to the minimum.\n",
"- If the learning rate is reduced too slowly, the algorithm may jump around the minimum for a long time."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Example learning schedule\n",
"\n",
"Set learning rate $\\alpha$ at iteration $t$ by\n",
"\n",
"$$\\alpha(t) = \\frac{t_0}{t + t_1},$$\n",
"\n",
"where $t_0$ and $t_1$ are hyperparameters of the schedule."
]
},
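{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal numerical sketch of this schedule (an illustration added here, not part of the original lecture code): it evaluates $\\alpha(t)$ at a few iteration counts using $t_0 = 5$ and $t_1 = 50$, the values used in the SGD example that follows. With $m = 100$ instances and 50 epochs, $t$ runs from 0 to 4999, so the learning rate falls from $0.1$ to roughly $0.001$. The helper name `alpha_schedule` is hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper implementing alpha(t) = t0 / (t + t1)\n",
"def alpha_schedule(t, t0=5, t1=50):\n",
"    return t0 / (t + t1)\n",
"\n",
"# Learning rate at a few iteration counts\n",
"for t in [0, 100, 1000, 4999]:\n",
"    print(f\"t = {t:4d}: alpha = {alpha_schedule(t):.5f}\")"
]
},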
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Stochastic gradient descent example"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:34.974213Z",
"iopub.status.busy": "2024-01-10T00:19:34.973677Z",
"iopub.status.idle": "2024-01-10T00:19:35.388064Z",
"shell.execute_reply": "2024-01-10T00:19:35.387394Z"
}
},
"outputs": [],
"source": [
"# Common imports\n",
"import os\n",
"import numpy as np\n",
"np.random.seed(42) # To make this notebook's output stable across runs\n",
"\n",
"# To plot pretty figures\n",
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['axes.labelsize'] = 14\n",
"plt.rcParams['xtick.labelsize'] = 12\n",
"plt.rcParams['ytick.labelsize'] = 12"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Set up training data (repeating example from previous lecture)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:35.392076Z",
"iopub.status.busy": "2024-01-10T00:19:35.391426Z",
"iopub.status.idle": "2024-01-10T00:19:35.624448Z",
"shell.execute_reply": "2024-01-10T00:19:35.623816Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 900x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"m = 100\n",
"X = 2 * np.random.rand(m, 1)\n",
"y = 4 + 3 * X + np.random.randn(m, 1)\n",
"plt.figure(figsize=(9,6))\n",
"plt.plot(X, y, \"b.\")\n",
"plt.xlabel(\"$x_1$\", fontsize=18)\n",
"plt.ylabel(\"$y$\", rotation=0, fontsize=18)\n",
"plt.axis([0, 2, 0, 15]);"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Add bias terms"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:35.627965Z",
"iopub.status.busy": "2024-01-10T00:19:35.627484Z",
"iopub.status.idle": "2024-01-10T00:19:35.631843Z",
"shell.execute_reply": "2024-01-10T00:19:35.631191Z"
}
},
"outputs": [],
"source": [
"X_b = np.c_[np.ones((m, 1)), X] # add x0 = 1 to each instance\n",
"X_new = np.array([[0], [2]])\n",
"X_new_b = np.c_[np.ones((2, 1)), X_new] # add x0 = 1 to each instance"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"#### Solve by SGD with learning schedule"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:35.635312Z",
"iopub.status.busy": "2024-01-10T00:19:35.635050Z",
"iopub.status.idle": "2024-01-10T00:19:35.639860Z",
"shell.execute_reply": "2024-01-10T00:19:35.639206Z"
}
},
"outputs": [],
"source": [
"theta_path_sgd = []\n",
"m = len(X_b)\n",
"np.random.seed(42)\n",
"\n",
"n_epochs = 50\n",
"t0, t1 = 5, 50 # learning schedule hyperparameters\n",
"\n",
"def learning_schedule(t):\n",
"    return t0 / (t + t1)\n",
"\n",
"theta = np.random.randn(2,1) # random initialization"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:35.642955Z",
"iopub.status.busy": "2024-01-10T00:19:35.642348Z",
"iopub.status.idle": "2024-01-10T00:19:35.887038Z",
"shell.execute_reply": "2024-01-10T00:19:35.886346Z"
},
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 900x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(9,6))\n",
"for epoch in range(n_epochs):\n",
"    for i in range(m):\n",
"\n",
"        # Plot the current model for the first 10 steps (initial model as a dashed red line)\n",
"        if epoch == 0 and i < 10:\n",
"            y_predict = X_new_b.dot(theta)\n",
"            style = \"b-\" if i > 0 else \"r--\"\n",
"            plt.plot(X_new, y_predict, style)\n",
"\n",
"        # SGD update using a single randomly chosen training instance\n",
"        random_index = np.random.randint(m)\n",
"        xi = X_b[random_index:random_index+1]\n",
"        yi = y[random_index:random_index+1]\n",
"        gradients = 2 * xi.T.dot(xi.dot(theta) - yi)  # MSE gradient for one instance\n",
"        alpha = learning_schedule(epoch * m + i)  # decaying learning rate\n",
"        theta = theta - alpha * gradients\n",
"        theta_path_sgd.append(theta)\n",
"\n",
"plt.plot(X, y, \"b.\")\n",
"plt.xlabel(\"$x_1$\", fontsize=18)\n",
"plt.ylabel(\"$y$\", rotation=0, fontsize=18)\n",
"plt.axis([0, 2, 0, 15]);"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:35.890139Z",
"iopub.status.busy": "2024-01-10T00:19:35.889906Z",
"iopub.status.idle": "2024-01-10T00:19:35.895764Z",
"shell.execute_reply": "2024-01-10T00:19:35.895204Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[4.21076011],\n",
" [2.74856079]])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"theta"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"Only 50 passes (epochs) over the data are used, compared with 1000 iterations (each a full pass) for batch gradient descent."
]
},
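{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check (added here, not part of the original lecture), the SGD estimate can be compared with the exact least-squares solution for the same data. The sketch below regenerates the training set with a local `np.random.RandomState(42)`, which draws the same sequence as the `np.random.seed(42)` call made before the data were created above (assuming the cells are run in order), so the notebook's global random state is left untouched. The SGD estimate of `theta` should be close to, but not identical to, this solution."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Regenerate the lecture's training data without touching the global random state\n",
"rng_check = np.random.RandomState(42)\n",
"m_check = 100\n",
"X_check = 2 * rng_check.rand(m_check, 1)\n",
"y_check = 4 + 3 * X_check + rng_check.randn(m_check, 1)\n",
"X_b_check = np.c_[np.ones((m_check, 1)), X_check]  # add x0 = 1 to each instance\n",
"\n",
"# Exact least-squares solution, for comparison with the SGD estimate\n",
"theta_lstsq, *_ = np.linalg.lstsq(X_b_check, y_check, rcond=None)\n",
"theta_lstsq"
]
},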
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercise 1 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Mini-batch gradient descent"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"\n",
"Compute gradients on *mini-batches*: small random subsets of the training instances.\n",
"\n",
"Trades off the properties of batch GD and stochastic GD.\n",
"\n",
"Can get a performance boost over SGD by exploiting hardware optimisation of matrix operations, particularly on GPUs."
]
},
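{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of this trade-off (an illustration added here): the same MSE gradient formula, $\\nabla_\\theta \\text{MSE} = \\frac{2}{b} X_B^T (X_B \\theta - y_B)$ over a set $B$ of $b$ instances, covers all three methods; only the number of instances used per step changes ($b = m$ for batch GD, $b = 1$ for SGD and a small $b$ for mini-batch GD). The helper `mse_gradient`, the toy data and the batch size of 20 are hypothetical choices."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def mse_gradient(X_batch, y_batch, theta):\n",
"    \"\"\"Gradient of the MSE cost over the instances in X_batch (hypothetical helper).\"\"\"\n",
"    b = len(X_batch)\n",
"    return 2 / b * X_batch.T.dot(X_batch.dot(theta) - y_batch)\n",
"\n",
"# Toy data with a bias column, in the same form as X_b above\n",
"rng_demo = np.random.RandomState(0)\n",
"m_demo = 100\n",
"X_demo = np.c_[np.ones((m_demo, 1)), 2 * rng_demo.rand(m_demo, 1)]\n",
"y_demo = 4 + 3 * X_demo[:, 1:] + rng_demo.randn(m_demo, 1)\n",
"theta_demo = np.zeros((2, 1))\n",
"\n",
"i_demo = rng_demo.randint(m_demo)                           # one random instance\n",
"batch = rng_demo.choice(m_demo, size=20, replace=False)     # 20 random instances\n",
"\n",
"grad_batch = mse_gradient(X_demo, y_demo, theta_demo)                                 # batch GD (b = m)\n",
"grad_sgd = mse_gradient(X_demo[i_demo:i_demo+1], y_demo[i_demo:i_demo+1], theta_demo)  # SGD (b = 1)\n",
"grad_mini = mse_gradient(X_demo[batch], y_demo[batch], theta_demo)                     # mini-batch GD (b = 20)\n",
"print(grad_batch.ravel(), grad_sgd.ravel(), grad_mini.ravel())"
]
},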
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Shuffling training data\n",
"\n",
"The first step is to randomly shuffle the data set, since we do not want the result to be sensitive to the ordering of the data (we want each mini-batch to be representative)."
]
},
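{
"cell_type": "markdown",
"metadata": {},
"source": [
"One common way to do this is sketched below, under illustrative assumptions (the generator name `iterate_minibatches`, the toy data and the batch size are hypothetical): draw a fresh random permutation of the instance indices at the start of each epoch, then slice it into consecutive mini-batches, so that every instance is used exactly once per epoch and the last mini-batch may be smaller."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def iterate_minibatches(X, y, batch_size, rng):\n",
"    \"\"\"Yield (X_batch, y_batch) pairs covering one epoch in a random order.\"\"\"\n",
"    indices = rng.permutation(len(X))  # new random ordering on every call (every epoch)\n",
"    for start in range(0, len(X), batch_size):\n",
"        batch_idx = indices[start:start + batch_size]\n",
"        yield X[batch_idx], y[batch_idx]\n",
"\n",
"rng_shuffle = np.random.RandomState(42)\n",
"X_toy = np.arange(10).reshape(-1, 1)  # toy data: 10 instances, 1 feature\n",
"y_toy = 2 * X_toy\n",
"\n",
"for epoch_demo in range(2):\n",
"    batches = [Xb.ravel().tolist() for Xb, _ in iterate_minibatches(X_toy, y_toy, 4, rng_shuffle)]\n",
"    print(f\"epoch {epoch_demo}: {batches}\")  # batch sizes 4, 4, 2"
]
},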
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercise 2 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Comparing gradient descent algorithms"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Batch gradient descent (from previous lecture)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:35.899258Z",
"iopub.status.busy": "2024-01-10T00:19:35.898708Z",
"iopub.status.idle": "2024-01-10T00:19:35.905654Z",
"shell.execute_reply": "2024-01-10T00:19:35.905047Z"
}
},
"outputs": [],
"source": [
"theta_path_bgd = []\n",
"\n",
"def plot_gradient_descent(theta, alpha, theta_path=None):\n",
"    m = len(X_b)\n",
"    plt.plot(X, y, \"b.\")\n",
"    n_iterations = 1000\n",
"    for iteration in range(n_iterations):\n",
"        if iteration < 10:\n",
"            y_predict = X_new_b.dot(theta)\n",
"            style = \"b-\" if iteration > 0 else \"r--\"\n",
"            plt.plot(X_new, y_predict, style)\n",
"        gradients = 2/m * X_b.T.dot(X_b.dot(theta) - y)\n",
"        theta = theta - alpha * gradients\n",
"        if theta_path is not None:\n",
"            theta_path.append(theta)\n",
"    plt.xlabel(\"$x_1$\", fontsize=18)\n",
"    plt.axis([0, 2, 0, 15])\n",
"    plt.title(r\"$\\alpha = {}$\".format(alpha), fontsize=16)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:35.908611Z",
"iopub.status.busy": "2024-01-10T00:19:35.907981Z",
"iopub.status.idle": "2024-01-10T00:19:36.473185Z",
"shell.execute_reply": "2024-01-10T00:19:36.472575Z"
},
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1000x400 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"np.random.seed(42)\n",
"theta = np.random.randn(2,1) # random initialization\n",
"\n",
"plt.figure(figsize=(10,4))\n",
"plt.subplot(131); plot_gradient_descent(theta, alpha=0.02)\n",
"plt.ylabel(\"$y$\", rotation=0, fontsize=18)\n",
"plt.subplot(132); plot_gradient_descent(theta, alpha=0.1, theta_path=theta_path_bgd)\n",
"plt.subplot(133); plot_gradient_descent(theta, alpha=0.5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Convert lists to numpy arrays"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:36.476582Z",
"iopub.status.busy": "2024-01-10T00:19:36.475915Z",
"iopub.status.idle": "2024-01-10T00:19:36.482264Z",
"shell.execute_reply": "2024-01-10T00:19:36.481660Z"
}
},
"outputs": [],
"source": [
"theta_path_bgd = np.array(theta_path_bgd)\n",
"theta_path_sgd = np.array(theta_path_sgd)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Algorithm trajectories"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T00:19:36.485127Z",
"iopub.status.busy": "2024-01-10T00:19:36.484696Z",
"iopub.status.idle": "2024-01-10T00:19:36.727779Z",
"shell.execute_reply": "2024-01-10T00:19:36.727175Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(2.5, 4.5, 2.3, 3.9)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10,5))\n",
"plt.plot(theta_path_sgd[:, 0], theta_path_sgd[:, 1], \"r-s\", linewidth=1, label=\"Stochastic\")\n",
"plt.plot(theta_path_bgd[:, 0], theta_path_bgd[:, 1], \"b-o\", linewidth=3, label=\"Batch\")\n",
"plt.legend(loc=\"upper left\", fontsize=16)\n",
"plt.xlabel(r\"$\\theta_0$\", fontsize=20)\n",
"plt.ylabel(r\"$\\theta_1$ \", fontsize=20, rotation=0)\n",
"plt.axis([2.5, 4.5, 2.3, 3.9])"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Finally, we show the full trajectories of all the optimisation methods, including mini-batch gradient descent (computed in the exercises)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture06_Images/algorithm_trajectories.png\" alt=\"trajectories\" style=\"height: 400px;\"/>"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 4
}