{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Lecture 16: Decision trees"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"source": [
"\n",
"[Run in colab](https://colab.research.google.com/drive/1P9IoqXN9dbjJ3TN50wa8wwDdvn9P6hX7)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:19.492309Z",
"iopub.status.busy": "2025-02-27T23:21:19.492086Z",
"iopub.status.idle": "2025-02-27T23:21:19.498527Z",
"shell.execute_reply": "2025-02-27T23:21:19.497958Z"
},
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Last executed: 2025-02-27 23:21:19\n"
]
}
],
"source": [
"import datetime\n",
"now = datetime.datetime.now()\n",
"print(\"Last executed: \" + now.strftime(\"%Y-%m-%d %H:%M:%S\"))"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"A decision tree can be considered conceptually as a *flow diagram*: a tree of decisions, each based on inspecting a property of the data set.\n",
"\n",
"- Can perform both classification and regression.\n",
"- A fundamental component of random forests (a powerful machine learning algorithm covered in the next lecture).\n",
"- We will learn how to visualise decision trees and make predictions with them.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"source": [
"## Conceptual example"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/DecisionTree.jpg\" alt=\"data-layout\" width=\"500\" style=\"display:block; margin:auto\"/>\n",
"\n",
"[[Image source](https://inside-machinelearning.com/en/decision-tree-and-hyperparameters/)]"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"source": [
"## Walk-through of a decision tree\n",
"\n",
"Let's consider an illustration using the Iris data set (introduced in Lecture 3)."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Images of different Iris species\n",
"\n",
"#### Iris Setosa\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture03_Images/iris_setosa.jpg\" width=\"300\" style=\"display:block; margin:auto\"/>\n",
"\n",
"#### Iris Versicolor\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture03_Images/iris_versicolor.jpg\" width=\"300\" style=\"display:block; margin:auto\"/>\n",
"\n",
"#### Iris Virginica\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture03_Images/iris_virginica.jpg\" width=\"300\" style=\"display:block; margin:auto\"/>\n",
"\n",
"[[Image source](https://github.com/jakevdp/sklearn_tutorial)]\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"Load the feature matrix, where each row corresponds to an observed (*sampled*) flower described by a number of *features*, together with the corresponding target vector.\n",
"\n",
"For now, consider only two features (petal length and width)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:19.532095Z",
"iopub.status.busy": "2025-02-27T23:21:19.531889Z",
"iopub.status.idle": "2025-02-27T23:21:20.377758Z",
"shell.execute_reply": "2025-02-27T23:21:20.377159Z"
},
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassifier(max_depth=2, random_state=42)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"iris = load_iris(as_frame=True)\n",
"X_iris = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y_iris = iris.target\n",
"\n",
"tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)\n",
"tree_clf.fit(X_iris, y_iris)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:20.379798Z",
"iopub.status.busy": "2025-02-27T23:21:20.379575Z",
"iopub.status.idle": "2025-02-27T23:21:20.569651Z",
"shell.execute_reply": "2025-02-27T23:21:20.568896Z"
},
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"outputs": [],
"source": [
"# Visualise the fitted tree as a flow diagram using Graphviz\n",
"from sklearn.tree import export_graphviz\n",
"\n",
"export_graphviz(tree_clf,\n",
"                out_file='./iris_tree.dot',\n",
"                feature_names=iris.feature_names[2:],  # petal length, petal width\n",
"                class_names=iris.target_names,\n",
"                rounded=True,\n",
"                filled=True)\n",
"\n",
"# export_graphviz writes a .dot file, so convert it to PNG with the Graphviz dot tool\n",
"! dot -Tpng ./iris_tree.dot -o ./iris_tree.png"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"source": [
"To run `dot` locally you will need to install [Graphviz](https://graphviz.org/).\n",
"\n",
"On macOS you can install it using Homebrew:\n",
"```bash\n",
"brew install graphviz\n",
"```\n",
"\n",
"On Ubuntu you can install it using apt:\n",
"```bash\n",
"sudo apt install graphviz\n",
"```\n",
"\n",
"Installation instructions for other systems are available [here](https://graphviz.org/download/)."
]
},
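{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"source": [
"If Graphviz is not installed, a possible alternative is scikit-learn's `plot_tree`, which draws the same kind of flow diagram using matplotlib alone. The cell below is a minimal sketch, assuming `matplotlib` is available. It refits the same depth-2 tree under illustrative names (`iris_bunch`, `X_petal`, `y_species`, `demo_tree`) so that it runs on its own without overwriting the variables defined above; the exact box styling may differ slightly between scikit-learn versions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"# Sketch: draw the fitted tree with matplotlib instead of Graphviz\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.tree import DecisionTreeClassifier, plot_tree\n",
"\n",
"iris_bunch = load_iris()\n",
"X_petal = iris_bunch.data[:, 2:]  # petal length (cm), petal width (cm)\n",
"y_species = iris_bunch.target\n",
"demo_tree = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X_petal, y_species)\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 5))\n",
"artists = plot_tree(demo_tree,\n",
"                    feature_names=iris_bunch.feature_names[2:],\n",
"                    class_names=list(iris_bunch.target_names),\n",
"                    rounded=True,\n",
"                    filled=True,\n",
"                    ax=ax)\n",
"plt.show()"
]
},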
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Decision tree for Iris classification (depth 2)\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/iris_tree.png\" alt=\"data-layout\" width=\"500\" style=\"display:block; margin:auto\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"The tree consists of a number of nodes:\n",
"- The top node is the _root node_.\n",
"- Intermediate nodes are _split nodes_.\n",
"- The bottom nodes are _leaf nodes_.\n",
"\n",
"Each decision is based on a *feature* and a *threshold*.\n",
"\n",
"To make a prediction, navigate the tree from the root node down to a leaf node."
]
},
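{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"source": [
"The navigation described above can be sketched by hand using the low-level arrays that a fitted scikit-learn tree exposes through its `tree_` attribute (`children_left`, `children_right`, `feature`, `threshold`, `value`). This is an illustrative sketch only: `predict_by_walking` is a hypothetical helper written for this example, it assumes a single-output classifier, and the cell refits the same depth-2 tree under illustrative names so it runs on its own. Its result should agree with `predict`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"# Sketch: predict by walking the fitted tree from the root node to a leaf node\n",
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"iris_bunch = load_iris()\n",
"X_petal = iris_bunch.data[:, 2:]  # petal length (cm), petal width (cm)\n",
"y_species = iris_bunch.target\n",
"demo_tree = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X_petal, y_species)\n",
"\n",
"def predict_by_walking(clf, x):\n",
"    # Hypothetical helper for illustration; not part of scikit-learn\n",
"    t = clf.tree_\n",
"    x = np.asarray(x, dtype=np.float32)  # scikit-learn trees compare features as float32\n",
"    node = 0  # node 0 is the root\n",
"    while t.children_left[node] != -1:  # -1 means the node has no children (a leaf)\n",
"        if x[t.feature[node]] <= t.threshold[node]:\n",
"            node = t.children_left[node]   # condition true: go left\n",
"        else:\n",
"            node = t.children_right[node]  # condition false: go right\n",
"    # A leaf predicts the class with the largest stored value\n",
"    return clf.classes_[np.argmax(t.value[node])]\n",
"\n",
"x_new = [5.0, 1.5]  # petal length 5 cm, petal width 1.5 cm\n",
"print(predict_by_walking(demo_tree, x_new), demo_tree.predict([x_new])[0])"
]
},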
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Interpreting node outputs\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/iris_tree_1node.png\" alt=\"data-layout\" width=\"300\" style=\"display:block; margin:auto\"/>\n",
"\n",
"Each node displays the following:\n",
"- The top line shows the feature and _threshold_ used to split the node (split nodes only).\n",
"- ```gini``` (see next slides) is a quantitative measure of the node's impurity.\n",
"- ```samples``` is the number of training instances that reach the node.\n",
"- ```value``` is the number of training instances of each class that reach the node.\n",
"- ```class``` is the class predicted at the node (the majority class)."
]
},
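{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"source": [
"The quantities displayed in each node can also be read from the fitted model. A minimal sketch, assuming a recent scikit-learn: `export_text` prints the thresholds and leaf classes as plain text, while the `tree_` arrays `impurity` and `n_node_samples` hold the `gini` and `samples` values. The per-class `value` can be read in a similar way, but whether `tree_.value` stores counts or fractions depends on the scikit-learn version, so it is not printed here. Variable names are illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"# Sketch: read the per-node quantities shown in the diagram from the fitted tree\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.tree import DecisionTreeClassifier, export_text\n",
"\n",
"iris_bunch = load_iris()\n",
"X_petal = iris_bunch.data[:, 2:]  # petal length (cm), petal width (cm)\n",
"y_species = iris_bunch.target\n",
"demo_tree = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X_petal, y_species)\n",
"\n",
"# Text version of the diagram: split thresholds and the class predicted at each leaf\n",
"print(export_text(demo_tree, feature_names=iris_bunch.feature_names[2:]))\n",
"\n",
"# gini and samples for every node (node 0 is the root)\n",
"tree_arrays = demo_tree.tree_\n",
"for node in range(tree_arrays.node_count):\n",
"    print(f'node {node}: gini = {tree_arrays.impurity[node]:.3f}, samples = {tree_arrays.n_node_samples[node]}')"
]
},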
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Decision boundaries"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/decision_tree_decision_boundaries_plot.png\" alt=\"data-layout\" width=\"600\" style=\"display:block; margin:auto\"/>\n",
"\n",
"We set ```max_depth=2```, so the algorithm stopped splitting once the tree reached depth 2."
]
},
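{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"source": [
"A plot like the one above can be approximated by predicting the class at every point of a dense grid in the (petal length, petal width) plane. This is a minimal sketch assuming `numpy` and `matplotlib`; the grid ranges, colour map and variable names are arbitrary choices for illustration and are not taken from the original figure. `get_depth()` and `get_n_leaves()` report the depth and number of leaves of the fitted tree."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"# Sketch: visualise the axis-aligned decision regions of the depth-2 tree\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"iris_bunch = load_iris()\n",
"X_petal = iris_bunch.data[:, 2:]  # petal length (cm), petal width (cm)\n",
"y_species = iris_bunch.target\n",
"demo_tree = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X_petal, y_species)\n",
"print('depth:', demo_tree.get_depth(), 'leaves:', demo_tree.get_n_leaves())\n",
"\n",
"# Predict the class at every point of a dense grid over the feature plane\n",
"lengths, widths = np.meshgrid(np.linspace(0, 7.5, 300), np.linspace(0, 3, 200))\n",
"grid = np.c_[lengths.ravel(), widths.ravel()]\n",
"y_grid = demo_tree.predict(grid).reshape(lengths.shape)\n",
"\n",
"# Each constant-colour region corresponds to one leaf of the tree\n",
"plt.contourf(lengths, widths, y_grid, alpha=0.3, cmap='brg')\n",
"plt.scatter(X_petal[:, 0], X_petal[:, 1], c=y_species, cmap='brg', edgecolors='k', s=15)\n",
"plt.xlabel('Petal length (cm)')\n",
"plt.ylabel('Petal width (cm)')\n",
"plt.show()"
]
},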
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Estimating class probabilities"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"We may also want to know the _probability_ that an instance $i$ belongs to class $k$.\n",
"\n",
"The tree estimates this probability by finding the leaf node that instance $i$ falls into and returning the ratio of training instances of class $k$ in that node.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"#### Example\n",
"For a flower with petals 5 cm long and 1.5 cm wide, the corresponding leaf node is the left node at depth 2.\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/decision_tree_decision_boundaries_plot.png\" alt=\"data-layout\" width=\"600\" style=\"display:block; margin:auto\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/iris_tree_1node2.png\" alt=\"data-layout\" width=\"300\" style=\"display:block; margin:auto\"/>\n",
"\n",
"So the estimated probabilities are 0% (Setosa), 49/54 = 90.7% (Versicolor) and 5/54 = 9.3% (Virginica)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:20.572369Z",
"iopub.status.busy": "2025-02-27T23:21:20.572178Z",
"iopub.status.idle": "2025-02-27T23:21:20.577551Z",
"shell.execute_reply": "2025-02-27T23:21:20.576948Z"
},
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"array([[0. , 0.907, 0.093]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree_clf.predict_proba([[5, 1.5]]).round(3)"
]
},
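{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"source": [
"The leaf-ratio rule can be checked directly. A minimal sketch, assuming the default uniform sample weights: `apply` returns the index of the leaf each instance lands in, so counting the training instances of each class that share the new flower's leaf and normalising should reproduce `predict_proba`. The cell refits the same depth-2 tree under illustrative names so it runs on its own."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"# Sketch: estimate class probabilities as class ratios in the leaf node\n",
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"iris_bunch = load_iris()\n",
"X_petal = iris_bunch.data[:, 2:]  # petal length (cm), petal width (cm)\n",
"y_species = iris_bunch.target\n",
"demo_tree = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X_petal, y_species)\n",
"\n",
"x_new = np.array([[5.0, 1.5]])\n",
"leaf = demo_tree.apply(x_new)[0]  # index of the leaf node the new flower falls into\n",
"\n",
"# Count the training instances of each class that end up in the same leaf\n",
"in_leaf = demo_tree.apply(X_petal) == leaf\n",
"counts = np.bincount(y_species[in_leaf], minlength=3)\n",
"manual_proba = counts / counts.sum()\n",
"\n",
"print('class counts in leaf:', counts)\n",
"print('manual estimate:     ', manual_proba.round(3))\n",
"print('predict_proba:       ', demo_tree.predict_proba(x_new).round(3))"
]
},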
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"source": [
"## Quality measures"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Gini impurity\n",
"\n",
"The Gini impurity of the $i^{\\rm th}$ node is defined by\n",
"$$\n",
"G_i=1-\\sum_{k=1}^{n}p^2_{i,k} ,\n",
"$$\n",
"where $n$ is the number of classes and $p_{i,k}$ is the ratio of class $k$ instances among the training instances in the $i^{\\rm th}$ node.\n",
"\n",
"$G_i=0$ means the node is 100% _pure_, i.e. all of its training instances belong to a single class."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"#### Gini example calculation\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/iris_tree_1node2.png\" alt=\"data-layout\" width=\"300\" style=\"display:block; margin:auto\"/>\n",
"\n",
"$$\n",
"G_i = 1 - (0/54)^2 - (49/54)^2 - (5/54)^2 \\approx 0.168\n",
"$$"
]
},
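{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"The same arithmetic can be wrapped in a small function. This is a minimal sketch in plain NumPy. `gini_impurity` is a hypothetical helper, not part of Scikit-Learn, and it takes the list of class counts in a node.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def gini_impurity(counts):\n",
"    # G_i = 1 - sum_k p_{i,k}^2, with p_{i,k} the fraction of class k in node i\n",
"    counts = np.asarray(counts, dtype=float)\n",
"    p = counts / counts.sum()\n",
"    return 1.0 - np.sum(p ** 2)\n",
"\n",
"# Node from the figure: 0 Setosa, 49 Versicolor, 5 Virginica\n",
"print(round(gini_impurity([0, 49, 5]), 3))\n",
"# A pure node has zero impurity\n",
"print(gini_impurity([50, 0, 0]))\n",
"```\n",
"\n",
"For the node in the figure it gives $0.168$, and a pure node gives exactly $0$."
]
},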
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"### Entropy \n",
|
||
|
"\n",
|
||
|
"An alternative to Gini impurity is to use entropy as the impurity measure:\n",
|
||
|
"$$\n",
|
||
|
"H_i=-\\sum_{k=1}^n p_{i,k}\\log_2(p_{i,k}),\n",
|
||
|
"$$\n",
|
||
|
"where the sum runs only over classes with $p_{i,k}\\neq 0$ (i.e. $0\\log_2 0$ is taken to be $0$). \n",
|
||
|
"\n",
|
||
|
"Entropy measures the average information content of the class label, i.e. the average number of bits needed to encode it. As with Gini impurity, $H_i=0$ for a pure node."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"#### Entropy example calculation\n",
|
||
|
"\n",
|
||
|
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/iris_tree_1node2.png\" alt=\"data-layout\" width=\"300\" style=\"display:block; margin:auto\"/>\n",
|
||
|
"\n",
|
||
|
"$$\n",
|
||
|
"H_i = - (49/54) \\log_2(49/54) - (5/54) \\log_2(5/54)= 0.445\n",
|
||
|
"$$"
|
||
|
]
|
||
|
},
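{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"Entropy can be computed the same way. This is a minimal NumPy sketch in which `entropy` is a hypothetical helper that takes the class counts of a node. Empty classes are dropped before taking logarithms, which implements the $p_{i,k}\\neq 0$ restriction above.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def entropy(counts):\n",
"    # H_i = -sum_k p_{i,k} log2(p_{i,k}), summing only over classes with p_{i,k} != 0\n",
"    counts = np.asarray(counts, dtype=float)\n",
"    p = counts / counts.sum()\n",
"    p = p[p > 0]  # drop empty classes (0 log 0 is taken as 0)\n",
"    return float(-np.sum(p * np.log2(p)))\n",
"\n",
"print(round(entropy([0, 49, 5]), 3))\n",
"# Two equally populated classes need exactly 1 bit\n",
"print(entropy([0, 25, 25]))\n",
"```\n",
"\n",
"It gives $0.445$ for the node in the figure and exactly $1$ bit for a node split evenly between two classes."
]
},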
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"### Does it make a difference what quality measure is used? \n",
|
||
|
"\n",
|
||
|
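"Not usually: both measures tend to lead to similar trees. When they do differ, Gini impurity tends to isolate the most frequent class in its own branch, while entropy tends to produce slightly more \"balanced\" trees. Gini impurity is also slightly faster to compute, since entropy requires evaluating logarithms."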
|
||
|
]
|
||
|
},
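{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"You can check empirically whether the choice matters for a particular dataset by fitting one tree per criterion and comparing them. The minimal sketch below assumes the petal-feature version of the iris data used in this lecture. It reloads the data with `load_iris` so the snippet stands alone, and it fixes `random_state` for reproducibility.\n",
"\n",
"```python\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"data = load_iris()\n",
"X = data.data[:, 2:]  # petal length and width, as in the lecture example\n",
"y = data.target\n",
"\n",
"# Fit one depth-2 tree per impurity criterion\n",
"trees = {crit: DecisionTreeClassifier(max_depth=2, criterion=crit, random_state=42).fit(X, y)\n",
"         for crit in ('gini', 'entropy')}\n",
"\n",
"for crit, tree in trees.items():\n",
"    print(crit, 'training accuracy:', round(tree.score(X, y), 3))\n",
"\n",
"# Fraction of training instances on which the two trees make the same prediction\n",
"agreement = (trees['gini'].predict(X) == trees['entropy'].predict(X)).mean()\n",
"print('agreement:', agreement)\n",
"```\n",
"\n",
"On this small, well-separated dataset the two trees should behave very similarly. On larger or noisier datasets the differences can be more noticeable."
]
},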
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"### Decision tree using entropy criterion"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2025-02-27T23:21:20.579634Z",
|
||
|
"iopub.status.busy": "2025-02-27T23:21:20.579461Z",
|
||
|
"iopub.status.idle": "2025-02-27T23:21:20.585935Z",
|
||
|
"shell.execute_reply": "2025-02-27T23:21:20.585283Z"
|
||
|
},
|
||
|
"slideshow": {
|
||
|
"slide_type": ""
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [
|
||
|
"<style>#sk-container-id-2 {\n",
|
||
|
" /* Definition of color scheme common for light and dark mode */\n",
|
||
|
" --sklearn-color-text: #000;\n",
|
||
|
" --sklearn-color-text-muted: #666;\n",
|
||
|
" --sklearn-color-line: gray;\n",
|
||
|
" /* Definition of color scheme for unfitted estimators */\n",
|
||
|
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
|
||
|
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
|
||
|
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
|
||
|
" --sklearn-color-unfitted-level-3: chocolate;\n",
|
||
|
" /* Definition of color scheme for fitted estimators */\n",
|
||
|
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
|
||
|
" --sklearn-color-fitted-level-1: #d4ebff;\n",
|
||
|
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
|
||
|
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
|
||
|
"\n",
|
||
|
" /* Specific color for light theme */\n",
|
||
|
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
|
||
|
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
|
||
|
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
|
||
|
" --sklearn-color-icon: #696969;\n",
|
||
|
"\n",
|
||
|
" @media (prefers-color-scheme: dark) {\n",
|
||
|
" /* Redefinition of color scheme for dark theme */\n",
|
||
|
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
|
||
|
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
|
||
|
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
|
||
|
" --sklearn-color-icon: #878787;\n",
|
||
|
" }\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 pre {\n",
|
||
|
" padding: 0;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 input.sk-hidden--visually {\n",
|
||
|
" border: 0;\n",
|
||
|
" clip: rect(1px 1px 1px 1px);\n",
|
||
|
" clip: rect(1px, 1px, 1px, 1px);\n",
|
||
|
" height: 1px;\n",
|
||
|
" margin: -1px;\n",
|
||
|
" overflow: hidden;\n",
|
||
|
" padding: 0;\n",
|
||
|
" position: absolute;\n",
|
||
|
" width: 1px;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-dashed-wrapped {\n",
|
||
|
" border: 1px dashed var(--sklearn-color-line);\n",
|
||
|
" margin: 0 0.4em 0.5em 0.4em;\n",
|
||
|
" box-sizing: border-box;\n",
|
||
|
" padding-bottom: 0.4em;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-container {\n",
|
||
|
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
|
||
|
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
|
||
|
" so we also need the `!important` here to be able to override the\n",
|
||
|
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
|
||
|
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
|
||
|
" display: inline-block !important;\n",
|
||
|
" position: relative;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-text-repr-fallback {\n",
|
||
|
" display: none;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"div.sk-parallel-item,\n",
|
||
|
"div.sk-serial,\n",
|
||
|
"div.sk-item {\n",
|
||
|
" /* draw centered vertical line to link estimators */\n",
|
||
|
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
|
||
|
" background-size: 2px 100%;\n",
|
||
|
" background-repeat: no-repeat;\n",
|
||
|
" background-position: center center;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Parallel-specific style estimator block */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-parallel-item::after {\n",
|
||
|
" content: \"\";\n",
|
||
|
" width: 100%;\n",
|
||
|
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
|
||
|
" flex-grow: 1;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-parallel {\n",
|
||
|
" display: flex;\n",
|
||
|
" align-items: stretch;\n",
|
||
|
" justify-content: center;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
" position: relative;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-parallel-item {\n",
|
||
|
" display: flex;\n",
|
||
|
" flex-direction: column;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-parallel-item:first-child::after {\n",
|
||
|
" align-self: flex-end;\n",
|
||
|
" width: 50%;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-parallel-item:last-child::after {\n",
|
||
|
" align-self: flex-start;\n",
|
||
|
" width: 50%;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-parallel-item:only-child::after {\n",
|
||
|
" width: 0;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Serial-specific style estimator block */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-serial {\n",
|
||
|
" display: flex;\n",
|
||
|
" flex-direction: column;\n",
|
||
|
" align-items: center;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
" padding-right: 1em;\n",
|
||
|
" padding-left: 1em;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
|
||
|
"clickable and can be expanded/collapsed.\n",
|
||
|
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
|
||
|
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
|
||
|
"*/\n",
|
||
|
"\n",
|
||
|
"/* Pipeline and ColumnTransformer style (default) */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-toggleable {\n",
|
||
|
" /* Default theme specific background. It is overwritten whether we have a\n",
|
||
|
" specific estimator or a Pipeline/ColumnTransformer */\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Toggleable label */\n",
|
||
|
"#sk-container-id-2 label.sk-toggleable__label {\n",
|
||
|
" cursor: pointer;\n",
|
||
|
" display: flex;\n",
|
||
|
" width: 100%;\n",
|
||
|
" margin-bottom: 0;\n",
|
||
|
" padding: 0.5em;\n",
|
||
|
" box-sizing: border-box;\n",
|
||
|
" text-align: center;\n",
|
||
|
" align-items: start;\n",
|
||
|
" justify-content: space-between;\n",
|
||
|
" gap: 0.5em;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 label.sk-toggleable__label .caption {\n",
|
||
|
" font-size: 0.6rem;\n",
|
||
|
" font-weight: lighter;\n",
|
||
|
" color: var(--sklearn-color-text-muted);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 label.sk-toggleable__label-arrow:before {\n",
|
||
|
" /* Arrow on the left of the label */\n",
|
||
|
" content: \"▸\";\n",
|
||
|
" float: left;\n",
|
||
|
" margin-right: 0.25em;\n",
|
||
|
" color: var(--sklearn-color-icon);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Toggleable content - dropdown */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-toggleable__content {\n",
|
||
|
" max-height: 0;\n",
|
||
|
" max-width: 0;\n",
|
||
|
" overflow: hidden;\n",
|
||
|
" text-align: left;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-toggleable__content.fitted {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-toggleable__content pre {\n",
|
||
|
" margin: 0.2em;\n",
|
||
|
" border-radius: 0.25em;\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-toggleable__content.fitted pre {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
|
||
|
" /* Expand drop-down */\n",
|
||
|
" max-height: 200px;\n",
|
||
|
" max-width: 100%;\n",
|
||
|
" overflow: auto;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
|
||
|
" content: \"▾\";\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Pipeline/ColumnTransformer-specific style */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Estimator-specific style */\n",
|
||
|
"\n",
|
||
|
"/* Colorize estimator box */\n",
|
||
|
"#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-label label.sk-toggleable__label,\n",
|
||
|
"#sk-container-id-2 div.sk-label label {\n",
|
||
|
" /* The background is the default theme color */\n",
|
||
|
" color: var(--sklearn-color-text-on-default-background);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* On hover, darken the color of the background */\n",
|
||
|
"#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Label box, darken color on hover, fitted */\n",
|
||
|
"#sk-container-id-2 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Estimator label */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-label label {\n",
|
||
|
" font-family: monospace;\n",
|
||
|
" font-weight: bold;\n",
|
||
|
" display: inline-block;\n",
|
||
|
" line-height: 1.2em;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-label-container {\n",
|
||
|
" text-align: center;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Estimator-specific */\n",
|
||
|
"#sk-container-id-2 div.sk-estimator {\n",
|
||
|
" font-family: monospace;\n",
|
||
|
" border: 1px dotted var(--sklearn-color-border-box);\n",
|
||
|
" border-radius: 0.25em;\n",
|
||
|
" box-sizing: border-box;\n",
|
||
|
" margin-bottom: 0.5em;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-estimator.fitted {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* on hover */\n",
|
||
|
"#sk-container-id-2 div.sk-estimator:hover {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 div.sk-estimator.fitted:hover {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
|
||
|
"\n",
|
||
|
"/* Common style for \"i\" and \"?\" */\n",
|
||
|
"\n",
|
||
|
".sk-estimator-doc-link,\n",
|
||
|
"a:link.sk-estimator-doc-link,\n",
|
||
|
"a:visited.sk-estimator-doc-link {\n",
|
||
|
" float: right;\n",
|
||
|
" font-size: smaller;\n",
|
||
|
" line-height: 1em;\n",
|
||
|
" font-family: monospace;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
" border-radius: 1em;\n",
|
||
|
" height: 1em;\n",
|
||
|
" width: 1em;\n",
|
||
|
" text-decoration: none !important;\n",
|
||
|
" margin-left: 0.5em;\n",
|
||
|
" text-align: center;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
|
||
|
" color: var(--sklearn-color-unfitted-level-1);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
".sk-estimator-doc-link.fitted,\n",
|
||
|
"a:link.sk-estimator-doc-link.fitted,\n",
|
||
|
"a:visited.sk-estimator-doc-link.fitted {\n",
|
||
|
" /* fitted */\n",
|
||
|
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
|
||
|
" color: var(--sklearn-color-fitted-level-1);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* On hover */\n",
|
||
|
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
|
||
|
".sk-estimator-doc-link:hover,\n",
|
||
|
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
|
||
|
".sk-estimator-doc-link:hover {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-3);\n",
|
||
|
" color: var(--sklearn-color-background);\n",
|
||
|
" text-decoration: none;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
|
||
|
".sk-estimator-doc-link.fitted:hover,\n",
|
||
|
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
|
||
|
".sk-estimator-doc-link.fitted:hover {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-3);\n",
|
||
|
" color: var(--sklearn-color-background);\n",
|
||
|
" text-decoration: none;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Span, style for the box shown on hovering the info icon */\n",
|
||
|
".sk-estimator-doc-link span {\n",
|
||
|
" display: none;\n",
|
||
|
" z-index: 9999;\n",
|
||
|
" position: relative;\n",
|
||
|
" font-weight: normal;\n",
|
||
|
" right: .2ex;\n",
|
||
|
" padding: .5ex;\n",
|
||
|
" margin: .5ex;\n",
|
||
|
" width: min-content;\n",
|
||
|
" min-width: 20ex;\n",
|
||
|
" max-width: 50ex;\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" box-shadow: 2pt 2pt 4pt #999;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background: var(--sklearn-color-unfitted-level-0);\n",
|
||
|
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
".sk-estimator-doc-link.fitted span {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background: var(--sklearn-color-fitted-level-0);\n",
|
||
|
" border: var(--sklearn-color-fitted-level-3);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
".sk-estimator-doc-link:hover span {\n",
|
||
|
" display: block;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 a.estimator_doc_link {\n",
|
||
|
" float: right;\n",
|
||
|
" font-size: 1rem;\n",
|
||
|
" line-height: 1em;\n",
|
||
|
" font-family: monospace;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
" border-radius: 1rem;\n",
|
||
|
" height: 1rem;\n",
|
||
|
" width: 1rem;\n",
|
||
|
" text-decoration: none;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" color: var(--sklearn-color-unfitted-level-1);\n",
|
||
|
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 a.estimator_doc_link.fitted {\n",
|
||
|
" /* fitted */\n",
|
||
|
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
|
||
|
" color: var(--sklearn-color-fitted-level-1);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* On hover */\n",
|
||
|
"#sk-container-id-2 a.estimator_doc_link:hover {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-3);\n",
|
||
|
" color: var(--sklearn-color-background);\n",
|
||
|
" text-decoration: none;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-2 a.estimator_doc_link.fitted:hover {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-3);\n",
|
||
|
"}\n",
|
||
|
"</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>DecisionTreeClassifier(criterion='entropy', max_depth=2)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" checked><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>DecisionTreeClassifier</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.6/modules/generated/sklearn.tree.DecisionTreeClassifier.html\">?<span>Documentation for DecisionTreeClassifier</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\"><pre>DecisionTreeClassifier(criterion='entropy', max_depth=2)</pre></div> </div></div></div></div>"
|
||
|
],
|
||
|
"text/plain": [
|
||
|
"DecisionTreeClassifier(criterion='entropy', max_depth=2)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Redo the first example but use entropy instead\n",
|
||
|
"tree_clf = DecisionTreeClassifier(max_depth=2, criterion='entropy')  # depth-2 tree using entropy as the impurity measure\n",
|
||
|
"tree_clf.fit(X_iris, y_iris)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2025-02-27T23:21:20.587706Z",
|
||
|
"iopub.status.busy": "2025-02-27T23:21:20.587538Z",
|
||
|
"iopub.status.idle": "2025-02-27T23:21:20.733060Z",
|
||
|
"shell.execute_reply": "2025-02-27T23:21:20.732336Z"
|
||
|
},
|
||
|
"slideshow": {
|
||
|
"slide_type": "skip"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"export_graphviz(tree_clf, \n",
|
||
|
" out_file = './iris_tree_entropy.dot', \n",
|
||
|
" feature_names = iris.feature_names[2:], \n",
|
||
|
" class_names = iris.target_names, \n",
|
||
|
" rounded = True, \n",
|
||
|
" filled = True)\n",
|
||
|
"! dot -Tpng ./iris_tree_entropy.dot -o ./iris_tree_entropy.png"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"#### Entropy tree\n",
|
||
|
"\n",
|
||
|
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/iris_tree_entropy.png\" alt=\"data-layout\" width=\"500\" style=\"display:block; margin:auto\"/>"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"#### Gini tree (for reference)\n",
|
||
|
"\n",
|
||
|
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/iris_tree.png\" alt=\"data-layout\" width=\"500\" style=\"display:block; margin:auto\"/>"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "slide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"## CART training algorithm\n",
|
||
|
"\n",
|
||
|
"Scikit-Learn trains decision trees with the Classification And Regression Tree (CART) algorithm. Training a tree is also called *growing* it."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"- Splits the data at a node into two subsets using a single feature $k$ and a threshold $t_k$.\n",
|
||
|
"- Chooses $k$ and $t_k$ by searching for the pair $(k, t_k)$ that produces the purest subsets, weighted by their size.\n",
|
||
|
"\n",
|
||
|
"The cost function minimised for each split is:\n",
|
||
|
"\n",
|
||
|
"$$\n",
|
||
|
"J(k,t_k)=\\frac{m_\\text{left}}{m}G_\\text{left}+\\frac{m_\\text{right}}{m}G_\\text{right} ,\n",
|
||
|
"$$\n",
|
||
|
"\n",
|
||
|
"where $G_\\text{left/right}$ measures the impurity of the left/right subset, $m_\\text{left/right}$ is the number of instances in the left/right subset, and $m=m_\\text{left}+m_\\text{right}$ is the number of instances in the node being split."
|
||
|
]
|
||
|
},
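{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"To make the cost function concrete, here is a minimal, unoptimised sketch of the search CART performs at a single node, using Gini impurity. The helper names `gini` and `best_split` are hypothetical. Taking candidate thresholds halfway between consecutive distinct feature values is an assumption of this sketch, and Scikit-Learn's optimised implementation differs in detail.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"\n",
"def gini(y):\n",
"    # Gini impurity 1 - sum_k p_k^2 of the class labels in y\n",
"    _, counts = np.unique(y, return_counts=True)\n",
"    p = counts / counts.sum()\n",
"    return 1.0 - np.sum(p ** 2)\n",
"\n",
"def best_split(X, y):\n",
"    # Try every feature k and every candidate threshold t_k (midpoints between\n",
"    # consecutive distinct values) and keep the pair minimising J(k, t_k)\n",
"    m = len(y)\n",
"    best = (None, None, np.inf)\n",
"    for k in range(X.shape[1]):\n",
"        values = np.unique(X[:, k])\n",
"        for t_k in (values[:-1] + values[1:]) / 2:\n",
"            left = X[:, k] <= t_k\n",
"            m_left = left.sum()\n",
"            m_right = m - m_left\n",
"            J = m_left / m * gini(y[left]) + m_right / m * gini(y[~left])\n",
"            if J < best[2]:\n",
"                best = (k, t_k, J)\n",
"    return best\n",
"\n",
"data = load_iris()\n",
"X = data.data[:, 2:]  # petal length (k=0) and petal width (k=1)\n",
"y = data.target\n",
"k, t_k, J = best_split(X, y)\n",
"print(f'best split: feature {k} <= {t_k:.2f}, J = {J:.3f}')\n",
"```\n",
"\n",
"On the petal features this should recover the root split of the Gini tree shown earlier (petal length ≤ 2.45 cm, $J = 1/3$). Petal width ≤ 0.8 cm separates Setosa equally well. The sketch breaks the tie by keeping the first feature examined, whereas Scikit-Learn randomly permutes the features at each split, so it can resolve such ties differently."
]
},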
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"Note that the CART algorithm:\n",
|
||
|
"- Only ever splits a node into two children, i.e. it produces binary trees.\n",
|
||
|
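"- Is a *greedy* algorithm: it searches for the optimal split at the current node, then repeats the process for each resulting child node. It does not check whether a split leads to the lowest possible impurity several levels down, so there is no guarantee that the overall optimal tree is found (finding the optimal tree is known to be NP-complete). Nevertheless, it usually produces a reasonably good tree."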
|
||
|
]
|
||
|
},
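{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"A recursive sketch shows the greedy, binary nature of CART clearly. Each node is split with the locally best $(k, t_k)$, and the two children are then grown independently, with no look-ahead and no revisiting of earlier splits. This toy illustration assumes Gini impurity, a `max_depth` stopping rule, and three classes as in iris. `gini`, `best_split` and `grow` are hypothetical helpers, not Scikit-Learn's implementation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"\n",
"def gini(y):\n",
"    # Gini impurity of a set of class labels\n",
"    _, counts = np.unique(y, return_counts=True)\n",
"    p = counts / counts.sum()\n",
"    return 1.0 - np.sum(p ** 2)\n",
"\n",
"def best_split(X, y):\n",
"    # One greedy step: the (k, t_k) minimising J(k, t_k) at this node only;\n",
"    # returns None if no feature takes more than one distinct value\n",
"    m, best = len(y), None\n",
"    for k in range(X.shape[1]):\n",
"        values = np.unique(X[:, k])\n",
"        for t_k in (values[:-1] + values[1:]) / 2:\n",
"            left = X[:, k] <= t_k\n",
"            J = (left.sum() * gini(y[left]) + (~left).sum() * gini(y[~left])) / m\n",
"            if best is None or J < best[2]:\n",
"                best = (k, t_k, J)\n",
"    return best\n",
"\n",
"def grow(X, y, depth=0, max_depth=2, n_classes=3):\n",
"    # Stop (make a leaf holding the class counts) at max_depth or at a pure node\n",
"    split = best_split(X, y) if depth < max_depth and gini(y) > 0 else None\n",
"    if split is None:\n",
"        return {'leaf': np.bincount(y, minlength=n_classes)}\n",
"    k, t_k, _ = split\n",
"    left = X[:, k] <= t_k\n",
"    # Recurse independently on each child: earlier splits are never revisited\n",
"    return {'feature': k, 'threshold': t_k,\n",
"            'left': grow(X[left], y[left], depth + 1, max_depth, n_classes),\n",
"            'right': grow(X[~left], y[~left], depth + 1, max_depth, n_classes)}\n",
"\n",
"data = load_iris()\n",
"tree = grow(data.data[:, 2:], data.target, max_depth=2)\n",
"print('root: feature', tree['feature'], '<=', round(tree['threshold'], 2))\n",
"print('left leaf class counts:', tree['left']['leaf'])\n",
"```\n",
"\n",
"With `max_depth=2`, the root split and the pure Setosa leaf should match the depth-2 Gini tree from earlier in the lecture."
]
},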
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "slide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"## Regularisation and hyperparameters\n",
|
||
|
"\n",
|
||
|
"Decision trees are **non-parametric** models: the number of parameters is not fixed prior to training, but is determined by the data as the tree grows.\n",
|
||
|
"\n",
|
||
|
"Left unconstrained, a decision tree adapts closely to the training data and therefore tends to overfit."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"We therefore need to regularise the model. This can be done by restricting the complexity of the tree, for example through the following Scikit-Learn hyperparameters:\n",
|
||
|
"\n",
|
||
|
"- `max_depth`: maximum depth of the tree.\n",
|
||
|
"- `max_features`: maximum number of features evaluated when searching for the best split at each node.\n",
|
||
|
"- `max_leaf_nodes`: maximum number of leaf nodes.\n",
|
||
|
"- `min_samples_split`: minimum number of samples a node must have before it can be split.\n",
|
||
|
"- `min_samples_leaf`: minimum number of samples a leaf node must have.\n",
|
||
|
"- `min_weight_fraction_leaf`: same as `min_samples_leaf`, but expressed as a fraction of the total (weighted) number of samples.\n",
|
||
|
"\n",
|
||
|
"Generally, increasing the `min_*` hyperparameters or reducing the `max_*` hyperparameters will regularise the model."
|
||
|
]
|
||
|
},
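{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"Comparing an unrestricted tree with a regularised one shows the effect of these hyperparameters. The minimal sketch below assumes a noisy synthetic dataset from `make_moons`, which is not used elsewhere in this lecture, and uses `min_samples_leaf=5` as an arbitrary illustrative value.\n",
"\n",
"```python\n",
"from sklearn.datasets import make_moons\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Noisy two-class toy data (illustrative choice of size, noise and seed)\n",
"X, y = make_moons(n_samples=300, noise=0.25, random_state=42)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
"\n",
"# Unrestricted tree vs. one regularised through a min_* hyperparameter\n",
"full = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)\n",
"reg = DecisionTreeClassifier(min_samples_leaf=5, random_state=42).fit(X_train, y_train)\n",
"\n",
"for name, clf in [('unrestricted', full), ('min_samples_leaf=5', reg)]:\n",
"    print(f'{name:>20}: leaves={clf.get_n_leaves():3d}  '\n",
"          f'train acc={clf.score(X_train, y_train):.3f}  '\n",
"          f'test acc={clf.score(X_test, y_test):.3f}')\n",
"```\n",
"\n",
"Typically the unrestricted tree fits the training set perfectly using many small leaves, while the regularised tree has fewer leaves and often generalises better. The exact numbers depend on the noise level and the random seeds."
]
},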
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"Other algorithms instead _prune_ the tree: they first grow a (relatively) unrestricted tree, then remove nodes whose splits are not statistically significant (e.g. as judged by a $\\chi^2$ test)."
|
||
|
]
|
||
|
},
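{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"In Scikit-Learn, post-pruning is available as minimal cost-complexity pruning, controlled by the `ccp_alpha` hyperparameter. This uses a different criterion from the statistical tests mentioned above. The minimal sketch below assumes `make_moons` toy data as a stand-in dataset. `cost_complexity_pruning_path` returns the effective alphas at which successive subtrees are pruned away, and refitting with increasing `ccp_alpha` gives smaller trees.\n",
"\n",
"```python\n",
"from sklearn.datasets import make_moons\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Toy two-class data, used here only as a stand-in dataset\n",
"X, y = make_moons(n_samples=300, noise=0.25, random_state=42)\n",
"\n",
"# Effective alphas at which successive subtrees of the full tree are pruned\n",
"path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X, y)\n",
"\n",
"# Refit with each alpha: larger ccp_alpha prunes more aggressively\n",
"n_leaves = [DecisionTreeClassifier(random_state=42, ccp_alpha=alpha).fit(X, y).get_n_leaves()\n",
"            for alpha in path.ccp_alphas]\n",
"\n",
"print('number of alphas on the path:', len(path.ccp_alphas))\n",
"print('leaves at smallest alpha:', n_leaves[0], '| at largest alpha:', n_leaves[-1])\n",
"```\n",
"\n",
"The largest alpha on the path prunes the tree all the way back to a single leaf (the root). In practice, `ccp_alpha` would be chosen by cross-validation."
]
},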
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": [
|
||
|
"exercise_pointer"
|
||
|
]
|
||
|
},
|
||
|
"source": [
|
||
|
"**Exercises:** *You can now complete Exercise 1 in the exercises associated with this lecture.*"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "slide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"## Decision trees for regression \n",
|
||
|
"\n",
|
||
|
"Decision trees can also be used for regression tasks."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"### Train regression model"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2025-02-27T23:21:20.735593Z",
|
||
|
"iopub.status.busy": "2025-02-27T23:21:20.735405Z",
|
||
|
"iopub.status.idle": "2025-02-27T23:21:20.743106Z",
|
||
|
"shell.execute_reply": "2025-02-27T23:21:20.742456Z"
|
||
|
},
|
||
|
"slideshow": {
|
||
|
"slide_type": ""
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [
|
||
|
"<style>#sk-container-id-3 {\n",
|
||
|
" /* Definition of color scheme common for light and dark mode */\n",
|
||
|
" --sklearn-color-text: #000;\n",
|
||
|
" --sklearn-color-text-muted: #666;\n",
|
||
|
" --sklearn-color-line: gray;\n",
|
||
|
" /* Definition of color scheme for unfitted estimators */\n",
|
||
|
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
|
||
|
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
|
||
|
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
|
||
|
" --sklearn-color-unfitted-level-3: chocolate;\n",
|
||
|
" /* Definition of color scheme for fitted estimators */\n",
|
||
|
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
|
||
|
" --sklearn-color-fitted-level-1: #d4ebff;\n",
|
||
|
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
|
||
|
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
|
||
|
"\n",
|
||
|
" /* Specific color for light theme */\n",
|
||
|
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
|
||
|
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
|
||
|
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
|
||
|
" --sklearn-color-icon: #696969;\n",
|
||
|
"\n",
|
||
|
" @media (prefers-color-scheme: dark) {\n",
|
||
|
" /* Redefinition of color scheme for dark theme */\n",
|
||
|
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
|
||
|
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
|
||
|
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
|
||
|
" --sklearn-color-icon: #878787;\n",
|
||
|
" }\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 pre {\n",
|
||
|
" padding: 0;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 input.sk-hidden--visually {\n",
|
||
|
" border: 0;\n",
|
||
|
" clip: rect(1px 1px 1px 1px);\n",
|
||
|
" clip: rect(1px, 1px, 1px, 1px);\n",
|
||
|
" height: 1px;\n",
|
||
|
" margin: -1px;\n",
|
||
|
" overflow: hidden;\n",
|
||
|
" padding: 0;\n",
|
||
|
" position: absolute;\n",
|
||
|
" width: 1px;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-dashed-wrapped {\n",
|
||
|
" border: 1px dashed var(--sklearn-color-line);\n",
|
||
|
" margin: 0 0.4em 0.5em 0.4em;\n",
|
||
|
" box-sizing: border-box;\n",
|
||
|
" padding-bottom: 0.4em;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-container {\n",
|
||
|
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
|
||
|
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
|
||
|
" so we also need the `!important` here to be able to override the\n",
|
||
|
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
|
||
|
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
|
||
|
" display: inline-block !important;\n",
|
||
|
" position: relative;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-text-repr-fallback {\n",
|
||
|
" display: none;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"div.sk-parallel-item,\n",
|
||
|
"div.sk-serial,\n",
|
||
|
"div.sk-item {\n",
|
||
|
" /* draw centered vertical line to link estimators */\n",
|
||
|
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
|
||
|
" background-size: 2px 100%;\n",
|
||
|
" background-repeat: no-repeat;\n",
|
||
|
" background-position: center center;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Parallel-specific style estimator block */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-parallel-item::after {\n",
|
||
|
" content: \"\";\n",
|
||
|
" width: 100%;\n",
|
||
|
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
|
||
|
" flex-grow: 1;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-parallel {\n",
|
||
|
" display: flex;\n",
|
||
|
" align-items: stretch;\n",
|
||
|
" justify-content: center;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
" position: relative;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-parallel-item {\n",
|
||
|
" display: flex;\n",
|
||
|
" flex-direction: column;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-parallel-item:first-child::after {\n",
|
||
|
" align-self: flex-end;\n",
|
||
|
" width: 50%;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-parallel-item:last-child::after {\n",
|
||
|
" align-self: flex-start;\n",
|
||
|
" width: 50%;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-parallel-item:only-child::after {\n",
|
||
|
" width: 0;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Serial-specific style estimator block */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-serial {\n",
|
||
|
" display: flex;\n",
|
||
|
" flex-direction: column;\n",
|
||
|
" align-items: center;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
" padding-right: 1em;\n",
|
||
|
" padding-left: 1em;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
|
||
|
"clickable and can be expanded/collapsed.\n",
|
||
|
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
|
||
|
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
|
||
|
"*/\n",
|
||
|
"\n",
|
||
|
"/* Pipeline and ColumnTransformer style (default) */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-toggleable {\n",
|
||
|
" /* Default theme specific background. It is overwritten whether we have a\n",
|
||
|
" specific estimator or a Pipeline/ColumnTransformer */\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Toggleable label */\n",
|
||
|
"#sk-container-id-3 label.sk-toggleable__label {\n",
|
||
|
" cursor: pointer;\n",
|
||
|
" display: flex;\n",
|
||
|
" width: 100%;\n",
|
||
|
" margin-bottom: 0;\n",
|
||
|
" padding: 0.5em;\n",
|
||
|
" box-sizing: border-box;\n",
|
||
|
" text-align: center;\n",
|
||
|
" align-items: start;\n",
|
||
|
" justify-content: space-between;\n",
|
||
|
" gap: 0.5em;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 label.sk-toggleable__label .caption {\n",
|
||
|
" font-size: 0.6rem;\n",
|
||
|
" font-weight: lighter;\n",
|
||
|
" color: var(--sklearn-color-text-muted);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 label.sk-toggleable__label-arrow:before {\n",
|
||
|
" /* Arrow on the left of the label */\n",
|
||
|
" content: \"▸\";\n",
|
||
|
" float: left;\n",
|
||
|
" margin-right: 0.25em;\n",
|
||
|
" color: var(--sklearn-color-icon);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 label.sk-toggleable__label-arrow:hover:before {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Toggleable content - dropdown */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-toggleable__content {\n",
|
||
|
" max-height: 0;\n",
|
||
|
" max-width: 0;\n",
|
||
|
" overflow: hidden;\n",
|
||
|
" text-align: left;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-toggleable__content.fitted {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-toggleable__content pre {\n",
|
||
|
" margin: 0.2em;\n",
|
||
|
" border-radius: 0.25em;\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-toggleable__content.fitted pre {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
|
||
|
" /* Expand drop-down */\n",
|
||
|
" max-height: 200px;\n",
|
||
|
" max-width: 100%;\n",
|
||
|
" overflow: auto;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
|
||
|
" content: \"▾\";\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Pipeline/ColumnTransformer-specific style */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Estimator-specific style */\n",
|
||
|
"\n",
|
||
|
"/* Colorize estimator box */\n",
|
||
|
"#sk-container-id-3 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-label label.sk-toggleable__label,\n",
|
||
|
"#sk-container-id-3 div.sk-label label {\n",
|
||
|
" /* The background is the default theme color */\n",
|
||
|
" color: var(--sklearn-color-text-on-default-background);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* On hover, darken the color of the background */\n",
|
||
|
"#sk-container-id-3 div.sk-label:hover label.sk-toggleable__label {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Label box, darken color on hover, fitted */\n",
|
||
|
"#sk-container-id-3 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Estimator label */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-label label {\n",
|
||
|
" font-family: monospace;\n",
|
||
|
" font-weight: bold;\n",
|
||
|
" display: inline-block;\n",
|
||
|
" line-height: 1.2em;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-label-container {\n",
|
||
|
" text-align: center;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Estimator-specific */\n",
|
||
|
"#sk-container-id-3 div.sk-estimator {\n",
|
||
|
" font-family: monospace;\n",
|
||
|
" border: 1px dotted var(--sklearn-color-border-box);\n",
|
||
|
" border-radius: 0.25em;\n",
|
||
|
" box-sizing: border-box;\n",
|
||
|
" margin-bottom: 0.5em;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-estimator.fitted {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* on hover */\n",
|
||
|
"#sk-container-id-3 div.sk-estimator:hover {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 div.sk-estimator.fitted:hover {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
|
||
|
"\n",
|
||
|
"/* Common style for \"i\" and \"?\" */\n",
|
||
|
"\n",
|
||
|
".sk-estimator-doc-link,\n",
|
||
|
"a:link.sk-estimator-doc-link,\n",
|
||
|
"a:visited.sk-estimator-doc-link {\n",
|
||
|
" float: right;\n",
|
||
|
" font-size: smaller;\n",
|
||
|
" line-height: 1em;\n",
|
||
|
" font-family: monospace;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
" border-radius: 1em;\n",
|
||
|
" height: 1em;\n",
|
||
|
" width: 1em;\n",
|
||
|
" text-decoration: none !important;\n",
|
||
|
" margin-left: 0.5em;\n",
|
||
|
" text-align: center;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
|
||
|
" color: var(--sklearn-color-unfitted-level-1);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
".sk-estimator-doc-link.fitted,\n",
|
||
|
"a:link.sk-estimator-doc-link.fitted,\n",
|
||
|
"a:visited.sk-estimator-doc-link.fitted {\n",
|
||
|
" /* fitted */\n",
|
||
|
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
|
||
|
" color: var(--sklearn-color-fitted-level-1);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* On hover */\n",
|
||
|
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
|
||
|
".sk-estimator-doc-link:hover,\n",
|
||
|
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
|
||
|
".sk-estimator-doc-link:hover {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-3);\n",
|
||
|
" color: var(--sklearn-color-background);\n",
|
||
|
" text-decoration: none;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
|
||
|
".sk-estimator-doc-link.fitted:hover,\n",
|
||
|
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
|
||
|
".sk-estimator-doc-link.fitted:hover {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-3);\n",
|
||
|
" color: var(--sklearn-color-background);\n",
|
||
|
" text-decoration: none;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Span, style for the box shown on hovering the info icon */\n",
|
||
|
".sk-estimator-doc-link span {\n",
|
||
|
" display: none;\n",
|
||
|
" z-index: 9999;\n",
|
||
|
" position: relative;\n",
|
||
|
" font-weight: normal;\n",
|
||
|
" right: .2ex;\n",
|
||
|
" padding: .5ex;\n",
|
||
|
" margin: .5ex;\n",
|
||
|
" width: min-content;\n",
|
||
|
" min-width: 20ex;\n",
|
||
|
" max-width: 50ex;\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" box-shadow: 2pt 2pt 4pt #999;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background: var(--sklearn-color-unfitted-level-0);\n",
|
||
|
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
".sk-estimator-doc-link.fitted span {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background: var(--sklearn-color-fitted-level-0);\n",
|
||
|
" border: var(--sklearn-color-fitted-level-3);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
".sk-estimator-doc-link:hover span {\n",
|
||
|
" display: block;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 a.estimator_doc_link {\n",
|
||
|
" float: right;\n",
|
||
|
" font-size: 1rem;\n",
|
||
|
" line-height: 1em;\n",
|
||
|
" font-family: monospace;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
" border-radius: 1rem;\n",
|
||
|
" height: 1rem;\n",
|
||
|
" width: 1rem;\n",
|
||
|
" text-decoration: none;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" color: var(--sklearn-color-unfitted-level-1);\n",
|
||
|
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 a.estimator_doc_link.fitted {\n",
|
||
|
" /* fitted */\n",
|
||
|
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
|
||
|
" color: var(--sklearn-color-fitted-level-1);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* On hover */\n",
|
||
|
"#sk-container-id-3 a.estimator_doc_link:hover {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-3);\n",
|
||
|
" color: var(--sklearn-color-background);\n",
|
||
|
" text-decoration: none;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-3 a.estimator_doc_link.fitted:hover {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-3);\n",
|
||
|
"}\n",
|
||
|
"</style><div id=\"sk-container-id-3\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>DecisionTreeRegressor(max_depth=2, random_state=42)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" checked><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>DecisionTreeRegressor</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.6/modules/generated/sklearn.tree.DecisionTreeRegressor.html\">?<span>Documentation for DecisionTreeRegressor</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\"><pre>DecisionTreeRegressor(max_depth=2, random_state=42)</pre></div> </div></div></div></div>"
|
||
|
],
|
||
|
"text/plain": [
|
||
|
"DecisionTreeRegressor(max_depth=2, random_state=42)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import numpy as np\n",
|
||
|
"from sklearn.tree import DecisionTreeRegressor\n",
|
||
|
"\n",
|
||
|
"np.random.seed(42)\n",
|
||
|
"X_quad = np.random.rand(200, 1) - 0.5  # a single random input feature in [-0.5, 0.5)\n",
"y_quad = X_quad ** 2 + 0.025 * np.random.randn(200, 1)  # quadratic target plus Gaussian noise\n",
|
||
|
"\n",
|
||
|
"tree_reg = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
|
||
|
"tree_reg.fit(X_quad, y_quad)"
|
||
|
]
|
||
|
},
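{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"source": [
"With the default squared-error criterion, a regression tree predicts the mean target value of the training instances that land in the same leaf. The sketch below is an illustrative check, not part of the original lecture code: it recreates the data set above, uses the fitted tree's `apply` method to find the leaf each instance reaches, and confirms that `predict` returns that leaf's mean. It flattens `y_quad` to 1-D with `ravel()` for convenience; the random draws are the same as above.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"# Recreate the noisy quadratic data set and the depth-2 tree from above\n",
"np.random.seed(42)\n",
"X_quad = np.random.rand(200, 1) - 0.5\n",
"y_quad = (X_quad ** 2 + 0.025 * np.random.randn(200, 1)).ravel()\n",
"\n",
"tree_reg = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
"tree_reg.fit(X_quad, y_quad)\n",
"\n",
"# apply() returns the index of the leaf node each instance ends up in\n",
"leaf_ids = tree_reg.apply(X_quad)\n",
"predictions = tree_reg.predict(X_quad)\n",
"\n",
"# Each leaf's prediction should equal the mean target of its training instances\n",
"for leaf in np.unique(leaf_ids):\n",
"    in_leaf = leaf_ids == leaf\n",
"    leaf_mean = y_quad[in_leaf].mean()\n",
"    print(f\"leaf {leaf}: {in_leaf.sum()} instances, mean y = {leaf_mean:.4f}\")\n",
"    assert np.allclose(predictions[in_leaf], leaf_mean)\n",
"```\n",
"\n",
"With `max_depth=2` the tree has at most four leaves, so the fitted function is a step function taking at most four distinct values."
]
},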
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"### Visualise tree"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2025-02-27T23:21:20.744939Z",
|
||
|
"iopub.status.busy": "2025-02-27T23:21:20.744742Z",
|
||
|
"iopub.status.idle": "2025-02-27T23:21:20.896284Z",
|
||
|
"shell.execute_reply": "2025-02-27T23:21:20.895553Z"
|
||
|
},
|
||
|
"slideshow": {
|
||
|
"slide_type": "skip"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
"from sklearn.tree import export_graphviz\n",
"\n",
"# Write the fitted regression tree to a Graphviz .dot file\n",
"export_graphviz(\n",
" tree_reg,\n",
" out_file=\"./regression_tree.dot\",\n",
" feature_names=[\"x1\"],\n",
" rounded=True,\n",
" filled=True\n",
")\n",
"\n",
"# A .dot file is not directly viewable, so convert it to PNG with the Graphviz `dot` command-line tool\n",
"! dot -Tpng ./regression_tree.dot -o ./regression_tree.png"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": ""
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/regression_tree.png\" alt=\"Depth-2 regression tree\" width=\"700\" style=\"display:block; margin:auto\"/>"
|
||
|
]
|
||
|
},
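{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"source": [
"If Graphviz is not installed, scikit-learn can describe a fitted tree without the `dot` tool: `sklearn.tree.export_text` returns the split rules as plain text, and `sklearn.tree.plot_tree` draws the tree with matplotlib. The sketch below is an illustration, not part of the original lecture code; it recreates the depth-2 tree above and prints its rules, where each `value:` line is the prediction made by one leaf.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.tree import DecisionTreeRegressor, export_text\n",
"\n",
"# Recreate the noisy quadratic data set and the depth-2 tree from above\n",
"np.random.seed(42)\n",
"X_quad = np.random.rand(200, 1) - 0.5\n",
"y_quad = X_quad ** 2 + 0.025 * np.random.randn(200, 1)\n",
"\n",
"tree_reg = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
"tree_reg.fit(X_quad, y_quad)\n",
"\n",
"# export_text returns the split thresholds and leaf values as a string,\n",
"# so no .dot -> .png conversion is needed\n",
"tree_rules = export_text(tree_reg, feature_names=[\"x1\"], decimals=3)\n",
"print(tree_rules)\n",
"```\n",
"\n",
"The text output carries the same split thresholds and leaf values as the Graphviz figure, which makes it convenient for quick inspection in a terminal or in logs."
]
},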
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"### Train a deeper model"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2025-02-27T23:21:20.898435Z",
|
||
|
"iopub.status.busy": "2025-02-27T23:21:20.898251Z",
|
||
|
"iopub.status.idle": "2025-02-27T23:21:20.904714Z",
|
||
|
"shell.execute_reply": "2025-02-27T23:21:20.904136Z"
|
||
|
},
|
||
|
"slideshow": {
|
||
|
"slide_type": ""
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [
|
||
|
"<style>#sk-container-id-4 {\n",
|
||
|
" /* Definition of color scheme common for light and dark mode */\n",
|
||
|
" --sklearn-color-text: #000;\n",
|
||
|
" --sklearn-color-text-muted: #666;\n",
|
||
|
" --sklearn-color-line: gray;\n",
|
||
|
" /* Definition of color scheme for unfitted estimators */\n",
|
||
|
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
|
||
|
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
|
||
|
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
|
||
|
" --sklearn-color-unfitted-level-3: chocolate;\n",
|
||
|
" /* Definition of color scheme for fitted estimators */\n",
|
||
|
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
|
||
|
" --sklearn-color-fitted-level-1: #d4ebff;\n",
|
||
|
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
|
||
|
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
|
||
|
"\n",
|
||
|
" /* Specific color for light theme */\n",
|
||
|
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
|
||
|
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
|
||
|
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
|
||
|
" --sklearn-color-icon: #696969;\n",
|
||
|
"\n",
|
||
|
" @media (prefers-color-scheme: dark) {\n",
|
||
|
" /* Redefinition of color scheme for dark theme */\n",
|
||
|
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
|
||
|
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
|
||
|
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
|
||
|
" --sklearn-color-icon: #878787;\n",
|
||
|
" }\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 pre {\n",
|
||
|
" padding: 0;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 input.sk-hidden--visually {\n",
|
||
|
" border: 0;\n",
|
||
|
" clip: rect(1px 1px 1px 1px);\n",
|
||
|
" clip: rect(1px, 1px, 1px, 1px);\n",
|
||
|
" height: 1px;\n",
|
||
|
" margin: -1px;\n",
|
||
|
" overflow: hidden;\n",
|
||
|
" padding: 0;\n",
|
||
|
" position: absolute;\n",
|
||
|
" width: 1px;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-dashed-wrapped {\n",
|
||
|
" border: 1px dashed var(--sklearn-color-line);\n",
|
||
|
" margin: 0 0.4em 0.5em 0.4em;\n",
|
||
|
" box-sizing: border-box;\n",
|
||
|
" padding-bottom: 0.4em;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-container {\n",
|
||
|
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
|
||
|
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
|
||
|
" so we also need the `!important` here to be able to override the\n",
|
||
|
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
|
||
|
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
|
||
|
" display: inline-block !important;\n",
|
||
|
" position: relative;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-text-repr-fallback {\n",
|
||
|
" display: none;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"div.sk-parallel-item,\n",
|
||
|
"div.sk-serial,\n",
|
||
|
"div.sk-item {\n",
|
||
|
" /* draw centered vertical line to link estimators */\n",
|
||
|
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
|
||
|
" background-size: 2px 100%;\n",
|
||
|
" background-repeat: no-repeat;\n",
|
||
|
" background-position: center center;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Parallel-specific style estimator block */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-parallel-item::after {\n",
|
||
|
" content: \"\";\n",
|
||
|
" width: 100%;\n",
|
||
|
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
|
||
|
" flex-grow: 1;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-parallel {\n",
|
||
|
" display: flex;\n",
|
||
|
" align-items: stretch;\n",
|
||
|
" justify-content: center;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
" position: relative;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-parallel-item {\n",
|
||
|
" display: flex;\n",
|
||
|
" flex-direction: column;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-parallel-item:first-child::after {\n",
|
||
|
" align-self: flex-end;\n",
|
||
|
" width: 50%;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-parallel-item:last-child::after {\n",
|
||
|
" align-self: flex-start;\n",
|
||
|
" width: 50%;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-parallel-item:only-child::after {\n",
|
||
|
" width: 0;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Serial-specific style estimator block */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-serial {\n",
|
||
|
" display: flex;\n",
|
||
|
" flex-direction: column;\n",
|
||
|
" align-items: center;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
" padding-right: 1em;\n",
|
||
|
" padding-left: 1em;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
|
||
|
"clickable and can be expanded/collapsed.\n",
|
||
|
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
|
||
|
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
|
||
|
"*/\n",
|
||
|
"\n",
|
||
|
"/* Pipeline and ColumnTransformer style (default) */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-toggleable {\n",
|
||
|
" /* Default theme specific background. It is overwritten whether we have a\n",
|
||
|
" specific estimator or a Pipeline/ColumnTransformer */\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Toggleable label */\n",
|
||
|
"#sk-container-id-4 label.sk-toggleable__label {\n",
|
||
|
" cursor: pointer;\n",
|
||
|
" display: flex;\n",
|
||
|
" width: 100%;\n",
|
||
|
" margin-bottom: 0;\n",
|
||
|
" padding: 0.5em;\n",
|
||
|
" box-sizing: border-box;\n",
|
||
|
" text-align: center;\n",
|
||
|
" align-items: start;\n",
|
||
|
" justify-content: space-between;\n",
|
||
|
" gap: 0.5em;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 label.sk-toggleable__label .caption {\n",
|
||
|
" font-size: 0.6rem;\n",
|
||
|
" font-weight: lighter;\n",
|
||
|
" color: var(--sklearn-color-text-muted);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 label.sk-toggleable__label-arrow:before {\n",
|
||
|
" /* Arrow on the left of the label */\n",
|
||
|
" content: \"▸\";\n",
|
||
|
" float: left;\n",
|
||
|
" margin-right: 0.25em;\n",
|
||
|
" color: var(--sklearn-color-icon);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 label.sk-toggleable__label-arrow:hover:before {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Toggleable content - dropdown */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-toggleable__content {\n",
|
||
|
" max-height: 0;\n",
|
||
|
" max-width: 0;\n",
|
||
|
" overflow: hidden;\n",
|
||
|
" text-align: left;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-toggleable__content.fitted {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-toggleable__content pre {\n",
|
||
|
" margin: 0.2em;\n",
|
||
|
" border-radius: 0.25em;\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-toggleable__content.fitted pre {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
|
||
|
" /* Expand drop-down */\n",
|
||
|
" max-height: 200px;\n",
|
||
|
" max-width: 100%;\n",
|
||
|
" overflow: auto;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
|
||
|
" content: \"▾\";\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Pipeline/ColumnTransformer-specific style */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Estimator-specific style */\n",
|
||
|
"\n",
|
||
|
"/* Colorize estimator box */\n",
|
||
|
"#sk-container-id-4 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-label label.sk-toggleable__label,\n",
|
||
|
"#sk-container-id-4 div.sk-label label {\n",
|
||
|
" /* The background is the default theme color */\n",
|
||
|
" color: var(--sklearn-color-text-on-default-background);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* On hover, darken the color of the background */\n",
|
||
|
"#sk-container-id-4 div.sk-label:hover label.sk-toggleable__label {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Label box, darken color on hover, fitted */\n",
|
||
|
"#sk-container-id-4 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Estimator label */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-label label {\n",
|
||
|
" font-family: monospace;\n",
|
||
|
" font-weight: bold;\n",
|
||
|
" display: inline-block;\n",
|
||
|
" line-height: 1.2em;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-label-container {\n",
|
||
|
" text-align: center;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Estimator-specific */\n",
|
||
|
"#sk-container-id-4 div.sk-estimator {\n",
|
||
|
" font-family: monospace;\n",
|
||
|
" border: 1px dotted var(--sklearn-color-border-box);\n",
|
||
|
" border-radius: 0.25em;\n",
|
||
|
" box-sizing: border-box;\n",
|
||
|
" margin-bottom: 0.5em;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-estimator.fitted {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* on hover */\n",
|
||
|
"#sk-container-id-4 div.sk-estimator:hover {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 div.sk-estimator.fitted:hover {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
|
||
|
"\n",
|
||
|
"/* Common style for \"i\" and \"?\" */\n",
|
||
|
"\n",
|
||
|
".sk-estimator-doc-link,\n",
|
||
|
"a:link.sk-estimator-doc-link,\n",
|
||
|
"a:visited.sk-estimator-doc-link {\n",
|
||
|
" float: right;\n",
|
||
|
" font-size: smaller;\n",
|
||
|
" line-height: 1em;\n",
|
||
|
" font-family: monospace;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
" border-radius: 1em;\n",
|
||
|
" height: 1em;\n",
|
||
|
" width: 1em;\n",
|
||
|
" text-decoration: none !important;\n",
|
||
|
" margin-left: 0.5em;\n",
|
||
|
" text-align: center;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
|
||
|
" color: var(--sklearn-color-unfitted-level-1);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
".sk-estimator-doc-link.fitted,\n",
|
||
|
"a:link.sk-estimator-doc-link.fitted,\n",
|
||
|
"a:visited.sk-estimator-doc-link.fitted {\n",
|
||
|
" /* fitted */\n",
|
||
|
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
|
||
|
" color: var(--sklearn-color-fitted-level-1);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* On hover */\n",
|
||
|
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
|
||
|
".sk-estimator-doc-link:hover,\n",
|
||
|
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
|
||
|
".sk-estimator-doc-link:hover {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-3);\n",
|
||
|
" color: var(--sklearn-color-background);\n",
|
||
|
" text-decoration: none;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
|
||
|
".sk-estimator-doc-link.fitted:hover,\n",
|
||
|
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
|
||
|
".sk-estimator-doc-link.fitted:hover {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-3);\n",
|
||
|
" color: var(--sklearn-color-background);\n",
|
||
|
" text-decoration: none;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* Span, style for the box shown on hovering the info icon */\n",
|
||
|
".sk-estimator-doc-link span {\n",
|
||
|
" display: none;\n",
|
||
|
" z-index: 9999;\n",
|
||
|
" position: relative;\n",
|
||
|
" font-weight: normal;\n",
|
||
|
" right: .2ex;\n",
|
||
|
" padding: .5ex;\n",
|
||
|
" margin: .5ex;\n",
|
||
|
" width: min-content;\n",
|
||
|
" min-width: 20ex;\n",
|
||
|
" max-width: 50ex;\n",
|
||
|
" color: var(--sklearn-color-text);\n",
|
||
|
" box-shadow: 2pt 2pt 4pt #999;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background: var(--sklearn-color-unfitted-level-0);\n",
|
||
|
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
".sk-estimator-doc-link.fitted span {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background: var(--sklearn-color-fitted-level-0);\n",
|
||
|
" border: var(--sklearn-color-fitted-level-3);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
".sk-estimator-doc-link:hover span {\n",
|
||
|
" display: block;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 a.estimator_doc_link {\n",
|
||
|
" float: right;\n",
|
||
|
" font-size: 1rem;\n",
|
||
|
" line-height: 1em;\n",
|
||
|
" font-family: monospace;\n",
|
||
|
" background-color: var(--sklearn-color-background);\n",
|
||
|
" border-radius: 1rem;\n",
|
||
|
" height: 1rem;\n",
|
||
|
" width: 1rem;\n",
|
||
|
" text-decoration: none;\n",
|
||
|
" /* unfitted */\n",
|
||
|
" color: var(--sklearn-color-unfitted-level-1);\n",
|
||
|
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 a.estimator_doc_link.fitted {\n",
|
||
|
" /* fitted */\n",
|
||
|
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
|
||
|
" color: var(--sklearn-color-fitted-level-1);\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"/* On hover */\n",
|
||
|
"#sk-container-id-4 a.estimator_doc_link:hover {\n",
|
||
|
" /* unfitted */\n",
|
||
|
" background-color: var(--sklearn-color-unfitted-level-3);\n",
|
||
|
" color: var(--sklearn-color-background);\n",
|
||
|
" text-decoration: none;\n",
|
||
|
"}\n",
|
||
|
"\n",
|
||
|
"#sk-container-id-4 a.estimator_doc_link.fitted:hover {\n",
|
||
|
" /* fitted */\n",
|
||
|
" background-color: var(--sklearn-color-fitted-level-3);\n",
|
||
|
"}\n",
|
||
|
"</style><div id=\"sk-container-id-4\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>DecisionTreeRegressor(max_depth=3, random_state=42)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-4\" type=\"checkbox\" checked><label for=\"sk-estimator-id-4\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>DecisionTreeRegressor</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.6/modules/generated/sklearn.tree.DecisionTreeRegressor.html\">?<span>Documentation for DecisionTreeRegressor</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\"><pre>DecisionTreeRegressor(max_depth=3, random_state=42)</pre></div> </div></div></div></div>"
|
||
|
],
|
||
|
"text/plain": [
|
||
|
"DecisionTreeRegressor(max_depth=3, random_state=42)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"tree_reg2 = DecisionTreeRegressor(max_depth=3, random_state=42)\n",
|
||
|
"tree_reg2.fit(X_quad, y_quad)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"### Plot models"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2025-02-27T23:21:20.906565Z",
|
||
|
"iopub.status.busy": "2025-02-27T23:21:20.906396Z",
|
||
|
"iopub.status.idle": "2025-02-27T23:21:21.488766Z",
|
||
|
"shell.execute_reply": "2025-02-27T23:21:21.488082Z"
|
||
|
},
|
||
|
"slideshow": {
|
||
|
"slide_type": ""
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"Text(0.5, 1.0, 'max_depth=3')"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 10,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1kAAAGKCAYAAAD+NIubAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAArfJJREFUeJzs3Xd8FGX+B/DP7ibZQEIKJCGQpRN6k1AVBJEzqFjOgvLTEzQUFc6So3oIeBakiFiwEsECgnq2OxQUhFMERUPvoQSygTRCdkOAtH1+fyy72d1smd3pk+/79coryezsPM/MzjzffWaeomOMMRBCCCGEEEIIEYRe7gwQQgghhBBCiJZQJYsQQgghhBBCBESVLEIIIYQQQggREFWyCCGEEEIIIURAVMkihBBCCCGEEAFRJYsQQgghhBBCBESVLEIIIYQQQggREFWyCCGEEEIIIURAVMkihBBCCCGEEAFRJYsQDWjbti3Gjx8vS9q5ubnQ6XRYsmSJLOkTQghRJopNpCGjShYhhJPvvvsO8+fPlyXtL7/8Evfddx/at2+Pxo0bo3PnzvjHP/6BsrIyWfJDCCFEGeSMTV999RXS09PRsmVLGI1GmEwm3HPPPThw4IAs+SHKQpUsQggn3333HZ577jlZ0p40aRIOHz6MBx98EK+//jpGjRqFN998E4MHD8bly5dlyRMhhBD5yRmb9u/fj/j4eDz55JN466238Nhjj2H37t0YMGAA9u7dK0ueiHKEyZ0BQggJ5IsvvsDw4cPdlqWlpWHcuHFYvXo1JkyYIE/GCCGENFhz586tt2zChAkwmUx4++238c4778iQK6IU9CSLNGjz58+HTqfDsWPH8OCDDyI2NhaJiYl49tlnwRhDXl4e7rjjDsTExCA5ORmvvPKK871VVVWYO3cu0tLSEBsbi6ioKAwdOhRbtmxxS2PevHnQ6/XYvHmz2/JJkyYhIiIiqLtdjDG88MILMJlMaNy4MW644QYcPHjQ67plZWV46qmn0KpVKxiNRnTs2BELFy6EzWZzruPaZv3VV19FmzZt0KhRIwwbNsytucP48eOxfPlyAIBOp3P+eHrvvffQoUMHGI1G9O/fH3/88QfnffPHs4IFAH/9618BAIcPHxYkDUIIUQqKTeqITd4kJSWhcePG1Jyd0JMsQgDgvvvuQ9euXfHyyy9j/fr1eOGFF9C0aVO8++67GDFiBBYuXIjVq1dj2rRp6N+/P66//npYrVasWLECY8eOxcSJE1FeXo6srCykp6dj586d6NOnDwBgzpw5+M9//oOMjAzs378fTZo0wcaNG/H+++/j+eefR+/evTnnc+7cuXjhhRdwyy234JZbbsGuXbtw0003oaqqym29S5cuYdiwYcjPz8fkyZPRunVrbN++HbNnz8a5c+ewbNkyt/U/+ugjlJeXY8qUKbhy5Qpee+01jBgxAvv370fz5s0xefJknD17Fj/++CM+/vhjr3lbs2YNysvLMXnyZOh0OixatAh33XUXTp48ifDwcABAZWUlysvLOe1rQkKC39cLCgo4rUcIIWpFsUkdsamsrAzV1dUoKCjAsmXLYLVaceONN3LaHtEwRkgDNm/ePAaATZo0ybmspqaGmUwmptPp2Msvv+xcfuHCBdaoUSM2btw453qVlZVu27tw4QJr3rw5e+SRR9yW79+/n0VERLAJEyawCxcusJSUFNavXz9WXV3NOa9FRUUsIiKC3XrrrcxmszmXP/PMMwyAM1+MMfb888+zqKgoduzYMbdtzJo1ixkMBnbmzBnGGGOnTp1iAFijRo2Y2Wx2rvf7778zAOzpp592LpsyZQrzVmQ4ttGsWTNWWlrqXP7NN98wAOw///mPc9nKlSsZAE4/gWRkZDCDwVBvHwkhRO0oNqkrNnXu3Nn5enR0NJszZw6rra3leASJVtGTLEIAtz49BoMB/fr1g9lsRkZGhnN5XFwcOnfujJMnTzrXMxgMAACbzY
aysjLYbDb069cPu3btctt+jx498Nxzz2H27NnYt28fSkpK8MMPPyAsjPsluGnTJlRVVeHvf/+7W3OIp556Ci+99JLbup9//jmGDh2K+Ph4lJSUOJePHDkSL7/8Mn7++Wc88MADzuV33nknUlJSnP8PGDAAAwcOxHfffYelS5dyyt99992H+Ph45/9Dhw4FAOfxAoD09HT8+OOPHPfYtzVr1iArKwszZsxAamoq7+0RQogSUWxSR2xauXIlrFYrTp48iZUrV+Ly5cuora2FXk+9choyqmQRAqB169Zu/8fGxiIyMrJes4DY2FicP3/e+f+HH36IV155BUeOHEF1dbVzebt27eqlMX36dKxduxY7d+7ESy+9hG7dugWVx9OnTwNAvUpFYmKiWwABgJycHOzbtw+JiYlet1VUVOT2v7eKSqdOnfDZZ59xzp/nMXTk6cKFC85lLVq0QIsWLThv05tffvkFGRkZSE9Px4svvshrW4QQomQUm9QRmwYPHuz8+/7770fXrl0BgOboauCokkUI4LzrF2gZYO/gCwCffPIJxo8fjzvvvBPTp09HUlISDAYDFixYgBMnTtR738mTJ5GTkwPAPuyrmGw2G/7yl79gxowZXl/v1KmT4GkGOl4AcPnyZVgsFk7bS05Orrds7969uP3229GjRw988cUXQd1tJYQQtaHYxJ8UsclVfHw8RowYgdWrV1Mlq4GjbyiEhOiLL75A+/bt8eWXX7o1kZg3b169dW02G8aPH4+YmBhnE4p77rkHd911F+f02rRpA8B+J7B9+/bO5cXFxW535ACgQ4cOuHjxIkaOHMlp244A6+rYsWNo27at839vIzYFa926dXj44Yc5resaAAHgxIkTGDVqFJKSkvDdd98hOjqad34IIURrKDYFj09s8iaYShvRLqpkERIix90xxpizkP/999+xY8eOes0Tli5diu3bt+Pbb7/Frbfeiq1bt+Kxxx7D9ddfz3l0vJEjRyI8PBxvvPEGbrrpJmeanqMxAcCYMWMwf/58bNy4Eenp6W6vlZWVITo62u0p0Ndff438/Hxn2/edO3fi999/x1NPPeVcJyoqyvn+uLg4Tnn2FGq794KCAtx0003Q6/XYuHGjz6YmhBDS0FFsCl6osamoqAhJSUluy3Jzc7F582b069cvpLwQ7aBKFiEhGj16NL788kv89a9/xa233opTp07hnXfeQbdu3XDx4kXneocPH8azzz6L8ePH47bbbgMArFq1Cn369MHjjz/OuW15YmIipk2bhgULFmD06NG45ZZbsHv3bnz//ff1guH06dPx7bffYvTo0Rg/fjzS0tJQUVGB/fv344svvkBubq7bezp27IghQ4bgscceQ2VlJZYtW4ZmzZq5NelIS0sDADzxxBNIT0+HwWDA/fffH9QxC7Xd+6hRo3Dy5EnMmDED27Ztw7Zt25yvNW/eHH/5y1+C3iYhhGgRxSbpYlPPnj1x4403ok+fPoiPj0dOTg6ysrJQXV2Nl19+OejtEY2Rb2BDQuTnGCa3uLjYbfm4ceNYVFRUvfWHDRvGunfvzhhjzGazsZdeeom1adOGGY1Gds0117D//ve/bNy4caxNmzaMMftQuv3792cmk4mVlZW5beu1115jANi6des457e2tpY999xzrEWLFqxRo0Zs+PDh7MCBA6xNmzZuw+Qyxlh5eTmbPXs269ixI4uIiGAJCQns2muvZUuWLGFVVVWMsbohbhcvXsxeeeUV1qpVK2Y0GtnQoUPZ3r173bZXU1PD/v73v7PExESm0+mcQ9m6bsMTADZv3jzO++cL/AynO2zYMN7bJ4QQJaHYpI7YNG/ePNavXz8WHx/PwsLCWMuWLdn999/P9u3bx3vbRP10jHFoXEoI0aTc3Fy0a9cOixcvxrRp0+TODiGEEEKxiWgCDeBPCCGEEEIIIQKiPlmEKEBxcTFqa2t9vh4REYGmTZtKmCNCCCENHcUmQkJHlSxCFKB///7OCR
29GTZsGLZu3SpdhgghhDR4FJsICZ2i+mQtX74cixcvRkFBAXr37o033ngDAwYM8Lru+++/j48++ggHDhwAYB9d5qW
|
||
|
"text/plain": [
|
||
|
"<Figure size 1000x400 with 2 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"\n",
|
||
|
"def plot_regression_predictions(tree_reg, X, y, axes=[-0.5, 0.5, -0.05, 0.25]):\n",
|
||
|
" x1 = np.linspace(axes[0], axes[1], 500).reshape(-1, 1)\n",
|
||
|
" y_pred = tree_reg.predict(x1)\n",
|
||
|
" plt.axis(axes)\n",
|
||
|
" plt.xlabel(\"$x_1$\")\n",
|
||
|
" plt.plot(X, y, \"b.\")\n",
|
||
|
" plt.plot(x1, y_pred, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n",
|
||
|
"\n",
|
||
|
"fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n",
|
||
|
"plt.sca(axes[0])\n",
|
||
|
"plot_regression_predictions(tree_reg, X_quad, y_quad)\n",
|
||
|
"\n",
|
||
|
"# Split thresholds of the depth-2 tree: root (node 0) and its two children (nodes 1 and 4)\n",
"th0, th1a, th1b = tree_reg.tree_.threshold[[0, 1, 4]]\n",
|
||
|
"for split, style in ((th0, \"k-\"), (th1a, \"k--\"), (th1b, \"k--\")):\n",
|
||
|
" plt.plot([split, split], [-0.05, 0.25], style, linewidth=2)\n",
|
||
|
"plt.text(th0, 0.16, \"Depth=0\", fontsize=15)\n",
|
||
|
"plt.text(th1a + 0.01, -0.01, \"Depth=1\", horizontalalignment=\"center\", fontsize=13)\n",
|
||
|
"plt.text(th1b + 0.01, -0.01, \"Depth=1\", fontsize=13)\n",
|
||
|
"plt.ylabel(\"$y$\", rotation=0)\n",
|
||
|
"plt.legend(loc=\"upper center\", fontsize=16)\n",
|
||
|
"plt.title(\"max_depth=2\")\n",
|
||
|
"\n",
|
||
|
"plt.sca(axes[1])\n",
|
||
|
"# Depth-2 split thresholds of the depth-3 tree (nodes 2, 5, 9 and 12 in depth-first order)\n",
"th2s = tree_reg2.tree_.threshold[[2, 5, 9, 12]]\n",
|
||
|
"plot_regression_predictions(tree_reg2, X_quad, y_quad)\n",
|
||
|
"# CART is greedy, so the top two levels match tree_reg and its thresholds can be reused\n",
"for split, style in ((th0, \"k-\"), (th1a, \"k--\"), (th1b, \"k--\")):\n",
|
||
|
" plt.plot([split, split], [-0.05, 0.25], style, linewidth=2)\n",
|
||
|
"for split in th2s:\n",
|
||
|
" plt.plot([split, split], [-0.05, 0.25], \"k:\", linewidth=1)\n",
|
||
|
"plt.text(th2s[2] + 0.01, 0.15, \"Depth=2\", fontsize=13)\n",
|
||
|
"plt.title(\"max_depth=3\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"### Training regression models\n",
|
||
|
"\n",
|
||
|
"The CART algorithm works in the same way for regression, except that the cost function measures the mean squared error (MSE) of each subset instead of its impurity:\n",
|
||
|
"\n",
|
||
|
"$$\n",
|
||
|
"J(k,t_k)=\\frac{m_\\text{left}}{m}\\text{MSE}_\\text{left}+\\frac{m_\\text{right}}{m}\\text{MSE}_\\text{right}.\n",
|
||
|
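"$$\n",
"\n",
"where $\\text{MSE}_\\text{node}=\\frac{1}{m_\\text{node}}\\sum_{i\\in\\text{node}}\\left(\\hat{y}_\\text{node}-y^{(i)}\\right)^2$ and $\\hat{y}_\\text{node}=\\frac{1}{m_\\text{node}}\\sum_{i\\in\\text{node}}y^{(i)}$, i.e. each node predicts the mean target value of its training instances."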
|
||
|
]
|
||
|
},
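{
"cell_type": "markdown",
"metadata": {"slideshow": {"slide_type": "subslide"}},
"source": [
"As a rough sketch of this cost function for a single feature (our own illustration, not scikit-learn's implementation), we can try every threshold midway between consecutive sorted feature values and keep the one with the lowest $J$. The helpers `mse` and `best_split` and the toy data are hypothetical. On this step-shaped data the chosen threshold should agree with the root threshold of a depth-1 `DecisionTreeRegressor`, up to floating-point precision."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def mse(y):\n",
"    # MSE of a node whose prediction is the mean target value\n",
"    return np.mean((y - y.mean()) ** 2)\n",
"\n",
"def best_split(x, y):\n",
"    # Brute-force search for t minimising J = m_left/m * MSE_left + m_right/m * MSE_right\n",
"    order = np.argsort(x)\n",
"    x_sorted, y_sorted = x[order], y[order]\n",
"    m = len(y)\n",
"    best_t, best_cost = None, np.inf\n",
"    for i in range(1, m):\n",
"        if x_sorted[i] == x_sorted[i - 1]:\n",
"            continue  # no threshold can separate equal feature values\n",
"        t = (x_sorted[i - 1] + x_sorted[i]) / 2  # candidate threshold at the midpoint\n",
"        y_left, y_right = y_sorted[:i], y_sorted[i:]\n",
"        cost = len(y_left) / m * mse(y_left) + len(y_right) / m * mse(y_right)\n",
"        if cost < best_cost:\n",
"            best_t, best_cost = t, cost\n",
"    return best_t, best_cost\n",
"\n",
"# Toy step-shaped data: the best split should fall between 0.4 and 0.6\n",
"x_toy = np.array([0.1, 0.2, 0.4, 0.6, 0.8, 0.9])\n",
"y_toy = np.array([1.0, 1.0, 1.0, 5.0, 5.0, 5.0])\n",
"t_best, cost_best = best_split(x_toy, y_toy)\n",
"print(t_best, cost_best)  # expected: 0.5 0.0"
]
},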
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": [
|
||
|
"exercise_pointer"
|
||
|
]
|
||
|
},
|
||
|
"source": [
|
||
|
"**Exercises:** *You can now complete Exercise 2 in the exercises associated with this lecture.*"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "slide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"## Limitations"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": ""
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"### Sensitivity to axis orientation"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": ""
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"Decision trees are sensitive to the orientation of the data, because every split is perpendicular to one of the feature axes.\n",
|
||
|
"\n",
|
||
|
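"Consider the same data rotated by 45${}^\\circ$: the class boundary is now diagonal, so the tree has to approximate it with a staircase of axis-aligned splits."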
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2025-02-27T23:21:21.490962Z",
|
||
|
"iopub.status.busy": "2025-02-27T23:21:21.490688Z",
|
||
|
"iopub.status.idle": "2025-02-27T23:21:21.495775Z",
|
||
|
"shell.execute_reply": "2025-02-27T23:21:21.495198Z"
|
||
|
},
|
||
|
"slideshow": {
|
||
|
"slide_type": "skip"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def plot_decision_boundary(clf, X, y, axes, cmap):\n",
|
||
|
" x1, x2 = np.meshgrid(np.linspace(axes[0], axes[1], 100),\n",
|
||
|
" np.linspace(axes[2], axes[3], 100))\n",
|
||
|
" X_new = np.c_[x1.ravel(), x2.ravel()]\n",
|
||
|
" y_pred = clf.predict(X_new).reshape(x1.shape)\n",
|
||
|
" \n",
|
||
|
" plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=cmap)\n",
|
||
|
" plt.contour(x1, x2, y_pred, cmap=\"Greys\", alpha=0.8)\n",
|
||
|
" colors = {\"Wistia\": [\"#78785c\", \"#c47b27\"], \"Pastel1\": [\"red\", \"blue\"]}\n",
|
||
|
" markers = (\"o\", \"^\")\n",
|
||
|
" for idx in (0, 1):\n",
|
||
|
" plt.plot(X[:, 0][y == idx], X[:, 1][y == idx],\n",
|
||
|
" color=colors[cmap][idx], marker=markers[idx], linestyle=\"none\")\n",
|
||
|
" plt.axis(axes)\n",
|
||
|
" plt.xlabel(r\"$x_1$\")\n",
|
||
|
" plt.ylabel(r\"$x_2$\", rotation=0)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"execution": {
|
||
|
"iopub.execute_input": "2025-02-27T23:21:21.497513Z",
|
||
|
"iopub.status.busy": "2025-02-27T23:21:21.497340Z",
|
||
|
"iopub.status.idle": "2025-02-27T23:21:21.688609Z",
|
||
|
"shell.execute_reply": "2025-02-27T23:21:21.687963Z"
|
||
|
},
|
||
|
"slideshow": {
|
||
|
"slide_type": ""
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"Text(0, 0.5, '')"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1MAAAFzCAYAAADbi1ODAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAYaRJREFUeJzt3Xt8VNW5P/7PnoQEUEKIkGQCsYj1hiIIHCKotV+IEkGPcqDeOFo5CJYDrYj1VbAJWG9oRY0iP6kUfrXnV6sVDv48BlE6yimXEC7a1lpAUSmUMIBGEi5KYGZ9/5gLM5O57NmzL2vv/Xm/Xnkhw57J2gFnzbPWs55HEUIIEBERERERUVY8Vg+AiIiIiIjIjhhMERERERERacBgioiIiIiISAMGU0RERERERBowmCIiIiIiItKAwRQREREREZEGDKaIiIiIiIg0yLd6ALIIBoNobm5Gt27doCiK1cMhInINIQSOHDmCiooKeDxc44vFuYmIyBpq5yYGU2HNzc2orKy0ehhERK61d+9e9OnTx+phSIVzExGRtTLNTQymwrp16wYA+McH21DU7UyLR0OkTrCoO3Z+9jlWr16NhoYG5OXl4bXXXkNRUZHVQyNSra2tDeecc070fZhO49xERGSNtiNH8Z3BQzLOTQymwiLpE0XdzkQRJ3SyiWBREc4880x07twZeXl5yM/PR1FREYMpsiWmsXXEuYmIyFqZ5iYmpxMREREREWnAYIqIiIiIiEgDBlNEREREREQaMJgiIiIiIiLSgMEUERERERGRBgymiIiIiIiINGAwRUREREREpAGDKSIiIiIiIg0YTBEREREREWnAYIqIiIiIiEgDBlNEREREREQaMJgiIiIiIiLSgMEUERERERGRBgymiIiIiIiINGAwRUREREREpIG0wdSiRYvQt29fdO7cGVVVVdi8eXPa6w8fPozp06fD6/WisLAQ559/PlatWmXSaImIiIiIyG3yrR5AMq+99hpmzZqFxYsXo6qqCvX19Rg9ejR27tyJ0tLSDte3t7fjmmuuQWlpKZYvX47evXvjH//4B4qLi80fPBERERERuYKUwdQzzzyDKVOmYNKkSQCAxYsXo6GhAcuWLcPs2bM7XL9s2TK0tLRg48aN6NSpEwCgb9++Zg6ZiIiIiHIVCACbmoCDB4DSMuDyKiAvz+pREaUkXZpfe3s7tm3bhurq6uhjHo8H1dXVaGxsTPqcN998E8OHD8f06dNRVlaGSy65BI8//jgCgUDK73PixAm0tbXFfREREVmJcxO5WsMqKEOHwTN+AjzTpsMzfgKUocOABh7bIHlJF0x9+eWXCAQCKCsri3u8rKwMfr8/6XM+//xzLF++HIFAAKtWrUJdXR2efvppPProoym/z/z589G9e/foV2Vlpa73QURElC3OTeRaDaug3D0F2L8//nG/P/Q4AyqSlHTBlBbBYBClpaV46aWXMGTIENxyyy34+c9/jsWLF6d8zpw5c9Da2hr92rt3r4kjJiIi6ohzE7lSIACltg4QAkrCHylChH6tmxtKASSSjHRnpnr27Im8vDwcOHAg7vEDBw6gvLw86XO8Xi86deqEvJic2osuugh+vx/t7e0oKCjo8JzCwkIUFhbqO3giIqIccG4iV9rUBCVxRyqGIgTQ3AyxqQm4YoSJAyPKTLqdqYKCAgwZMgQ+ny/6WDAYhM/nw/Dhw5M+54orrsCuXbsQDAajj33yySfwer1JAykiIiIiksTBA5mvyeY6IhNJF0wBwKxZs7BkyRK8/PLL2L59O6ZNm4Zjx45Fq/vdeeedmDNnTvT6adOmoaWlBffeey8++eQTNDQ04PHHH8f06dOtugUiIiIiUqO0LPM12VxHZCLp0vwA4JZbbsGhQ4cwd+5c+P1+DBo0CKtXr44WpdizZw88ntNxYGVlJd555x3cd999uPTSS9G7d2/ce++9+NnPfmbVLRARERGRGpdXQXi9oWIT4T
NSsYSiAF5vqEw6kWSkDKYAYMaMGZgxY0bSP1u7dm2Hx4YPH45NmzYZPCoiIiIi0lVeHsSjj0C5ewqEosQFVEIJlaQQjzzMflMkJSnT/IiIiIjIIIEAsGEjsHJl6FcZquSNHQPx6yVAYrExrzf0+Ngx1oyLKANpd6aIiIiISGcNq6DU1sVVzxNeL8Sjj1gfsIwdA1EzOlS17+CB0Bmpy6u4I0VSYzBFRERE5AaRxriJ55LCjXGl2AHKy2P5c7IVpvkREREROZ3bG+PKmNpIjsCdKSIiIiKnc3NjXJlTG8n2uDNFRERE5HRubYwbSW1MDCTDqY1oWGXNuMgxGEwREREROZ0bG+O6PbWRTMFgioiIiMjpwo1xI32bEglFgaiocFZj3HBqY/I7DgVUSnMzsKnJ1GGRszCYIiIiInK6cGNcAB0CKsc2xnVraiOZisEUERERkRs4sTFuuip9bkxtJNOxmh8RERGRWzipMW6mKn3h1Eb4/dEzUrGEogBer7NSG8l03JkiIiIicpNIY9xx40K/2jWQylSlz42pjWQ6BlNEREREZB/ZVOlzYmojSYXBFJHRXNp13edTMGBAHny+VHWUzH0dIiJyiGyr9I0dA7F1M4IrliP44iIEVyyH2NLEQIp0wTNTREZyadd1IYDaWg927FBQW+vByJEBpKjGa8rrEBGRg2ip0hdJbSTSGXemyF7stMvjgK7rWneF1qxRsHVr6DlbtypYs0ZbBKTX6xARkYOwSh9JhMEU2UfDKihDh8EzfgI806bDM34ClKHD5AxKHNB1PXFXKEkhpJTPmzfPg7y80BPy8gTmzVP/fL1fxyxMRyQiMokbGxCTtBhMkT3YbZfHAV3Xte4KRZ4XCISuDwQUTbtKer2OGbQGnkREmtgpS8MIrNJHEmEwRfKz4y6Pzbuua90VSnxeRLa7Snq9jlmYjkhEprFTloaRWKWPJMFgiuRnx10em+dza90VSnxeRLa7Snq9jhnslo5IRDZmtywNo7FKH0mAwRTpx6i0Azvu8tg4n1vrrlDkeR5P8gs8nux2t3J9HbPYKR2RiGzMjlkaekv2OcMJDYjJ1hhMkT6MTDuw4y6PjfO5te4KtbcDe/YAwWDyPw8GFezdG7ouHb1exwx2S0ckIhuzY5aGnpjeSJJinynKXSTtIPGTYzjtIOfc5fAuD/z+6OpbLKEogNcr3y5POJ9bqa2LT8nwekOBlIRpCLG7QsmCmciu0DXXdOz3VFgIbNoUwKFDqV+/tDR0XTp6vY4ZYs9KxQoFnqE/v/ZaRlRErhUIhIKbgwdCC36XV2lfRLNjloZejP6cQZQDBlOUmwxpB0JRoNTNhagZrX0CCe/yKHdPCb1ezJup7Ls8GDsGomY0hF6TqcHU7QoJtLcnD2YqK0NfudLrdYyUS+BJRC6gd9N2O2Zp6MGMzxlEOWAwRbkJpx2koggBNDeHgolcOo/bcJcnykZd1+20K2S1XANPInIwI3ZS7JqlkSuzPmcQacRginJjZtqBzXZ57MoOu0IyYOBJREkZtZNi5yyNXLg5vZFsgcEU5caMtAM9c87J8Xw+BTNnelBfH8SoUcaeV2LgSUQdGLmTIluWhhnzs1vTG8k2GExRboxOO9A755wcTQigttaDHTsU1NZ6MHIkzysRkcmM3kmRJUvDrPk58jkjRSVDx6Y3km2wNDrlxsgS4GxOSFmKra7HXk9EZImevdRd98kn2nsyqu2tZFT/R7Pm50AAqH8OaG1NHUjBoemNZBsMpih34bQDlJfHP+71ai9XyuaElKXEnk/s9UREpmtYBeXHP0l7SeQtyfPsc8b2SjKqL5NZ83PDKiiXXArPUwugHD+e/JoexSyLTpZjMEX6GDsGYutmBFcsR/DFRQiuWA6xpUn7G5zbmxNS1hKbDWdqMmwGn0/BgAF58Pm4Q0bkeJHdGr8/5SVJ13aMyLYwcufIjPk5Mv6vv055iQCAws5AzWjt38cKRu0WkmUYTJF+1KYdqM
HqPZSFxF2pCCt3pxLPb3GHjMjB0uzWJDI828LonSOj52eVP0sFCJ3XstOiqlG7hWQpBlMkJ1bvoSwk7kpFWLk7xfN
|
||
|
"text/plain": [
|
||
|
"<Figure size 1000x400 with 2 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"np.random.seed(6)\n",
|
||
|
"X_square = np.random.rand(100, 2) - 0.5\n",
|
||
|
"y_square = (X_square[:, 0] > 0).astype(np.int64)\n",
|
||
|
"\n",
|
||
|
"angle = np.pi / 4 # 45 degrees\n",
|
||
|
"rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)],\n",
|
||
|
" [np.sin(angle), np.cos(angle)]])\n",
|
||
|
"X_rotated_square = X_square.dot(rotation_matrix)\n",
|
||
|
"\n",
|
||
|
"tree_clf_square = DecisionTreeClassifier(random_state=42)\n",
|
||
|
"tree_clf_square.fit(X_square, y_square)\n",
|
||
|
"tree_clf_rotated_square = DecisionTreeClassifier(random_state=42)\n",
|
||
|
"tree_clf_rotated_square.fit(X_rotated_square, y_square)\n",
|
||
|
"\n",
|
||
|
"fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n",
|
||
|
"plt.sca(axes[0])\n",
|
||
|
"plot_decision_boundary(tree_clf_square, X_square, y_square,\n",
|
||
|
" axes=[-0.7, 0.7, -0.7, 0.7], cmap=\"Pastel1\")\n",
|
||
|
"plt.sca(axes[1])\n",
|
||
|
"plot_decision_boundary(tree_clf_rotated_square, X_rotated_square, y_square,\n",
|
||
|
" axes=[-0.7, 0.7, -0.7, 0.7], cmap=\"Pastel1\")\n",
|
||
|
"plt.ylabel(\"\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
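"This can sometimes be mitigated by first applying principal component analysis (PCA), which rotates the data so that its directions of greatest variance align with the axes, or by using ensemble methods (see upcoming lectures)."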
|
||
|
]
|
||
|
},
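{
"cell_type": "markdown",
"metadata": {"slideshow": {"slide_type": "subslide"}},
"source": [
"A minimal sketch of the PCA idea on a hypothetical toy dataset. The uniform square above has no dominant direction of variance, so PCA would not reliably un-rotate it; instead we build an elongated cloud whose class boundary runs along its long axis, rotate it by 45 degrees, and compare how many leaves a tree needs with and without a PCA step. We would expect far fewer leaves with PCA, but the exact counts depend on the random sample. In practice you would fit PCA on the training data only, e.g. inside `make_pipeline(PCA(), DecisionTreeClassifier())`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"rng = np.random.default_rng(42)\n",
"# Elongated cloud: large spread along the first axis, small spread along the second\n",
"X_long = rng.normal(size=(300, 2)) * np.array([3.0, 0.5])\n",
"X_long = X_long[np.abs(X_long[:, 1]) > 0.1]  # leave a small margin around the boundary\n",
"y_long = (X_long[:, 1] > 0).astype(np.int64)  # boundary runs along the long axis\n",
"\n",
"# Rotate by 45 degrees so that the boundary becomes diagonal\n",
"theta = np.pi / 4\n",
"R = np.array([[np.cos(theta), -np.sin(theta)],\n",
"              [np.sin(theta), np.cos(theta)]])\n",
"X_long_rotated = X_long @ R\n",
"\n",
"tree_no_pca = DecisionTreeClassifier(random_state=42).fit(X_long_rotated, y_long)\n",
"# PCA re-aligns the main directions of variance with the coordinate axes\n",
"X_long_pca = PCA(n_components=2).fit_transform(X_long_rotated)\n",
"tree_pca = DecisionTreeClassifier(random_state=42).fit(X_long_pca, y_long)\n",
"\n",
"print('leaves without PCA:', tree_no_pca.get_n_leaves())\n",
"print('leaves with PCA:', tree_pca.get_n_leaves())"
]
},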
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"### Decision trees can be unstable"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": ""
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"- Small changes to hyperparameters or to data may produce very different models.\n",
|
||
|
"- Even retraining on the same data with a different random seed can produce a very different model."
|
||
|
]
|
||
|
},
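{
"cell_type": "markdown",
"metadata": {"slideshow": {"slide_type": "subslide"}},
"source": [
"Why can the seed matter? According to the scikit-learn documentation, features are randomly permuted at each split, so when several splits give exactly the same improvement, the one selected depends on `random_state`. The deliberately contrived sketch below (hypothetical data with two identical features) makes this visible: across seeds we would expect the root to split on feature 0 for some seeds and on feature 1 for others."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x_tie = rng.uniform(-1, 1, size=100)\n",
"X_tie = np.c_[x_tie, x_tie]  # two identical features, so every split ties between them\n",
"y_tie = (x_tie > 0).astype(np.int64)\n",
"\n",
"# Index of the feature used at the root node, for ten different seeds\n",
"root_features = [\n",
"    int(DecisionTreeClassifier(max_depth=1, random_state=seed).fit(X_tie, y_tie).tree_.feature[0])\n",
"    for seed in range(10)\n",
"]\n",
"print(root_features)"
]
},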
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": ""
|
||
|
},
|
||
|
"tags": []
|
||
|
},
|
||
|
"source": [
|
||
|
"We can exploit this property by averaging the predictions of many different trees, which gives rise to *ensemble methods* (discussed in the next lecture)."
|
||
|
]
|
||
|
},
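{
"cell_type": "markdown",
"metadata": {"slideshow": {"slide_type": "subslide"}},
"source": [
"As a preview, here is a minimal sketch of one such method, bagging (bootstrap aggregating), on a hypothetical noisy quadratic: each unpruned tree is trained on a bootstrap sample (drawn with replacement) and their predictions are averaged. We would expect the averaged prediction to have a noticeably lower test error than a single tree, although the exact numbers depend on the random data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"rng = np.random.default_rng(42)\n",
"X_toy = rng.uniform(-0.5, 0.5, size=(200, 1))\n",
"y_toy = X_toy[:, 0] ** 2 + 0.025 * rng.normal(size=200)  # noisy quadratic\n",
"X_test = np.linspace(-0.5, 0.5, 500).reshape(-1, 1)\n",
"y_true = X_test[:, 0] ** 2  # noise-free target for evaluation\n",
"\n",
"# Train each unpruned tree on its own bootstrap sample\n",
"preds = []\n",
"for seed in range(50):\n",
"    idx = rng.integers(0, len(X_toy), size=len(X_toy))\n",
"    tree = DecisionTreeRegressor(random_state=seed).fit(X_toy[idx], y_toy[idx])\n",
"    preds.append(tree.predict(X_test))\n",
"preds = np.array(preds)\n",
"\n",
"mse_single = np.mean((preds[0] - y_true) ** 2)  # one tree on its own\n",
"mse_average = np.mean((preds.mean(axis=0) - y_true) ** 2)  # average of 50 trees\n",
"print(f'single tree MSE: {mse_single:.5f}, average of 50 trees: {mse_average:.5f}')"
]
},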
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"editable": true,
|
||
|
"slideshow": {
|
||
|
"slide_type": "subslide"
|
||
|
},
|
||
|
"tags": [
|
||
|
"exercise_pointer"
|
||
|
]
|
||
|
},
|
||
|
"source": [
|
||
|
"**Exercises:** *You can now complete Exercises 1-2 in the exercises associated with this lecture.*"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"celltoolbar": "Slideshow",
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.11.11"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 4
|
||
|
}
|