Austin David Brown

Classification Decision Trees in Python


Classification decision trees use the binary tree data structure to recursively split the feature space and fit a weight at each leaf of the tree as the classification prediction. A tree with $K$ leaf nodes divides the feature space into regions $(R_k)$ with corresponding weights $(w_k)$, so that the tree is
\[
T(x) = \sum_{k = 1}^K w_k I(x \in R_k).
\]
Decision trees are a supervised learning method. Let $\mathcal{D}_n = (X_i, y_i)_{i = 1}^n$ be the training data. We need to establish how to compute a weight at each leaf node and how to choose a splitting measure for the tree.

We predict the weight at a leaf as the class with the largest posterior probability over the $C$ classes,
\[
\arg\max_{c \in \{1, \ldots, C\}} P(Y = c \mid X = x, \mathcal{D}_n),
\]
estimated from the training points that fall in that leaf. In practice this is simply the most common class (the mode) of the labels in the leaf.

One splitting measure is the information gain, defined as the entropy of the parent node minus the weighted average entropy of the two child nodes. Formally, with $H$ the Shannon entropy and $N = N_{Left} + N_{Right}$, the information gain is
\[
H(y_{Parent}) - \frac{1}{N} \left( N_{Left} H(y_{Left}) + N_{Right} H(y_{Right}) \right).
\]
We use a greedy algorithm that evaluates the information gain of every feature and every observed split value and keeps the best candidate. This computation is expensive, since each node requires a full pass over the features and their unique values. Combining the weight prediction and the greedy split gives a recursive algorithm for growing the tree.

Algorithm: Grow Tree

function $growTree(\mathcal{D}, depth)$
  node $\leftarrow$ empty node
  node.weight $\leftarrow$ predicted weight from $\mathcal{D}$
  $j^*, s^*, \mathcal{D}_{Left}, \mathcal{D}_{Right} \leftarrow$ greedy split of $\mathcal{D}$
  if we should branch the tree then
    node.feature_index $\leftarrow j^*$
    node.split_value $\leftarrow s^*$
    node.left $\leftarrow growTree(\mathcal{D}_{Left}, depth + 1)$
    node.right $\leftarrow growTree(\mathcal{D}_{Right}, depth + 1)$
  return node
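As a quick numerical illustration, the following sketch computes the entropy of a toy parent node and the information gain of one candidate split; the labels and the standalone toy_entropy helper are made up purely for illustration and are not part of the implementation below.

import numpy as np

def toy_entropy(v):
  # Shannon entropy (in bits) of the empirical label distribution of v
  _, counts = np.unique(v, return_counts = True)
  p = counts / v.shape[0]
  return -np.sum(p * np.log2(p))

y_parent = np.array([0, 0, 0, 0, 1, 1, 1, 1])  # entropy = 1 bit
y_left   = np.array([0, 0, 0, 1])              # entropy is about 0.811 bits
y_right  = np.array([0, 1, 1, 1])              # entropy is about 0.811 bits

N = y_parent.shape[0]
IG = toy_entropy(y_parent) - 1/N * (y_left.shape[0] * toy_entropy(y_left)
                                    + y_right.shape[0] * toy_entropy(y_right))
print(IG)  # about 0.189, so this split gains roughly 0.19 bits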

We implement the classification decision tree in Python.
import numpy as np
import matplotlib.pyplot as plt

class DTree():
  # Shannon entropy (base 2) of the label vector v
  @staticmethod
  def entropy(v):
    _, counts = np.unique(v, return_counts = True)
    p = counts / v.shape[0]
    return -np.sum(np.log2(p) * p)

  @staticmethod
  def split_data(X, y, feature_index, feature_value):
    return {
      "I_left": np.where(X[:, feature_index] <= feature_value)[0],
      "I_right": np.where(X[:, feature_index] > feature_value)[0],
    }

  # Greedy search over every feature and every observed value,
  # keeping the split with the largest information gain
  @staticmethod
  def greedy_best_split(X, y):
    best_feature_index = 0
    best_split_value = 0
    best_IG = 0
    best_split = {
      "I_left": np.array([]),
      "I_right": np.array([]),
    }

    n_features = X.shape[1]
    parent_entropy = DTree.entropy(y)
    N = y.shape[0]
    for feature_index in range(0, n_features):
      split_values = np.unique(X[:, feature_index])
      for split_value in split_values:
        split = DTree.split_data(X, y, feature_index, split_value)

        # Compute the information gain
        N_left = split["I_left"].shape[0]
        N_right = split["I_right"].shape[0]
        IG = parent_entropy - 1/N * (N_left * DTree.entropy(y[split["I_left"]]) + N_right * DTree.entropy(y[split["I_right"]]))

        # Update if the information gain is the largest so far
        if IG >= best_IG:
          best_feature_index = feature_index
          best_split_value = split_value
          best_split = split
          best_IG = IG
    return best_IG, best_feature_index, best_split_value, best_split

  @staticmethod
  def fit_tree(X, y, depth = 1, 
               max_depth = 100, tolerance = 10**(-3)):
    node = {}

    # Set weight with the mode
    S_y, counts = np.unique(y, return_counts = True)
    node["w"] = S_y[np.argmax(counts)] # mode

    node["left"] = None
    node["right"] = None

    # If we can split, find the best split by greedy algorithm
    if y.shape[0] >= 2:
      IG, feature_index, split_value, split = DTree.greedy_best_split(X, y)
      # If the greedy split is non-trivial and the stopping criteria are not met, branch into left and right subtrees
      if split["I_left"].shape[0] > 0 and split["I_right"].shape[0] > 0 and IG >= tolerance and depth < max_depth:
        node["information_gain"] = IG
        node["feature_index"] = feature_index
        node["split_value"] = split_value

        node["left"] = DTree.fit_tree(X[split["I_left"]], y[split["I_left"]], depth = depth + 1, max_depth = max_depth, tolerance = tolerance)
        node["right"] = DTree.fit_tree(X[split["I_right"]], y[split["I_right"]], depth = depth + 1, max_depth = max_depth, tolerance = tolerance) 
    return node

  ###
  # Predict
  ###
  @staticmethod
  def predict_one(node, x):
    if node["left"] == None:
      return node["w"]
    else:
      if x[node["feature_index"]] <= node["split_value"]:
        return DTree.predict_one(node["left"], x)
      else:
        return DTree.predict_one(node["right"], x)

  @staticmethod
  def predict(node, X):
    n_samples = X.shape[0]
    predictions = np.zeros(n_samples)
    for i in range(0, n_samples):
      predictions[i] = DTree.predict_one(node, X[i])
    return predictions

  @staticmethod
  def print_tree(node, depth = 0):
    if node["left"] == None:
      print(f'{depth * "  "}weight: {node["w"]}')
    else:
      print(f'{depth * "  "}X{node["feature_index"]} <= {node["split_value"]}')
      DTree.print_tree(node["left"], depth + 1)
      DTree.print_tree(node["right"], depth + 1)
We will use the test set from the UCI optical handwritten digits dataset, plot a few images, and do a train/test split.
test = np.loadtxt("data/optdigits_test.txt", delimiter = ",")
X = test[:, 0:64]
y = test[:, 64]

# Train/test split
n_samples = X.shape[0]
n_TRAIN = int(.75 * n_samples)
I = np.arange(0, n_samples)
TRAIN = np.random.choice(I, n_TRAIN, replace = False)
TEST = np.setdiff1d(I, TRAIN)
X_train = X[TRAIN, :]
y_train = y[TRAIN]
X_test = X[TEST, :]
y_test = y[TEST]

# Plot some of the digits
fig = plt.figure(figsize=(8, 6))
for i in range(0, 20):
    ax = fig.add_subplot(5, 5, i + 1)
    ax.imshow(X[i].reshape((8,8)), cmap = "Greys", vmin = 0, vmax = 16)
fig.tight_layout()
plt.show()

We train the decision tree and report the training and test accuracy.
tree = DTree.fit_tree(X_train, y_train, max_depth = 100, tolerance = 10**(-3))

print("Train accuracy:", 1/X_train.shape[0] * np.sum(DTree.predict(tree, X_train) == y_train))
print("Test accuracy", 1/X_test.shape[0] * np.sum(DTree.predict(tree, X_test) == y_test))
Train accuracy: 1.0
Test accuracy: 0.8644444444444445

References.

http://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits

Richard O. Duda, Peter E. Hart, and David G. Stork. 2000. Pattern Classification (2nd Edition). Wiley-Interscience, New York, NY, USA.

Kevin P. Murphy. 2012. Machine Learning: A Probabilistic Perspective. The MIT Press.

https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/