spce0038-machine-learning-w.../week6/exercises/Lecture17_EnsembleRFs_Exercises_no_solutions.ipynb
2025-02-28 11:02:51 +00:00

1 line
4.1 KiB
Plaintext

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "# Exercises for Lecture 17 (Ensemble Learning and Random Forests)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.ensemble import GradientBoostingRegressor\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "## Exercise 1: Early stopping"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "### Set up mock data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Training set: a noisy quadratic function\n",
    "np.random.seed(42)\n",
    "X = np.random.rand(100, 1) - 0.5\n",
    "y = 3*X[:, 0]**2 + 0.05 * np.random.randn(100)\n",
    "\n",
    "# Split into training and validation sets; the validation set is used below\n",
    "# to measure the error at each boosting stage. No random_state is passed, so\n",
    "# the split draws from the global NumPy RNG seeded at the top of this cell.\n",
    "X_train, X_val, y_train, y_val = train_test_split(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "### Train with many trees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "n_estimators = 300\n",
    "gbrt = GradientBoostingRegressor(\n",
    "    max_depth=2,\n",
    "    n_estimators=n_estimators,\n",
    "    learning_rate=0.1,  # Set a low learning rate here\n",
    "    random_state=42)\n",
    "\n",
    "gbrt.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "### Compute and plot validation error for intermediate number of trees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# measure MSE validation error at each stage\n",
    "errors = [mean_squared_error(y_val, y_pred) for y_pred in gbrt.staged_predict(X_val)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# errors[i] is the validation error of the ensemble made of the first i + 1 trees,\n",
    "# so the x values start at 1 to match the \"Number of trees\" axis label.\n",
    "n_trees = np.arange(1, len(errors) + 1)\n",
    "\n",
    "plt.figure(figsize=(11, 4))\n",
    "plt.plot(n_trees, errors, \"b.-\")\n",
    "plt.axis([0, n_estimators, 0, 0.01])\n",
    "plt.xlabel(\"Number of trees\")\n",
    "plt.ylabel(\"Validation MSE\")\n",
    "plt.title(\"Validation error\", fontsize=14);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": "slide"
    },
    "tags": []
   },
   "source": [
    "### Training a better model with fewer trees\n",
    "\n",
    "- Find the best number of trees from the validation error. Show this on a plot.\n",
    "- Train a new GBRT using the optimal number of trees from above.\n",
    "- Plot predictions of the original and best models.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def plot_predictions(\n",
    "        regressors, X, y, axes,\n",
    "        label=None,\n",
    "        style=\"r-\",\n",
    "        data_style=\"b.\",\n",
    "        data_label=None):\n",
    "    \"\"\"Plot data points together with the summed prediction of several regressors.\n",
    "\n",
    "    Draws on the current matplotlib figure (no new figure is created).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    regressors : iterable\n",
    "        Objects exposing ``predict``; their predictions on a 500-point grid\n",
    "        are added together to form the plotted curve.\n",
    "    X : array\n",
    "        Data inputs; only the first column ``X[:, 0]`` is plotted.\n",
    "    y : array\n",
    "        Values plotted against ``X[:, 0]``.\n",
    "    axes : sequence\n",
    "        Passed to ``plt.axis``; ``axes[0]`` and ``axes[1]`` also bound the\n",
    "        prediction grid.\n",
    "    label : str, optional\n",
    "        Legend label for the prediction curve.\n",
    "    style : str\n",
    "        Matplotlib format string for the prediction curve.\n",
    "    data_style : str\n",
    "        Matplotlib format string for the data points.\n",
    "    data_label : str, optional\n",
    "        Legend label for the data points.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``regressors`` is empty, since there is nothing to sum.\n",
    "    \"\"\"\n",
    "    x1 = np.linspace(axes[0], axes[1], 500)\n",
    "    # predict() is called with a 2D column, so reshape the grid once up front.\n",
    "    x1_column = x1.reshape(-1, 1)\n",
    "\n",
    "    predictions = [regressor.predict(x1_column) for regressor in regressors]\n",
    "    if not predictions:\n",
    "        # Fail before drawing anything rather than with a shape error mid-plot.\n",
    "        raise ValueError(\"plot_predictions requires at least one regressor\")\n",
    "    y_pred = sum(predictions)\n",
    "\n",
    "    plt.plot(X[:, 0], y, data_style, label=data_label)\n",
    "    plt.plot(x1, y_pred, style, linewidth=2, label=label)\n",
    "    if label or data_label:\n",
    "        plt.legend(loc=\"upper center\", fontsize=16)\n",
    "    plt.axis(axes)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}