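We implement a regression decision tree from scratch in Python with NumPy. Each node predicts the mean of its training targets, and splits are chosen greedily to maximize the reduction in mean squared error.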
import numpy as np
import matplotlib.pyplot as plt
class RegTree():
    """Regression decision tree represented as nested dicts.

    Every node is a ``dict`` holding ``"w"`` (the mean of the training targets
    that reached the node) and ``"left"`` / ``"right"`` children, which are
    ``None`` for a leaf. Internal nodes additionally hold ``"feature_index"``,
    ``"split_value"`` and ``"dloss"`` (the decrease in mean squared error the
    split achieved). A sample is routed left when
    ``x[feature_index] <= split_value`` and right otherwise.
    """

    @staticmethod
    def mse(v):
        """Return the mean squared deviation of ``v`` from its own mean."""
        return np.mean(np.square(v - np.mean(v)))

    @staticmethod
    def split_data(X, y, feature_index, feature_value):
        """Partition the rows of the 2-D array ``X`` on one feature.

        Returns a dict with the row indices satisfying
        ``X[:, feature_index] <= feature_value`` under ``"I_left"`` and those
        satisfying ``X[:, feature_index] > feature_value`` under ``"I_right"``.
        ``y`` is not used; it is kept for interface compatibility.
        """
        column = X[:, feature_index]
        return {
            "I_left": np.where(column <= feature_value)[0],
            "I_right": np.where(column > feature_value)[0],
        }

    # Greedy algorithm for finding the best split
    @staticmethod
    def greedy_best_split(X, y):
        """Exhaustively search feature/threshold pairs for the best split.

        Candidate thresholds are the distinct values of each feature. A split
        is only considered when both sides are non-empty. Its decrease in loss
        is ``mse(y) - (N_left * mse(y_left) + N_right * mse(y_right)) / N``.
        Ties are resolved in favour of the split examined last (``>=``).

        Returns:
            ``(best_dloss, best_feature_index, best_split_value, best_split)``.
            When no valid split exists (e.g. every feature is constant) this
            is ``(0, 0, 0, {"I_left": empty, "I_right": empty})``.
        """
        best_feature_index = 0
        best_split_value = 0
        best_dloss = 0
        best_split = {
            "I_left": np.array([]),
            "I_right": np.array([]),
        }
        parent_mse = RegTree.mse(y)
        N = y.shape[0]
        for feature_index in range(X.shape[1]):
            # The largest distinct value sends every row left and leaves the
            # right side empty, so it can never be a valid threshold.
            for split_value in np.unique(X[:, feature_index])[:-1]:
                split = RegTree.split_data(X, y, feature_index, split_value)
                I_left, I_right = split["I_left"], split["I_right"]
                N_left = I_left.shape[0]
                N_right = I_right.shape[0]
                # Still needed: rows with NaN satisfy neither comparison.
                if N_left == 0 or N_right == 0:
                    continue
                weighted_child_mse = (N_left * RegTree.mse(y[I_left])
                                      + N_right * RegTree.mse(y[I_right])) / N
                dloss = parent_mse - weighted_child_mse
                # Update if the change in loss is the largest so far
                if dloss >= best_dloss:
                    best_feature_index = feature_index
                    best_split_value = split_value
                    best_split = split
                    best_dloss = dloss
        return best_dloss, best_feature_index, best_split_value, best_split

    @staticmethod
    def fit_tree(X, y, depth=1,
                 max_depth=100, tolerance=10**(-3)):
        """Recursively grow a regression tree on ``X`` (2-D) and ``y`` (1-D).

        Args:
            X: feature matrix indexed as ``X[row, feature]``.
            y: targets aligned with the rows of ``X``.
            depth: depth of the node being built (the root is depth 1).
            max_depth: nodes at this depth or deeper are always leaves.
            tolerance: minimum decrease in loss required to split a node.

        Returns:
            The node dict described in the class docstring.
        """
        # Every node predicts with the mean of its targets.
        node = {"w": np.mean(y), "left": None, "right": None}
        # Fewer than two samples cannot be split, and the depth limit forces
        # a leaf; skip the exhaustive split search in both cases.
        if y.shape[0] < 2 or depth >= max_depth:
            return node
        dloss, feature_index, split_value, split = RegTree.greedy_best_split(X, y)
        I_left, I_right = split["I_left"], split["I_right"]
        # Branch only when a real split exists and it improves the loss enough.
        if I_left.shape[0] > 0 and I_right.shape[0] > 0 and dloss >= tolerance:
            node["dloss"] = dloss
            node["feature_index"] = feature_index
            node["split_value"] = split_value
            node["left"] = RegTree.fit_tree(X[I_left], y[I_left], depth=depth + 1,
                                            max_depth=max_depth, tolerance=tolerance)
            node["right"] = RegTree.fit_tree(X[I_right], y[I_right], depth=depth + 1,
                                             max_depth=max_depth, tolerance=tolerance)
        return node

    ###
    # Predict
    ###
    @staticmethod
    def predict_one(node, x):
        """Return the leaf weight reached by routing the single sample ``x``."""
        # Iterative descent: a leaf is a node whose "left" child is None.
        while node["left"] is not None:
            if x[node["feature_index"]] <= node["split_value"]:
                node = node["left"]
            else:
                node = node["right"]
        return node["w"]

    @staticmethod
    def predict(node, X):
        """Return a float array with one prediction per row of ``X``."""
        predictions = np.zeros(X.shape[0])
        for i, x in enumerate(X):
            predictions[i] = RegTree.predict_one(node, x)
        return predictions

    @staticmethod
    def print_tree(node, depth=0):
        """Print the tree, indenting each level by one space."""
        indent = depth * " "
        if node["left"] is None:
            print(f'{indent}weight: {node["w"]}')
        else:
            print(f'{indent}X{node["feature_index"]} <= {node["split_value"]}')
            RegTree.print_tree(node["left"], depth + 1)
            RegTree.print_tree(node["right"], depth + 1)
We generate synthetic data from a noisy linear model and split it into training and test sets.
# Simulate y = 5 + X @ B + e with B = 3 * ones, X_ij ~ N(0, 10) i.i.d.
# and e_i ~ N(0, 1) i.i.d.
rng = np.random.default_rng(0)  # fixed seed so the reported errors are reproducible
n_samples = 100
n_features = 10
intercept = 5 * np.ones(n_samples)
B = 3 * np.ones(n_features)
# An isotropic covariance 10 * I is the same as scaling independent standard
# normals by sqrt(10), so no per-row multivariate draws are needed.
X = np.sqrt(10) * rng.standard_normal((n_samples, n_features))
e = rng.standard_normal(n_samples)
y = intercept + X @ B + e

# Train/test split: 75% of the rows, sampled without replacement, for training.
n_TRAIN = int(.75 * n_samples)
I = np.arange(0, n_samples)
TRAIN = rng.choice(I, n_TRAIN, replace=False)
TEST = np.setdiff1d(I, TRAIN)
X_train = X[TRAIN, :]
y_train = y[TRAIN]
X_test = X[TEST, :]
y_test = y[TEST]

# We train the decision tree and report the training and test mean squared error.
# Fit the tree on the training split, then report the mean squared error on
# both splits. The second line is the held-out (test) error.
tree = RegTree.fit_tree(X_train, y_train, max_depth=100, tolerance=10**(-3))
print("Train MSE:", np.mean(np.square(y_train - RegTree.predict(tree, X_train))))
print("Test MSE:", np.mean(np.square(y_test - RegTree.predict(tree, X_test))))
Train MSE: 0.0 Train MSE: 1045.3882889479746
References.
Kevin P. Murphy. 2012. Machine Learning: A Probabilistic Perspective. The MIT Press.
https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/
Tweets by austindavbrown