Learn about regression trees in this tutorial by Giuseppe Ciaburro, a Ph.D. in environmental technical physics with over 15 years of experience in programming with Python, R, and MATLAB, in the fields of combustion, acoustics, and noise control.
Decision trees are used to predict a response (class y) from several input variables: x1, x2,…,xn. If y is a continuous response, it’s called a regression tree; if y is categorical, it’s called a classification tree. That’s why these methods are often called Classification and Regression Trees (CART). The algorithm checks the value of an input (xi) at every node of the tree and continues to the left or right branch based on the (binary) answer. When you reach a leaf, you will find the prediction.
The algorithm starts with all the data grouped into a single node (the root node) and, at every step, searches exhaustively through all possible subdivisions. At each step, the best subdivision (the one that produces branches that are as homogeneous as possible) is chosen.
In regression trees, you try to partition the data space into parts small enough that a simple model, different for each part, can be applied. The non-leaf part of the tree is just the procedure that determines, for each data point x, which model you will use to predict it.
A regression tree is formed by a series of nodes that split the root branch into two child branches. Such subdivision continues to cascade. Each new branch can, then, go to another node or remain a leaf with the predicted value.
Starting from the whole dataset (root), the algorithm creates the tree through the following procedure:
- Identify the best feature X1 on which to divide the dataset and the best division value s1. The left-hand branch will be the set of observations where X1 is below s1, while the right-hand branch comprises the set of observations in which X1 is greater than or equal to s1.
- This operation is then recursively executed again (independently) for every branch until there is no possibility of division.
- When the divisions are completed, a leaf is created, which indicates the output values.
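The first step of the procedure above can be sketched in base R. This is a minimal illustration, not the implementation used by any package: the toy vectors `x` and `y` are invented for the example, `best_split()` is a hypothetical helper, and the sum of squared errors is assumed as the homogeneity measure.

```r
# Toy data, invented purely for illustration
x <- c(1, 2, 3, 10, 11, 12)
y <- c(5, 6, 5, 20, 21, 19)

# Sum of squared errors around the mean of a vector
sse <- function(v) sum((v - mean(v))^2)

# Hypothetical helper: exhaustive search for the best threshold on one predictor
best_split <- function(x, y) {
  xs <- sort(unique(x))
  # Candidate thresholds: midpoints between consecutive distinct values
  candidates <- (head(xs, -1) + tail(xs, -1)) / 2
  # Total error when the left branch is x < s and the right branch is x >= s
  total <- sapply(candidates, function(s) sse(y[x < s]) + sse(y[x >= s]))
  list(threshold = candidates[which.min(total)], sse = min(total))
}

result <- best_split(x, y)
result$threshold  # 6.5: separates the low responses from the high ones
```

A full tree would repeat this search over every predictor, keep the best one, and then recurse independently into the two resulting subsets.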
Suppose you have a response variable with only two continuous predictors (X1 and X2) and four division values (s1, s2, s3, s4). The following figure proposes a way to represent the whole dataset graphically:
The goal of a regression tree is to encapsulate the whole dataset in the smallest possible tree. To minimize the tree size, the simplest possible explanation for a set of observations is preferred over other explanations. All this is justified by the fact that small trees are much easier to comprehend than large trees.
You saw how the regression tree algorithm works. These steps can be summarized in the following processes:
- Splitting: The dataset is partitioned into subsets. The split operation is based on a set of rules, for example, the sums of squares computed on the data. Splitting continues until a leaf node, which contains a small subset of the observations, is constructed.
- Pruning: In this process, the tree branches are shortened. The tree is reduced by transforming some branch nodes into leaf nodes and removing the leaf nodes under the original branch. Care must be taken, as the lower branches can be strongly influenced by abnormal values. Pruning produces a sequence of progressively smaller trees, and a simpler tree often avoids overfitting.
- Tree selection: Finally, the smallest tree that matches the data is selected. This process is executed by choosing the tree that produces the lowest cross-validated error.
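The three processes above could look roughly like this with the tree package, using the built-in mtcars data that the example later in this article relies on. It is a sketch under stated assumptions: the package is already installed, the names `full_tree`, `cv`, `best_size`, and `pruned` are illustrative, and the seed and the number of folds are arbitrary choices.

```r
library(tree)  # assumes the tree package is already installed

# Splitting: grow the full tree on the built-in mtcars data
full_tree <- tree(mpg ~ ., data = mtcars)

# Cross-validate the sequence of pruned subtrees
set.seed(1)  # arbitrary seed so the random folds are repeatable
cv <- cv.tree(full_tree, FUN = prune.tree, K = 10)

# Tree selection: take the size with the lowest cross-validated deviance
best_size <- cv$size[which.min(cv$dev)]

# Pruning: cut the full tree back to that number of terminal nodes
pruned <- prune.tree(full_tree, best = best_size)
pruned
```

With only 32 observations the cross-validated deviances vary noticeably from one seed to another, so the selected size should be read as indicative rather than definitive.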
To fit a regression tree in R, you can use the tree() function implemented in the tree package. In this package, a tree is grown via binary recursive partitioning by using the response in the specified formula and choosing splits from the terms of the right-hand side. Numeric variables are divided into X < a and X > a. The split that maximizes the reduction in impurity is chosen, the dataset is split, and the process is repeated. Splitting continues until the terminal nodes are too small or too few to be split. Take a look at the following table for basic information on this package:
|Field|Value|
|---|---|
|Date|January 21, 2016|
|Title|Classification and Regression Trees|
To work through a regression tree example, begin with the data. Use the mtcars dataset contained in the datasets package. The data was extracted from the 1974 Motor Trend US magazine and comprises fuel consumption and ten aspects of automobile design and performance for 32 automobiles (1973–74 models). It is a data frame with 32 observations on the following 11 variables:
- mpg: Miles per gallon
- cyl: Number of cylinders
- disp: Engine displacement (cubic inches)
- hp: Engine horsepower
- drat: Rear axle ratio
- wt: Weight (1000 lbs)
- qsec: 1/4 mile time
- vs: Engine shape (0 = V-shaped, 1 = straight)
- am: Transmission (0 = automatic, 1 = manual)
- gear: Number of forward gears
- carb: Number of carburetors
The fuel consumption of vehicles has always been studied by major manufacturers around the world. In an era characterized by oil supply problems and even greater air pollution problems, fuel consumption by vehicles has become a key factor. In this example, you’ll build a regression tree with the purpose of predicting the fuel consumption of vehicles according to certain characteristics.
The analysis begins by loading the dataset:
The dataset is contained in the datasets package; to load it, use the data() function. To display a compact summary of the dataset simply type:
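The two commands just described can be sketched as follows, assuming a standard R session in which the datasets package is attached by default:

```r
# Load the mtcars data frame from the datasets package
data(mtcars)

# Display a compact summary of its structure
str(mtcars)
```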
The results are shown as follows:
> str(mtcars)
'data.frame':	32 obs. of  11 variables:
 $ mpg : num  21 21 22.8 21.4 18.7 18.1 14.3 24.4 22.8 19.2 ...
 $ cyl : num  6 6 4 6 8 6 8 4 4 6 ...
 $ disp: num  160 160 108 258 360 ...
 $ hp  : num  110 110 93 110 175 105 245 62 95 123 ...
 $ drat: num  3.9 3.9 3.85 3.08 3.15 2.76 3.21 3.69 3.92 3.92 ...
 $ wt  : num  2.62 2.88 2.32 3.21 3.44 ...
 $ qsec: num  16.5 17 18.6 19.4 17 ...
 $ vs  : num  0 0 1 1 0 1 0 1 1 1 ...
 $ am  : num  1 1 1 0 0 0 0 0 0 0 ...
 $ gear: num  4 4 4 3 3 3 3 4 4 4 ...
 $ carb: num  4 4 1 1 2 1 4 2 2 4 ...
You have thus confirmed that these are 11 numeric variables with 32 observations. To extract more information, use the summary() function:
> summary(mtcars)
      mpg             cyl             disp             hp
 Min.   :10.40   Min.   :4.000   Min.   : 71.1   Min.   : 52.0
 1st Qu.:15.43   1st Qu.:4.000   1st Qu.:120.8   1st Qu.: 96.5
 Median :19.20   Median :6.000   Median :196.3   Median :123.0
 Mean   :20.09   Mean   :6.188   Mean   :230.7   Mean   :146.7
 3rd Qu.:22.80   3rd Qu.:8.000   3rd Qu.:326.0   3rd Qu.:180.0
 Max.   :33.90   Max.   :8.000   Max.   :472.0   Max.   :335.0
      drat             wt             qsec             vs
 Min.   :2.760   Min.   :1.513   Min.   :14.50   Min.   :0.0000
 1st Qu.:3.080   1st Qu.:2.581   1st Qu.:16.89   1st Qu.:0.0000
 Median :3.695   Median :3.325   Median :17.71   Median :0.0000
 Mean   :3.597   Mean   :3.217   Mean   :17.85   Mean   :0.4375
 3rd Qu.:3.920   3rd Qu.:3.610   3rd Qu.:18.90   3rd Qu.:1.0000
 Max.   :4.930   Max.   :5.424   Max.   :22.90   Max.   :1.0000
       am              gear            carb
 Min.   :0.0000   Min.   :3.000   Min.   :1.000
 1st Qu.:0.0000   1st Qu.:3.000   1st Qu.:2.000
 Median :0.0000   Median :4.000   Median :2.000
 Mean   :0.4062   Mean   :3.688   Mean   :2.812
 3rd Qu.:1.0000   3rd Qu.:4.000   3rd Qu.:4.000
 Max.   :1.0000   Max.   :5.000   Max.   :8.000
Before starting with data analysis, conduct an exploratory analysis to understand how the data is distributed and extract preliminary knowledge. First, try to find out whether the variables are related to each other. You can do this using the pairs() function to create a matrix of sub-axes containing scatter plots of the columns of a matrix. To reduce the number of plots in the matrix, limit your analysis to just four predictors: cylinders, displacement, horsepower, and weight. The target is the mpg variable that contains the miles per gallon of 32 sample cars:
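A call along the following lines would produce such a scatter plot matrix. The exact call is not shown here, so the formula is an assumption based on the response and the four predictors just named:

```r
# Load the data (a no-op if it is already available)
data(mtcars)

# Scatter plot matrix of mpg against the four chosen predictors;
# each term of the formula becomes one row and one column of panels
pairs(mpg ~ cyl + disp + hp + wt, data = mtcars)
```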
To specify the response and predictors, the formula argument is used. Each term gives a separate variable in the pairs plot, so terms must be numeric vectors. The response is interpreted as another variable, but not treated specially. The following figure shows a scatter plot matrix:
By observing the plots in the first row, it can be noted that mpg decreases (that is, fuel consumption increases) as the number of cylinders, the engine displacement, the horsepower, and the weight of the vehicle increase.
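The visual impression can be checked numerically. The sketch below, which is an addition to the original analysis, computes the Pearson correlation between mpg and each of the four predictors; the names `vars` and `cors` are illustrative.

```r
# Load the data (a no-op if it is already available)
data(mtcars)

# Correlation of mpg with each of the four predictors
vars <- c("mpg", "cyl", "disp", "hp", "wt")
cors <- cor(mtcars[, vars])["mpg", -1]  # drop the mpg-mpg entry
round(cors, 2)
```

All four coefficients are negative: miles per gallon falls as each predictor grows, which is the same as saying that fuel consumption rises.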
At this point, you can use the tree() function to build the regression tree. First, install the tree package. To install a library that is not present in the initial distribution of R, you must use the install.packages() function. This is the main function used to install packages. It takes a vector of names and a destination library, downloads the packages from the repositories, and installs them.
Now, load the library through the library command:
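A minimal sketch of the installation and loading steps follows. The requireNamespace() guard and the CRAN mirror URL are assumptions added here so that the snippet can be rerun safely; any CRAN mirror would do.

```r
# Install the tree package only if it is missing
# (the mirror URL is an assumption; any CRAN mirror works)
if (!requireNamespace("tree", quietly = TRUE)) {
  install.packages("tree", repos = "https://cloud.r-project.org")
}

# Attach the package so that tree() is available
library(tree)
```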
You can use the tree() function that builds a regression tree:
RTModel <- tree(mpg ~ ., data = mtcars)
Only two arguments are passed: a formula and the dataset name. The left-hand side of the formula (response) should be a numerical vector when a regression tree is fitted. The right-hand side should be a series of numeric or factor variables separated by +; there should be no interaction terms. Both . and - are allowed; regression trees can have offset terms.
Here are the results:
> RTModel
node), split, n, deviance, yval
      * denotes terminal node

 1) root 32 1126.000 20.09
   2) wt < 2.26 6 44.550 30.07 *
   3) wt > 2.26 26 346.600 17.79
     6) cyl < 7 12 42.120 20.92
      12) cyl < 5 5 5.968 22.58 *
      13) cyl > 5 7 12.680 19.74 *
     7) cyl > 7 14 85.200 15.10
      14) hp < 192.5 7 16.590 16.79 *
      15) hp > 192.5 7 28.830 13.41 *
These results describe each node in the tree exactly. Information on each node is presented in an indented format. The indentation indicates the tree topology; that is, it indicates the parent and child relationships (also referred to as primary and secondary splits). Also, to denote a terminal node, an asterisk (*) is used.
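To see where the n, deviance, and yval columns come from, the root line can be reproduced in base R. This sketch assumes that, for a regression tree, yval is the mean response in the node and deviance is the sum of squared deviations from that mean; the variable names are illustrative.

```r
# Load the data (a no-op if it is already available)
data(mtcars)

n_root    <- nrow(mtcars)                      # observations in the root
yval_root <- mean(mtcars$mpg)                  # predicted value at the root
dev_root  <- sum((mtcars$mpg - yval_root)^2)   # sum of squared errors

round(c(n = n_root, deviance = dev_root, yval = yval_root), 2)
```

The three numbers agree with the root line of the output: 32 observations, a deviance of about 1126, and a predicted value of about 20.09.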
In the tree sequence, nodes are labeled with unique numbers. These numbers are generated by the following formula: the child nodes of a node x are always numbered 2*x (left child) and 2*x+1 (right child). The root node is numbered as one. The following figure explains this rule:
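The numbering rule is simple enough to write down directly. The helpers `children()` and `parent()` below are hypothetical names introduced only for this illustration:

```r
# Hypothetical helper: numbers of the two children of node x
children <- function(x) c(left = 2 * x, right = 2 * x + 1)

# Hypothetical helper: number of the parent of node x (integer division)
parent <- function(x) x %/% 2

children(1)  # the root splits into nodes 2 and 3
children(3)  # node 3 splits into nodes 6 and 7
children(7)  # node 7 splits into nodes 14 and 15
parent(13)   # node 13 hangs from node 6
```

This explains why the node labels in the printed tree jump from 3 to 6: node 2 is a terminal node, so numbers 4 and 5 are never used.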
From the analysis of the results, you can see a selection of variables; in fact, among the ten available predictors, only three (wt, cyl, and hp) were selected. More information can be obtained from the summary() function:
> summary(RTModel)

Regression tree:
tree(formula = mpg ~ ., data = mtcars)
Variables actually used in tree construction:
[1] "wt"  "cyl" "hp"
Number of terminal nodes:  5
Residual mean deviance:  4.023 = 108.6 / 27
Distribution of residuals:
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max.
 -4.067  -1.361   0.220   0.000   1.361   3.833
The output of summary() confirms that only three of the variables have been used in constructing the tree. In the context of a regression tree, the deviance is simply the sum of squared errors for the tree; the residual mean deviance divides it by the number of observations minus the number of terminal nodes (here 32 - 5 = 27). Now, you can plot the regression tree:
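The plotting step can be sketched with the plot and text methods that the tree package provides. The snippet refits the model so that it runs on its own, and it assumes the package is already installed:

```r
library(tree)  # assumes the tree package is already installed

# Refit the model so the snippet is self-contained
RTModel <- tree(mpg ~ ., data = mtcars)

# Draw the tree skeleton, then label the splits and the leaves
plot(RTModel)
text(RTModel)
```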
Two commands are needed: the first, plot(), draws the regression tree, while the second, text(), adds the text on the branches to explain the workflow. The resulting plot is shown in the following figure:
Now look at what the regression tree has returned. The first thing that stands out is a sort of indication of the importance of the variables. The choice of three predictors out of the ten available variables already suggests that these three are the ones that most affect the fuel consumption of the cars included in the dataset.
Now, you can add that the most important predictor is the weight of the vehicle; in fact, a weight below 2.26 (that is, 2,260 lbs, since wt is expressed in units of 1,000 lbs) leads directly to a terminal node, which gives an estimated mileage of 30.07 miles per (US) gallon. Next in importance come the number of cylinders of the engine and the horsepower.
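The prediction in that first terminal node can be verified without the tree package at all, assuming (as is standard for regression trees) that a leaf predicts the mean response of the observations it contains. The variable names below are illustrative.

```r
# Load the data (a no-op if it is already available)
data(mtcars)

# Mileage of the cars that fall in the leaf wt < 2.26
light <- mtcars$mpg[mtcars$wt < 2.26]

leaf_n    <- length(light)                 # cars in the leaf
leaf_yval <- mean(light)                   # predicted mileage
leaf_dev  <- sum((light - leaf_yval)^2)    # deviance of the leaf

round(c(n = leaf_n, yval = leaf_yval, deviance = leaf_dev), 2)
```

The result matches node 2 of the printed tree: 6 cars, a predicted value of about 30.07 miles per gallon, and a deviance of about 44.55.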
If you found this article interesting, you can explore Giuseppe Ciaburro’s Regression Analysis with R to build effective regression models in R to extract valuable insights from real data. This book will give you a rundown explaining what regression analysis is, explaining to you the process from scratch.