Causal AI, exploring the integration of causal reasoning into machine learning

Ryan O'Sullivan

Towards Data Science

This article gives a practical introduction to the potential of causal graphs.

It is aimed at anyone who wants to understand more about:

  • What causal graphs are and how they work
  • A worked case study in Python illustrating how to build causal graphs
  • How they compare to ML
  • The key challenges and future considerations

The full notebook can be found here:

Causal graphs help us disentangle causes from correlations. They are a key part of the causal inference/causal ML/causal AI toolbox and can be used to answer causal questions.

Often referred to as a DAG (directed acyclic graph), a causal graph contains nodes and edges: edges link nodes that are causally related and point from cause to effect.
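A causal graph can be represented in code as a mapping from each node to its direct causes. The sketch below uses Python's standard-library graphlib rather than any causal-inference package. It encodes the Regular exercise -> Cardiovascular fitness -> Overall health chain used below as a mediator example and checks acyclicity: a topological ordering exists only if the graph has no directed cycles. The helper is_dag is a hypothetical name for illustration.

```python
from graphlib import CycleError, TopologicalSorter

# Each node maps to the set of its direct causes (parents); edges point cause -> effect.
causal_graph = {
    "Regular exercise": set(),
    "Cardiovascular fitness": {"Regular exercise"},
    "Overall health": {"Cardiovascular fitness"},
}


def is_dag(graph):
    """Hypothetical helper: True if the graph has no directed cycles."""
    try:
        # static_order() raises CycleError if the graph contains a cycle
        tuple(TopologicalSorter(graph).static_order())
        return True
    except CycleError:
        return False


# A topological order lists every cause before its effects
order = list(TopologicalSorter(causal_graph).static_order())
print(order)  # ['Regular exercise', 'Cardiovascular fitness', 'Overall health']
print(is_dag(causal_graph))  # True

# A feedback edge Overall health -> Regular exercise creates a cycle, so this is not a DAG
cyclic_graph = {**causal_graph, "Regular exercise": {"Overall health"}}
print(is_dag(cyclic_graph))  # False
```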

There are two ways to determine a causal graph:

  • Expert domain knowledge
  • Causal discovery algorithms

For now, we will assume we have expert domain knowledge to determine the causal graph (we will cover causal discovery algorithms further down the line).

The objective of ML is to classify or predict as accurately as possible given some training data. There is no incentive for an ML algorithm to ensure the features it uses are causally linked to the target, and no guarantee that the direction (positive/negative effect) and strength of each feature will align with the true data-generating process. A standard ML model won't take the following situations into account:

  • Spurious correlations: two variables are spuriously correlated when they share a common cause, e.g. high temperatures increasing both ice cream sales and shark attacks.
  • Confounders: a variable that affects both your treatment and your outcome, e.g. demand affecting both how much we spend on marketing and how many new customers sign up.
  • Colliders: a variable that is affected by two independent variables, e.g. Quality of customer care -> User satisfaction <- Size of company. Conditioning on a collider can create a spurious association between its causes.
  • Mediators: two variables that are (indirectly) linked through a mediator, e.g. Regular exercise -> Cardiovascular fitness (the mediator) -> Overall health.
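The spurious-correlation and collider cases can be reproduced with a small simulation. The sketch below uses made-up linear relationships (the coefficients and noise levels are assumptions, not data) and NumPy only. It measures each correlation before and after removing the linear influence of a third variable: adjusting for the common cause (temperature) makes the spurious correlation vanish, whereas adjusting for the collider (satisfaction) creates a correlation between two variables that are actually independent. The helper residualise is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Common cause: temperature drives both ice cream sales and shark attacks (stylised, continuous).
temperature = rng.normal(size=n)
ice_cream = 2.0 * temperature + rng.normal(size=n)
shark_attacks = 1.5 * temperature + rng.normal(size=n)


def residualise(y, x):
    """Hypothetical helper: remove the linear effect of x from y (least squares with intercept)."""
    design = np.column_stack([np.ones_like(x), x])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return y - design @ coef


raw_corr = np.corrcoef(ice_cream, shark_attacks)[0, 1]
adjusted_corr = np.corrcoef(residualise(ice_cream, temperature),
                            residualise(shark_attacks, temperature))[0, 1]

# Collider: care quality and company size are independent, but both affect satisfaction.
care_quality = rng.normal(size=n)
company_size = rng.normal(size=n)
satisfaction = care_quality + company_size + rng.normal(size=n)

collider_raw = np.corrcoef(care_quality, company_size)[0, 1]
collider_adjusted = np.corrcoef(residualise(care_quality, satisfaction),
                                residualise(company_size, satisfaction))[0, 1]

print(f"ice cream vs sharks: raw {raw_corr:.2f}, adjusted for temperature {adjusted_corr:.2f}")
print(f"care vs size: raw {collider_raw:.2f}, adjusted for satisfaction {collider_adjusted:.2f}")
```

A model that simply uses every available variable as a feature has no way of knowing which of these two adjustments is appropriate; the causal graph supplies that information.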

Because of these complexities and the black-box nature of ML, we can’t be confident in its ability to answer causal questions.

Given a known causal graph and observed data, we can train a structural causal model (SCM). An SCM can be thought of as a series of causal models, one per node: each model uses a node as its target and that node's direct parents as its features. If the relationships in our observed data are linear, the SCM will be a series of linear equations, which could be modelled with a series of linear regression models. If the relationships are non-linear, they could be modelled with a series of boosted trees instead.
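To make the "one model per node" idea concrete, here is a minimal sketch that fits a linear SCM by ordinary least squares with NumPy. The data-generating coefficients (0.8 and 0.6), the noise scales and the helper fit_linear_scm are hypothetical; root nodes are skipped because they are modelled by a distribution rather than a regression.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 5_000

# Hypothetical linear data-generating process: exercise -> fitness -> health.
exercise = rng.normal(size=n)
fitness = 0.8 * exercise + rng.normal(scale=0.5, size=n)
health = 0.6 * fitness + rng.normal(scale=0.5, size=n)
data = {"exercise": exercise, "fitness": fitness, "health": health}

# The causal graph: each node mapped to its direct parents.
parents = {"exercise": [], "fitness": ["exercise"], "health": ["fitness"]}


def fit_linear_scm(data, parents):
    """Hypothetical helper: one linear regression per non-root node, with its parents as features."""
    scm = {}
    for node, pa in parents.items():
        if not pa:
            continue  # root nodes are modelled by their marginal distribution instead
        X = np.column_stack([np.ones(len(data[node]))] + [data[p] for p in pa])
        coef, *_ = np.linalg.lstsq(X, data[node], rcond=None)
        scm[node] = dict(zip(["intercept"] + pa, coef))
    return scm


scm = fit_linear_scm(data, parents)
print(scm)
```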

The key difference from traditional ML is that an SCM models causal relationships and, provided the causal graph is correct, accounts for spurious correlations, confounders, colliders and mediators.

It is common to use an additive noise model (ANM) for each non-root node (meaning it has at least one parent). This allows us to use a range of machine learning algorithms (plus a noise term) to estimate each non-root node.

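Y := f(X) + N

where X denotes the parents of Y, f is a (possibly non-linear) function learned from the data, and N is a noise term assumed to be independent of X.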
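As a sketch of how an additive noise model could be estimated for a single node, assume a made-up non-linear mechanism f(x) = x squared with Gaussian noise of standard deviation 0.1. Any regressor can stand in for f; a degree-2 polynomial fitted with numpy.polyfit is used here purely for brevity. The residuals Y - f(X) then serve as samples of the noise term N.

```python
import numpy as np

rng = np.random.default_rng(7)
n = 2_000

# Hypothetical mechanism: Y := f(X) + N with f(x) = x**2 and N ~ Normal(0, 0.1).
x = rng.uniform(-2, 2, size=n)
y = x**2 + rng.normal(scale=0.1, size=n)

# Estimate f with a regressor; here a degree-2 polynomial fitted by least squares.
f_hat = np.poly1d(np.polyfit(x, y, deg=2))

# The ANM noise term is whatever the fitted function does not explain.
noise = y - f_hat(x)
print(f"estimated noise std: {noise.std():.3f}")

# New values of Y for given parent values: apply f_hat and add noise resampled from the residuals.
x_new = np.array([0.0, 1.0, 2.0])
y_new = f_hat(x_new) + rng.choice(noise, size=x_new.size)
print(y_new)
```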

Root nodes (nodes without parents) can be modelled using a stochastic model that describes their distribution, for example the empirical distribution of the observed data.

An SCM can be seen as a generative model because it can generate new samples of data, and this enables it to answer a range of causal questions. It generates new data by sampling from the root nodes and then propagating those samples through the graph, parents before children.
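This sampling procedure is often called ancestral sampling. Below is a minimal sketch for a hypothetical linear SCM on the exercise -> fitness -> health chain; the mechanisms and noise scales are chosen purely for illustration, and graphlib supplies an order in which every parent is sampled before its children.

```python
import numpy as np
from graphlib import TopologicalSorter

rng = np.random.default_rng(1)

# Hypothetical SCM: graph (node -> parents), root distributions, mechanisms and noise scales.
parents = {"exercise": [], "fitness": ["exercise"], "health": ["fitness"]}
root_samplers = {"exercise": lambda size: rng.normal(size=size)}
mechanisms = {
    "fitness": lambda pa: 0.8 * pa["exercise"],
    "health": lambda pa: 0.6 * pa["fitness"],
}
noise_std = {"fitness": 0.5, "health": 0.5}


def sample(n):
    """Draw the root nodes, then propagate through the graph in topological order."""
    values = {}
    for node in TopologicalSorter(parents).static_order():
        if not parents[node]:
            values[node] = root_samplers[node](n)
        else:
            pa = {p: values[p] for p in parents[node]}
            # Additive noise model: deterministic mechanism plus freshly drawn noise
            values[node] = mechanisms[node](pa) + rng.normal(scale=noise_std[node], size=n)
    return values


samples = sample(10_000)
print({name: round(float(v.std()), 3) for name, v in samples.items()})
```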

The value of an SCM is that it allows us to answer causal questions by calculating counterfactuals and simulating interventions:

  • Counterfactuals: using historically observed data to calculate what would have happened to y if we had changed x, e.g. What would have happened to the number of customers churning if we had reduced call waiting time by 20% last month?
  • Interventions: very similar to counterfactuals (and often used interchangeably), but interventions simulate what would happen in the future, e.g. What will happen to the number of customers churning if we reduce call waiting time by 20% next year?
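In SCM terms, the practical difference is that a counterfactual keeps the noise inferred for each observed sample, while an intervention draws fresh noise. The sketch below walks through the standard three counterfactual steps (abduction, action, prediction) for a hypothetical one-edge model Y := 2X + N; the model and the observed values are made up for illustration.

```python
import numpy as np


# Hypothetical one-edge SCM: Y := 2 * X + N, with N ~ Normal(0, 1).
def f(x):
    return 2.0 * x


# One observed unit (e.g. last month): X = 1.0, Y = 3.0.
x_obs, y_obs = 1.0, 3.0

# Counterfactual: 1) abduction, infer this unit's noise; 2) action, reduce X by 20%; 3) prediction.
noise_obs = y_obs - f(x_obs)   # 1.0
x_cf = 0.8 * x_obs
y_cf = f(x_cf) + noise_obs     # about 2.6

# Intervention: set X to the same value but draw fresh noise, giving a distribution
# whose mean is f(0.8) = 1.6 rather than a statement about the observed unit.
rng = np.random.default_rng(0)
y_do = f(x_cf) + rng.normal(size=100_000)
print(y_cf, y_do.mean())
```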

In our case study, the customer service team monitors several KPIs, one of which is call waiting time. Increasing the number of call centre staff will decrease call waiting times.

But how will decreasing call waiting time impact customer churn levels? And will this offset the cost of additional call centre staff?

The Data Science team is asked to build and evaluate the business case.

The population of interest is customers who make an inbound call. The following time-series data is collected daily:

Using Causal Graphs to answer causal questions | by Ryan O'Sullivan | Jan, 2024 - image  on
Image by author

In this example, we use time-series data, but causal graphs can also work with customer-level data.

We also use expert domain knowledge to determine the causal graph.

import numpy as np

# Create node lookup: maps each adjacency-matrix index to a node name
node_lookup = {0: 'Demand',
               1: 'Call waiting time',
               2: 'Call abandoned',
               3: 'Reported problems',
               4: 'Discount sent',
               5: 'Churn'}

total_nodes = len(node_lookup)

# Create adjacency matrix - this is the base for our graph
# (graph_actual[i, j] = 1 means node i is a direct cause of node j)
graph_actual = np.zeros((total_nodes, total_nodes))

# Create graph using expert domain knowledge
graph_actual[0, 1] = 1.0  # Demand -> Call waiting time
graph_actual[0, 2] = 1.0  # Demand -> Call abandoned
graph_actual[0, 3] = 1.0  # Demand -> Reported problems
graph_actual[1, 2] = 1.0  # Call waiting time -> Call abandoned
graph_actual[1, 5] = 1.0  # Call waiting time -> Churn
graph_actual[2, 3] = 1.0  # Call abandoned -> Reported problems
graph_actual[2, 5] = 1.0  # Call abandoned -> Churn
graph_actual[3, 4] = 1.0  # Reported problems -> Discount sent
graph_actual[3, 5] = 1.0  # Reported problems -> Churn
graph_actual[4, 5] = 1.0  # Discount sent -> Churn
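Because graph_actual[i, j] = 1 encodes an edge from node i to node j, the direct parents of a node can be read off its column. This small NumPy sketch rebuilds the same matrix so it runs on its own (the helper parents_of is hypothetical) and lists the features each per-node model in the SCM would use.

```python
import numpy as np

node_lookup = {0: 'Demand', 1: 'Call waiting time', 2: 'Call abandoned',
               3: 'Reported problems', 4: 'Discount sent', 5: 'Churn'}
edges = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (3, 5), (4, 5)]

graph_actual = np.zeros((len(node_lookup), len(node_lookup)))
for cause, effect in edges:
    graph_actual[cause, effect] = 1.0


def parents_of(adjacency, j):
    """Hypothetical helper: names of nodes with an edge into node j (non-zero rows of column j)."""
    return [node_lookup[int(i)] for i in np.flatnonzero(adjacency[:, j])]


for j, name in node_lookup.items():
    print(f"{name}: {parents_of(graph_actual, j)}")

# Every edge goes from a lower to a higher index, so the matrix is strictly upper
# triangular, which guarantees the graph is acyclic; Demand has no parents (root node).
```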

Image by author

Next, we need to generate data for our case study.

We want to generate data that will allow us to compare counterfactuals calculated using causal graphs with those calculated using ML (to keep things simple, ridge regression).

As we identified the causal graph in the last section, we can use this knowledge to create a data-generating process.

import pandas as pd

def data_generator(max_call_waiting, inbound_calls, call_reduction):
    '''
    A data generating function that has the flexibility to reduce the value of node 1 (Call waiting time) - this enables us to calculate ground truth counterfactuals

    Args:
        max_call_waiting (int): Upper bound (exclusive) of the simulated Demand values, which drive call waiting time
        inbound_calls (int): Total number of inbound calls (observations in data)
        call_reduction (float): Multiplier applied to call waiting time (1.0 = no reduction, 0.8 = 20% reduction)

    Returns:
        DataFrame: Generated data
    '''

    df = pd.DataFrame(columns=list(node_lookup.values()))

    df[node_lookup[0]] = np.random.randint(low=10, high=max_call_waiting, size=(inbound_calls))  # Demand
    df[node_lookup[1]] = (df[node_lookup[0]] * 0.5) * (call_reduction) + np.random.normal(loc=0, scale=40, size=inbound_calls)  # Call waiting time
    df[node_lookup[2]] = (df[node_lookup[1]] * 0.5) + (df[node_lookup[0]] * 0.2) + np.random.normal(loc=0, scale=30, size=inbound_calls)  # Call abandoned
    df[node_lookup[3]] = (df[node_lookup[2]] * 0.6) + (df[node_lookup[0]] * 0.3) + np.random.normal(loc=0, scale=20, size=inbound_calls)  # Reported problems
    df[node_lookup[4]] = (df[node_lookup[3]] * 0.7) + np.random.normal(loc=0, scale=10, size=inbound_calls)  # Discount sent
    df[node_lookup[5]] = (0.10 * df[node_lookup[1]]) + (0.30 * df[node_lookup[2]]) + (0.15 * df[node_lookup[3]]) + (-0.20 * df[node_lookup[4]])  # Churn

    return df

# Generate data
df = data_generator(max_call_waiting=600, inbound_calls=10000, call_reduction=1.00)


Image by author

We now have some data and an adjacency matrix that represents our causal graph. We use the gcm module from the dowhy Python package to train an SCM.

It’s important to think about what causal mechanism to use for the root and non-root nodes. If you look at our data generator function, you will see all of the relationships are linear. Therefore choosing ridge regression should be sufficient.

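import networkx as nx
from dowhy import gcm

# Setup graph: build a directed graph from the adjacency matrix and label its nodes
graph = nx.from_numpy_array(graph_actual, create_using=nx.DiGraph)
graph = nx.relabel_nodes(graph, node_lookup)

# Create SCM (invertible, so it can be used to calculate counterfactuals)
causal_model = gcm.InvertibleStructuralCausalModel(graph)
causal_model.set_causal_mechanism('Demand', gcm.EmpiricalDistribution())  # Root node
causal_model.set_causal_mechanism('Call waiting time', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor()))  # Non-root node
causal_model.set_causal_mechanism('Call abandoned', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor()))  # Non-root node
causal_model.set_causal_mechanism('Reported problems', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor()))  # Non-root node
causal_model.set_causal_mechanism('Discount sent', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor()))  # Non-root node
causal_model.set_causal_mechanism('Churn', gcm.AdditiveNoiseModel(gcm.ml.create_ridge_regressor()))  # Non-root node

# Fit each causal mechanism to the observed data
gcm.fit(causal_model, df)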

Instead of assigning the causal mechanisms manually, you could also use the auto assignment function, gcm.auto.assign_causal_mechanisms, to assign them automatically.

For more info on the gcm package see the docs:

We also use ridge regression to create a baseline for comparison. Comparing its coefficients with the data generator, we can see that it correctly estimates the direct coefficient of each variable. However, in addition to directly influencing churn, call waiting time indirectly influences churn through abandoned calls, reported problems and discounts sent.
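In a linear SCM, the total effect of one variable on another is the sum, over every directed path between them, of the product of the edge coefficients along that path (path tracing). The sketch below applies this to the coefficients hard-coded in data_generator. The helper directed_paths is hypothetical, and only the edges downstream of Call waiting time are listed because no path from it passes through Demand.

```python
# Edge coefficients copied from data_generator (call_reduction = 1.0),
# restricted to edges reachable from Call waiting time.
edges = {
    ("Call waiting time", "Call abandoned"): 0.5,
    ("Call waiting time", "Churn"): 0.10,
    ("Call abandoned", "Reported problems"): 0.6,
    ("Call abandoned", "Churn"): 0.30,
    ("Reported problems", "Discount sent"): 0.7,
    ("Reported problems", "Churn"): 0.15,
    ("Discount sent", "Churn"): -0.20,
}


def directed_paths(source, target, path=None):
    """Hypothetical helper: yield every directed path from source to target (graph is acyclic)."""
    path = [source] if path is None else path
    if source == target:
        yield path
        return
    for a, b in edges:
        if a == source:
            yield from directed_paths(b, target, path + [b])


total_effect = 0.0
for path in directed_paths("Call waiting time", "Churn"):
    effect = 1.0
    for a, b in zip(path, path[1:]):
        effect *= edges[(a, b)]  # multiply coefficients along the path
    print(" -> ".join(path), round(effect, 3))
    total_effect += effect

print("direct effect:", edges[("Call waiting time", "Churn")])
print("total effect:", round(total_effect, 3))
```

The total effect (about 0.253) is roughly two and a half times the direct coefficient of 0.10, which is the only part a regression on all of churn's parents can see.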

When it comes to estimating counterfactuals it is going to be interesting to see how the SCM compares to ridge regression.

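from sklearn.linear_model import RidgeCV

# Ridge regression: predict churn from its four direct causes
y = df['Churn'].copy()
X = df.iloc[:, 1:-1].copy()  # Call waiting time, Call abandoned, Reported problems, Discount sent
model = RidgeCV()
model = model.fit(X, y)
y_pred = model.predict(X)

print(f'Intercept: {model.intercept_}')
print(f'Coefficient: {model.coef_}')
# Ground truth coefficients: [0.10 0.30 0.15 -0.20]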

Image by author
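One way to see why the regression misses part of the effect is to compare two sets of control variables on a re-simulation of the same process. The sketch below re-implements data_generator in plain NumPy (an assumption made so the block is self-contained) and swaps in ordinary least squares for RidgeCV. Controlling for all of churn's parents, as the ridge model does, recovers only the direct effect of call waiting time; controlling only for the confounder Demand leaves the mediated paths open and recovers the total effect. The helper ols is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

# Plain-NumPy re-simulation of data_generator with call_reduction = 1.0
demand = rng.integers(10, 600, size=n).astype(float)
wait = 0.5 * demand + rng.normal(0, 40, n)
abandoned = 0.5 * wait + 0.2 * demand + rng.normal(0, 30, n)
reported = 0.6 * abandoned + 0.3 * demand + rng.normal(0, 20, n)
discount = 0.7 * reported + rng.normal(0, 10, n)
churn = 0.10 * wait + 0.30 * abandoned + 0.15 * reported - 0.20 * discount


def ols(y, *features):
    """Hypothetical helper: least-squares coefficients of y on the features (intercept dropped)."""
    X = np.column_stack([np.ones_like(y), *features])
    return np.linalg.lstsq(X, y, rcond=None)[0][1:]


# Same features as the ridge model: the mediators are held fixed, leaving only the direct effect
with_mediators = ols(churn, wait, abandoned, reported, discount)[0]
# Adjusting only for the confounder Demand keeps the mediated paths: the total effect
confounder_only = ols(churn, wait, demand)[0]
print(round(with_mediators, 3), round(confounder_only, 3))
```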

Before we move on to calculating counterfactuals using causal graphs and ridge regression, we need a ground truth benchmark. We can use our data generator to create counterfactual samples after we have reduced call waiting time by 20%.

We couldn't do this with a real-world problem, but here it allows us to assess how effective the causal graph and ridge regression are.

# Set call reduction to 20%
reduce = 0.20
call_reduction = 1 - reduce

# Generate counterfactual data
df_cf = data_generator(max_call_waiting=600, inbound_calls=10000, call_reduction=call_reduction)

We can now estimate what would have happened if we had decreased the call waiting time by 20% using our 3 methods:

  • Ground truth (from the data generator)
  • Ridge regression
  • Causal graph

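We see that ridge regression significantly underestimates the impact on churn, because its call waiting time coefficient only captures the direct effect and ignores the indirect effects through the mediators, whilst the causal graph is very close to the ground truth.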

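# Ground truth counterfactual: relative reduction in churn between the original and counterfactual data
ground_truth = round((df['Churn'].sum() - df_cf['Churn'].sum()) / df['Churn'].sum(), 3)

# Causal graph counterfactual: reduce each observed call waiting time by 20% and propagate through the SCM
df_counterfactual = gcm.counterfactual_samples(causal_model, {'Call waiting time': lambda x: x*call_reduction}, observed_data=df)
causal_graph = round((df['Churn'].sum() - df_counterfactual['Churn'].sum()) / (df['Churn'].sum()), 3)

# Ridge regression counterfactual: only the call waiting time coefficient (the direct effect) is used
ridge_regression = round((df['Call waiting time'].sum() * 1.0 * model.coef_[0] - (df['Call waiting time'].sum() * call_reduction * model.coef_[0])) / (df['Churn'].sum()), 3)

print(f'Ground truth: {ground_truth}, Causal graph: {causal_graph}, Ridge regression: {ridge_regression}')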

Image by author
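Because every relationship in data_generator is linear and every noise term has mean zero, the expected size of the effect can also be derived analytically, which helps when reading the comparison above. This NumPy sketch (helper names are hypothetical) builds the coefficient matrix B from the generator and uses the identity (I - B)^-1 = I + B + B^2 + ... to sum the effects along all paths. It then compares the relative churn reduction implied by the full graph with the reduction implied by the direct effect alone. These are expected values under the generator's assumptions, not measured results.

```python
import numpy as np


# Node order matches node_lookup: Demand, Call waiting time, Call abandoned,
# Reported problems, Discount sent, Churn
def coefficient_matrix(call_reduction):
    """B[i, j] is the coefficient of node i in node j's equation in data_generator."""
    B = np.zeros((6, 6))
    B[0, 1] = 0.5 * call_reduction                                 # -> Call waiting time
    B[0, 2], B[1, 2] = 0.2, 0.5                                    # -> Call abandoned
    B[0, 3], B[2, 3] = 0.3, 0.6                                    # -> Reported problems
    B[3, 4] = 0.7                                                  # -> Discount sent
    B[1, 5], B[2, 5], B[3, 5], B[4, 5] = 0.10, 0.30, 0.15, -0.20  # -> Churn
    return B


def total_effects(B):
    """Entry [i, j] of (I - B)^-1 sums the coefficient products over all paths i -> j."""
    return np.linalg.inv(np.eye(len(B)) - B)


# Expected churn is proportional to expected Demand (all noise has mean zero),
# so the relative reduction does not depend on the Demand distribution.
churn_per_demand = total_effects(coefficient_matrix(1.0))[0, 5]
churn_per_demand_cf = total_effects(coefficient_matrix(0.8))[0, 5]
expected_reduction = 1 - churn_per_demand_cf / churn_per_demand

# Direct-effect-only estimate: 20% of expected waiting time (0.5 * Demand) times the coefficient 0.10
direct_only_reduction = 0.10 * 0.2 * 0.5 / churn_per_demand

print(f"expected reduction (all paths): {expected_reduction:.3f}")              # 0.133
print(f"expected reduction (direct effect only): {direct_only_reduction:.3f}")  # 0.052
```

Under these assumptions, the direct-effect-only figure is roughly 40% of the full expected effect, which matches the pattern of ridge regression underestimating the impact on churn.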

This was a simple example to start you thinking about the power of causal graphs.

For more complex situations, there are several challenges that would need some consideration:

  • What assumptions are made and what is the impact of these being violated?
  • What if we don't have the expert domain knowledge to identify the causal graph?
  • What if there are non-linear relationships?
  • How damaging is multi-collinearity?
  • What if some variables have lagged effects?
  • How can we deal with high-dimensional datasets (lots of variables)?

All of these points will be covered in future blogs.

If you're interested in learning more about causal AI, I highly recommend the following resources:

“Meet Ryan, a seasoned Lead Data Scientist with a specialized focus on employing causal techniques within business contexts, spanning Marketing, Operations, and Customer Service. His proficiency lies in unraveling the intricacies of cause-and-effect relationships to drive informed decision-making and strategic enhancements across diverse organizational functions.”
