spce0038-machine-learning-w.../week6/slides/Lecture16_DecisionTrees.ipynb

2025-02-28 11:02:51 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Lecture 16: Decision trees"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"source": [
"![](https://www.tensorflow.org/images/colab_logo_32px.png)\n",
"[Run in colab](https://colab.research.google.com/drive/1P9IoqXN9dbjJ3TN50wa8wwDdvn9P6hX7)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:19.492309Z",
"iopub.status.busy": "2025-02-27T23:21:19.492086Z",
"iopub.status.idle": "2025-02-27T23:21:19.498527Z",
"shell.execute_reply": "2025-02-27T23:21:19.497958Z"
},
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Last executed: 2025-02-27 23:21:19\n"
]
}
],
"source": [
"import datetime\n",
"now = datetime.datetime.now()\n",
"print(\"Last executed: \" + now.strftime(\"%Y-%m-%d %H:%M:%S\"))"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"A decision tree can be viewed conceptually as a *flow diagram*: a tree of decisions, each based on inspecting a property of the data set.\n",
"\n",
"- Can perform both classification and regression.\n",
"- A fundamental component of random forests (a powerful machine learning algorithm covered in the next lecture). \n",
"- We will learn how to visualise and make predictions using Decision Trees.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"source": [
"## Conceptual example"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/DecisionTree.jpg\" alt=\"data-layout\" width=\"500\" style=\"display:block; margin:auto\"/>\n",
"\n",
"[[Image source](https://inside-machinelearning.com/en/decision-tree-and-hyperparameters/)]"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"source": [
"## Walk-through of decision tree\n",
"\n",
"Let's consider an illustration using the Iris data set (introduced in Lecture 3)."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Images of different Iris species\n",
"\n",
"#### Iris Setosa\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture03_Images/iris_setosa.jpg\" width=\"300\" style=\"display:block; margin:auto\"/>\n",
"\n",
"#### Iris Versicolor\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture03_Images/iris_versicolor.jpg\" width=\"300\" style=\"display:block; margin:auto\"/>\n",
"\n",
"#### Iris Virginica\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture03_Images/iris_virginica.jpg\" width=\"300\" style=\"display:block; margin:auto\"/>\n",
"\n",
"[[Image source](https://github.com/jakevdp/sklearn_tutorial)]\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"Load the feature matrix, where each row corresponds to an observed (*sampled*) flower described by a number of *features*, along with the corresponding target vector.\n",
"\n",
"Consider two features only for now (petal length and width)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:19.532095Z",
"iopub.status.busy": "2025-02-27T23:21:19.531889Z",
"iopub.status.idle": "2025-02-27T23:21:20.377758Z",
"shell.execute_reply": "2025-02-27T23:21:20.377159Z"
},
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassifier(max_depth=2, random_state=42)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"iris = load_iris(as_frame=True)\n",
"X_iris = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y_iris = iris.target\n",
"\n",
"tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)\n",
"tree_clf.fit(X_iris, y_iris)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:20.379798Z",
"iopub.status.busy": "2025-02-27T23:21:20.379575Z",
"iopub.status.idle": "2025-02-27T23:21:20.569651Z",
"shell.execute_reply": "2025-02-27T23:21:20.568896Z"
},
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"outputs": [],
"source": [
"# Visualise the flow diagram of the fitted tree using graphviz\n",
"from sklearn.tree import export_graphviz\n",
"\n",
"export_graphviz(tree_clf,\n",
"                out_file='./iris_tree.dot',\n",
"                feature_names=iris.feature_names[2:],\n",
"                class_names=iris.target_names,\n",
"                rounded=True,\n",
"                filled=True)\n",
"\n",
"# export_graphviz writes a .dot file; convert it to a PNG with the graphviz dot tool\n",
"! dot -Tpng ./iris_tree.dot -o ./iris_tree.png"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"source": [
"To run dot locally you will need to install [graphviz](https://graphviz.org/).\n",
"\n",
"You can install on Mac using Homebrew:\n",
"```bash\n",
"brew install graphviz\n",
"```\n",
"\n",
"You can install on Ubuntu using apt:\n",
"```bash\n",
"sudo apt install graphviz\n",
"```\n",
"\n",
"Installation instructions for other systems are available [here](https://graphviz.org/download/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Decision tree for Iris classification (depth 2)\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/iris_tree.png\" alt=\"data-layout\" width=\"500\" style=\"display:block; margin:auto\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"The tree consists of a number of nodes.\n",
"- The top node is the _root node_.\n",
"- Intermediate nodes are _split nodes_.\n",
"- The bottom nodes are _leaf nodes_.\n",
"\n",
"Each decision compares a *feature* against a *threshold*.\n",
"\n",
"To make a prediction, navigate the tree from the root node down to a leaf node."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Interpreting node outputs\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/iris_tree_1node.png\" alt=\"data-layout\" width=\"300\" style=\"display:block; margin:auto\"/>\n",
"\n",
"The entries in each node are:\n",
"- The top entry shows the _feature_ and _threshold_ used to split the node (absent for leaf nodes).\n",
"- ```gini``` (see next slides) is a quantitative measure of impurity.\n",
"- ```samples``` denotes the number of training instances that reach the node.\n",
"- ```value``` denotes the number of training instances of each class that reach the node.\n",
"- ```class``` is the class predicted for the node (the majority class)."
]
},
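{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"source": [
"The entries shown in each node can also be read from the fitted estimator. The following is a minimal sketch, assuming the same two-feature Iris tree as above; it uses the arrays of the low-level `tree_` attribute (`children_left`, `feature`, `threshold`, `impurity`, `n_node_samples`), and the helper name `describe_nodes` is hypothetical, introduced only for illustration.\n",
"\n",
"```python\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"iris = load_iris(as_frame=True)\n",
"feature_names = ['petal length (cm)', 'petal width (cm)']\n",
"X = iris.data[feature_names].values\n",
"y = iris.target\n",
"\n",
"tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X, y)\n",
"\n",
"\n",
"def describe_nodes(clf, names):\n",
"    # Hypothetical helper: build one summary string per node of a fitted tree\n",
"    t = clf.tree_\n",
"    lines = []\n",
"    for node in range(t.node_count):\n",
"        summary = f'node {node}: gini={t.impurity[node]:.3f}, samples={t.n_node_samples[node]}'\n",
"        if t.children_left[node] == -1:  # -1 marks a leaf (no children)\n",
"            lines.append(summary + ' (leaf)')\n",
"        else:\n",
"            test = f'{names[t.feature[node]]} <= {t.threshold[node]:.2f}'\n",
"            lines.append(summary + ', split on ' + test)\n",
"    return lines\n",
"\n",
"\n",
"node_lines = describe_nodes(tree_clf, feature_names)\n",
"for line in node_lines:\n",
"    print(line)\n",
"```\n",
"\n",
"Each printed line should correspond to one box of the tree diagram; leaf nodes carry no split test."
]
},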
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Decision boundaries"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/decision_tree_decision_boundaries_plot.png\" alt=\"data-layout\" width=\"600\" style=\"display:block; margin:auto\"/>\n",
"\n",
"We set ```max_depth=2```, so the algorithm stopped after two divisions."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Estimating class probabilities"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"We may also want to know the _probability_ that an instance $i$ belongs to class $k$.\n",
"\n",
"The class probability is found by locating the leaf node for instance $i$ and returning the ratio of training instances of class $k$ in this node.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"#### Example\n",
"\n",
"For a flower whose petals are 5cm long and 1.5cm wide,\n",
"the corresponding leaf node is the depth-2 left node.\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/decision_tree_decision_boundaries_plot.png\" alt=\"data-layout\" width=\"600\" style=\"display:block; margin:auto\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/iris_tree_1node2.png\" alt=\"data-layout\" width=\"300\" style=\"display:block; margin:auto\"/>\n",
"\n",
"\n",
"So the probabilities are: 0% (Setosa), 49/54=90.7% (Versicolor), 5/54=9.3% (Virginica)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:20.572369Z",
"iopub.status.busy": "2025-02-27T23:21:20.572178Z",
"iopub.status.idle": "2025-02-27T23:21:20.577551Z",
"shell.execute_reply": "2025-02-27T23:21:20.576948Z"
},
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"array([[0. , 0.907, 0.093]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree_clf.predict_proba([[5, 1.5]]).round(3)"
]
},
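{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"source": [
"To see where these probabilities come from, the leaf ratio can be recomputed by hand. This is a minimal sketch, assuming the same two-feature Iris tree as above: it uses the estimator's `apply` method to find the leaf each sample falls into, then counts the training instances of each class in the leaf reached by the new flower. The variable names are ours, chosen for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"iris = load_iris(as_frame=True)\n",
"X = iris.data[['petal length (cm)', 'petal width (cm)']].values\n",
"y = iris.target.values\n",
"\n",
"tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X, y)\n",
"\n",
"# apply() returns the index of the leaf node that each sample ends up in\n",
"new_flower = [[5, 1.5]]\n",
"leaf_id = tree_clf.apply(new_flower)[0]\n",
"in_same_leaf = tree_clf.apply(X) == leaf_id\n",
"\n",
"# Count the training instances of each class in that leaf, then normalise\n",
"class_counts = np.bincount(y[in_same_leaf], minlength=3)\n",
"leaf_proba = class_counts / class_counts.sum()\n",
"\n",
"print(class_counts)\n",
"print(leaf_proba.round(3))\n",
"```\n",
"\n",
"The ratios computed this way should agree with the output of `predict_proba` for the same flower."
]
},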
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"source": [
"## Quality measures"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Gini impurity\n",
"\n",
"Gini impurity is defined by\n",
"$$\n",
"G_i=1-\\sum_{k=1}^{n}p^2_{i,k} ,\n",
"$$\n",
"where $p_{i,k}$ is the fraction of training instances in the $i^{\\rm th}$ node that belong to class $k$, and $n$ is the number of classes.\n",
"\n",
"$G_i=0$ means the node is 100% _pure_, i.e. all of its training instances belong to a single class."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"#### Gini example calculation\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/iris_tree_1node2.png\" alt=\"data-layout\" width=\"300\" style=\"display:block; margin:auto\"/>\n",
"\n",
"$$\n",
"G_i = 1 - (0/54)^2 - (49/54)^2 - (5/54)^2= 0.168\n",
"$$"
]
},
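{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"source": [
"The calculation above can be checked numerically. This is a minimal sketch in plain Python; the function name `gini_impurity` is ours, and the class counts are those of the depth-2 left node, of the root node (50 flowers of each species) and of a pure node.\n",
"\n",
"```python\n",
"def gini_impurity(class_counts):\n",
"    # Gini impurity G = 1 - sum_k p_k^2, computed from per-class instance counts\n",
"    total = sum(class_counts)\n",
"    return 1.0 - sum((count / total) ** 2 for count in class_counts)\n",
"\n",
"\n",
"leaf_gini = gini_impurity([0, 49, 5])    # depth-2 left node\n",
"root_gini = gini_impurity([50, 50, 50])  # root node: all three species equally represented\n",
"pure_gini = gini_impurity([50, 0, 0])    # pure node: a single class\n",
"\n",
"print(round(leaf_gini, 3), round(root_gini, 3), pure_gini)\n",
"```\n",
"\n",
"The root node, where the three classes are equally mixed, has the largest impurity of the three; the pure node has impurity zero."
]
},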
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Entropy\n",
"\n",
"An alternative to Gini is to use entropy as the impurity measure\n",
"$$\n",
"H_i=-\\sum_{k=1}^n p_{i,k}\\log_2(p_{i,k}),\n",
"$$\n",
"where the sum runs over classes with $p_{i,k}\\neq 0$.\n",
"\n",
"Entropy measures the average information content of a variable (the average number of bits required to encode it)."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"#### Entropy example calculation\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/iris_tree_1node2.png\" alt=\"data-layout\" width=\"300\" style=\"display:block; margin:auto\"/>\n",
"\n",
"$$\n",
"H_i = - (49/54) \\log_2(49/54) - (5/54) \\log_2(5/54)= 0.445\n",
"$$"
]
},
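{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"source": [
"The entropy calculation can be checked in the same way. This is a minimal sketch in plain Python; the function name `entropy` is ours, and empty classes are skipped so that the logarithm is only evaluated for non-zero ratios.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def entropy(class_counts):\n",
"    # Entropy H = -sum_k p_k log2(p_k), skipping classes with no instances\n",
"    total = sum(class_counts)\n",
"    probs = [count / total for count in class_counts if count > 0]\n",
"    return -sum(p * math.log2(p) for p in probs)\n",
"\n",
"\n",
"leaf_entropy = entropy([0, 49, 5])    # depth-2 left node\n",
"root_entropy = entropy([50, 50, 50])  # root node: all three species equally represented\n",
"\n",
"print(round(leaf_entropy, 3), round(root_entropy, 3))\n",
"```\n",
"\n",
"For the equally mixed root node the entropy equals $\\log_2 3$, the largest value possible with three classes."
]
},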
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Does it make a difference what quality measure is used? \n",
"\n",
"Not usually, although Gini tends to isolate the most frequent class in its own branch, whereas entropy tends to produce more \"balanced\" trees. Entropy is slightly more expensive to compute because it requires evaluating logarithms."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Decision tree using entropy criterion"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:20.579634Z",
"iopub.status.busy": "2025-02-27T23:21:20.579461Z",
"iopub.status.idle": "2025-02-27T23:21:20.585935Z",
"shell.execute_reply": "2025-02-27T23:21:20.585283Z"
},
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-2 {\n",
" /* Definition of color scheme common for light and dark mode */\n",
" --sklearn-color-text: #000;\n",
" --sklearn-color-text-muted: #666;\n",
" --sklearn-color-line: gray;\n",
" /* Definition of color scheme for unfitted estimators */\n",
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
" --sklearn-color-unfitted-level-3: chocolate;\n",
" /* Definition of color scheme for fitted estimators */\n",
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
" --sklearn-color-fitted-level-1: #d4ebff;\n",
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
"\n",
" /* Specific color for light theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-icon: #696969;\n",
"\n",
" @media (prefers-color-scheme: dark) {\n",
" /* Redefinition of color scheme for dark theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-icon: #878787;\n",
" }\n",
"}\n",
"\n",
"#sk-container-id-2 {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"#sk-container-id-2 pre {\n",
" padding: 0;\n",
"}\n",
"\n",
"#sk-container-id-2 input.sk-hidden--visually {\n",
" border: 0;\n",
" clip: rect(1px 1px 1px 1px);\n",
" clip: rect(1px, 1px, 1px, 1px);\n",
" height: 1px;\n",
" margin: -1px;\n",
" overflow: hidden;\n",
" padding: 0;\n",
" position: absolute;\n",
" width: 1px;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-dashed-wrapped {\n",
" border: 1px dashed var(--sklearn-color-line);\n",
" margin: 0 0.4em 0.5em 0.4em;\n",
" box-sizing: border-box;\n",
" padding-bottom: 0.4em;\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-container {\n",
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
" so we also need the `!important` here to be able to override the\n",
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
" display: inline-block !important;\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-text-repr-fallback {\n",
" display: none;\n",
"}\n",
"\n",
"div.sk-parallel-item,\n",
"div.sk-serial,\n",
"div.sk-item {\n",
" /* draw centered vertical line to link estimators */\n",
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
" background-size: 2px 100%;\n",
" background-repeat: no-repeat;\n",
" background-position: center center;\n",
"}\n",
"\n",
"/* Parallel-specific style estimator block */\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item::after {\n",
" content: \"\";\n",
" width: 100%;\n",
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
" flex-grow: 1;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel {\n",
" display: flex;\n",
" align-items: stretch;\n",
" justify-content: center;\n",
" background-color: var(--sklearn-color-background);\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item:first-child::after {\n",
" align-self: flex-end;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item:last-child::after {\n",
" align-self: flex-start;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item:only-child::after {\n",
" width: 0;\n",
"}\n",
"\n",
"/* Serial-specific style estimator block */\n",
"\n",
"#sk-container-id-2 div.sk-serial {\n",
" display: flex;\n",
" flex-direction: column;\n",
" align-items: center;\n",
" background-color: var(--sklearn-color-background);\n",
" padding-right: 1em;\n",
" padding-left: 1em;\n",
"}\n",
"\n",
"\n",
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
"clickable and can be expanded/collapsed.\n",
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
"*/\n",
"\n",
"/* Pipeline and ColumnTransformer style (default) */\n",
"\n",
"#sk-container-id-2 div.sk-toggleable {\n",
" /* Default theme specific background. It is overwritten whether we have a\n",
" specific estimator or a Pipeline/ColumnTransformer */\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"/* Toggleable label */\n",
"#sk-container-id-2 label.sk-toggleable__label {\n",
" cursor: pointer;\n",
" display: flex;\n",
" width: 100%;\n",
" margin-bottom: 0;\n",
" padding: 0.5em;\n",
" box-sizing: border-box;\n",
" text-align: center;\n",
" align-items: start;\n",
" justify-content: space-between;\n",
" gap: 0.5em;\n",
"}\n",
"\n",
"#sk-container-id-2 label.sk-toggleable__label .caption {\n",
" font-size: 0.6rem;\n",
" font-weight: lighter;\n",
" color: var(--sklearn-color-text-muted);\n",
"}\n",
"\n",
"#sk-container-id-2 label.sk-toggleable__label-arrow:before {\n",
" /* Arrow on the left of the label */\n",
" content: \"▸\";\n",
" float: left;\n",
" margin-right: 0.25em;\n",
" color: var(--sklearn-color-icon);\n",
"}\n",
"\n",
"#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"/* Toggleable content - dropdown */\n",
"\n",
"#sk-container-id-2 div.sk-toggleable__content {\n",
" max-height: 0;\n",
" max-width: 0;\n",
" overflow: hidden;\n",
" text-align: left;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-toggleable__content.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-toggleable__content pre {\n",
" margin: 0.2em;\n",
" border-radius: 0.25em;\n",
" color: var(--sklearn-color-text);\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-toggleable__content.fitted pre {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
" /* Expand drop-down */\n",
" max-height: 200px;\n",
" max-width: 100%;\n",
" overflow: auto;\n",
"}\n",
"\n",
"#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
" content: \"▾\";\n",
"}\n",
"\n",
"/* Pipeline/ColumnTransformer-specific style */\n",
"\n",
"#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator-specific style */\n",
"\n",
"/* Colorize estimator box */\n",
"#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-label label.sk-toggleable__label,\n",
"#sk-container-id-2 div.sk-label label {\n",
" /* The background is the default theme color */\n",
" color: var(--sklearn-color-text-on-default-background);\n",
"}\n",
"\n",
"/* On hover, darken the color of the background */\n",
"#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"/* Label box, darken color on hover, fitted */\n",
"#sk-container-id-2 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator label */\n",
"\n",
"#sk-container-id-2 div.sk-label label {\n",
" font-family: monospace;\n",
" font-weight: bold;\n",
" display: inline-block;\n",
" line-height: 1.2em;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-label-container {\n",
" text-align: center;\n",
"}\n",
"\n",
"/* Estimator-specific */\n",
"#sk-container-id-2 div.sk-estimator {\n",
" font-family: monospace;\n",
" border: 1px dotted var(--sklearn-color-border-box);\n",
" border-radius: 0.25em;\n",
" box-sizing: border-box;\n",
" margin-bottom: 0.5em;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-estimator.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"/* on hover */\n",
"#sk-container-id-2 div.sk-estimator:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-estimator.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
"\n",
"/* Common style for \"i\" and \"?\" */\n",
"\n",
".sk-estimator-doc-link,\n",
"a:link.sk-estimator-doc-link,\n",
"a:visited.sk-estimator-doc-link {\n",
" float: right;\n",
" font-size: smaller;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1em;\n",
" height: 1em;\n",
" width: 1em;\n",
" text-decoration: none !important;\n",
" margin-left: 0.5em;\n",
" text-align: center;\n",
" /* unfitted */\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted,\n",
"a:link.sk-estimator-doc-link.fitted,\n",
"a:visited.sk-estimator-doc-link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"/* Span, style for the box shown on hovering the info icon */\n",
".sk-estimator-doc-link span {\n",
" display: none;\n",
" z-index: 9999;\n",
" position: relative;\n",
" font-weight: normal;\n",
" right: .2ex;\n",
" padding: .5ex;\n",
" margin: .5ex;\n",
" width: min-content;\n",
" min-width: 20ex;\n",
" max-width: 50ex;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: 2pt 2pt 4pt #999;\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted span {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link:hover span {\n",
" display: block;\n",
"}\n",
"\n",
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
"\n",
"#sk-container-id-2 a.estimator_doc_link {\n",
" float: right;\n",
" font-size: 1rem;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1rem;\n",
" height: 1rem;\n",
" width: 1rem;\n",
" text-decoration: none;\n",
" /* unfitted */\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
"}\n",
"\n",
"#sk-container-id-2 a.estimator_doc_link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"#sk-container-id-2 a.estimator_doc_link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"#sk-container-id-2 a.estimator_doc_link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>DecisionTreeClassifier(criterion=&#x27;entropy&#x27;, max_depth=2)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" checked><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>DecisionTreeClassifier</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.6/modules/generated/sklearn.tree.DecisionTreeClassifier.html\">?<span>Documentation for DecisionTreeClassifier</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\"><pre>DecisionTreeClassifier(criterion=&#x27;entropy&#x27;, max_depth=2)</pre></div> </div></div></div></div>"
],
"text/plain": [
"DecisionTreeClassifier(criterion='entropy', max_depth=2)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Redo the first example but use entropy instead\n",
"tree_clf = DecisionTreeClassifier(max_depth = 2,criterion='entropy') #making a decision tree of depth 2 from the data \n",
"tree_clf.fit(X_iris, y_iris)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:20.587706Z",
"iopub.status.busy": "2025-02-27T23:21:20.587538Z",
"iopub.status.idle": "2025-02-27T23:21:20.733060Z",
"shell.execute_reply": "2025-02-27T23:21:20.732336Z"
},
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"outputs": [],
"source": [
"export_graphviz(tree_clf, \n",
" out_file = './iris_tree_entropy.dot', \n",
" feature_names = iris.feature_names[ 2:], \n",
" class_names = iris.target_names, \n",
" rounded = True, \n",
" filled = True)\n",
"! dot -Tpng ./iris_tree_entropy.dot -o ./iris_tree_entropy.png"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"#### Entropy tree\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/iris_tree_entropy.png\" alt=\"data-layout\" width=\"500\" style=\"display:block; margin:auto\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"#### Gini tree (for reference)\n",
"\n",
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/iris_tree.png\" alt=\"data-layout\" width=\"500\" style=\"display:block; margin:auto\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"source": [
"## CART training algorithm\n",
"\n",
"Classification And Regression Tree (CART) algorihtm can be used to train decision trees -- also called *growing trees* algorithm (used by SciKit Learn)."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"- Splits the sample into two subsets using a single feature $k$ at threshold $t_k$\n",
"- Chooses feature $k$ and threshold $t_k$ by finding pair that produces purest subset, weighted by their size.\n",
"\n",
"Cost function minimized for each split:\n",
"\n",
"$$\n",
"J(k,t_k)=\\frac{m_\\text{left}}{m}G_\\text{left}+\\frac{m_\\text{right}}{m}G_\\text{right} ,\n",
"$$\n",
"\n",
"where $G_\\text{left/right}$ measures the impurity of the left/right subset and $m_\\text{left/right}$ is the number of instances in the left/right subset."
]
},
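{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"As a worked example of the cost function, consider an illustrative split (the numbers are assumptions chosen for this example, not read off the trees above). Suppose a node holds $m=150$ instances, 50 from each of three classes, and $G$ is the Gini impurity $G = 1 - \\sum_i p_i^2$, with $p_i$ the fraction of instances of class $i$. A candidate split sends $m_\\text{left}=50$ instances of a single class to the left and the remaining $m_\\text{right}=100$ instances, evenly divided between the other two classes, to the right. Then\n",
"\n",
"$$\n",
"G_\\text{left} = 1 - 1^2 = 0, \\qquad G_\\text{right} = 1 - \\left(\\tfrac{1}{2}\\right)^2 - \\left(\\tfrac{1}{2}\\right)^2 = 0.5 ,\n",
"$$\n",
"\n",
"$$\n",
"J = \\frac{50}{150}\\times 0 + \\frac{100}{150}\\times 0.5 = \\frac{1}{3} \\approx 0.333 ,\n",
"$$\n",
"\n",
"compared with an impurity of $1 - 3\\left(\\tfrac{1}{3}\\right)^2 = \\tfrac{2}{3} \\approx 0.667$ for the unsplit node."
]
},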
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"Note that the CART algorithm:\n",
"- Only splits data in two at each stage, i.e. is binary.\n",
"- Is a greedy algorithm. It searches for the optimal split at each level, then repeats for subsequent levels. There is no guarantee that the overall optimal tree is found. Nevertheless, usually produces a tree that is reasonably good."
]
},
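{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"The split search described above can be sketched in a few lines of NumPy. This is only an illustrative sketch, not SciKit Learn's actual (heavily optimised) implementation: the helper names `gini` and `best_split`, the use of midpoints between sorted feature values as candidate thresholds, and the toy data `X_toy`, `y_toy` are all assumptions made for this example. It assumes the Gini impurity $G = 1 - \\sum_i p_i^2$, where $p_i$ is the fraction of instances of class $i$ in a subset. Applying the same search recursively to each resulting subset would grow the full tree, which is what makes the algorithm greedy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def gini(y):\n",
"    # Gini impurity of a 1-D array of class labels (hypothetical helper)\n",
"    _, counts = np.unique(y, return_counts=True)\n",
"    p = counts / counts.sum()\n",
"    return 1.0 - np.sum(p ** 2)\n",
"\n",
"def best_split(X, y):\n",
"    # Exhaustive search for the (feature, threshold) pair minimising J(k, t_k)\n",
"    m, n_features = X.shape\n",
"    best_k, best_t, best_J = None, None, np.inf\n",
"    for k in range(n_features):\n",
"        values = np.unique(X[:, k])                   # sorted unique feature values\n",
"        thresholds = (values[:-1] + values[1:]) / 2   # candidate thresholds: midpoints (an assumption)\n",
"        for t in thresholds:\n",
"            left = X[:, k] <= t\n",
"            m_left = left.sum()\n",
"            # size-weighted impurity of the two subsets\n",
"            J = (m_left * gini(y[left]) + (m - m_left) * gini(y[~left])) / m\n",
"            if J < best_J:\n",
"                best_k, best_t, best_J = k, t, J\n",
"    return best_k, best_t, best_J\n",
"\n",
"# Toy data (illustrative only): four instances, two features, two classes\n",
"X_toy = np.array([[1.0, 5.0], [2.0, 4.0], [3.0, 7.0], [4.0, 6.0]])\n",
"y_toy = np.array([0, 0, 1, 1])\n",
"print(best_split(X_toy, y_toy))"
]
},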
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"source": [
"## Regularisation and hyperparameters\n",
"\n",
"Decision Trees are **non-parametric** classification algorithms since the number of parameters is not determined prior to training.\n",
"\n",
"Tends to overfit if not careful. "
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"Need to regularise the problem. This can be done by the restricting the complexity of the tree, for example though the following SciKit Learn parameters.\n",
"\n",
"- `max_depth`: maximum depth of the tree.\n",
"- `max_features` maximum number of features used in splitting at each node.\n",
"- `max_leaf_nodes` maximum number of leaf nodes.\n",
"- `min_samples_split`: mimimum number of samples a node must have before it can be split.\n",
"- `min_samples_leaf`: minimum number a leaf can have to be created.\n",
"- `min_weight_fraction`: Same as `min_samples_leaf` but expressed as a fraction of total samples.\n",
"\n",
"Generally, increasing ```min_*``` hyperparameters or reducing ```max_*``` hyperparameters will regularise the model."
]
},
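{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"The effect of these hyperparameters can be illustrated with a small experiment. The sketch below is not part of the original lecture material: the `make_moons` data set, the choice `min_samples_leaf=10`, and the variable names are assumptions chosen purely for illustration. It fits an unrestricted and a restricted tree to the same noisy data and prints the number of leaves together with the training and test accuracy of each, so that the two can be compared."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"from sklearn.datasets import make_moons\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Illustrative noisy two-class data set (an assumption, not used elsewhere in the lecture)\n",
"X_moons, y_moons = make_moons(n_samples=500, noise=0.3, random_state=42)\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X_moons, y_moons, random_state=42)\n",
"\n",
"# Unrestricted tree: no limit on depth or leaf size\n",
"tree_free = DecisionTreeClassifier(random_state=42).fit(X_tr, y_tr)\n",
"# Restricted tree: every leaf must contain at least 10 training samples\n",
"tree_restricted = DecisionTreeClassifier(min_samples_leaf=10, random_state=42).fit(X_tr, y_tr)\n",
"\n",
"for name, clf in [('unrestricted', tree_free), ('min_samples_leaf=10', tree_restricted)]:\n",
"    print(name, '| leaves:', clf.get_n_leaves(),\n",
"          '| train accuracy:', clf.score(X_tr, y_tr),\n",
"          '| test accuracy:', clf.score(X_te, y_te))"
]
},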
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"Other algorithms _prune_, i.e. make a (relatively) unrestricted tree then remove statistically insignificant nodes."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercise 1 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"source": [
"## Decision trees for regression \n",
"\n",
"Decision trees can also be used for regression tasks."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Train regression model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:20.735593Z",
"iopub.status.busy": "2025-02-27T23:21:20.735405Z",
"iopub.status.idle": "2025-02-27T23:21:20.743106Z",
"shell.execute_reply": "2025-02-27T23:21:20.742456Z"
},
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-3 {\n",
" /* Definition of color scheme common for light and dark mode */\n",
" --sklearn-color-text: #000;\n",
" --sklearn-color-text-muted: #666;\n",
" --sklearn-color-line: gray;\n",
" /* Definition of color scheme for unfitted estimators */\n",
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
" --sklearn-color-unfitted-level-3: chocolate;\n",
" /* Definition of color scheme for fitted estimators */\n",
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
" --sklearn-color-fitted-level-1: #d4ebff;\n",
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
"\n",
" /* Specific color for light theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-icon: #696969;\n",
"\n",
" @media (prefers-color-scheme: dark) {\n",
" /* Redefinition of color scheme for dark theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-icon: #878787;\n",
" }\n",
"}\n",
"\n",
"#sk-container-id-3 {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"#sk-container-id-3 pre {\n",
" padding: 0;\n",
"}\n",
"\n",
"#sk-container-id-3 input.sk-hidden--visually {\n",
" border: 0;\n",
" clip: rect(1px 1px 1px 1px);\n",
" clip: rect(1px, 1px, 1px, 1px);\n",
" height: 1px;\n",
" margin: -1px;\n",
" overflow: hidden;\n",
" padding: 0;\n",
" position: absolute;\n",
" width: 1px;\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-dashed-wrapped {\n",
" border: 1px dashed var(--sklearn-color-line);\n",
" margin: 0 0.4em 0.5em 0.4em;\n",
" box-sizing: border-box;\n",
" padding-bottom: 0.4em;\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-container {\n",
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
" so we also need the `!important` here to be able to override the\n",
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
" display: inline-block !important;\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-text-repr-fallback {\n",
" display: none;\n",
"}\n",
"\n",
"div.sk-parallel-item,\n",
"div.sk-serial,\n",
"div.sk-item {\n",
" /* draw centered vertical line to link estimators */\n",
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
" background-size: 2px 100%;\n",
" background-repeat: no-repeat;\n",
" background-position: center center;\n",
"}\n",
"\n",
"/* Parallel-specific style estimator block */\n",
"\n",
"#sk-container-id-3 div.sk-parallel-item::after {\n",
" content: \"\";\n",
" width: 100%;\n",
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
" flex-grow: 1;\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-parallel {\n",
" display: flex;\n",
" align-items: stretch;\n",
" justify-content: center;\n",
" background-color: var(--sklearn-color-background);\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-parallel-item {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-parallel-item:first-child::after {\n",
" align-self: flex-end;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-parallel-item:last-child::after {\n",
" align-self: flex-start;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-parallel-item:only-child::after {\n",
" width: 0;\n",
"}\n",
"\n",
"/* Serial-specific style estimator block */\n",
"\n",
"#sk-container-id-3 div.sk-serial {\n",
" display: flex;\n",
" flex-direction: column;\n",
" align-items: center;\n",
" background-color: var(--sklearn-color-background);\n",
" padding-right: 1em;\n",
" padding-left: 1em;\n",
"}\n",
"\n",
"\n",
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
"clickable and can be expanded/collapsed.\n",
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
"*/\n",
"\n",
"/* Pipeline and ColumnTransformer style (default) */\n",
"\n",
"#sk-container-id-3 div.sk-toggleable {\n",
" /* Default theme specific background. It is overwritten whether we have a\n",
" specific estimator or a Pipeline/ColumnTransformer */\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"/* Toggleable label */\n",
"#sk-container-id-3 label.sk-toggleable__label {\n",
" cursor: pointer;\n",
" display: flex;\n",
" width: 100%;\n",
" margin-bottom: 0;\n",
" padding: 0.5em;\n",
" box-sizing: border-box;\n",
" text-align: center;\n",
" align-items: start;\n",
" justify-content: space-between;\n",
" gap: 0.5em;\n",
"}\n",
"\n",
"#sk-container-id-3 label.sk-toggleable__label .caption {\n",
" font-size: 0.6rem;\n",
" font-weight: lighter;\n",
" color: var(--sklearn-color-text-muted);\n",
"}\n",
"\n",
"#sk-container-id-3 label.sk-toggleable__label-arrow:before {\n",
" /* Arrow on the left of the label */\n",
" content: \"▸\";\n",
" float: left;\n",
" margin-right: 0.25em;\n",
" color: var(--sklearn-color-icon);\n",
"}\n",
"\n",
"#sk-container-id-3 label.sk-toggleable__label-arrow:hover:before {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"/* Toggleable content - dropdown */\n",
"\n",
"#sk-container-id-3 div.sk-toggleable__content {\n",
" max-height: 0;\n",
" max-width: 0;\n",
" overflow: hidden;\n",
" text-align: left;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-toggleable__content.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-toggleable__content pre {\n",
" margin: 0.2em;\n",
" border-radius: 0.25em;\n",
" color: var(--sklearn-color-text);\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-toggleable__content.fitted pre {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-3 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
" /* Expand drop-down */\n",
" max-height: 200px;\n",
" max-width: 100%;\n",
" overflow: auto;\n",
"}\n",
"\n",
"#sk-container-id-3 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
" content: \"▾\";\n",
"}\n",
"\n",
"/* Pipeline/ColumnTransformer-specific style */\n",
"\n",
"#sk-container-id-3 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator-specific style */\n",
"\n",
"/* Colorize estimator box */\n",
"#sk-container-id-3 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-label label.sk-toggleable__label,\n",
"#sk-container-id-3 div.sk-label label {\n",
" /* The background is the default theme color */\n",
" color: var(--sklearn-color-text-on-default-background);\n",
"}\n",
"\n",
"/* On hover, darken the color of the background */\n",
"#sk-container-id-3 div.sk-label:hover label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"/* Label box, darken color on hover, fitted */\n",
"#sk-container-id-3 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator label */\n",
"\n",
"#sk-container-id-3 div.sk-label label {\n",
" font-family: monospace;\n",
" font-weight: bold;\n",
" display: inline-block;\n",
" line-height: 1.2em;\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-label-container {\n",
" text-align: center;\n",
"}\n",
"\n",
"/* Estimator-specific */\n",
"#sk-container-id-3 div.sk-estimator {\n",
" font-family: monospace;\n",
" border: 1px dotted var(--sklearn-color-border-box);\n",
" border-radius: 0.25em;\n",
" box-sizing: border-box;\n",
" margin-bottom: 0.5em;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-estimator.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"/* on hover */\n",
"#sk-container-id-3 div.sk-estimator:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-3 div.sk-estimator.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
"\n",
"/* Common style for \"i\" and \"?\" */\n",
"\n",
".sk-estimator-doc-link,\n",
"a:link.sk-estimator-doc-link,\n",
"a:visited.sk-estimator-doc-link {\n",
" float: right;\n",
" font-size: smaller;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1em;\n",
" height: 1em;\n",
" width: 1em;\n",
" text-decoration: none !important;\n",
" margin-left: 0.5em;\n",
" text-align: center;\n",
" /* unfitted */\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted,\n",
"a:link.sk-estimator-doc-link.fitted,\n",
"a:visited.sk-estimator-doc-link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"/* Span, style for the box shown on hovering the info icon */\n",
".sk-estimator-doc-link span {\n",
" display: none;\n",
" z-index: 9999;\n",
" position: relative;\n",
" font-weight: normal;\n",
" right: .2ex;\n",
" padding: .5ex;\n",
" margin: .5ex;\n",
" width: min-content;\n",
" min-width: 20ex;\n",
" max-width: 50ex;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: 2pt 2pt 4pt #999;\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted span {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link:hover span {\n",
" display: block;\n",
"}\n",
"\n",
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
"\n",
"#sk-container-id-3 a.estimator_doc_link {\n",
" float: right;\n",
" font-size: 1rem;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1rem;\n",
" height: 1rem;\n",
" width: 1rem;\n",
" text-decoration: none;\n",
" /* unfitted */\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
"}\n",
"\n",
"#sk-container-id-3 a.estimator_doc_link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"#sk-container-id-3 a.estimator_doc_link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"#sk-container-id-3 a.estimator_doc_link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"</style><div id=\"sk-container-id-3\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>DecisionTreeRegressor(max_depth=2, random_state=42)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" checked><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>DecisionTreeRegressor</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.6/modules/generated/sklearn.tree.DecisionTreeRegressor.html\">?<span>Documentation for DecisionTreeRegressor</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\"><pre>DecisionTreeRegressor(max_depth=2, random_state=42)</pre></div> </div></div></div></div>"
],
"text/plain": [
"DecisionTreeRegressor(max_depth=2, random_state=42)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"np.random.seed(42)\n",
"X_quad = np.random.rand(200, 1) - 0.5 # a single random input feature\n",
"y_quad = X_quad ** 2 + 0.025 * np.random.randn(200, 1)\n",
"\n",
"tree_reg = DecisionTreeRegressor(max_depth=2, random_state=42)\n",
"tree_reg.fit(X_quad, y_quad)"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Visualise tree"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:20.744939Z",
"iopub.status.busy": "2025-02-27T23:21:20.744742Z",
"iopub.status.idle": "2025-02-27T23:21:20.896284Z",
"shell.execute_reply": "2025-02-27T23:21:20.895553Z"
},
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"outputs": [],
"source": [
"export_graphviz(\n",
" tree_reg,\n",
" out_file=\"./regression_tree.dot\",\n",
" feature_names=[\"x1\"],\n",
" rounded=True,\n",
" filled=True\n",
")\n",
"\n",
"# export_graphviz writes a Graphviz .dot file, so convert it to a PNG with the dot command\n",
"! dot -Tpng ./regression_tree.dot -o ./regression_tree.png"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"<img src=\"https://raw.githubusercontent.com/astro-informatics/course_mlbd_images/master/Lecture16_Images/regression_tree.png\" alt=\"data-layout\" width=\"700\" style=\"display:block; margin:auto\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Train a deeper model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:20.898435Z",
"iopub.status.busy": "2025-02-27T23:21:20.898251Z",
"iopub.status.idle": "2025-02-27T23:21:20.904714Z",
"shell.execute_reply": "2025-02-27T23:21:20.904136Z"
},
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-4 {\n",
" /* Definition of color scheme common for light and dark mode */\n",
" --sklearn-color-text: #000;\n",
" --sklearn-color-text-muted: #666;\n",
" --sklearn-color-line: gray;\n",
" /* Definition of color scheme for unfitted estimators */\n",
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
" --sklearn-color-unfitted-level-3: chocolate;\n",
" /* Definition of color scheme for fitted estimators */\n",
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
" --sklearn-color-fitted-level-1: #d4ebff;\n",
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
"\n",
" /* Specific color for light theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-icon: #696969;\n",
"\n",
" @media (prefers-color-scheme: dark) {\n",
" /* Redefinition of color scheme for dark theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-icon: #878787;\n",
" }\n",
"}\n",
"\n",
"#sk-container-id-4 {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"#sk-container-id-4 pre {\n",
" padding: 0;\n",
"}\n",
"\n",
"#sk-container-id-4 input.sk-hidden--visually {\n",
" border: 0;\n",
" clip: rect(1px 1px 1px 1px);\n",
" clip: rect(1px, 1px, 1px, 1px);\n",
" height: 1px;\n",
" margin: -1px;\n",
" overflow: hidden;\n",
" padding: 0;\n",
" position: absolute;\n",
" width: 1px;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-dashed-wrapped {\n",
" border: 1px dashed var(--sklearn-color-line);\n",
" margin: 0 0.4em 0.5em 0.4em;\n",
" box-sizing: border-box;\n",
" padding-bottom: 0.4em;\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-container {\n",
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
" so we also need the `!important` here to be able to override the\n",
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
" display: inline-block !important;\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-text-repr-fallback {\n",
" display: none;\n",
"}\n",
"\n",
"div.sk-parallel-item,\n",
"div.sk-serial,\n",
"div.sk-item {\n",
" /* draw centered vertical line to link estimators */\n",
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
" background-size: 2px 100%;\n",
" background-repeat: no-repeat;\n",
" background-position: center center;\n",
"}\n",
"\n",
"/* Parallel-specific style estimator block */\n",
"\n",
"#sk-container-id-4 div.sk-parallel-item::after {\n",
" content: \"\";\n",
" width: 100%;\n",
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
" flex-grow: 1;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-parallel {\n",
" display: flex;\n",
" align-items: stretch;\n",
" justify-content: center;\n",
" background-color: var(--sklearn-color-background);\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-parallel-item {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-parallel-item:first-child::after {\n",
" align-self: flex-end;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-parallel-item:last-child::after {\n",
" align-self: flex-start;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-parallel-item:only-child::after {\n",
" width: 0;\n",
"}\n",
"\n",
"/* Serial-specific style estimator block */\n",
"\n",
"#sk-container-id-4 div.sk-serial {\n",
" display: flex;\n",
" flex-direction: column;\n",
" align-items: center;\n",
" background-color: var(--sklearn-color-background);\n",
" padding-right: 1em;\n",
" padding-left: 1em;\n",
"}\n",
"\n",
"\n",
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
"clickable and can be expanded/collapsed.\n",
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
"*/\n",
"\n",
"/* Pipeline and ColumnTransformer style (default) */\n",
"\n",
"#sk-container-id-4 div.sk-toggleable {\n",
" /* Default theme specific background. It is overwritten whether we have a\n",
" specific estimator or a Pipeline/ColumnTransformer */\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"/* Toggleable label */\n",
"#sk-container-id-4 label.sk-toggleable__label {\n",
" cursor: pointer;\n",
" display: flex;\n",
" width: 100%;\n",
" margin-bottom: 0;\n",
" padding: 0.5em;\n",
" box-sizing: border-box;\n",
" text-align: center;\n",
" align-items: start;\n",
" justify-content: space-between;\n",
" gap: 0.5em;\n",
"}\n",
"\n",
"#sk-container-id-4 label.sk-toggleable__label .caption {\n",
" font-size: 0.6rem;\n",
" font-weight: lighter;\n",
" color: var(--sklearn-color-text-muted);\n",
"}\n",
"\n",
"#sk-container-id-4 label.sk-toggleable__label-arrow:before {\n",
" /* Arrow on the left of the label */\n",
" content: \"▸\";\n",
" float: left;\n",
" margin-right: 0.25em;\n",
" color: var(--sklearn-color-icon);\n",
"}\n",
"\n",
"#sk-container-id-4 label.sk-toggleable__label-arrow:hover:before {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"/* Toggleable content - dropdown */\n",
"\n",
"#sk-container-id-4 div.sk-toggleable__content {\n",
" max-height: 0;\n",
" max-width: 0;\n",
" overflow: hidden;\n",
" text-align: left;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-toggleable__content.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-toggleable__content pre {\n",
" margin: 0.2em;\n",
" border-radius: 0.25em;\n",
" color: var(--sklearn-color-text);\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-toggleable__content.fitted pre {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-4 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
" /* Expand drop-down */\n",
" max-height: 200px;\n",
" max-width: 100%;\n",
" overflow: auto;\n",
"}\n",
"\n",
"#sk-container-id-4 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
" content: \"▾\";\n",
"}\n",
"\n",
"/* Pipeline/ColumnTransformer-specific style */\n",
"\n",
"#sk-container-id-4 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator-specific style */\n",
"\n",
"/* Colorize estimator box */\n",
"#sk-container-id-4 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-label label.sk-toggleable__label,\n",
"#sk-container-id-4 div.sk-label label {\n",
" /* The background is the default theme color */\n",
" color: var(--sklearn-color-text-on-default-background);\n",
"}\n",
"\n",
"/* On hover, darken the color of the background */\n",
"#sk-container-id-4 div.sk-label:hover label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"/* Label box, darken color on hover, fitted */\n",
"#sk-container-id-4 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator label */\n",
"\n",
"#sk-container-id-4 div.sk-label label {\n",
" font-family: monospace;\n",
" font-weight: bold;\n",
" display: inline-block;\n",
" line-height: 1.2em;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-label-container {\n",
" text-align: center;\n",
"}\n",
"\n",
"/* Estimator-specific */\n",
"#sk-container-id-4 div.sk-estimator {\n",
" font-family: monospace;\n",
" border: 1px dotted var(--sklearn-color-border-box);\n",
" border-radius: 0.25em;\n",
" box-sizing: border-box;\n",
" margin-bottom: 0.5em;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-estimator.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"/* on hover */\n",
"#sk-container-id-4 div.sk-estimator:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-estimator.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
"\n",
"/* Common style for \"i\" and \"?\" */\n",
"\n",
".sk-estimator-doc-link,\n",
"a:link.sk-estimator-doc-link,\n",
"a:visited.sk-estimator-doc-link {\n",
" float: right;\n",
" font-size: smaller;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1em;\n",
" height: 1em;\n",
" width: 1em;\n",
" text-decoration: none !important;\n",
" margin-left: 0.5em;\n",
" text-align: center;\n",
" /* unfitted */\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted,\n",
"a:link.sk-estimator-doc-link.fitted,\n",
"a:visited.sk-estimator-doc-link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"/* Span, style for the box shown on hovering the info icon */\n",
".sk-estimator-doc-link span {\n",
" display: none;\n",
" z-index: 9999;\n",
" position: relative;\n",
" font-weight: normal;\n",
" right: .2ex;\n",
" padding: .5ex;\n",
" margin: .5ex;\n",
" width: min-content;\n",
" min-width: 20ex;\n",
" max-width: 50ex;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: 2pt 2pt 4pt #999;\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted span {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link:hover span {\n",
" display: block;\n",
"}\n",
"\n",
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
"\n",
"#sk-container-id-4 a.estimator_doc_link {\n",
" float: right;\n",
" font-size: 1rem;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1rem;\n",
" height: 1rem;\n",
" width: 1rem;\n",
" text-decoration: none;\n",
" /* unfitted */\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
"}\n",
"\n",
"#sk-container-id-4 a.estimator_doc_link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"#sk-container-id-4 a.estimator_doc_link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"#sk-container-id-4 a.estimator_doc_link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"</style><div id=\"sk-container-id-4\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>DecisionTreeRegressor(max_depth=3, random_state=42)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-4\" type=\"checkbox\" checked><label for=\"sk-estimator-id-4\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>DecisionTreeRegressor</div></div><div><a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.6/modules/generated/sklearn.tree.DecisionTreeRegressor.html\">?<span>Documentation for DecisionTreeRegressor</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\"><pre>DecisionTreeRegressor(max_depth=3, random_state=42)</pre></div> </div></div></div></div>"
],
"text/plain": [
"DecisionTreeRegressor(max_depth=3, random_state=42)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree_reg2 = DecisionTreeRegressor(max_depth=3, random_state=42)\n",
"tree_reg2.fit(X_quad, y_quad)"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Plot models"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:20.906565Z",
"iopub.status.busy": "2025-02-27T23:21:20.906396Z",
"iopub.status.idle": "2025-02-27T23:21:21.488766Z",
"shell.execute_reply": "2025-02-27T23:21:21.488082Z"
},
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'max_depth=3')"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1kAAAGKCAYAAAD+NIubAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAArfJJREFUeJzs3Xd8FGX+B/DP7ibZQEIKJCGQpRN6k1AVBJEzqFjOgvLTEzQUFc6So3oIeBakiFiwEsECgnq2OxQUhFMERUPvoQSygTRCdkOAtH1+fyy72d1smd3pk+/79coryezsPM/MzjzffWaeomOMMRBCCCGEEEIIEYRe7gwQQgghhBBCiJZQJYsQQgghhBBCBESVLEIIIYQQQggREFWyCCGEEEIIIURAVMkihBBCCCGEEAFRJYsQQgghhBBCBESVLEIIIYQQQggREFWyCCGEEEIIIURAVMkihBBCCCGEEAFRJYsQDWjbti3Gjx8vS9q5ubnQ6XRYsmSJLOkTQghRJopNpCGjShYhhJPvvvsO8+fPlyXtL7/8Evfddx/at2+Pxo0bo3PnzvjHP/6BsrIyWfJDCCFEGeSMTV999RXS09PRsmVLGI1GmEwm3HPPPThw4IAs+SHKQpUsQggn3333HZ577jlZ0p40aRIOHz6MBx98EK+//jpGjRqFN998E4MHD8bly5dlyRMhhBD5yRmb9u/fj/j4eDz55JN466238Nhjj2H37t0YMGAA9u7dK0ueiHKEyZ0BQggJ5IsvvsDw4cPdlqWlpWHcuHFYvXo1JkyYIE/GCCGENFhz586tt2zChAkwmUx4++238c4778iQK6IU9CSLNGjz58+HTqfDsWPH8OCDDyI2NhaJiYl49tlnwRhDXl4e7rjjDsTExCA5ORmvvPKK871VVVWYO3cu0tLSEBsbi6ioKAwdOhRbtmxxS2PevHnQ6/XYvHmz2/JJkyYhIiIiqLtdjDG88MILMJlMaNy4MW644QYcPHjQ67plZWV46qmn0KpVKxiNRnTs2BELFy6EzWZzruPaZv3VV19FmzZt0KhRIwwbNsytucP48eOxfPlyAIBOp3P+eHrvvffQoUMHGI1G9O/fH3/88QfnffPHs4IFAH/9618BAIcPHxYkDUIIUQqKTeqITd4kJSWhcePG1Jyd0JMsQgDgvvvuQ9euXfHyyy9j/fr1eOGFF9C0aVO8++67GDFiBBYuXIjVq1dj2rRp6N+/P66//npYrVasWLECY8eOxcSJE1FeXo6srCykp6dj586d6NOnDwBgzpw5+M9//oOMjAzs378fTZo0wcaNG/H+++/j+eefR+/evTnnc+7cuXjhhRdwyy234JZbbsGuXbtw0003oaqqym29S5cuYdiwYcjPz8fkyZPRunVrbN++HbNnz8a5c+ewbNkyt/U/+ugjlJeXY8qUKbhy5Qpee+01jBgxAvv370fz5s0xefJknD17Fj/++CM+/vhjr3lbs2YNysvLMXnyZOh0OixatAh33XUXTp48ifDwcABAZWUlysvLOe1rQkKC39cLCgo4rUcIIWpFsUkdsamsrAzV1dUoKCjAsmXLYLVaceONN3LaHtEwRkgDNm/ePAaATZo0ybmspqaGmUwmptPp2Msvv+xcfuHCBdaoUSM2btw453qVlZVu27tw4QJr3rw5e+SRR9yW79+/n0VERLAJEyawCxcusJSUFNavXz9WXV3NOa9FRUUsIiKC3XrrrcxmszmXP/PMMwyAM1+MMfb888+zqKgoduzYMbdtzJo1ixkMBnbmzBnGGGOnTp1iAFijRo2Y2Wx2rvf7778zAOzpp592LpsyZQrzVmQ4ttGsWTNWWlrqXP7NN98wAOw///mPc9nKlSsZAE4/gWRkZDCDwVBvHwkhRO0oNqkrNnXu3Nn5enR0NJszZw6rra3leASJVtGTLEIAtz49BoMB/fr1g9lsRkZGhnN5XFwcOnfujJMnTzrXMxgMAACbzY
aysjLYbDb069cPu3btctt+jx498Nxzz2H27NnYt28fSkpK8MMPPyAsjPsluGnTJlRVVeHvf/+7W3OIp556Ci+99JLbup9//jmGDh2K+Ph4lJSUOJePHDkSL7/8Mn7++Wc88MADzuV33nknUlJSnP8PGDAAAwcOxHfffYelS5dyyt99992H+Ph45/9Dhw4FAOfxAoD09HT8+OOPHPfYtzVr1iArKwszZsxAamoq7+0RQogSUWxSR2xauXIlrFYrTp48iZUrV+Ly5cuora2FXk+9choyqmQRAqB169Zu/8fGxiIyMrJes4DY2FicP3/e+f+HH36IV155BUeOHEF1dbVzebt27eqlMX36dKxduxY7d+7ESy+9hG7dugWVx9OnTwNAvUpFYmKiWwABgJycHOzbtw+JiYlet1VUVOT2v7eKSqdOnfDZZ59xzp/nMXTk6cKFC85lLVq0QIsWLThv05tffvkFGRkZSE9Px4svvshrW4QQomQUm9QRmwYPHuz8+/7770fXrl0BgOboauCokkUI4LzrF2gZYO/gCwCffPIJxo8fjzvvvBPTp09HUlISDAYDFixYgBMnTtR738mTJ5GTkwPAPuyrmGw2G/7yl79gxowZXl/v1KmT4GkGOl4AcPnyZVgsFk7bS05Orrds7969uP3229GjRw988cUXQd1tJYQQtaHYxJ8UsclVfHw8RowYgdWrV1Mlq4GjbyiEhOiLL75A+/bt8eWXX7o1kZg3b169dW02G8aPH4+YmBhnE4p77rkHd911F+f02rRpA8B+J7B9+/bO5cXFxW535ACgQ4cOuHjxIkaOHMlp244A6+rYsWNo27at839vIzYFa926dXj44Yc5resaAAHgxIkTGDVqFJKSkvDdd98hOjqad34IIURrKDYFj09s8iaYShvRLqpkERIix90xxpizkP/999+xY8eOes0Tli5diu3bt+Pbb7/Frbfeiq1bt+Kxxx7D9ddfz3l0vJEjRyI8PBxvvPEGbrrpJmeanqMxAcCYMWMwf/58bNy4Eenp6W6vlZWVITo62u0p0Ndff438/Hxn2/edO3fi999/x1NPPeVcJyoqyvn+uLg4Tnn2FGq794KCAtx0003Q6/XYuHGjz6YmhBDS0FFsCl6osamoqAhJSUluy3Jzc7F582b069cvpLwQ7aBKFiEhGj16NL788kv89a9/xa233opTp07hnXfeQbdu3XDx4kXneocPH8azzz6L8ePH47bbbgMArFq1Cn369MHjjz/OuW15YmIipk2bhgULFmD06NG45ZZbsHv3bnz//ff1guH06dPx7bffYvTo0Rg/fjzS0tJQUVGB/fv344svvkBubq7bezp27IghQ4bgscceQ2VlJZYtW4ZmzZq5NelIS0sDADzxxBNIT0+HwWDA/fffH9QxC7Xd+6hRo3Dy5EnMmDED27Ztw7Zt25yvNW/eHH/5y1+C3iYhhGgRxSbpYlPPnj1x4403ok+fPoiPj0dOTg6ysrJQXV2Nl19+OejtEY2Rb2BDQuTnGCa3uLjYbfm4ceNYVFRUvfWHDRvGunfvzhhjzGazsZdeeom1adOGGY1Gds0117D//ve/bNy4caxNmzaMMftQuv3792cmk4mVlZW5beu1115jANi6des457e2tpY999xzrEWLFqxRo0Zs+PDh7MCBA6xNmzZuw+Qyxlh5eTmbPXs269ixI4uIiGAJCQns2muvZUuWLGFVVVWMsbohbhcvXsxeeeUV1qpVK2Y0GtnQoUPZ3r173bZXU1PD/v73v7PExESm0+mcQ9m6bsMTADZv3jzO++cL/AynO2zYMN7bJ4QQJaHYpI7YNG/ePNavXz8WHx/PwsLCWMuWLdn999/P9u3bx3vbRP10jHFoXEoI0aTc3Fy0a9cOixcvxrRp0+TODiGEEEKxiWgCDeBPCCGEEEIIIQKiPlmEKEBxcTFqa2t9vh4REYGmTZtKmCNCCCENHcUmQkJHlSxCFKB///7OCR
29GTZsGLZu3SpdhgghhDR4FJsICZ2i+mQtX74cixcvRkFBAXr37o033ngDAwYM8Lru+++/j48++ggHDhwAYB9d5qW
"text/plain": [
"<Figure size 1000x400 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def plot_regression_predictions(tree_reg, X, y, axes=[-0.5, 0.5, -0.05, 0.25]):\n",
" x1 = np.linspace(axes[0], axes[1], 500).reshape(-1, 1)\n",
" y_pred = tree_reg.predict(x1)\n",
" plt.axis(axes)\n",
" plt.xlabel(\"$x_1$\")\n",
" plt.plot(X, y, \"b.\")\n",
" plt.plot(x1, y_pred, \"r.-\", linewidth=2, label=r\"$\\hat{y}$\")\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_regression_predictions(tree_reg, X_quad, y_quad)\n",
"\n",
"th0, th1a, th1b = tree_reg.tree_.threshold[[0, 1, 4]]\n",
"for split, style in ((th0, \"k-\"), (th1a, \"k--\"), (th1b, \"k--\")):\n",
" plt.plot([split, split], [-0.05, 0.25], style, linewidth=2)\n",
"plt.text(th0, 0.16, \"Depth=0\", fontsize=15)\n",
"plt.text(th1a + 0.01, -0.01, \"Depth=1\", horizontalalignment=\"center\", fontsize=13)\n",
"plt.text(th1b + 0.01, -0.01, \"Depth=1\", fontsize=13)\n",
"plt.ylabel(\"$y$\", rotation=0)\n",
"plt.legend(loc=\"upper center\", fontsize=16)\n",
"plt.title(\"max_depth=2\")\n",
"\n",
"plt.sca(axes[1])\n",
"th2s = tree_reg2.tree_.threshold[[2, 5, 9, 12]]\n",
"plot_regression_predictions(tree_reg2, X_quad, y_quad)\n",
"for split, style in ((th0, \"k-\"), (th1a, \"k--\"), (th1b, \"k--\")):\n",
" plt.plot([split, split], [-0.05, 0.25], style, linewidth=2)\n",
"for split in th2s:\n",
" plt.plot([split, split], [-0.05, 0.25], \"k:\", linewidth=1)\n",
"plt.text(th2s[2] + 0.01, 0.15, \"Depth=2\", fontsize=13)\n",
"plt.title(\"max_depth=3\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Training regression models\n",
"\n",
"The CART algorithm works in the same way as for classification, except that the cost function is based on the mean squared error (MSE):\n",
"\n",
"$$\n",
"J(k,t_k)=\\frac{m_\\text{left}}{m}\\text{MSE}_\\text{left}+\\frac{m_\\text{right}}{m}\\text{MSE}_\\text{right}.\n",
"$$"
]
},
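{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"Here the MSE of each subset is measured about that subset's mean target value, which is also the value the node predicts. As a sketch, with the index set $S_\\text{node}$ and the symbols $\\hat{y}_\\text{node}$ and $y^{(i)}$ introduced here for illustration (they are assumed notation, not taken from the text above):\n",
"\n",
"$$\n",
"\\text{MSE}_\\text{node} = \\frac{1}{m_\\text{node}} \\sum_{i \\in S_\\text{node}} \\left( \\hat{y}_\\text{node} - y^{(i)} \\right)^2,\n",
"\\quad\n",
"\\hat{y}_\\text{node} = \\frac{1}{m_\\text{node}} \\sum_{i \\in S_\\text{node}} y^{(i)},\n",
"$$\n",
"\n",
"where node is either left or right. For example, splitting the targets $\\{0, 0, 1, 1\\}$ into $\\{0, 0\\}$ and $\\{1, 1\\}$ gives $J = 0$, whereas splitting into $\\{0\\}$ and $\\{0, 1, 1\\}$ gives $J = \\frac{1}{4} \\cdot 0 + \\frac{3}{4} \\cdot \\frac{2}{9} = \\frac{1}{6}$, so the first split would be preferred."
]
},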
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercise 2 in the exercises associated with this lecture.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "slide"
},
"tags": []
},
"source": [
"## Limitations"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"### Sensitivity to axis orientation"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"Decision trees are sensitive to the orientation of the axes, since every split is perpendicular to a single feature axis.\n",
"\n",
"Consider the same data but with the axes rotated by $45^\\circ$."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:21.490962Z",
"iopub.status.busy": "2025-02-27T23:21:21.490688Z",
"iopub.status.idle": "2025-02-27T23:21:21.495775Z",
"shell.execute_reply": "2025-02-27T23:21:21.495198Z"
},
"slideshow": {
"slide_type": "skip"
},
"tags": []
},
"outputs": [],
"source": [
"def plot_decision_boundary(clf, X, y, axes, cmap):\n",
" x1, x2 = np.meshgrid(np.linspace(axes[0], axes[1], 100),\n",
" np.linspace(axes[2], axes[3], 100))\n",
" X_new = np.c_[x1.ravel(), x2.ravel()]\n",
" y_pred = clf.predict(X_new).reshape(x1.shape)\n",
" \n",
" plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=cmap)\n",
" plt.contour(x1, x2, y_pred, cmap=\"Greys\", alpha=0.8)\n",
" colors = {\"Wistia\": [\"#78785c\", \"#c47b27\"], \"Pastel1\": [\"red\", \"blue\"]}\n",
" markers = (\"o\", \"^\")\n",
" for idx in (0, 1):\n",
" plt.plot(X[:, 0][y == idx], X[:, 1][y == idx],\n",
" color=colors[cmap][idx], marker=markers[idx], linestyle=\"none\")\n",
" plt.axis(axes)\n",
" plt.xlabel(r\"$x_1$\")\n",
" plt.ylabel(r\"$x_2$\", rotation=0)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"editable": true,
"execution": {
"iopub.execute_input": "2025-02-27T23:21:21.497513Z",
"iopub.status.busy": "2025-02-27T23:21:21.497340Z",
"iopub.status.idle": "2025-02-27T23:21:21.688609Z",
"shell.execute_reply": "2025-02-27T23:21:21.687963Z"
},
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, '')"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1MAAAFzCAYAAADbi1ODAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAYaRJREFUeJzt3Xt8VNW5P/7PnoQEUEKIkGQCsYj1hiIIHCKotV+IEkGPcqDeOFo5CJYDrYj1VbAJWG9oRY0iP6kUfrXnV6sVDv48BlE6yimXEC7a1lpAUSmUMIBGEi5KYGZ9/5gLM5O57NmzL2vv/Xm/Xnkhw57J2gFnzbPWs55HEUIIEBERERERUVY8Vg+AiIiIiIjIjhhMERERERERacBgioiIiIiISAMGU0RERERERBowmCIiIiIiItKAwRQREREREZEGDKaIiIiIiIg0yLd6ALIIBoNobm5Gt27doCiK1cMhInINIQSOHDmCiooKeDxc44vFuYmIyBpq5yYGU2HNzc2orKy0ehhERK61d+9e9OnTx+phSIVzExGRtTLNTQymwrp16wYA+McH21DU7UyLR0OkTrCoO3Z+9jlWr16NhoYG5OXl4bXXXkNRUZHVQyNSra2tDeecc070fZhO49xERGSNtiNH8Z3BQzLOTQymwiLpE0XdzkQRJ3SyiWBREc4880x07twZeXl5yM/PR1FREYMpsiWmsXXEuYmIyFqZ5iYmpxMREREREWnAYIqIiIiIiEgDBlNEREREREQaMJgiIiIiIiLSgMEUERERERGRBgymiIiIiIiINGAwRUREREREpAGDKSIiIiIiIg0YTBEREREREWnAYIqIiIiIiEgDBlNEREREREQaMJgiIiIiIiLSgMEUERERERGRBgymiIiIiIiINGAwRUREREREpIG0wdSiRYvQt29fdO7cGVVVVdi8eXPa6w8fPozp06fD6/WisLAQ559/PlatWmXSaImIiIiIyG3yrR5AMq+99hpmzZqFxYsXo6qqCvX19Rg9ejR27tyJ0tLSDte3t7fjmmuuQWlpKZYvX47evXvjH//4B4qLi80fPBERERERuYKUwdQzzzyDKVOmYNKkSQCAxYsXo6GhAcuWLcPs2bM7XL9s2TK0tLRg48aN6NSpEwCgb9++Zg6ZiIiIiHIVCACbmoCDB4DSMuDyKiAvz+pREaUkXZpfe3s7tm3bhurq6uhjHo8H1dXVaGxsTPqcN998E8OHD8f06dNRVlaGSy65BI8//jgCgUDK73PixAm0tbXFfREREVmJcxO5WsMqKEOHwTN+AjzTpsMzfgKUocOABh7bIHlJF0x9+eWXCAQCKCsri3u8rKwMfr8/6XM+//xzLF++HIFAAKtWrUJdXR2efvppPProoym/z/z589G9e/foV2Vlpa73QURElC3OTeRaDaug3D0F2L8//nG/P/Q4AyqSlHTBlBbBYBClpaV46aWXMGTIENxyyy34+c9/jsWLF6d8zpw5c9Da2hr92rt3r4kjJiIi6ohzE7lSIACltg4QAkrCHylChH6tmxtKASSSjHRnpnr27Im8vDwcOHAg7vEDBw6gvLw86XO8Xi86deqEvJic2osuugh+vx/t7e0oKCjo8JzCwkIUFhbqO3giIqIccG4iV9rUBCVxRyqGIgTQ3AyxqQm4YoSJAyPKTLqdqYKCAgwZMgQ+ny/6WDAYhM/nw/Dhw5M+54orrsCuXbsQDAajj33yySfwer1JAykiIiIiksTBA5mvyeY6IhNJF0wBwKxZs7BkyRK8/PLL2L59O6ZNm4Zjx45Fq/vdeeedmDNnTvT6adOmoaWlBffeey8++eQTNDQ04PHHH8f06dOtugUiIiIiUqO0LPM12VxHZCLp0vwA4JZbbsGhQ4cwd+5c+P1+DBo0CKtXr44WpdizZw88ntNxYGVlJd555x3cd999uPTSS9G7d2/ce++9+NnPfmbVLRARERGRGpdXQXi9oWIT4T
NSsYSiAF5vqEw6kWSkDKYAYMaMGZgxY0bSP1u7dm2Hx4YPH45NmzYZPCoiIiIi0lVeHsSjj0C5ewqEosQFVEIJlaQQjzzMflMkJSnT/IiIiIjIIIEAsGEjsHJl6FcZquSNHQPx6yVAYrExrzf0+Ngx1oyLKANpd6aIiIiISGcNq6DU1sVVzxNeL8Sjj1gfsIwdA1EzOlS17+CB0Bmpy6u4I0VSYzBFRERE5AaRxriJ55LCjXGl2AHKy2P5c7IVpvkREREROZ3bG+PKmNpIjsCdKSIiIiKnc3NjXJlTG8n2uDNFRERE5HRubYwbSW1MDCTDqY1oWGXNuMgxGEwREREROZ0bG+O6PbWRTMFgioiIiMjpwo1xI32bEglFgaiocFZj3HBqY/I7DgVUSnMzsKnJ1GGRszCYIiIiInK6cGNcAB0CKsc2xnVraiOZisEUERERkRs4sTFuuip9bkxtJNOxmh8RERGRWzipMW6mKn3h1Eb4/dEzUrGEogBer7NSG8l03JkiIiIicpNIY9xx40K/2jWQylSlz42pjWQ6BlNEREREZB/ZVOlzYmojSYXBFJHRXNp13edTMGBAHny+VHWUzH0dIiJyiGyr9I0dA7F1M4IrliP44iIEVyyH2NLEQIp0wTNTREZyadd1IYDaWg927FBQW+vByJEBpKjGa8rrEBGRg2ip0hdJbSTSGXemyF7stMvjgK7rWneF1qxRsHVr6DlbtypYs0ZbBKTX6xARkYOwSh9JhMEU2UfDKihDh8EzfgI806bDM34ClKHD5AxKHNB1PXFXKEkhpJTPmzfPg7y80BPy8gTmzVP/fL1fxyxMRyQiMokbGxCTtBhMkT3YbZfHAV3Xte4KRZ4XCISuDwQUTbtKer2OGbQGnkREmtgpS8MIrNJHEmEwRfKz4y6Pzbuua90VSnxeRLa7Snq9jlmYjkhEprFTloaRWKWPJMFgiuRnx10em+dza90VSnxeRLa7Snq9jhnslo5IRDZmtywNo7FKH0mAwRTpx6i0Azvu8tg4n1vrrlDkeR5P8gs8nux2t3J9HbPYKR2RiGzMjlkaekv2OcMJDYjJ1hhMkT6MTDuw4y6PjfO5te4KtbcDe/YAwWDyPw8GFezdG7ouHb1exwx2S0ckIhuzY5aGnpjeSJJinynKXSTtIPGTYzjtIOfc5fAuD/z+6OpbLKEogNcr3y5POJ9bqa2LT8nwekOBlIRpCLG7QsmCmciu0DXXdOz3VFgIbNoUwKFDqV+/tDR0XTp6vY4ZYs9KxQoFnqE/v/ZaRlRErhUIhIKbgwdCC36XV2lfRLNjloZejP6cQZQDBlOUmwxpB0JRoNTNhagZrX0CCe/yKHdPCb1ezJup7Ls8GDsGomY0hF6TqcHU7QoJtLcnD2YqK0NfudLrdYyUS+BJRC6gd9N2O2Zp6MGMzxlEOWAwRbkJpx2koggBNDeHgolcOo/bcJcnykZd1+20K2S1XANPInIwI3ZS7JqlkSuzPmcQacRginJjZtqBzXZ57MoOu0IyYOBJREkZtZNi5yyNXLg5vZFsgcEU5caMtAM9c87J8Xw+BTNnelBfH8SoUcaeV2LgSUQdGLmTIluWhhnzs1vTG8k2GExRboxOO9A755wcTQigttaDHTsU1NZ6MHIkzysRkcmM3kmRJUvDrPk58jkjRSVDx6Y3km2wNDrlxsgS4GxOSFmKra7HXk9EZImevdRd98kn2nsyqu2tZFT/R7Pm50AAqH8OaG1NHUjBoemNZBsMpih34bQDlJfHP+71ai9XyuaElKXEnk/s9UREpmtYBeXHP0l7SeQtyfPsc8b2SjKqL5NZ83PDKiiXXArPUwugHD+e/JoexSyLTpZjMEX6GDsGYutmBFcsR/DFRQiuWA6xpUn7G5zbmxNS1hKbDWdqMmwGn0/BgAF58Pm4Q0bkeJHdGr8/5SVJ13aMyLYwcufIjPk5Mv6vv055iQCAws5AzWjt38cKRu0WkmUYTJF+1KYdqM
HqPZSFxF2pCCt3pxLPb3GHjMjB0uzWJDI828LonSOj52eVP0sFCJ3XstOiqlG7hWQpBlMkJ1bvoSwk7kpFWLk7xfN
"text/plain": [
"<Figure size 1000x400 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"np.random.seed(6)\n",
"X_square = np.random.rand(100, 2) - 0.5\n",
"y_square = (X_square[:, 0] > 0).astype(np.int64)\n",
"\n",
"angle = np.pi / 4 # 45 degrees\n",
"rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)],\n",
" [np.sin(angle), np.cos(angle)]])\n",
"X_rotated_square = X_square.dot(rotation_matrix)\n",
"\n",
"tree_clf_square = DecisionTreeClassifier(random_state=42)\n",
"tree_clf_square.fit(X_square, y_square)\n",
"tree_clf_rotated_square = DecisionTreeClassifier(random_state=42)\n",
"tree_clf_rotated_square.fit(X_rotated_square, y_square)\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)\n",
"plt.sca(axes[0])\n",
"plot_decision_boundary(tree_clf_square, X_square, y_square,\n",
" axes=[-0.7, 0.7, -0.7, 0.7], cmap=\"Pastel1\")\n",
"plt.sca(axes[1])\n",
"plot_decision_boundary(tree_clf_rotated_square, X_rotated_square, y_square,\n",
" axes=[-0.7, 0.7, -0.7, 0.7], cmap=\"Pastel1\")\n",
"plt.ylabel(\"\")"
]
},
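{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"A short derivation, under the row-vector convention used by `X_square.dot(rotation_matrix)` above, of where the class boundary ends up; the primed symbols are notation assumed here for illustration. With $\\theta = \\pi/4$,\n",
"\n",
"$$\n",
"(x_1', x_2') = (x_1, x_2) \\begin{pmatrix} \\cos\\theta & -\\sin\\theta \\\\ \\sin\\theta & \\cos\\theta \\end{pmatrix}\n",
"\\quad\\Rightarrow\\quad\n",
"x_1 = x_1' \\cos\\theta - x_2' \\sin\\theta = \\frac{x_1' - x_2'}{\\sqrt{2}}.\n",
"$$\n",
"\n",
"The original boundary $x_1 = 0$ therefore becomes the diagonal $x_2' = x_1'$. Since each split thresholds a single feature, the tree can only approximate this diagonal with a staircase of axis-aligned segments."
]
},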
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"This sensitivity can be mitigated with principal component analysis (PCA) or with ensemble methods (see upcoming lectures)."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": []
},
"source": [
"### Decision trees are not highly stable"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"- Small changes to hyperparameters or to data may produce very different models.\n",
"- Even retraining with a different random seed can produce very different models."
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"This instability can be leveraged by averaging over many trees, which gives rise to *ensemble methods* (as discussed in the next lecture)."
]
},
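{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"As a rough illustration of why averaging helps, assume (for this sketch only) that $B$ trees give predictions $\\hat{y}_b$ with a common variance $\\sigma^2$ and pairwise correlation $\\rho$; these symbols are introduced here and do not appear elsewhere in the lecture. The variance of the averaged prediction is then\n",
"\n",
"$$\n",
"\\text{Var}\\left( \\frac{1}{B} \\sum_{b=1}^{B} \\hat{y}_b \\right) = \\rho \\sigma^2 + \\frac{1 - \\rho}{B} \\sigma^2,\n",
"$$\n",
"\n",
"so the less correlated the individual trees are, the more the variance is reduced by averaging."
]
},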
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": "subslide"
},
"tags": [
"exercise_pointer"
]
},
"source": [
"**Exercises:** *You can now complete Exercises 1-2 in the exercises associated with this lecture.*"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 4
}