Build a Decision Tree in Polars from Scratch
Explore decision trees with polars backend
Decision tree algorithms have always fascinated me. They are easy to implement and achieve good results on various classification and regression tasks. Combined with boosting, decision trees are still state-of-the-art in many applications.
Frameworks such as sklearn, lightgbm, xgboost and catboost have done a very good job so far. However, in the past few months, I have been missing support for arrow datasets. While lightgbm has recently added support for them, it is still missing in most other frameworks. The arrow data format could be a perfect match for decision trees, since its columnar structure is optimized for efficient data processing. Pandas has already added support for it, and polars also takes advantage of it.
Polars has shown significant performance advantages over most other data frameworks. It handles data efficiently and avoids copying it unnecessarily. It also provides a streaming engine that can process datasets larger than memory. This is why I decided to use polars as the backend for building a decision tree from scratch.
The goal is to explore the advantages of using polars for decision trees in terms of memory and runtime and, of course, to learn more about polars: how to define expressions efficiently and how the streaming engine works.
The code for the implementation can be found in this repository.
Code overview
To give a first overview of the code, here is the structure of the DecisionTreeClassifier:
```python
import pickle
from typing import Iterable, List, Union

import polars as pl


class DecisionTreeClassifier:
    def __init__(self, streaming=False, max_depth=None, categorical_columns=None):
        ...

    def save_model(self, path: str) -> None:
        ...

    def load_model(self, path: str) -> None:
        ...

    def apply_categorical_mappings(self, data: Union[pl.DataFrame, pl.LazyFrame]) -> Union[pl.DataFrame, pl.LazyFrame]:
        ...

    def fit(self, data: Union[pl.DataFrame, pl.LazyFrame], target_name: str) -> None:
        ...

    def predict_many(self, data: Union[pl.DataFrame, pl.LazyFrame]) -> List[Union[int, float]]:
        ...

    def predict(self, data: Iterable[dict]):
        ...

    def get_majority_class(self, df: Union[pl.DataFrame, pl.LazyFrame], target_name: str) -> str:
        ...

    def _build_tree(
        self,
        data: Union[pl.DataFrame, pl.LazyFrame],
        feature_names: list[str],
        target_name: str,
        unique_targets: list[int],
        depth: int,
    ) -> dict:
        ...
```
The first thing to note is the imports. I wanted to keep the import section clean, with as few dependencies as possible, and succeeded: the only dependencies are polars, pickle, and typing.
The __init__ method defines whether the polars streaming engine is used, and the max_depth of the tree can be set here as well. Another option is the definition of categorical columns. These are handled differently from numerical features, using a target encoding.
It is possible to save and load the decision tree model. It is represented as a nested dict and can be saved to disk as a pickled file.
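Because the model is just a nested dict, saving and loading can be as simple as a pickle round trip. A minimal sketch, assuming the tree is stored as a plain dict; save_tree() and load_tree() are hypothetical helper names, not the repository's save_model()/load_model(), and the tree is a simplified, hand-written example:

```python
import os
import pickle
import tempfile

# Hand-written, simplified tree in the nested-dict format (hypothetical, not trained)
tree = {
    "type": "node",
    "feature": "ap_hi",
    "threshold": 129.0,
    "left": {"type": "leaf", "value": 0},
    "right": {"type": "leaf", "value": 1},
}


def save_tree(tree: dict, path: str) -> None:
    """Hypothetical helper: pickle the nested dict to disk."""
    with open(path, "wb") as f:
        pickle.dump(tree, f)


def load_tree(path: str) -> dict:
    """Hypothetical helper: read the pickled nested dict back."""
    with open(path, "rb") as f:
        return pickle.load(f)


path = os.path.join(tempfile.mkdtemp(), "tree.pkl")
save_tree(tree, path)
restored = load_tree(path)
print(restored == tree)  # True
```

Since unpickling can execute arbitrary code, model files like this should only be loaded from trusted sources.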
The polars magic happens in the fit() and _build_tree() methods. Both accept LazyFrames and DataFrames to support in-memory processing as well as streaming.
There are two prediction methods, predict() and predict_many(). The predict() method is meant for a small number of examples, which are provided as an iterable of dicts. For a large test set, it is more efficient to use predict_many(), which accepts the data as a polars DataFrame or LazyFrame.
Fitting the tree
To train the decision tree classifier, the fit() method needs to be used.
```python
def fit(self, data: Union[pl.DataFrame, pl.LazyFrame], target_name: str) -> None:
    """
    Fit method to train the decision tree.

    :param data: Polars DataFrame or LazyFrame containing the training data.
    :param target_name: Name of the target column
    """
    columns = data.collect_schema().names()
    feature_names = [col for col in columns if col != target_name]

    # Shrink dtypes
    data = data.with_columns(pl.all().shrink_dtype()).with_columns(
        pl.col(target_name).cast(pl.UInt64).shrink_dtype().alias(target_name)
    )

    # Prepare categorical columns with target encoding
    if self.categorical_columns:
        categorical_mappings = {}
        for categorical_column in self.categorical_columns:
            categorical_mappings[categorical_column] = {
                value: index
                for index, value in enumerate(
                    data.lazy()
                    .group_by(categorical_column)
                    .agg(pl.col(target_name).mean().alias("avg"))
                    .sort("avg")
                    .collect(streaming=self.streaming)[categorical_column]
                )
            }

        self.categorical_mappings = categorical_mappings
        data = self.apply_categorical_mappings(data)

    unique_targets = data.select(target_name).unique()
    if isinstance(unique_targets, pl.LazyFrame):
        unique_targets = unique_targets.collect(streaming=self.streaming)
    unique_targets = unique_targets[target_name].to_list()

    self.tree = self._build_tree(data, feature_names, target_name, unique_targets, depth=0)
```
It receives a polars LazyFrame or DataFrame that contains all features and the target column. To identify the target column, the target_name needs to be provided.
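Polars provides a convenient way to optimize the memory usage of the data: pl.all().shrink_dtype() selects all columns and converts each of them to the smallest dtype that can hold its values. The target column is additionally cast to an unsigned integer (pl.UInt64) before it is shrunk.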
The categorical encoding
To encode categorical values, a target encoding is used. For that, the rows are grouped by the values of the categorical feature, and the average target value is calculated per group. The groups are then sorted by their average target value, and each one is assigned its rank. This rank is used as the numerical representation of the feature value.
Since both polars DataFrames and LazyFrames are accepted, I call data.lazy() first. If the given data is a DataFrame, it is converted to a LazyFrame; if it is already a LazyFrame, it is returned unchanged. This trick ensures that LazyFrames and DataFrames are processed in the same way and that the collect() method, which is only available on LazyFrames, can be used.
To illustrate the outcome of the calculations in the different steps of the fitting process, I apply it to a dataset for heart disease prediction. It can be found on Kaggle and is published under the Database Contents License.
Here is an example of the categorical feature representation for the glucose levels:
│ rank ┆ gluc ┆ avg │
│ --- ┆ --- ┆ --- │
│ u32 ┆ i8 ┆ f64 │
│ 0 ┆ 1 ┆ 0.476139 │
│ 1 ┆ 2 ┆ 0.586319 │
│ 2 ┆ 3 ┆ 0.620972 │
For each glucose level, the proportion of examples with heart disease (the average target value) is calculated. The levels are sorted by this value and then ranked, so that each level is mapped to a rank value.
Getting the target values
As the last part of the fit() method, the unique target values are determined.
```python
unique_targets = data.select(target_name).unique()
if isinstance(unique_targets, pl.LazyFrame):
    unique_targets = unique_targets.collect(streaming=self.streaming)
unique_targets = unique_targets[target_name].to_list()

self.tree = self._build_tree(data, feature_names, target_name, unique_targets, depth=0)
```
This is the last preparation step before _build_tree() is called, which then calls itself recursively.
Building the tree
After the data is prepared in the fit() method, the _build_tree() method is called. This is done recursively until a stopping criterion is met, e.g., the max depth of the tree is reached. The first call is executed from the fit() method with a depth of zero.
```python
def _build_tree(
    self,
    data: Union[pl.DataFrame, pl.LazyFrame],
    feature_names: list[str],
    target_name: str,
    unique_targets: list[int],
    depth: int,
) -> dict:
    """
    Builds the decision tree recursively.
    If max_depth is reached, returns a leaf node with the majority class.
    Otherwise, finds the best split and creates internal nodes for left and right children.

    :param data: The dataframe to evaluate.
    :param feature_names: Name of the feature columns.
    :param target_name: Name of the target column.
    :param unique_targets: unique target values.
    :param depth: The current depth of the tree.
    :return: A dictionary representing the node.
    """
    if self.max_depth is not None and depth >= self.max_depth:
        return {"type": "leaf", "value": self.get_majority_class(data, target_name)}

    # Make data lazy here to avoid that it is evaluated in each loop iteration.
    data = data.lazy()

    # Evaluate entropy per feature:
    information_gain_dfs = []
    for feature_name in feature_names:
        feature_data = data.select([feature_name, target_name]).filter(pl.col(feature_name).is_not_null())
        feature_data = feature_data.rename({feature_name: "feature_value"})

        # No streaming (yet)
        information_gain_df = (
            feature_data.group_by("feature_value")
            .agg(
                [
                    pl.col(target_name)
                    .filter(pl.col(target_name) == target_value)
                    .len()
                    .alias(f"class_{target_value}_count")
                    for target_value in unique_targets
                ]
                + [pl.col(target_name).len().alias("count_examples")]
            )
            .sort("feature_value")
            .select(
                [
                    pl.col(f"class_{target_value}_count").cum_sum().alias(f"cum_sum_class_{target_value}_count")
                    for target_value in unique_targets
                ]
                + [
                    pl.col(f"class_{target_value}_count").sum().alias(f"sum_class_{target_value}_count")
                    for target_value in unique_targets
                ]
                + [
                    pl.col("count_examples").cum_sum().alias("cum_sum_count_examples"),
                    pl.col("count_examples").sum().alias("sum_count_examples"),
                ]
                + [
                    # From previous select
                    pl.col("feature_value"),
                ]
            )
            .filter(
                # At least one example available
                pl.col("sum_count_examples")
                > pl.col("cum_sum_count_examples")
            )
            .select(
                [
                    (pl.col(f"cum_sum_class_{target_value}_count") / pl.col("cum_sum_count_examples")).alias(
                        f"left_proportion_class_{target_value}"
                    )
                    for target_value in unique_targets
                ]
                + [
                    (
                        (pl.col(f"sum_class_{target_value}_count") - pl.col(f"cum_sum_class_{target_value}_count"))
                        / (pl.col("sum_count_examples") - pl.col("cum_sum_count_examples"))
                    ).alias(f"right_proportion_class_{target_value}")
                    for target_value in unique_targets
                ]
                + [
                    (pl.col(f"sum_class_{target_value}_count") / pl.col("sum_count_examples")).alias(
                        f"parent_proportion_class_{target_value}"
                    )
                    for target_value in unique_targets
                ]
                + [
                    # From previous select
                    pl.col("cum_sum_count_examples"),
                    pl.col("sum_count_examples"),
                    pl.col("feature_value"),
                ]
            )
            # Entropy per side; 0 * log2(0) yields NaN, so those terms are set to 0 by convention
            .select(
                (
                    -1
                    * pl.sum_horizontal(
                        [
                            (
                                pl.col(f"left_proportion_class_{target_value}")
                                * pl.col(f"left_proportion_class_{target_value}").log(base=2)
                            ).fill_nan(0.0)
                            for target_value in unique_targets
                        ]
                    )
                ).alias("left_entropy"),
                (
                    -1
                    * pl.sum_horizontal(
                        [
                            (
                                pl.col(f"right_proportion_class_{target_value}")
                                * pl.col(f"right_proportion_class_{target_value}").log(base=2)
                            ).fill_nan(0.0)
                            for target_value in unique_targets
                        ]
                    )
                ).alias("right_entropy"),
                (
                    -1
                    * pl.sum_horizontal(
                        [
                            (
                                pl.col(f"parent_proportion_class_{target_value}")
                                * pl.col(f"parent_proportion_class_{target_value}").log(base=2)
                            ).fill_nan(0.0)
                            for target_value in unique_targets
                        ]
                    )
                ).alias("parent_entropy"),
                # From previous select
                pl.col("cum_sum_count_examples"),
                pl.col("sum_count_examples"),
                pl.col("feature_value"),
            )
            .select(
                (
                    pl.col("cum_sum_count_examples") / pl.col("sum_count_examples") * pl.col("left_entropy")
                    + (pl.col("sum_count_examples") - pl.col("cum_sum_count_examples"))
                    / pl.col("sum_count_examples")
                    * pl.col("right_entropy")
                ).alias("child_entropy"),
                # From previous select
                pl.col("parent_entropy"),
                pl.col("feature_value"),
            )
            .select(
                (pl.col("parent_entropy") - pl.col("child_entropy")).alias("information_gain"),
                # From previous select
                pl.col("parent_entropy"),
                pl.col("feature_value"),
            )
            .sort("information_gain", descending=True)
            .head(1)
            .with_columns(feature=pl.lit(feature_name))
        )
        information_gain_dfs.append(information_gain_df)

    if isinstance(information_gain_dfs[0], pl.LazyFrame):
        information_gain_dfs = pl.collect_all(information_gain_dfs, streaming=self.streaming)
    information_gain_dfs = pl.concat(information_gain_dfs, how="vertical_relaxed").sort(
        "information_gain", descending=True
    )

    information_gain = 0
    if len(information_gain_dfs) > 0:
        best_params = information_gain_dfs.row(0, named=True)
        information_gain = best_params["information_gain"]

    if information_gain > 0:
        left_mask = data.select(filter=pl.col(best_params["feature"]) <= best_params["feature_value"])
        if isinstance(left_mask, pl.LazyFrame):
            left_mask = left_mask.collect(streaming=self.streaming)
        left_mask = left_mask["filter"]

        # Split data
        left_df = data.filter(left_mask)
        right_df = data.filter(~left_mask)
        left_subtree = self._build_tree(left_df, feature_names, target_name, unique_targets, depth + 1)
        right_subtree = self._build_tree(right_df, feature_names, target_name, unique_targets, depth + 1)

        if isinstance(data, pl.LazyFrame):
            target_distribution = data.select(target_name).collect(streaming=self.streaming)[target_name]
            target_distribution = target_distribution.value_counts().sort(target_name)["count"].to_list()
        else:
            target_distribution = data[target_name].value_counts().sort(target_name)["count"].to_list()

        return {
            "type": "node",
            "feature": best_params["feature"],
            "threshold": best_params["feature_value"],
            "information_gain": best_params["information_gain"],
            "entropy": best_params["parent_entropy"],
            "target_distribution": target_distribution,
            "left": left_subtree,
            "right": right_subtree,
        }

    return {"type": "leaf", "value": self.get_majority_class(data, target_name)}
```
This method is the heart of the tree building, so I will explain it step by step. First, when entering the method, it checks whether the max depth stopping criterion is met.
```python
if self.max_depth is not None and depth >= self.max_depth:
    return {"type": "leaf", "value": self.get_majority_class(data, target_name)}
```
If the current depth is equal to or greater than the max_depth, a node of the type leaf will be returned. The value of the leaf corresponds to the majority class of the data. This is calculated as follows:
```python
def get_majority_class(self, df: Union[pl.DataFrame, pl.LazyFrame], target_name: str) -> str:
    """
    Returns the majority class of a dataframe.

    :param df: The dataframe to evaluate.
    :param target_name: Name of the target column.
    :return: majority class.
    """
    majority_class = df.group_by(target_name).len().filter(pl.col("len") == pl.col("len").max()).select(target_name)
    if isinstance(majority_class, pl.LazyFrame):
        majority_class = majority_class.collect(streaming=self.streaming)
    return majority_class[target_name][0]
```
To get the majority class, the rows are grouped by the target column and counted with len(). The target value with the highest count is returned as the majority class.
Information Gain as Splitting Criteria
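To find a good split of the data, the information gain is used. It is the difference between the entropy of the parent node and the weighted entropy of its two children (Equation 1):

$$IG = H_{\text{parent}} - H_{\text{child}}, \qquad H_{\text{child}} = \frac{n_{\text{left}}}{n} H_{\text{left}} + \frac{n_{\text{right}}}{n} H_{\text{right}} \tag{1}$$

To get the information gain, the parent entropy and child entropy need to be calculated. The entropy of a set of examples is computed from the proportion $p_c$ of examples of each class $c$ (Equation 2), using the convention $0 \cdot \log_2 0 = 0$:

$$H = -\sum_{c} p_c \log_2 p_c \tag{2}$$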
A good explanation of the interpretation of information gain can be found here.
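To make Equation 1 and Equation 2 concrete, here is a small pure-Python sketch for a single binary split. The class counts are made up for illustration, and terms with a zero count are skipped, which implements the 0 · log2(0) = 0 convention:

```python
import math
from typing import List


def entropy(counts: List[int]) -> float:
    """Shannon entropy (base 2) of a class-count vector (Equation 2)."""
    total = sum(counts)
    return -sum(c / total * math.log2(c / total) for c in counts if c > 0)


def information_gain(left: List[int], right: List[int]) -> float:
    """Parent entropy minus the size-weighted entropy of the two children (Equation 1)."""
    parent = [l + r for l, r in zip(left, right)]
    n_left, n_right = sum(left), sum(right)
    n = n_left + n_right
    child = n_left / n * entropy(left) + n_right / n * entropy(right)
    return entropy(parent) - child


# Parent: 4 examples of class 0 and 4 of class 1
print(entropy([4, 4]))  # 1.0
# Split into 3/1 on the left and 1/3 on the right
print(round(information_gain([3, 1], [1, 3]), 4))  # 0.1887
```

A perfect split such as [4, 0] versus [0, 4] yields an information gain of 1.0, the full parent entropy.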
Calculating The Information Gain in Polars
The information gain is calculated for each feature value that is present in a feature column.
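```python
information_gain_df = (
    feature_data.group_by("feature_value")
    .agg(
        [
            pl.col(target_name)
            .filter(pl.col(target_name) == target_value)
            .len()
            .alias(f"class_{target_value}_count")
            for target_value in unique_targets
        ]
        + [pl.col(target_name).len().alias("count_examples")]
    )
    .sort("feature_value")
```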
The data is grouped by feature value, and for each group the count of every target value is calculated. Additionally, the total number of rows for that feature value is saved as count_examples. In the last step, the data is sorted by feature_value, which is needed to calculate the splits in the next step.
For the heart disease dataset, after the first calculation step, the data looks like this:
│ feature_value ┆ class_0_count ┆ class_1_count ┆ count_examples │
│ --- ┆ --- ┆ --- ┆ --- │
│ i8 ┆ u32 ┆ u32 ┆ u32 │
│ 29 ┆ 2 ┆ 0 ┆ 2 │
│ 30 ┆ 1 ┆ 0 ┆ 1 │
│ 39 ┆ 1068 ┆ 331 ┆ 1399 │
│ 40 ┆ 975 ┆ 263 ┆ 1238 │
│ 41 ┆ 1052 ┆ 438 ┆ 1490 │
│ … ┆ … ┆ … ┆ … │
│ 60 ┆ 1054 ┆ 1460 ┆ 2514 │
│ 61 ┆ 695 ┆ 1408 ┆ 2103 │
│ 62 ┆ 566 ┆ 1125 ┆ 1691 │
│ 63 ┆ 572 ┆ 1517 ┆ 2089 │
│ 64 ┆ 479 ┆ 1217 ┆ 1696 │
Here, the feature age_years is processed. Class 0 stands for “no heart disease,” and class 1 stands for “heart disease.” The data is sorted by the age_years feature, and the columns contain the count of class 0, the count of class 1, and the total count of examples for the respective feature value.
In the next step, the cumulative sum over the count of classes is calculated for each feature value.
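```python
.select(
    [
        pl.col(f"class_{target_value}_count").cum_sum().alias(f"cum_sum_class_{target_value}_count")
        for target_value in unique_targets
    ]
    + [
        pl.col(f"class_{target_value}_count").sum().alias(f"sum_class_{target_value}_count")
        for target_value in unique_targets
    ]
    + [
        pl.col("count_examples").cum_sum().alias("cum_sum_count_examples"),
        pl.col("count_examples").sum().alias("sum_count_examples"),
    ]
    + [
        # From previous select
        pl.col("feature_value"),
    ]
)
.filter(
    # At least one example available
    pl.col("sum_count_examples")
    > pl.col("cum_sum_count_examples")
)
```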
The intuition is that a split at a specific feature value sends all examples with that value or a smaller one to the left side, so the cumulative sum gives the count of each target value on the left. To be able to calculate proportions, the total sum of each target value is calculated as well. The same is done for count_examples, for which both the cumulative sum and the total sum are calculated. Finally, the filter removes the split at the largest feature value, since it would leave no examples on the right side.
After the calculation, the data looks like this:
│ cum_sum_class_0_count ┆ cum_sum_class_1_count ┆ sum_class_0_count ┆ sum_class_1_count ┆ cum_sum_count_examples ┆ sum_count_examples ┆ feature_value │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ u32 ┆ u32 ┆ u32 ┆ u32 ┆ u32 ┆ u32 ┆ i8 │
│ 3 ┆ 0 ┆ 27717 ┆ 26847 ┆ 3 ┆ 54564 ┆ 29 │
│ 4 ┆ 0 ┆ 27717 ┆ 26847 ┆ 4 ┆ 54564 ┆ 30 │
│ 1097 ┆ 324 ┆ 27717 ┆ 26847 ┆ 1421 ┆ 54564 ┆ 39 │
│ 2090 ┆ 595 ┆ 27717 ┆ 26847 ┆ 2685 ┆ 54564 ┆ 40 │
│ 3155 ┆ 1025 ┆ 27717 ┆ 26847 ┆ 4180 ┆ 54564 ┆ 41 │
│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │
│ 24302 ┆ 20162 ┆ 27717 ┆ 26847 ┆ 44464 ┆ 54564 ┆ 59 │
│ 25356 ┆ 21581 ┆ 27717 ┆ 26847 ┆ 46937 ┆ 54564 ┆ 60 │
│ 26046 ┆ 23020 ┆ 27717 ┆ 26847 ┆ 49066 ┆ 54564 ┆ 61 │
│ 26615 ┆ 24131 ┆ 27717 ┆ 26847 ┆ 50746 ┆ 54564 ┆ 62 │
│ 27216 ┆ 25652 ┆ 27717 ┆ 26847 ┆ 52868 ┆ 54564 ┆ 63 │
In the next step, the proportions are calculated for each feature value.
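```python
.select(
    [
        (pl.col(f"cum_sum_class_{target_value}_count") / pl.col("cum_sum_count_examples")).alias(
            f"left_proportion_class_{target_value}"
        )
        for target_value in unique_targets
    ]
    + [
        (
            (pl.col(f"sum_class_{target_value}_count") - pl.col(f"cum_sum_class_{target_value}_count"))
            / (pl.col("sum_count_examples") - pl.col("cum_sum_count_examples"))
        ).alias(f"right_proportion_class_{target_value}")
        for target_value in unique_targets
    ]
    + [
        (pl.col(f"sum_class_{target_value}_count") / pl.col("sum_count_examples")).alias(
            f"parent_proportion_class_{target_value}"
        )
        for target_value in unique_targets
    ]
    + [
        # From previous select
        pl.col("cum_sum_count_examples"),
        pl.col("sum_count_examples"),
        pl.col("feature_value"),
    ]
)
```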
To calculate the proportions, the results from the previous step are used. For the left proportion, the cumulative sum of each target value is divided by the cumulative sum of the example count. For the right proportion, we need to know how many examples of each target value are on the right side. This is calculated by subtracting the cumulative sum of the target value from its total sum. In the same way, the total count of examples on the right side is obtained by subtracting the cumulative sum of the example count from the total example count. Additionally, the parent proportion is calculated by dividing the total count of each target value by the total count of examples.
This is the result data after this step:
│ left_proportion_class_0 ┆ left_proportion_class_1 ┆ right_proportion_class_0 ┆ right_proportion_class_1 ┆ … ┆ parent_proportion_class_1 ┆ cum_sum_count_examples ┆ sum_count_examples ┆ feature_value │
│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │
│ f64 ┆ f64 ┆ f64 ┆ f64 ┆ ┆ f64 ┆ u32 ┆ u32 ┆ i8 │
│ 1.0 ┆ 0.0 ┆ 0.506259 ┆ 0.493741 ┆ … ┆ 0.493714 ┆ 3 ┆ 54564 ┆ 29 │
│ 1.0 ┆ 0.0 ┆ 0.50625 ┆ 0.49375 ┆ … ┆ 0.493714 ┆ 4 ┆ 54564 ┆ 30 │
│ 0.754902 ┆ 0.245098 ┆ 0.499605 ┆ 0.500395 ┆ … ┆ 0.493714 ┆ 1428 ┆ 54564 ┆ 39 │
│ 0.765596 ┆ 0.234404 ┆ 0.492739 ┆ 0.507261 ┆ … ┆ 0.493714 ┆ 2709 ┆ 54564 ┆ 40 │
│ 0.741679 ┆ 0.258321 ┆ 0.486929 ┆ 0.513071 ┆ … ┆ 0.493714 ┆ 4146 ┆ 54564 ┆ 41 │
│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │
│ 0.545735 ┆ 0.454265 ┆ 0.333563 ┆ 0.666437 ┆ … ┆ 0.493714 ┆ 44419 ┆ 54564 ┆ 59 │
│ 0.539065 ┆ 0.460935 ┆ 0.305025 ┆ 0.694975 ┆ … ┆ 0.493714 ┆ 46922 ┆ 54564 ┆ 60 │
│ 0.529725 ┆ 0.470275 ┆ 0.297071 ┆ 0.702929 ┆ … ┆ 0.493714 ┆ 49067 ┆ 54564 ┆ 61 │
│ 0.523006 ┆ 0.476994 ┆ 0.282551 ┆ 0.717449 ┆ … ┆ 0.493714 ┆ 50770 ┆ 54564 ┆ 62 │
│ 0.513063 ┆ 0.486937 ┆ 0.296188 ┆ 0.703812 ┆ … ┆ 0.493714 ┆ 52859 ┆ 54564 ┆ 63 │
Now that the proportions are available, the entropy can be calculated.
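```python
# Entropy per side; 0 * log2(0) yields NaN, so those terms are set to 0 by convention
.select(
    (
        -1
        * pl.sum_horizontal(
            [
                (
                    pl.col(f"left_proportion_class_{target_value}")
                    * pl.col(f"left_proportion_class_{target_value}").log(base=2)
                ).fill_nan(0.0)
                for target_value in unique_targets
            ]
        )
    ).alias("left_entropy"),
    (
        -1
        * pl.sum_horizontal(
            [
                (
                    pl.col(f"right_proportion_class_{target_value}")
                    * pl.col(f"right_proportion_class_{target_value}").log(base=2)
                ).fill_nan(0.0)
                for target_value in unique_targets
            ]
        )
    ).alias("right_entropy"),
    (
        -1
        * pl.sum_horizontal(
            [
                (
                    pl.col(f"parent_proportion_class_{target_value}")
                    * pl.col(f"parent_proportion_class_{target_value}").log(base=2)
                ).fill_nan(0.0)
                for target_value in unique_targets
            ]
        )
    ).alias("parent_entropy"),
    # From previous select
    pl.col("cum_sum_count_examples"),
    pl.col("sum_count_examples"),
    pl.col("feature_value"),
)
```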
For the calculation of the entropy, Equation 2 is used. The left entropy is calculated from the left proportions, the right entropy from the right proportions, and the parent entropy from the parent proportions. In this implementation, pl.sum_horizontal() is used to sum the p·log2(p) terms of all classes, so that polars can apply its optimizations. It could also be replaced with Python's built-in sum() function.
The data with the entropy values looks as follows:
│ left_entropy ┆ right_entropy ┆ parent_entropy ┆ cum_sum_count_examples ┆ sum_count_examples ┆ feature_value │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ f64 ┆ f64 ┆ f64 ┆ u32 ┆ u32 ┆ i8 │
│ -0.0 ┆ 0.999854 ┆ 0.999853 ┆ 3 ┆ 54564 ┆ 29 │
│ -0.0 ┆ 0.999854 ┆ 0.999853 ┆ 4 ┆ 54564 ┆ 30 │
│ 0.783817 ┆ 1.0 ┆ 0.999853 ┆ 1427 ┆ 54564 ┆ 39 │
│ 0.767101 ┆ 0.999866 ┆ 0.999853 ┆ 2694 ┆ 54564 ┆ 40 │
│ 0.808516 ┆ 0.999503 ┆ 0.999853 ┆ 4177 ┆ 54564 ┆ 41 │
│ … ┆ … ┆ … ┆ … ┆ … ┆ … │
│ 0.993752 ┆ 0.918461 ┆ 0.999853 ┆ 44483 ┆ 54564 ┆ 59 │
│ 0.995485 ┆ 0.890397 ┆ 0.999853 ┆ 46944 ┆ 54564 ┆ 60 │
│ 0.997367 ┆ 0.880977 ┆ 0.999853 ┆ 49106 ┆ 54564 ┆ 61 │
│ 0.99837 ┆ 0.859431 ┆ 0.999853 ┆ 50800 ┆ 54564 ┆ 62 │
│ 0.999436 ┆ 0.872346 ┆ 0.999853 ┆ 52877 ┆ 54564 ┆ 63 │
Almost there! The final step is to calculate the child entropy and use it to get the information gain.
```python
.select(
    (
        pl.col("cum_sum_count_examples") / pl.col("sum_count_examples") * pl.col("left_entropy")
        + (pl.col("sum_count_examples") - pl.col("cum_sum_count_examples"))
        / pl.col("sum_count_examples")
        * pl.col("right_entropy")
    ).alias("child_entropy"),
    # From previous select
    pl.col("parent_entropy"),
    pl.col("feature_value"),
)
.select(
    (pl.col("parent_entropy") - pl.col("child_entropy")).alias("information_gain"),
    # From previous select
    pl.col("parent_entropy"),
    pl.col("feature_value"),
)
.sort("information_gain", descending=True)
.head(1)
```
For the child entropy, the left and right entropies are weighted by the fraction of examples that fall on each side of the split, and the two weighted values are summed. To calculate the information gain, we simply subtract the child entropy from the parent entropy, as can be seen in Equation 1. The best feature value is determined by sorting the data by information gain and selecting the first row. It is appended to a list that gathers the best feature values of all features.
Before applying .head(1), the data looks as follows:
│ information_gain ┆ parent_entropy ┆ feature_value │
│ --- ┆ --- ┆ --- │
│ f64 ┆ f64 ┆ i8 │
│ 0.028388 ┆ 0.999928 ┆ 54 │
│ 0.027719 ┆ 0.999928 ┆ 52 │
│ 0.027283 ┆ 0.999928 ┆ 53 │
│ 0.026826 ┆ 0.999928 ┆ 50 │
│ 0.026812 ┆ 0.999928 ┆ 51 │
│ … ┆ … ┆ … │
│ 0.010928 ┆ 0.999928 ┆ 62 │
│ 0.005872 ┆ 0.999928 ┆ 39 │
│ 0.004155 ┆ 0.999928 ┆ 63 │
│ 0.000072 ┆ 0.999928 ┆ 30 │
│ 0.000054 ┆ 0.999928 ┆ 29 │
Here, it can be seen that the age feature value of 54 has the highest information gain. This feature value will be collected for the age feature and needs to compete against the other features.
Selecting the Best Split and Defining Subtrees
To select the best split, the highest information gain needs to be found across all features.
```python
if isinstance(information_gain_dfs[0], pl.LazyFrame):
    information_gain_dfs = pl.collect_all(information_gain_dfs, streaming=self.streaming)
information_gain_dfs = pl.concat(information_gain_dfs, how="vertical_relaxed").sort(
    "information_gain", descending=True
)
```
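For that, pl.collect_all() is applied to information_gain_dfs. It evaluates all LazyFrames in parallel, which makes the processing very efficient. The result is a list of polars DataFrames, which are concatenated and sorted by information gain. The how="vertical_relaxed" option casts the columns to a common supertype; this is needed because feature_value has a different dtype for each feature (for example, i8 for age_years), and it becomes f64 in the combined result.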
For the heart disease example, the data looks like this:
│ information_gain ┆ parent_entropy ┆ feature_value ┆ feature │
│ --- ┆ --- ┆ --- ┆ --- │
│ f64 ┆ f64 ┆ f64 ┆ str │
│ 0.138032 ┆ 0.999909 ┆ 129.0 ┆ ap_hi │
│ 0.09087 ┆ 0.999909 ┆ 85.0 ┆ ap_lo │
│ 0.029966 ┆ 0.999909 ┆ 0.0 ┆ cholesterol │
│ 0.028388 ┆ 0.999909 ┆ 54.0 ┆ age_years │
│ 0.01968 ┆ 0.999909 ┆ 27.435041 ┆ bmi │
│ … ┆ … ┆ … ┆ … │
│ 0.000851 ┆ 0.999909 ┆ 0.0 ┆ active │
│ 0.000351 ┆ 0.999909 ┆ 156.0 ┆ height │
│ 0.000223 ┆ 0.999909 ┆ 0.0 ┆ smoke │
│ 0.000098 ┆ 0.999909 ┆ 0.0 ┆ alco │
│ 0.000031 ┆ 0.999909 ┆ 0.0 ┆ gender │
Out of all features, the ap_hi (Systolic blood pressure) feature value of 129 results in the best information gain and thus will be selected for the first split.
```python
information_gain = 0
if len(information_gain_dfs) > 0:
    best_params = information_gain_dfs.row(0, named=True)
    information_gain = best_params["information_gain"]
```
In some cases, information_gain_dfs might be empty, for example, when all splits result in having only examples on the left or right side. If this is the case, the information gain is zero. Otherwise, we get the feature value with the highest information gain.
```python
if information_gain > 0:
    left_mask = data.select(filter=pl.col(best_params["feature"]) <= best_params["feature_value"])
    if isinstance(left_mask, pl.LazyFrame):
        left_mask = left_mask.collect(streaming=self.streaming)
    left_mask = left_mask["filter"]

    # Split data
    left_df = data.filter(left_mask)
    right_df = data.filter(~left_mask)
    left_subtree = self._build_tree(left_df, feature_names, target_name, unique_targets, depth + 1)
    right_subtree = self._build_tree(right_df, feature_names, target_name, unique_targets, depth + 1)

    if isinstance(data, pl.LazyFrame):
        target_distribution = data.select(target_name).collect(streaming=self.streaming)[target_name]
        target_distribution = target_distribution.value_counts().sort(target_name)["count"].to_list()
    else:
        target_distribution = data[target_name].value_counts().sort(target_name)["count"].to_list()

    return {
        "type": "node",
        "feature": best_params["feature"],
        "threshold": best_params["feature_value"],
        "information_gain": best_params["information_gain"],
        "entropy": best_params["parent_entropy"],
        "target_distribution": target_distribution,
        "left": left_subtree,
        "right": right_subtree,
    }

return {"type": "leaf", "value": self.get_majority_class(data, target_name)}
```
When the information gain is greater than zero, the subtrees are created. For that, the left mask is computed from the feature value that resulted in the best information gain and is collected once. Applying the mask to the parent data gives the left data frame, and applying its negation gives the right data frame. Both are passed to _build_tree() again with depth + 1. As the last step, the target distribution is calculated. It is stored as additional information on the node and is shown, together with the other information, when the tree is plotted.
When the information gain is zero, a leaf node is returned. It contains the majority class of the given data.
Make predictions
It is possible to make predictions in two different ways. If the input data is small, the predict() method can be used.
```python
def predict(self, data: Iterable[dict]):
    def _predict_sample(node, sample):
        if node["type"] == "leaf":
            return node["value"]
        if sample[node["feature"]] <= node["threshold"]:
            return _predict_sample(node["left"], sample)
        return _predict_sample(node["right"], sample)

    predictions = [_predict_sample(self.tree, sample) for sample in data]
    return predictions
```
Here, the data is provided as an iterable of dicts. Each dict contains the feature names as keys and the feature values as values. The _predict_sample() function follows the path through the tree until a leaf node is reached, which contains the class assigned to the respective example.
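As a standalone illustration of the nested-dict format, here is a hypothetical two-level tree (hand-written, not a trained model) and an iterative variant of the same traversal:

```python
from typing import Iterable, List

# Hypothetical two-level tree in the nested-dict format (not a trained model)
tree = {
    "type": "node",
    "feature": "ap_hi",
    "threshold": 129.0,
    "left": {"type": "leaf", "value": 0},
    "right": {
        "type": "node",
        "feature": "age_years",
        "threshold": 54.0,
        "left": {"type": "leaf", "value": 0},
        "right": {"type": "leaf", "value": 1},
    },
}


def predict(tree: dict, data: Iterable[dict]) -> List[int]:
    predictions = []
    for sample in data:
        node = tree
        # Follow the "<= threshold goes left" rule until a leaf is reached
        while node["type"] == "node":
            node = node["left"] if sample[node["feature"]] <= node["threshold"] else node["right"]
        predictions.append(node["value"])
    return predictions


samples = [
    {"ap_hi": 120, "age_years": 60},
    {"ap_hi": 140, "age_years": 50},
    {"ap_hi": 140, "age_years": 60},
]
print(predict(tree, samples))  # [0, 0, 1]
```

The loop avoids Python's recursion limit; with a bounded max_depth, the recursive version works just as well.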
```python
def predict_many(self, data: Union[pl.DataFrame, pl.LazyFrame]) -> List[Union[int, float]]:
    """
    Predict method.

    :param data: Polars DataFrame or LazyFrame.
    :return: List of predicted target values.
    """
    if self.categorical_mappings:
        data = self.apply_categorical_mappings(data)

    def _predict_many(node, temp_data):
        if node["type"] == "node":
            left = _predict_many(node["left"], temp_data.filter(pl.col(node["feature"]) <= node["threshold"]))
            right = _predict_many(node["right"], temp_data.filter(pl.col(node["feature"]) > node["threshold"]))
            return pl.concat([left, right], how="diagonal_relaxed")
        return temp_data.select(pl.col("temp_prediction_index"), pl.lit(node["value"]).alias("prediction"))

    data = data.with_row_index("temp_prediction_index")
    predictions = _predict_many(self.tree, data).sort("temp_prediction_index").select(pl.col("prediction"))

    # Convert predictions to a list
    if isinstance(predictions, pl.LazyFrame):
        # Although the execution plan says there is no streaming, using streaming here significantly
        # increases the performance and decreases the memory footprint.
        predictions = predictions.collect(streaming=True)

    predictions = predictions["prediction"].to_list()
    return predictions
```
If a big example set should be predicted, it is more efficient to use the predict_many() method. This makes use of the advantages that polars provides in terms of parallel processing and memory efficiency.
The data can be provided as a polars DataFrame or LazyFrame. Similar to the _build_tree() method in the training process, a _predict_many() function is called recursively. The examples are filtered into the subtrees until a leaf node is reached, and all examples that followed the same path get the same prediction value. At the end of the process, all sub-frames are concatenated again. Since this does not preserve the original order, a temporary prediction index is added at the beginning, and once all predictions are made, the original order is restored by sorting by that index.
Using the classifier on a dataset
A usage example for the decision tree classifier can be found here. The decision tree is trained on a heart disease dataset. A train and a test set are defined to evaluate the performance of the implementation. After training, the tree is plotted and saved to a file.
With a max depth of four, the resulting tree looks as follows:

It achieves a train and test accuracy of 73% on the given data.
Runtime comparison
One goal of using polars as a backend for decision trees is to explore runtime and memory usage and compare them to other frameworks. For that, I created a memory profiling script that can be found here.
The script compares this implementation, called “efficient-trees”, against sklearn and lightgbm. For efficient-trees, both the lazy streaming variant and the non-lazy in-memory variant are tested.

The graph shows that lightgbm is the fastest and most memory-efficient framework. Since lightgbm added support for arrow datasets a while ago, it can process the data efficiently. However, the whole dataset still needs to be loaded into memory and can’t be streamed, which can lead to scaling issues.
The next best framework is efficient-trees, with and without streaming. The variant without streaming has the better runtime, while the streaming variant uses less memory.
The sklearn implementation achieves the worst results in terms of both memory usage and runtime. Since the data needs to be provided as a numpy array, memory usage grows considerably. The runtime can be explained by sklearn using only one CPU core: support for multi-threading or multi-processing doesn’t exist yet.
Deep dive: Streaming in polars
As the comparison shows, the possibility of streaming the data instead of holding it in memory sets this implementation apart from all other frameworks. However, the streaming engine is still considered an experimental feature, and not all operations are compatible with streaming yet.
To get a better understanding of what happens in the background, a look into the execution plan is useful. Let’s jump back into the training process and get the execution plan for the following operation:
```python
def fit(self, data: Union[pl.DataFrame, pl.LazyFrame], target_name: str) -> None:
    """
    Fit method to train the decision tree.
    :param data: Polars DataFrame or LazyFrame containing the training data.
    :param target_name: Name of the target column
    """
    columns = data.collect_schema().names()
    feature_names = [col for col in columns if col != target_name]

    # Shrink dtypes
    data = ...  # the assignment is truncated in this excerpt
```
The execution plan for data can be created with the following command:
This returns the execution plan for the LazyFrame.
```
SELECT [col("gender").shrink_dtype(), col("height").shrink_dtype(), col("weight").shrink_dtype(), col("ap_hi").shrink_dtype(), col("ap_lo").shrink_dtype(), col("cholesterol").shrink_dtype(), col("gluc").shrink_dtype(), col("smoke").shrink_dtype(), col("alco").shrink_dtype(), col("active").shrink_dtype(), col("cardio").shrink_dtype(), col("age_years").shrink_dtype(), col("bmi").shrink_dtype()] FROM
DF ["gender", "height", "weight", "ap_hi"]; PROJECT 13/13 COLUMNS; SELECTION: None
```
The important keyword here is STREAMING. The plan shows that the initial loading of the dataset happens in streaming mode, but shrinking the dtypes requires the whole dataset to be loaded into memory. Since dtype shrinking is not essential, I temporarily remove it to find out up to which operation streaming is supported.
The next problematic operation is the mapping of the categorical features.
```python
def apply_categorical_mappings(self, data: Union[pl.DataFrame, pl.LazyFrame]) -> Union[pl.DataFrame, pl.LazyFrame]:
    """
    Apply categorical mappings on input frame.
    :param data: Polars DataFrame or LazyFrame with categorical columns.
    :return: Polars DataFrame or LazyFrame with mapped categorical columns
    """
    return data.with_columns(
        [pl.col(col).replace(self.categorical_mappings[col]).cast(pl.UInt32) for col in self.categorical_columns]
    )
```
The replace expression doesn’t support the streaming mode. Even after removing the cast, streaming is not used, as the execution plan shows.
```
[col("gender").replace([Series, Series]), col("cholesterol").replace([Series, Series]), col("gluc").replace([Series, Series]), col("smoke").replace([Series, Series]), col("alco").replace([Series, Series]), col("active").replace([Series, Series])]
DF ["gender", "height", "weight", "ap_hi"]; PROJECT */13 COLUMNS; SELECTION: None
```
Moving on, I also remove support for categorical features. The next step is the calculation of the information gain.
```python
# Excerpt; the surrounding parts of this expression are omitted
information_gain_df = (
            .filter(pl.col(target_name) == target_value)
            for target_value in unique_targets
        + [pl.col(target_name).len().alias("count_examples")]
```
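For the arithmetic behind the split criterion, here is a minimal pure-Python sketch of entropy-based information gain for a single threshold, using the same `<=` / `>` split as the tree. It is the textbook definition evaluated one threshold at a time, not the repository code, which starts from per-class counts aggregated in polars; the function names are hypothetical.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy, in bits, of a sequence of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    # Sum -p * log2(p) over the class frequencies
    return sum(-(c / n) * log2(c / n) for c in Counter(labels).values())


def information_gain(feature, labels, threshold):
    """Entropy reduction when splitting into feature <= threshold and feature > threshold."""
    left = [y for x, y in zip(feature, labels) if x <= threshold]
    right = [y for x, y in zip(feature, labels) if x > threshold]
    n = len(labels)
    # Children entropy, weighted by the share of examples on each side
    children = len(left) / n * entropy(left) + len(right) / n * entropy(right)
    return entropy(labels) - children


feature = [1, 2, 3, 4]
labels = [0, 0, 1, 1]
print(information_gain(feature, labels, 2))  # 1.0: the split separates the classes perfectly
print(information_gain(feature, labels, 1))  # smaller gain: the right side is still mixed
```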
Unfortunately, streaming is already unsupported in the first part of this calculation: using pl.col().filter() inside the aggregation prevents the data from being streamed.
```
SORT BY [col("feature_value")]
[col("cardio").filter([(col("cardio")) == (1)]).count().alias("class_1_count"), col("cardio").filter([(col("cardio")) == (0)]).count().alias("class_0_count"), col("cardio").count().alias("count_examples")] BY [col("feature_value")] FROM
simple π 2/2 ["gender", "cardio"]
DF ["gender", "height", "weight", "ap_hi"]; PROJECT 2/13 COLUMNS; SELECTION: col("gender").is_not_null()
```
Since this is not easy to change, I will stop the exploration here. In conclusion, the decision tree implementation with a polars backend can’t use the full potential of streaming yet, because important operators still lack streaming support. As the streaming mode is under active development, it might become possible to run most of the operators, or even the whole decision tree calculation, in streaming mode in the future.
In this blog post, I presented my custom implementation of a decision tree using polars as a backend. I showed implementation details and compared it to other decision tree frameworks. The comparison shows that this implementation can outperform sklearn in terms of runtime and memory usage, while other frameworks such as lightgbm still provide a better runtime and more efficient processing. The streaming mode holds a lot of potential for a polars backend. Currently, some operators prevent an end-to-end streaming approach because they lack streaming support, but this is under active development. Once polars makes progress there, it is worth revisiting this implementation and comparing it to other frameworks again.
Build a Decision Tree in Polars from Scratch was originally published in Towards Data Science on Medium, where people are continuing the conversation by highlighting and responding to this story.