Hennie de Harder, Author at Towards Data Science (https://towardsdatascience.com)

Graph Neural Networks Part 3: How GraphSAGE Handles Changing Graph Structure
And how you can use it for large graphs
Published Tue, 01 Apr 2025 on Towards Data Science: https://towardsdatascience.com/graph-neural-networks-part-3-how-graphsage-handles-changing-graph-structure/

The post Graph Neural Networks Part 3: How GraphSAGE Handles Changing Graph Structure appeared first on Towards Data Science.

In the previous parts of this series, we looked at Graph Convolutional Networks (GCNs) and Graph Attention Networks (GATs). Both architectures work fine, but they also have some limitations! A big one is that for large graphs, calculating the node representations with GCNs and GATs will become v-e-r-y slow. Another limitation is that if the graph structure changes, GCNs and GATs trained in the standard transductive way will not be able to generalize: if nodes are added to the graph, the model cannot make predictions for them. Luckily, these issues can be solved!

In this post, I will explain GraphSAGE and how it solves common problems of GCNs and GATs. We will train GraphSAGE and use it for node classification to compare its performance with GCNs and GATs.

New to GNNs? You can start with post 1 about GCNs (which also contains the initial setup for running the code samples) and post 2 about GATs.


Two Key Problems with GCNs and GATs

I briefly touched upon it in the introduction, but let’s dive a bit deeper. What are the problems with the previous GNN models?

Problem 1. They don’t generalize

GCNs and GATs struggle to generalize to unseen graphs. The graph structure needs to be the same as in the training data. This is known as transductive learning, where the model trains and makes predictions on the same fixed graph. In effect, the model is tied to one specific graph topology. In reality, graphs change: nodes and edges are added or removed, and this happens often in real-world scenarios. We want our GNNs to learn patterns that generalize to unseen nodes, or to entirely new graphs (this is called inductive learning).

Problem 2. They have scalability issues

Training GCNs and GATs on large-scale graphs is computationally expensive. In the full-graph setup, every layer aggregates over all neighbors of every node, and the number of nodes involved in computing a single node’s representation grows exponentially with the number of layers. GATs add (multi-head) attention on top of this: a score is computed for every edge and every head, so the cost grows quickly as the graph gets bigger.
In big production recommendation systems with millions of users and products, GCNs and GATs trained this way are impractical and slow.
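To make the growth concrete, here is a small back-of-the-envelope calculation. It assumes, hypothetically, that every node has the same degree. It compares the size of the computation tree for one target node when all neighbors are used with the size when a fixed number of neighbors is sampled per layer, as GraphSAGE does. This is a sketch of the arithmetic only, not of any library's internals.

```python
def computation_tree_size(fanouts):
    """Number of neighbor slots in an L-hop computation tree for one target node.

    fanouts[k] is how many neighbors each node expands to at hop k + 1.
    Repeated nodes are counted every time they appear, so this is an upper
    bound on the number of distinct nodes involved.
    """
    total, frontier = 0, 1
    for fanout in fanouts:
        frontier *= fanout  # every node at the current hop expands to `fanout` nodes
        total += frontier
    return total


avg_degree = 50  # hypothetical average degree
for num_layers in (1, 2, 3):
    full = computation_tree_size([avg_degree] * num_layers)
    print(f"{num_layers} layer(s), all neighbors: {full} nodes")

# Sampling 10 neighbors per layer for two layers, the fan-outs used later in this post
print(computation_tree_size([10, 10]))  # 10 + 10 * 10 = 110
```

With all neighbors, the tree grows from 50 to 2,550 to 127,550 nodes as layers are added. With sampling, the cost is capped by the chosen fan-outs, whatever the real degrees are.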

Let’s take a look at GraphSAGE to fix these issues.

GraphSAGE (SAmple and aggreGatE)

GraphSAGE makes training much faster and more scalable. It does this by sampling only a subset of neighbors. For super large graphs, it’s computationally infeasible to process all neighbors of every node, as traditional GCNs do (unless you have limitless time, which none of us have…). Another important step of GraphSAGE is combining the features of the sampled neighbors with an aggregation function.
We will walk through all the steps of GraphSAGE below.

1. Sampling Neighbors

With tabular data, sampling is easy. It’s something you do in every common machine learning project when creating train, test, and validation sets. With graphs, you cannot simply select random nodes. This can result in disconnected graphs, nodes without neighbors, etcetera:

Randomly selecting nodes, but some are disconnected. Image by author.

What you can do with graphs is select a random, fixed-size subset of neighbors for each node. For example, in a social network, you can sample 3 friends for each user (instead of all friends):

Randomly selecting three rows in the table, all neighbors selected in the GCN, three neighbors selected in GraphSAGE. Image by author.
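Here is a minimal sketch of fixed-size neighbor sampling in plain Python. It assumes the graph is stored as a hypothetical adjacency dictionary that maps each user to a list of friends. When a node has fewer than k neighbors, this sketch simply keeps all of them; other implementations may sample with replacement instead.

```python
import random

def sample_neighbors(adjacency, node, k, rng):
    """Uniformly sample at most k distinct neighbors of `node`."""
    neighbors = adjacency[node]
    if len(neighbors) <= k:
        return list(neighbors)  # fewer than k neighbors: keep them all
    return rng.sample(neighbors, k)  # sampling without replacement

# Hypothetical friendship graph: user -> list of friends
adjacency = {
    "ann": ["bob", "cat", "dan", "eve", "fay"],
    "bob": ["ann", "cat"],
    "cat": ["ann", "bob"],
    "dan": ["ann"],
    "eve": ["ann"],
    "fay": ["ann"],
}

rng = random.Random(42)  # fixed seed for reproducible sampling
for user in adjacency:
    print(user, sample_neighbors(adjacency, user, 3, rng))
```

Whatever a node's degree, the work per node is now bounded by k. That is what makes the later aggregation step cheap.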

2. Aggregate Information

After the neighbors are selected in the previous step, GraphSAGE combines their features into one single representation. There are multiple ways to do this (multiple aggregation functions). The most common types, and the ones explained in the paper, are mean aggregation, LSTM aggregation, and pooling aggregation.

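With mean aggregation, the average is computed over all sampled neighbors’ features (very simple and often effective). In a formula:

$$h_{\mathcal{N}(v)}^{k} = \frac{1}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} h_u^{k-1}$$

where $\mathcal{N}(v)$ is the set of sampled neighbors of node $v$, and $h_u^{k-1}$ is the representation of neighbor $u$ from the previous layer.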

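LSTM aggregation uses an LSTM (a type of recurrent neural network) to process the neighbor features sequentially. It can capture more complex relationships and is more expressive than mean aggregation. An LSTM depends on input order, but neighbors have no natural order, so the paper applies it to a random permutation of the neighbors.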

The third type, pooling aggregation, passes each neighbor’s features through a fully connected layer with a non-linearity, and then takes the element-wise maximum over the neighbors (think about max-pooling in a neural network, where you also take the maximum of a set of values).
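Below is a small NumPy sketch of the mean and max-pooling aggregators for a single node. The neighbor features and the pooling weights are hypothetical stand-ins for learned values, and the LSTM aggregator is left out for brevity. Both aggregators are symmetric functions, so shuffling the neighbors does not change the result.

```python
import numpy as np

def mean_aggregate(neighbor_feats):
    """Mean aggregator: element-wise average over the sampled neighbors."""
    return neighbor_feats.mean(axis=0)

def pool_aggregate(neighbor_feats, W_pool, b_pool):
    """Max-pooling aggregator: dense layer + ReLU per neighbor, then element-wise max."""
    transformed = np.maximum(0.0, neighbor_feats @ W_pool + b_pool)  # (num_neighbors, d_out)
    return transformed.max(axis=0)

# 3 sampled neighbors with 2 features each (hypothetical values)
neighbor_feats = np.array([[1.0, 0.0],
                           [3.0, 2.0],
                           [2.0, 4.0]])
W_pool = np.eye(2)    # stand-in for learned pooling weights
b_pool = np.zeros(2)  # stand-in for the learned bias

print(mean_aggregate(neighbor_feats))                  # [2. 2.]
print(pool_aggregate(neighbor_feats, W_pool, b_pool))  # [3. 4.]
```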

3. Update Node Representation

After sampling and aggregation, the node combines its previous features with the aggregated neighbor features. Nodes will learn from their neighbors but also keep their own identity, just like we saw before with GCNs and GATs. Information can flow across the graph effectively. 

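This is the formula for this step:

$$h_v^{k} = \sigma\left(W^{k} \cdot \mathrm{CONCAT}\left(h_v^{k-1},\ h_{\mathcal{N}(v)}^{k}\right)\right), \qquad h_v^{k} \leftarrow \frac{h_v^{k}}{\lVert h_v^{k} \rVert_2}$$

where $h_{\mathcal{N}(v)}^{k}$ is the aggregated neighbor representation from step 2.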

The aggregation from step 2 is computed over the sampled neighbors. The result is concatenated with the node’s own representation from the previous layer. This vector is multiplied by the weight matrix and passed through a non-linearity (for example ReLU). As a final, optional step, the result can be L2-normalized.
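The update step can be sketched for a single node in a few lines of NumPy. The weight matrix and the input vectors below are hypothetical. For comparison, PyG's SAGEConv writes the same step as W1·x_i + W2·mean_j x_j, which is equivalent to multiplying the concatenation by the block matrix [W1 | W2]. Its normalize option, which is off by default, applies the L2 step.

```python
import numpy as np

def sage_update(h_self, h_neigh, W, normalize=True):
    """One GraphSAGE update for a single node.

    h_self:  previous representation of the node, shape (d_in,)
    h_neigh: aggregated neighbor representation, shape (d_in,)
    W:       weight matrix, shape (d_out, 2 * d_in)
    """
    concat = np.concatenate([h_self, h_neigh])  # keep the node's own identity
    h_new = np.maximum(0.0, W @ concat)         # linear map + ReLU
    if normalize:
        norm = np.linalg.norm(h_new)
        if norm > 0:
            h_new = h_new / norm                # optional L2 normalization
    return h_new

h_self = np.array([1.0, 0.0])
h_neigh = np.array([2.0, 2.0])       # e.g. the mean of the sampled neighbors
W = np.array([[1.0, 0.0, 0.0, 0.0],  # hypothetical weights
              [0.0, 0.0, 1.0, 1.0]])
print(sage_update(h_self, h_neigh, W, normalize=False))  # [1. 4.]
print(sage_update(h_self, h_neigh, W))                   # [1, 4] / sqrt(17)
```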

4. Repeat for Multiple Layers

The first three steps can be repeated multiple times. When this happens, information can flow from more distant neighbors. In the image below, you see a node with three neighbors selected in the first layer (direct neighbors), and two neighbors selected in the second layer (neighbors of neighbors).

Selected node with selected neighbors, three in the first layer, two in the second layer. Note that the selected node is itself a neighbor of its first-layer neighbors, so it can also be picked when two neighbors are sampled in the second step (this is just a bit harder to visualize). Image by author.
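A minimal sketch of multi-layer sampling in plain Python, mirroring the picture above: three neighbors for the target, then two neighbors for each of those. The graph is a hypothetical adjacency dictionary. As in the figure, the target node itself can reappear in the second hop.

```python
import random

def sample_computation_tree(adjacency, target, fanouts, rng):
    """Sample neighbors hop by hop, starting from `target`.

    Returns one list per hop: hop 0 is the target, and hop k holds the neighbors
    sampled for every node at hop k - 1 (duplicates are possible).
    """
    hops = [[target]]
    for fanout in fanouts:
        next_hop = []
        for node in hops[-1]:
            neighbors = adjacency[node]
            k = min(fanout, len(neighbors))
            next_hop.extend(rng.sample(neighbors, k))
        hops.append(next_hop)
    return hops

adjacency = {  # hypothetical undirected graph
    "a": ["b", "c", "d", "e"],
    "b": ["a", "c"],
    "c": ["a", "b", "d"],
    "d": ["a", "c", "e"],
    "e": ["a", "d"],
}
hops = sample_computation_tree(adjacency, "a", fanouts=[3, 2], rng=random.Random(0))
for depth, nodes in enumerate(hops):
    print(depth, nodes)
```

The GNN layers then run bottom-up over this tree. The second-hop nodes feed into the first-hop nodes, which in turn feed into the target.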

To summarize, the key strengths of GraphSAGE are:
- Scalability: sampling makes it efficient for massive graphs.
- Flexibility: it supports inductive learning, so it works well for predicting on unseen nodes and graphs.
- Generalization: aggregation smooths out noisy features.
- Depth: multiple layers allow the model to learn from far-away nodes.

Cool! And the best thing, GraphSAGE is implemented in PyG, so we can use it easily in PyTorch.

Predicting with GraphSAGE

In the previous posts, we implemented an MLP, GCN, and GAT on the Cora dataset (CC BY-SA). To refresh your memory, Cora is a dataset of scientific publications where you have to predict the subject of each paper, with seven classes in total. This dataset is relatively small, so it might not be the best dataset for testing GraphSAGE. We will use it anyway, to be able to compare. Let’s see how well GraphSAGE performs.

Here are the parts of the code related to GraphSAGE that I’d like to highlight:

  • The NeighborLoader, which samples the neighbors for each layer:
from torch_geometric.loader import NeighborLoader

# 10 neighbors sampled in the first layer, 10 in the second layer
num_neighbors = [10, 10]

# use the training nodes as seed nodes (their sampled neighbors can be any node)
train_loader = NeighborLoader(
    data,
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    input_nodes=data.train_mask,
)
  • The aggregation type is set in the SAGEConv layer. The default is mean; you can change it to max or lstm:
from torch_geometric.nn import SAGEConv

SAGEConv(in_c, out_c, aggr='mean')
  • Another important difference is that GraphSAGE is trained in mini-batches, while GCN and GAT are trained on the full graph. This touches the essence of GraphSAGE: because of neighbor sampling, we can train in mini-batches and no longer need the full graph. In the classic form used in the previous posts, GCNs and GATs propagate features over the complete graph: the GCN normalization uses each node’s full degree, and the GAT attention scores are normalized over all neighbors. That’s why we train GCNs and GATs on the full graph here.
  • The rest of the code is similar to before, except that we now have one class in which all models are instantiated based on the model_type (GCN, GAT, or SAGE). This makes it easy to compare them or to make small changes.

This is the complete script. We train for 100 epochs and repeat the experiment 10 times to calculate the average accuracy and standard deviation for each model:

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, GCNConv, GATConv
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader

# dataset_name can be 'Cora', 'CiteSeer', 'PubMed'
dataset_name = 'Cora'
hidden_dim = 64
num_layers = 2
num_neighbors = [10, 10]
batch_size = 128
num_epochs = 100
model_types = ['GCN', 'GAT', 'SAGE']

dataset = Planetoid(root='data', name=dataset_name)
data = dataset[0]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, model_type='SAGE', gat_heads=8):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.model_type = model_type
        self.gat_heads = gat_heads

        def get_conv(in_c, out_c, is_final=False):
            if model_type == 'GCN':
                return GCNConv(in_c, out_c)
            elif model_type == 'GAT':
                heads = 1 if is_final else gat_heads
                concat = False if is_final else True
                return GATConv(in_c, out_c, heads=heads, concat=concat)
            else:
                return SAGEConv(in_c, out_c, aggr='mean')

        if model_type == 'GAT':
            self.convs.append(get_conv(in_channels, hidden_channels))
            in_dim = hidden_channels * gat_heads
            for _ in range(num_layers - 2):
                self.convs.append(get_conv(in_dim, hidden_channels))
                in_dim = hidden_channels * gat_heads
            self.convs.append(get_conv(in_dim, out_channels, is_final=True))
        else:
            self.convs.append(get_conv(in_channels, hidden_channels))
            for _ in range(num_layers - 2):
                self.convs.append(get_conv(hidden_channels, hidden_channels))
            self.convs.append(get_conv(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
        x = self.convs[-1](x, edge_index)
        return x

@torch.no_grad()
def test(model):
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs

results = {}

for model_type in model_types:
    print(f'Training {model_type}')
    results[model_type] = []

    for i in range(10):
        model = GNN(dataset.num_features, hidden_dim, dataset.num_classes, num_layers, model_type, gat_heads=8).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

        if model_type == 'SAGE':
            train_loader = NeighborLoader(
                data,
                num_neighbors=num_neighbors,
                batch_size=batch_size,
                input_nodes=data.train_mask,
            )

            def train():
                model.train()
                total_loss = 0
                for batch in train_loader:
                    batch = batch.to(device)
                    optimizer.zero_grad()
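                    out = model(batch.x, batch.edge_index)
                    # only the first batch.batch_size nodes are the seed (training) nodes;
                    # the rest are sampled neighbors whose labels must not be used
                    out = out[:batch.batch_size]
                    loss = F.cross_entropy(out, batch.y[:batch.batch_size])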
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                return total_loss / len(train_loader)

        else:
            def train():
                model.train()
                optimizer.zero_grad()
                out = model(data.x, data.edge_index)
                loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
                loss.backward()
                optimizer.step()
                return loss.item()

        best_val_acc = 0
        best_test_acc = 0
        for epoch in range(1, num_epochs + 1):
            loss = train()
            train_acc, val_acc, test_acc = test(model)
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc
            if epoch % 10 == 0:
                print(f'Epoch {epoch:02d} | Loss: {loss:.4f} | Train: {train_acc:.4f} | Val: {val_acc:.4f} | Test: {test_acc:.4f}')

        results[model_type].append([best_val_acc, best_test_acc])

for model_name, model_results in results.items():
    model_results = torch.tensor(model_results)
    print(f'{model_name} Val Accuracy: {model_results[:, 0].mean():.3f} ± {model_results[:, 0].std():.3f}')
    print(f'{model_name} Test Accuracy: {model_results[:, 1].mean():.3f} ± {model_results[:, 1].std():.3f}')

And here are the results:

GCN Val Accuracy: 0.791 ± 0.007
GCN Test Accuracy: 0.806 ± 0.006
GAT Val Accuracy: 0.790 ± 0.007
GAT Test Accuracy: 0.800 ± 0.004
SAGE Val Accuracy: 0.899 ± 0.005
SAGE Test Accuracy: 0.907 ± 0.004

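Impressive improvement! Even on this small dataset, GraphSAGE outperforms GAT and GCN easily! I repeated this test for the CiteSeer and PubMed datasets, and GraphSAGE came out best every time. One caveat: these numbers come from a training loop that computed the SAGE loss on batch.y[:out.size(0)], that is, on every node in each sampled subgraph. Those nodes include sampled neighbors outside the training set (including validation and test nodes), so their labels leak into training and the SAGE scores may be inflated. Computing the loss only on the first batch.batch_size (seed) nodes gives a fairer comparison.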

What I’d like to note here is that GCN is still very useful: it’s one of the most effective baselines (if the graph structure allows it). Also, I didn’t do much hyperparameter tuning and just went with some standard values (like 8 heads for the GAT multi-head attention). In larger, more complex, and noisier graphs, the advantages of GraphSAGE become clearer than in this example. We didn’t do any speed benchmarks, because for these small graphs GraphSAGE isn’t faster than GCN.


Conclusion

GraphSAGE brings us very nice improvements and benefits compared to GATs and GCNs. Inductive learning is possible: GraphSAGE can handle changing graph structures quite well. And although we didn’t test it in this post, neighbor sampling makes it possible to create feature representations for larger graphs with good performance.

You Think 80% Means 80%? Why Prediction Probabilities Need a Second Look
Understand the gap between predicted probabilities and real-world outcomes
Published Tue, 14 Jan 2025 on Towards Data Science: https://towardsdatascience.com/you-think-80-means-80-why-prediction-probabilities-need-a-second-look-ebc8b650cf21/

The post You Think 80% Means 80%? Why Prediction Probabilities Need a Second Look appeared first on Towards Data Science.

Real world occurrences versus model confidence scores. Image created with Dall·E by the author.

How reliable are the probabilities predicted by a machine learning model? What does a predicted probability of 80% mean? Is it the same as an 80% chance of an event occurring? In this beginner-friendly post, you’ll learn the basics of prediction probabilities, calibration, and how to interpret these numbers in a practical context. I will show with a demo how you can evaluate and improve these probabilities for better decision-making.


What do prediction probabilities represent?

Instead of calling model.predict(data), which gives you a 0 or 1 prediction for a binary classification problem, you might have used model.predict_proba(data). This gives you probabilities instead of zeros and ones. In many data science cases this is useful, because it gives you more insight. But what do these probabilities actually mean?

A predicted probability of 0.8 means that the model is 80% confident that an instance belongs to the positive class. Let’s repeat that: the model is 80% confident that an instance belongs to the positive class. So it doesn’t mean that there is an 80% real-world likelihood of the event occurring. These are two different things. It’s important to make this distinction, especially in use cases where probabilities drive decisions, for example in fraud detection, medical diagnoses, or risk assessment.

This touches on my motivation for writing this post: I noticed that many people easily make this mistake. As a data scientist, it can be important to explain this difference if your model is ‘guilty’, because business stakeholders might assume that an 80% prediction probability means an 80% real-world likelihood of the event occurring.

The problem with interpreting model confidence scores as probabilities

Why can it be a problem to interpret scores from predict_proba as real-world likelihood?

Some models are overconfident: they produce high confidence scores while often being wrong. Other models are underconfident: they produce low confidence scores while often being right. If model confidence scores are not calibrated, this can be misleading.

Another thing to keep in mind is that confidence scores are based on patterns learned from the training data. If the data is imbalanced or biased, the confidence scores might not reflect true probabilities in real-world scenarios.

Let’s take a look at an example. Your XGBoost model might output a probability of 80%, while in reality the event occurs only 65% of the time when the model outputs 80%. Ideally, if the model predicted 80% for 100 cases, the event would occur in approximately 80 of them. Otherwise, we can’t trust the probabilities.
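The check described above boils down to a simple count: collect the cases where the model predicted (about) 80%, and see how often the event actually happened. Here is a minimal sketch with hypothetical data chosen to match the 65% in the example.

```python
def observed_rate(probs, outcomes, low, high):
    """Fraction of positive outcomes among predictions in [low, high)."""
    selected = [y for p, y in zip(probs, outcomes) if low <= p < high]
    return sum(selected) / len(selected) if selected else float("nan")

# Hypothetical: 20 cases where the model predicted 0.8, but only 13 were positive
probs = [0.8] * 20
outcomes = [1] * 13 + [0] * 7
print(observed_rate(probs, outcomes, 0.75, 0.85))  # 0.65
```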

How can we determine if a model is well-calibrated, meaning that the model confidence scores match the true likelihood of the event? Let’s take a look at calibration and ways to improve it.

Calibrating model confidence scores

First, we want to visualize the alignment between model confidence scores and true outcomes on the test set. It’s super easy:

  1. Group prediction probabilities into bins; below, I used 20 bins.
  2. Compute the fraction of positive cases for each bin. This corresponds to the true probability of the event occurring in this bin.
  3. Plot these true probabilities against the mean predicted probability of each bin.
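The three steps translate almost directly into NumPy. The sketch below uses equal-width bins and skips empty bins, which is roughly what scikit-learn's calibration_curve does with strategy='uniform'. It is an illustration, not a reimplementation of that function. The synthetic data in the usage part is calibrated by construction, so its points should lie close to the diagonal.

```python
import numpy as np

def reliability_curve(y_true, y_prob, n_bins=20):
    """Return (fraction of positives, mean predicted probability) per non-empty bin."""
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Step 1: bin index per prediction (a probability of exactly 1.0 lands in the last bin)
    bin_ids = np.digitize(y_prob, edges[1:-1])
    prob_true, prob_pred = [], []
    for b in range(n_bins):
        in_bin = bin_ids == b
        if in_bin.any():
            prob_true.append(y_true[in_bin].mean())  # Step 2: fraction of positives
            prob_pred.append(y_prob[in_bin].mean())  # x-coordinate for the plot
    return np.array(prob_true), np.array(prob_pred)

# Hypothetical, perfectly calibrated model: outcomes drawn with the predicted probability
rng = np.random.default_rng(0)
y_prob = rng.uniform(size=10_000)
y_true = (rng.uniform(size=y_prob.size) < y_prob).astype(int)
prob_true, prob_pred = reliability_curve(y_true, y_prob, n_bins=10)
print(np.round(prob_true - prob_pred, 2))  # Step 3 would plot prob_true vs prob_pred
```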

And of course, if the model is perfectly calibrated, the points will lie on the diagonal line. For example, out of all cases in the 10% bin, the true probability (fraction of positives) is around 10%. Below you can see examples of a fairly well-calibrated XGBoost model versus a poorly calibrated Naive Bayes model. These models are trained on the Adult dataset.

Plotting calibration curves on the test set for the adult dataset for XGBoost and Naive Bayes prediction scores. Image by author.

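Another way to check how well the model is calibrated is the Brier Score. This is also easy! It measures the mean squared difference between the predicted probabilities and the actual outcomes (so the lower, the better):

$$\text{BS} = \frac{1}{N} \sum_{i=1}^{N} \left(p_i - y_i\right)^2$$

where $p_i$ is the predicted probability for instance $i$, $y_i \in \{0, 1\}$ is the actual outcome, and $N$ is the number of instances.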
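The formula is short enough to compute by hand. Here is a minimal NumPy sketch with four hypothetical predictions, which should give the same value as scikit-learn's brier_score_loss on the same inputs.

```python
import numpy as np

def brier_score(y_true, y_prob):
    """Mean squared difference between predicted probabilities and 0/1 outcomes."""
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    return float(np.mean((y_prob - y_true) ** 2))

# Hypothetical predictions: squared errors 0.01, 0.04, 0.16, 0.04
print(brier_score([1, 0, 1, 1], [0.9, 0.2, 0.6, 0.8]))  # ≈ 0.0625
```

For reference, always predicting 0.5 gives a Brier Score of 0.25, whatever the outcomes are.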

If we calculate the Brier Score for the two models above, we get the following results:

Brier scores for adult dataset:
XGBoost:       0.10849946433956742
Naive Bayes:   0.1920520011951727

What we can conclude from the calibration plots is that the calibration of the XGBoost model is quite good. The one for Naive Bayes is far from perfect: the curve deviates from the diagonal line, and the Brier Score is high (almost twice as high as the Brier Score for the XGBoost model). Let’s continue with the Naive Bayes model to show how we can improve the calibration! There are different ways of improving it; in this post, we will take a look at Platt Scaling and Isotonic Regression.

The calibration curve and Brier score are implemented in scikit-learn; you can compute them using the following code:

from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss
from sklearn.naive_bayes import GaussianNB

# fit model on training data
model = GaussianNB()
model.fit(X_train, y_train)

# calculate predicted probabilities with predict_proba on the test set
probs = model.predict_proba(X_test)[:, 1]

brier_score = brier_score_loss(y_test, probs)
prob_true, prob_pred = calibration_curve(y_test, probs, n_bins=20, strategy='uniform')

Platt Scaling

Platt Scaling is a simple and effective method for calibrating predicted probabilities. It works by fitting a logistic regression model to the output of the uncalibrated model’s probabilities. Specifically, it minimizes the log-loss on a validation set, ensuring that the calibrated probabilities better reflect the true likelihood of the events.

To apply Platt Scaling, you split your data into training and validation sets. The first step is to train your model on the training set and generate uncalibrated probabilities for the validation set. Then you use these probabilities as the input feature to fit a logistic regression model that adjusts the predictions. This approach is particularly effective for models that produce continuous scores, such as SVMs or Naive Bayes. One note: Platt Scaling assumes that the relationship between the model’s scores and the true outcomes follows a sigmoid (S-shaped) curve, which might not always hold.
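To show what happens under the hood, here is a minimal NumPy sketch. It fits a one-feature logistic regression, p = sigmoid(a · score + b), with Newton's method. Some assumptions apply. The score fed in is the log-odds of the raw probability; Platt's original method uses the classifier's raw score, such as an SVM decision value. Platt's target smoothing is omitted. The data is synthetic: an overconfident model whose log-odds are three times too large, so the ideal fit is close to a = 1/3 and b = 0. For brevity, the sketch fits and evaluates on the same sample; in practice, fit on a validation set and evaluate on the test set.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fit_platt(scores, y, n_iter=50):
    """Fit p = sigmoid(a * score + b) by Newton's method on the log-loss."""
    X = np.column_stack([scores, np.ones_like(scores)])  # score + intercept column
    w = np.zeros(2)
    for _ in range(n_iter):
        p = sigmoid(X @ w)
        grad = X.T @ (p - y)                          # gradient of the summed log-loss
        hess = X.T @ (X * (p * (1.0 - p))[:, None])   # Hessian of the summed log-loss
        w = w - np.linalg.solve(hess, grad)           # Newton step
    return w[0], w[1]

# Hypothetical data: an overconfident model whose log-odds are 3x too large
rng = np.random.default_rng(0)
q = rng.uniform(0.05, 0.95, size=20_000)            # true event probabilities
y = (rng.uniform(size=q.size) < q).astype(float)    # observed outcomes
raw = sigmoid(3.0 * np.log(q / (1.0 - q)))          # uncalibrated model output

raw_logit = np.log(raw / (1.0 - raw))
a, b = fit_platt(raw_logit, y)
calibrated = sigmoid(a * raw_logit + b)

print(f"a={a:.3f}, b={b:.3f}")  # a should be close to 1/3, b close to 0
print("Brier raw:       ", np.mean((raw - y) ** 2))
print("Brier calibrated:", np.mean((calibrated - y) ** 2))
```

The fit has only two parameters, which is why Platt Scaling is data-efficient. It is also why it cannot fix calibration curves that are not S-shaped.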

Here you can see the code for applying Platt Scaling, and the new calibration curve if we apply Platt Scaling to our Naive Bayes classifier on the adult dataset:

from sklearn.calibration import CalibratedClassifierCV
from sklearn.naive_bayes import GaussianNB

# first, make sure you have fitted your model on the train set 
model = GaussianNB()
model.fit(X_train, y_train)

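# apply Platt Scaling on a held-out validation set (X_val, y_val),
# not on the training data the model was fitted on
# cv='prefit' makes sure that the method uses the already trained model
# (in scikit-learn 1.6+, cv='prefit' is deprecated in favour of wrapping
# the model in sklearn.frozen.FrozenEstimator)
platt_model = CalibratedClassifierCV(model, method='sigmoid', cv='prefit')
platt_model.fit(X_val, y_val)
platt_probs = platt_model.predict_proba(X_test)[:, 1]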

Isotonic Regression

Another common calibration technique is Isotonic Regression. This is a non-parametric technique for calibrating probabilities. Unlike Platt Scaling, it does not assume a particular shape for the calibration curve, only that the curve is non-decreasing. This makes it more flexible, but also potentially prone to overfitting when you are working with a smaller dataset. The method fits a step function that adjusts the predicted probabilities so they align better with the actual outcomes. Because this function is non-decreasing, the probabilities stay in order: higher predictions still represent a higher likelihood of the event happening than lower predictions.

To implement Isotonic Regression, you again split your data and train the base model on the training set. The predicted probabilities on the validation set are used as inputs to fit an isotonic regression model, which adjusts the probabilities. It tends to produce better calibration than Platt Scaling when the calibration curve is irregular (not S-shaped), like in our example. But watch out with small datasets, because Isotonic Regression can introduce artifacts like sharp jumps or dips in the calibration curve.
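The step function is usually fitted with the pool-adjacent-violators (PAV) algorithm, and scikit-learn's IsotonicRegression is built on the same idea. Below is a minimal plain-Python sketch of textbook PAV. The input is hypothetical: validation outcomes sorted by the model's raw predicted probability. Whenever a value is lower than the one before it, the two are merged into their weighted average. This merging is also where the flat steps, and sometimes abrupt jumps, in isotonic calibration curves come from.

```python
def pool_adjacent_violators(values, weights=None):
    """Return the non-decreasing sequence closest to `values` in weighted least squares."""
    if weights is None:
        weights = [1.0] * len(values)
    blocks = []  # each block: [mean, total_weight, count]
    for v, w in zip(values, weights):
        blocks.append([float(v), float(w), 1])
        # merge while the last two blocks violate the non-decreasing order
        while len(blocks) > 1 and blocks[-2][0] > blocks[-1][0]:
            m2, w2, c2 = blocks.pop()
            m1, w1, c1 = blocks.pop()
            total = w1 + w2
            blocks.append([(m1 * w1 + m2 * w2) / total, total, c1 + c2])
    fitted = []
    for mean, _, count in blocks:
        fitted.extend([mean] * count)
    return fitted

# Hypothetical validation set, already sorted by the raw predicted probability
raw_probs = [0.10, 0.30, 0.35, 0.50, 0.70, 0.90]
outcomes = [0, 1, 0, 0, 1, 1]
print(pool_adjacent_violators(outcomes))  # [0.0, 0.333..., 0.333..., 0.333..., 1.0, 1.0]
```

To calibrate a new prediction, you look up (or interpolate) the fitted value at its raw probability.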

Below are the code and the calibration curve again. You can clearly spot the jump and dip at a mean predicted probability of 0.6! Besides that, the curve looks nice.

from sklearn.calibration import CalibratedClassifierCV
from sklearn.naive_bayes import GaussianNB

# first, make sure you have fitted your model on the train set 
model = GaussianNB()
model.fit(X_train, y_train)

# apply isotonic regression with method='isotonic',
# again fitted on the held-out validation set (X_val, y_val)
iso_model = CalibratedClassifierCV(model, method='isotonic', cv='prefit')
iso_model.fit(X_val, y_val)
iso_probs = iso_model.predict_proba(X_test)[:, 1]

Comparing Calibration Methods

If we combine all the plots of the Naive Bayes model on the adult dataset (uncalibrated model, Platt Scaling, and Isotonic Regression), to compare them, this is the result:

Uncalibrated Naive Bayes model, Platt Scaling, and Isotonic Regression. Image by author.

Looking at this plot, the Isotonic Regression calibration curve seems to fit best in this example. It only has the strange jump and dip at a mean predicted probability of 0.6, mentioned earlier. We can perform an extra check by calculating the Brier Scores:

Brier scores for adult dataset and Naive Bayes model:
Uncalibrated:        0.1920520011951727
Platt Scaling:       0.15621506274566171
Isotonic Regression: 0.12849532236356562

Indeed! Isotonic Regression has the best score.

You may have noticed that the uncalibrated XGBoost model had an even better Brier Score and calibration plot, and you are right. We could save ourselves the hassle of calibrating the results of the Naive Bayes model and go for XGBoost for this dataset! Of course, if you test this in real life on your own data, it’s not guaranteed that this is the case 🙂


Conclusion and Further Reading

Calibration is often overlooked but can be critical in decision-sensitive applications. Many models, including tree-based ensembles like XGBoost and LightGBM, improve calibration only indirectly during training, for example by minimizing log-loss. This does not guarantee well-calibrated probabilities, especially for imbalanced datasets or datasets that contain noisy labels. It is generally good practice to validate calibration using plots and metrics like the Brier Score.

It is most crucial to calibrate your results if probabilities are used for decision-making rather than ranking. In healthcare, for example, a calibrated probability could help estimate risk more accurately, enabling better resource allocation. If this post sparked your interest, related topics are Bayesian methods for uncertainty quantification, advanced ensembling techniques, and confidence intervals for predicted probabilities. You might also be interested in other calibration methods, for example Logistic Correction.

Graph Neural Networks Part 2. Graph Attention Networks vs. GCNs
A model that pays attention to your graph
Published Tue, 08 Oct 2024 on Towards Data Science: https://towardsdatascience.com/graph-neural-networks-part-2-graph-attention-networks-vs-gcns-029efd7a1d92/

The post Graph Neural Networks Part 2. Graph Attention Networks vs. GCNs appeared first on Towards Data Science.

Graph Neural Networks Part 2. Graph Attention Networks vs. Graph Convolutional Networks

Welcome to the second post about GNN architectures! In the previous post, we saw a staggering improvement in accuracy on the Cora dataset by incorporating the graph structure in the model using a Graph Convolutional Network (GCN). This post explains Graph Attention Networks (GATs), another fundamental architecture of graph neural networks. Can we improve the accuracy even further with a GAT?

First, let’s talk about the difference between GATs and GCNs. Then let’s train a GAT and compare the accuracy with the GCN and basic neural network.

This blog post is part of a series. Are you new to GNNs? I recommend starting with the first post, which explains graphs, neural networks, the dataset, and GCNs.


Graph Attention Networks

In my previous post, we saw a GCN in action. Let’s take it a step further and look at Graph Attention Networks (GATs). As you might remember, GCNs treat all neighbors equally. For GATs, this is different. GATs allow the model to learn different importance (attention) scores for different neighbors. They aggregate neighbor information by using attention mechanisms (this might ring a bell because these mechanisms are also used in transformers).

How does this work? In the GCN, the weight a neighbor gets depends only on the degrees of the nodes. GATs, on the other hand, also take the feature values into account to assign attention scores to different neighbors.
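To make the contrast concrete, here is a minimal sketch of the weight a standard GCN (symmetric normalization with self-loops) gives to neighbor j when updating node i. The weight is a fixed number computed from the two degrees alone, and the node features play no role. The degrees below are hypothetical.

```python
import math

def gcn_weight(deg_i: int, deg_j: int) -> float:
    """Symmetric GCN normalization 1 / sqrt((deg_i + 1) * (deg_j + 1)).

    The +1 accounts for the self-loop added to every node. Two neighbors with
    the same degree always get the same weight, whatever their features are.
    """
    return 1.0 / math.sqrt((deg_i + 1) * (deg_j + 1))

# Node i has 3 neighbors; neighbor j has 1 neighbor (node i itself)
print(gcn_weight(3, 1))  # 1 / sqrt(8) ≈ 0.354
print(gcn_weight(3, 3))  # self-loop weight of node i: 1 / 4 = 0.25
```

A GAT replaces this fixed, degree-based number with a learned score that depends on the features of both nodes.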

So instead of treating all neighbors equally, an attention mechanism is introduced that assigns varying levels of importance to different neighbors. This allows the network to focus on the most relevant parts of the graph structure, essentially learning "where to look" when making predictions.

So, how exactly does the attention mechanism work in GATs? Let’s break it down.

Step 1: Computing Attention Scores

For each node, we calculate an attention score for every neighboring node. This score is a measure of how important a specific neighbor’s features are when updating the current node’s features. The score is learned during training, so the model decides which nodes matter most for each task.

There are multiple ways of computing attention scores in GATs. In this post, I explain the second version (GATv2) instead of the original one, because it is usually more effective.

Mathematically, given a node $i$ and its neighbor $j$, the attention coefficient $\alpha_{ij}$ is computed as follows:

Feature Transformation
We start with the feature vectors $h_i$ and $h_j$ of nodes $i$ and $j$. The first step is to apply a shared weight matrix $W$ to both:

$$h_i \rightarrow W h_i, \qquad h_j \rightarrow W h_j$$

Next, the transformed features are summed (in the original GAT version, they were concatenated instead):

$$W h_i + W h_j$$

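Score Calculation
Now we can calculate the raw (unnormalized) attention score by applying a LeakyReLU function and taking the dot product with a learnable attention vector $a$:

$$e_{ij} = a^{\top}\, \mathrm{LeakyReLU}\left(W h_i + W h_j\right)$$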

Step 2: Normalizing Attention Scores

The raw attention scores $e_{ij}$ from the previous step are normalized across all neighbors of node $i$ using the softmax function. This makes the coefficients easy to interpret, as they sum to 1 for each node:

$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}$$

The normalized attention coefficients $\alpha_{ij}$ determine how much weight each neighbor $j$ contributes to the new feature representation of node $i$.

Step 3: Aggregating Neighbor Information

Finally, node $i$’s new feature representation is computed as a weighted sum of its neighbors’ transformed features, where the weights are the attention coefficients:

$$h_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}\, W h_j\right)$$
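Steps 1 to 3 fit in a short single-head NumPy sketch. The features, the weight matrix W, and the attention vector a below are hypothetical stand-ins for learned parameters. The 0.2 negative slope for LeakyReLU is an assumed default (it is also the default in PyG's GAT layers). The final non-linearity σ is left to the caller. Node i is included in its own neighbor list to mimic a self-loop.

```python
import numpy as np

def leaky_relu(x, negative_slope=0.2):
    return np.where(x > 0, x, negative_slope * x)

def gatv2_update(h, i, neighbors, W, a):
    """Single-head GATv2-style attention update for node i (no final activation).

    h: (num_nodes, d_in) node features, W: (d_out, d_in) shared weight matrix,
    a: (d_out,) attention vector, neighbors: indices of i's neighbors.
    """
    Wh = h @ W.T                                                      # W h for every node
    e = np.array([a @ leaky_relu(Wh[i] + Wh[j]) for j in neighbors])  # raw scores e_ij
    alpha = np.exp(e - e.max())                                       # softmax, shifted for stability
    alpha = alpha / alpha.sum()
    h_new = (alpha[:, None] * Wh[neighbors]).sum(axis=0)              # sum_j alpha_ij W h_j
    return alpha, h_new

h = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # hypothetical features of 3 nodes
W = np.eye(2)                                        # stand-in for learned weights
a = np.array([1.0, -1.0])                            # stand-in attention vector
alpha, h0 = gatv2_update(h, 0, [0, 1, 2], W, a)
print(alpha)  # attention of node 0 over nodes 0, 1, 2 (sums to 1)
print(h0)
```

If a is all zeros, every raw score is equal. The attention then becomes uniform and the update reduces to a plain average of W h_j.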

So now we have our attention scores and updated feature representations! Let’s continue with another important aspect of GATs: multi-head attention.

Multi-Head Attention

Just like transformers, GATs often use multi-head attention to improve their performance. But what does multi-head attention mean, and why is it so beneficial?

Multi-head attention refers to running several separate attention mechanisms, or heads, in parallel. Each of these heads independently computes attention scores for the neighbors of a node, learning to focus on different aspects of the graph structure or node features. After these heads process the graph, their outputs are either concatenated or averaged to form the final node representation.

So one of the key reasons for using multiple heads instead of one is to learn diverse patterns: each attention head has its own learnable parameters and can learn to focus on different parts of the neighborhood. Another reason is that it stabilizes the training process. You can compare it with an ensemble, where other heads can compensate for a "noisy head".

A center node with 6 neighbors. Two different attention heads are represented by the blue and green arrows. The thickness of the arrows represents the varying levels of importance (the attention scores). Image by author.

How is multi-head attention implemented in GATs? The first step is that each attention head computes its own set of attention scores and new node features independently. For N heads, and a given node i, we’ll end up with N different sets of transformed features. Next up, all outputs are concatenated (stacked) or averaged. Concatenation is more common because it increases the model’s expressiveness, but on the other hand the output dimension will be larger. Averaging helps to smooth out the differences between the heads. A general rule is to use concatenation when it’s a hidden layer in the network and averaging when it’s the last layer. When all attention heads are combined, we hope to get a comprehensive view of the graph, because the different heads have different perspectives on the relationships in the graph.
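The difference between concatenating and averaging is easiest to see in the tensor shapes. The sketch below uses random, hypothetical per-head outputs (8 heads with 16 features each) in place of real attention heads:

```python
import torch

torch.manual_seed(0)
num_nodes, heads, out_dim = 5, 8, 16

# Hypothetical per-head outputs: one [num_nodes, out_dim] result per attention head,
# stacked into a [num_nodes, heads, out_dim] tensor.
head_outputs = torch.randn(num_nodes, heads, out_dim)

# Hidden layer: concatenate the heads -> [num_nodes, heads * out_dim]
concatenated = head_outputs.reshape(num_nodes, heads * out_dim)

# Last layer: average the heads -> [num_nodes, out_dim]
averaged = head_outputs.mean(dim=1)
```

This mirrors the model below in two ways:

- The first GATv2Conv layer concatenates its heads, so the second layer expects hidden_dim * heads input features.
- concat=False in the last layer averages the heads, so the output size equals the number of classes.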

Multiple heads according to Dall·E.

PyTorch Implementation

Let’s implement a GAT in Python and train it on the Cora dataset. You can use the same setup as in the previous post.

from torch_geometric.nn import GATv2Conv  # use GATConv for the first GAT version
import torch
import torch.nn.functional as F

class GAT(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, heads=8):
        super().__init__()
        self.gat1 = GATv2Conv(input_dim, hidden_dim, heads=heads)
        # for the last GAT layer we use concat=False to average the outputs of the heads
        self.gat2 = GATv2Conv(hidden_dim * heads, output_dim, heads=heads, concat=False)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.gat2(x, edge_index)
        return F.log_softmax(x, dim=1)

Looking back, the results for the MLP and GCN were as follows:

MLP Test Accuracy: 54.35 ± 1.06
GCN Test Accuracy: 78.76 ± 0.38

Will we improve this with the GAT? Let’s run the code (same as previous post):

for model_class in [MLP, GCN, GAT]:
    results[model_class.__name__] = []
    # same training loop as before...

The model is training… The GAT model takes a bit longer than the GCN and MLP…

And here is the result:

GAT Test Accuracy: 78.45 ± 1.11

The GAT performance is comparable to that of the GCN! This can happen, and for the Cora dataset it looks like it doesn’t matter much which of the two models we use. But we didn’t fine-tune either model, so maybe the GAT will come out ahead in the end.

According to the original GAT paper (version 1), GATs outperform GCNs on the benchmark datasets.

Considerations for GATs

While GATs have shown great promise in improving accuracy (not in this post, but it’s better to trust the paper here), there are a few things to keep in mind:

  • The attention mechanism in GATs adds additional complexity to the model, both in terms of computation and the number of parameters. This makes GATs more resource-intensive and slower to train than GCNs.
  • Multi-head attention helps stabilize training, but there is still a risk of overfitting, especially when using many attention heads or deep GAT architectures. Using techniques like dropout and early stopping can help to mitigate this.
  • One advantage of GATs is that they provide interpretability through attention scores. These scores can be analyzed to understand which nodes are most influential in making predictions, offering insights into the graph structure.
  • Another point I didn’t address in the previous post is how to fine-tune GNNs. Many steps in fine-tuning GNNs are similar to those for traditional neural networks: testing different values for the hyperparameters and preventing overfitting with early stopping. With GATs, for example, you also need to tune the number of attention heads. Small changes to node and edge features can have an impact on GNN performance, so it might help to experiment with different feature combinations or to create new features. Augmenting data can improve generalization. You can do this by adding noise to edges, randomly dropping nodes, or by performing subgraph sampling.
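As an illustration of edge augmentation, here is a minimal sketch that randomly drops a fraction of the edges (sometimes called edge dropout). This specific technique is my choice of example. The function drop_edges is hypothetical and written in plain PyTorch. PyTorch Geometric also provides a utility for this (dropout_edge in torch_geometric.utils), so check its documentation before rolling your own.

```python
import torch


def drop_edges(edge_index, p=0.2, generator=None):
    """Randomly remove edges as a simple augmentation (illustrative sketch).

    edge_index: [2, num_edges] tensor of (source, target) pairs.
    Every column is kept independently with probability 1 - p, so for an
    undirected graph stored in both directions, the two directions of an
    edge may be dropped separately.
    """
    keep = torch.rand(edge_index.size(1), generator=generator) >= p
    return edge_index[:, keep]


edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
g = torch.Generator().manual_seed(42)
augmented = drop_edges(edge_index, p=0.5, generator=g)  # a sparser edge set on every call
```

You would typically apply this only during training, passing a freshly augmented edge_index to the model in each epoch while evaluating on the full graph.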

Conclusion

GATs take GCNs a step further by introducing attention mechanisms that assign different levels of importance to each node’s neighbors. This added flexibility allows GATs to achieve better performance in many cases compared to GCNs. However, this comes at the cost of increased computational complexity and the need to tune extra hyperparameters like the number of attention heads.

GATs and GCNs represent just two foundational architectures of GNNs. Each has its strengths and trade-offs, and the choice of which to use depends on the dataset and prediction task. For many tasks, GATs can offer a performance boost, especially when the relationships between nodes are not uniformly important.

Are you curious about other architectures? I’m not sure yet how I will follow up on these two blog posts on GNNs. For those who can’t wait, here are some other interesting architectures and papers to investigate:

  • If you are looking for an overview of GNN models and a general design pipeline, this paper is a good place to start.
  • Relational Graph Convolutional Networks (R-GCNs) are an extension of GCNs. R-GCNs are specifically designed for situations where edges can have different types or relations. R-GCNs use relation-specific weights to handle different types of edges and their unique relationships.
  • GraphSAGE samples a fixed number of neighbors for each node, instead of aggregating features from all neighbors (like GCNs and GATs do). It’s an interesting architecture for efficient, large-scale graph representation learning.
  • SEAL is specifically designed for link prediction. It extracts a local subgraph around each target link and shows great performance. Here you can find the GitHub repo.
  • Is it possible to generate realistic graphs? GraphRNN uses auto-regressive models to model graph generation, where nodes and edges are treated as time-based events.
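The core idea behind GraphSAGE’s scalability is to aggregate a fixed-size sample of neighbors instead of the full neighborhood, and it can be sketched in a few lines. This is a simplified, hypothetical version with these assumptions:

- A mean aggregator.
- Separate weights for the node and its neighborhood, which is equivalent to concatenating both and applying one weight matrix.
- Illustrative function names.

The original method has more options. In PyTorch Geometric you would instead combine SAGEConv with a neighbor-sampling loader such as NeighborLoader.

```python
import random

import torch


def sample_neighbors(adj_list, node, k, rng):
    """Return exactly k neighbors of `node`.

    Samples without replacement when the node has at least k neighbors and
    with replacement otherwise (assumes every node has at least one neighbor).
    """
    neighbors = adj_list[node]
    if len(neighbors) >= k:
        return rng.sample(neighbors, k)
    return [rng.choice(neighbors) for _ in range(k)]


def sage_mean_layer(x, adj_list, W_self, W_neigh, k, rng):
    """One GraphSAGE-style layer: mean of k sampled neighbors, combined with the node itself."""
    rows = []
    for v in range(x.size(0)):
        sampled = sample_neighbors(adj_list, v, k, rng)
        neigh_mean = x[sampled].mean(dim=0)                # aggregate a fixed-size sample
        rows.append(x[v] @ W_self + neigh_mean @ W_neigh)  # combine node and neighborhood
    return torch.relu(torch.stack(rows))


# Hypothetical graph as an adjacency list; node 1 has the most neighbors.
adj_list = {0: [1, 2], 1: [0, 2, 3], 2: [0, 1], 3: [1]}
torch.manual_seed(0)
x = torch.randn(4, 3)
W_self, W_neigh = torch.randn(3, 2), torch.randn(3, 2)
h = sage_mean_layer(x, adj_list, W_self, W_neigh, k=2, rng=random.Random(0))
```

The cost per node depends on k rather than on the node’s degree, so a node with thousands of neighbors is no more expensive than one with a single neighbor. And since the weights do not depend on a fixed set of nodes, the same layer can be applied to nodes that were not seen during training.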

Related

Graph Neural Networks Part 1. Graph Convolutional Networks Explained

Optimizing Connections: Mathematical Optimization within Graphs

Simplify Your Machine Learning Projects

The post Graph Neural Networks Part 2. Graph Attention Networks vs. GCNs appeared first on Towards Data Science.

]]>
Graph Neural Networks Part 1. Graph Convolutional Networks Explained https://towardsdatascience.com/graph-neural-networks-part-1-graph-convolutional-networks-explained-9c6aaa8a406e/ Tue, 01 Oct 2024 17:34:34 +0000 https://towardsdatascience.com/graph-neural-networks-part-1-graph-convolutional-networks-explained-9c6aaa8a406e/ Node classification with Graph Convolutional Networks

The post Graph Neural Networks Part 1. Graph Convolutional Networks Explained appeared first on Towards Data Science.

]]>
Data doesn’t always fit neatly into rows and columns. Often, data follows a graph structure: take for example social networks, protein structures, recommendation systems or transportation systems. Leaving the information about the graph topology out of a machine learning model can decrease its performance drastically. Luckily, there is a way to include this information.

Graph Neural Networks (GNNs) are designed to learn from data represented as nodes and edges. GNNs have evolved over the years, and in this post you will learn about Graph Convolutional Networks (GCNs). My next post will cover Graph Attention Networks (GATs). GCNs and GATs are two fundamental architectures on which current state-of-the-art models are based, so if you want to learn about GNNs, this is a good start. Let’s dive in!

New to graphs? The first part of this post (Graph Basics) explains the basics of graphs. You should also be familiar with neural networks (a short recap is provided in the Neural Network Recap section of this article).

Undirected graph with nodes (or vertices) and edges (or links). The graph doesn't contain self loops. The adjacency matrix shows the connections (edges) between nodes. Image by author.

Why Graph Neural Networks?

A GNN is a type of neural network specifically designed to process and analyze data structured as graphs. Unlike traditional neural networks that operate on regular grid-like data (like sequences for natural language processing or grids for image data), GNNs are built to work with graph data. Each node in a graph is associated with a feature vector (or embedding) that captures its attributes or characteristics. The GNN updates these embeddings by aggregating information from neighboring nodes and edges. We will dive deeper into this in the following paragraphs of the post.

Why do we need GNNs? In essence, GNNs allow us to leverage the rich relational data present in graphs to make predictions, detect patterns, and gain insights from complex systems that are not easily represented by traditional data formats. GNNs excel at leveraging the connections between entities in a graph, capturing both local and global relationships. As we will see in this post, GNNs can outperform traditional models in tasks where relational data is crucial, offering more accurate predictions with less data.

There are three common types of prediction tasks in graphs:

  • You can predict on graph level. The input of the model consists of many different graphs, and every graph gets one classification. For example, the class a molecule belongs to: every molecule is represented by one graph, and every molecule needs a prediction. Another example is image classification. Yes, images can also be represented as graphs (as you can see in the image below)!
  • Another way to use GNNs is by predicting on node level. The input of the GNN is one graph, and every node needs a prediction. This prediction is a characteristic of the node. The demo in this post will show node classification. Node regression is of course possible as well! Compared to classification, you only need to change the output layer activation function, the loss function, the evaluation metric, and obviously the target.
  • Finally, we can predict on edge level. The value of an edge is predicted, or the likelihood that an edge will appear soon. An example is friend recommendations on social media (a.k.a. link prediction).

All these prediction tasks can be solved with GNNs!

From image to graph. Neighbor pixels are connected with an edge. Image by author.
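Here is a minimal sketch of the image-to-graph idea. It assumes every pixel is a node with its color values as features, and that pixels sharing a side are connected (a 4-neighborhood; the figure may use a different neighborhood). grid_graph_edges is a hypothetical helper.

```python
import torch


def grid_graph_edges(height, width):
    """edge_index for an image grid where pixels sharing a side are connected.

    Pixel (r, c) becomes node r * width + c. Every undirected edge is stored in
    both directions, the edge_index convention used by PyTorch Geometric.
    """
    src, dst = [], []
    for r in range(height):
        for c in range(width):
            node = r * width + c
            if c + 1 < width:   # horizontal neighbor to the right
                src.extend([node, node + 1])
                dst.extend([node + 1, node])
            if r + 1 < height:  # vertical neighbor below
                src.extend([node, node + width])
                dst.extend([node + width, node])
    return torch.tensor([src, dst])


image = torch.rand(2, 3, 3)   # hypothetical 2x3 RGB image, shape [height, width, channels]
x = image.reshape(-1, 3)      # node features: one RGB vector per pixel -> [6, 3]
edge_index = grid_graph_edges(2, 3)
```

For a 2x3 image this gives 7 undirected edges (4 horizontal, 3 vertical), stored as 14 directed ones.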

Dataset and Prerequisites

If you want to run the code yourself, you have to create a new environment and install the dependencies:

Step 1. Create a new environment with Poetry:

poetry init --no-interaction

Step 2. Add the dependencies:

poetry add torch torchvision torchaudio torch-geometric matplotlib scikit-learn

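Step 3. Activate the environment (in Poetry 2.0 and later, the shell command was moved to the poetry-plugin-shell plugin; alternatively, poetry env activate prints the command to activate the environment):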

poetry shell

A pyproject.toml file is created with all the dependencies for the project, making sure it’s easy to reproduce. Mine looks like this:

[tool.poetry]
name = "gnnsdemo"
version = "0.1.0"
description = ""
authors = ["Hennie <31654880+henniedeharder@users.noreply.github.com>"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.9"
torch = "^2.4.1"
torchvision = "^0.19.1"
torchaudio = "^2.4.1"
torch-geometric = "^2.6.1"
matplotlib = "^3.9.2"
scikit-learn = "^1.5.2"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" 

The Cora dataset is a benchmark dataset for graph neural networks. It contains data about 2708 scientific publications, which are the nodes of the graph. An edge between two nodes (publications) is created when one publication references the other. The target is to predict the subject of each paper; there are seven classes in total. Here’s how you can load the dataset:

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root=".", name="Cora")
data = dataset[0]
num_labels = len(set(data.y.numpy()))  # used for output_dim

I created multiple graph datasets myself for testing, only to find out something surprising… I’ll get back to that.

Neural Network Recap

A neural network is composed of layers: you have an input layer, then hidden layers, and finally the output layer. Each layer consists of neurons, and they are connected to the next layer via weights. The goal of a neural network is to optimize the weights using the training data. When data passes through the neurons, the outputs are multiplied by weights, summed, and transformed by activation functions at each neuron. The output layer produces the final prediction. Other important aspects of a neural network are the loss function, which measures the difference between the predicted and actual outcome (or the error), and backpropagation for updating the weights of the network (with an optimizer like Adam).

Recap of a basic neural network. Image by author.

Let’s train a normal neural network on the dataset. For this, we don’t use the graph information; we only use the features of every node to predict the target. The following code shows the class for the neural network; it has two linear layers (one hidden layer).

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.lin1 = nn.Linear(input_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        x = data.x  # no graph structure, only node features
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return F.log_softmax(x, dim=1)

We use the following code for training and evaluating the models:

import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
results = {}

def accuracy(pred, true):
    # helper used below: percentage of correctly classified nodes (assumed definition)
    return (pred == true).float().mean().item() * 100

# iterate over the different model types
for model_class in [MLP]: # later we test also with GCN (this post) and GAT (next blog post)
    results[model_class.__name__] = []
    for i in range(10):
        print(f"Training {model_class.__name__} iteration {i+1}")

        # the output_dim is the number of unique classes in the set
        model = model_class(input_dim=data.x.shape[1], hidden_dim=32, output_dim=num_labels).to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

        # deal with the class imbalance
        class_weights = torch.bincount(data.y) / len(data.y)
        # the models return log-probabilities; CrossEntropyLoss applies log_softmax again,
        # which leaves them unchanged, so this behaves like a weighted NLLLoss
        loss_fn = nn.CrossEntropyLoss(weight=1/class_weights).to(device)

        data = data.to(device)

        # training loop
        for epoch in range(100):
            model.train()
            optimizer.zero_grad()
            out = model(data)

            # calculate loss
            train_loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
            acc = accuracy(out[data.train_mask].argmax(dim=1), data.y[data.train_mask])
            train_loss.backward()
            optimizer.step()

            if epoch % 10 == 0:
                model.eval()
                with torch.no_grad():
                    out = model(data)  # re-run in eval mode, so dropout is disabled for validation
                    val_loss = loss_fn(out[data.val_mask], data.y[data.val_mask])
                    val_acc = accuracy(out[data.val_mask].argmax(dim=1), data.y[data.val_mask])
                    print(f'Epoch {epoch} | Training Loss: {train_loss.item():.2f} | Train Acc: {acc:>5.2f} | Validation Loss: {val_loss.item():.2f} | Validation Acc: {val_acc:>5.2f}')

        # final evaluation on the test set
        model.eval()
        with torch.no_grad():
            out = model(data)
            test_loss = loss_fn(out[data.test_mask], data.y[data.test_mask])
            test_acc = accuracy(out[data.test_mask].argmax(dim=1), data.y[data.test_mask])
            print(f'{model_class.__name__} Test Loss: {test_loss.item():.2f} | Test Acc: {test_acc:>5.2f}')
            results[model_class.__name__].append([acc, val_acc, test_acc])

# print average on test set and standard deviation
for model_name, model_results in results.items():
    model_results = torch.tensor(model_results)
    print(f'{model_name} Test Accuracy: {model_results[:, 2].mean():.2f} ± {model_results[:, 2].std():.2f}')

So we train 10 times and calculate the average accuracy and standard deviation. The output for the MLP class is:

MLP Test Accuracy: 54.35 ± 1.06

Let’s see if we can improve this result with a Graph Convolutional Network!

Graph Convolutional Networks

How can we add graph information to the basic neural network? For understanding one node, we need to look at its neighborhood and include that information in the GNN. Remember that a linear layer from a normal neural network is written as:

X is the input matrix, and W is the weight matrix. For simplicity we leave the biases out.

In graphs there are different ways to define the edges. One way is with an adjacency matrix. By multiplying the adjacency matrix with the input matrix X, the features from the neighbor nodes will be summed. In the following example calculations, we will look at the following graph (with corresponding adjacency matrix):

Graph with adjacency matrix. Image by author.

But I can hear you thinking: What about the features of the node itself? Yes, we also need to include the features from the node itself, and we can do this by adding loops (edges from a node to itself). The adjacency matrix needs to have ones on the diagonal to contain loops, so let’s add an identity matrix to the adjacency matrix.

Adjacency matrix + identity matrix = adapted adjacency matrix.
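To see what the multiplication does, here is a small example on a hypothetical 4-node graph (not the graph from the figure). Each node has a single scalar feature, chosen so the sums are easy to read:

```python
import torch

# Hypothetical undirected graph with edges 0-1, 0-2 and 2-3
A = torch.tensor([[0., 1., 1., 0.],
                  [1., 0., 0., 0.],
                  [1., 0., 0., 1.],
                  [0., 0., 1., 0.]])
# One scalar feature per node, so each sum shows exactly which nodes contributed
X = torch.tensor([[1.], [10.], [100.], [1000.]])

neighbor_sum = A @ X                     # sums only the neighbors' features
A_tilde = A + torch.eye(A.size(0))       # add self-loops (ones on the diagonal)
self_and_neighbor_sum = A_tilde @ X      # neighbors' features plus the node's own
```

Node 0 is connected to nodes 1 and 2, so A @ X gives 10 + 100 = 110 for it. With the self-loop added, its own feature is included as well (111).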

Now we can add this updated matrix to the linear layer:
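Here X is the node-feature matrix, W the weight matrix and Ã = A + I the adjacency matrix with self-loops (notation chosen here). The step from a normal linear layer to the graph linear layer is then:

```latex
\begin{aligned}
\text{linear layer: } &\ H = XW \\
\text{graph linear layer: } &\ H = \tilde{A}XW, \qquad \tilde{A} = A + I
\end{aligned}
```

Row i of ÃXW is computed from the features of node i and its neighbors, all transformed by the same W.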

And now we have a graph linear layer! We can replace the normal linear layer with this one, and then we have our first graph neural network. But there is one important step we should take before actually testing it, and that is normalization. Without normalization, nodes with more connections (e.g. one node having 10 neighbors vs. another with just 1) can dominate the learning process. The node with 10 neighbors would aggregate far more information than the one with 1, leading to imbalance and unstable learning. Normalization ensures that each node’s contribution is appropriately scaled, so the network learns from the graph structure rather than being skewed by uneven node degrees.

How does the normalization step work? In GNNs it’s common to use symmetric normalization. The idea is to normalize each node’s aggregated features by the square root of its degree (the number of neighbors, including itself for self-loops). This helps to ensure that nodes with different degrees contribute equally during aggregation. First we need the degree matrix D:

Degree matrix + identity matrix (for self loops) to get the adapted degree matrix we can use for normalization.

We can normalize D by doing:

We have multiple options for normalization, for example:

But instead of normalizing across rows or columns, it’s better to balance between rows and columns (symmetric normalization). This ensures that both the source and destination nodes in an edge are treated equally, which prevents biases toward high-degree nodes and helps maintain stability in the learning process. The symmetrically normalized (adapted) adjacency matrix is computed like this:
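Let D̃ be the degree matrix of Ã, with self-loops included. The normalization options and the symmetric version can then be summarized as follows (notation chosen here, consistent with the text):

```latex
\begin{aligned}
\tilde{D}_{ii} &= \sum_{j} \tilde{A}_{ij} \\
\text{row normalization: } &\ \tilde{D}^{-1}\tilde{A} \\
\text{column normalization: } &\ \tilde{A}\tilde{D}^{-1} \\
\text{symmetric normalization: } &\ \hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2},
\qquad \hat{A}_{ij} = \frac{\tilde{A}_{ij}}{\sqrt{\tilde{D}_{ii}\,\tilde{D}_{jj}}}
\end{aligned}
```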

And if we combine everything, we get the linear layer of the GCN:
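The complete layer, H = D̃^(-1/2) Ã D̃^(-1/2) X W, can be written from scratch in a few lines of PyTorch. This sketch assumes a dense adjacency matrix on a hypothetical 4-node graph, with no bias or activation, and gcn_layer is an illustrative name. GCNConv computes the same normalization internally on a sparse edge_index, and by default it also adds self-loops and a bias term.

```python
import torch


def gcn_layer(X, A, W):
    """One GCN layer without bias or activation: H = D^-1/2 (A + I) D^-1/2 X W."""
    A_tilde = A + torch.eye(A.size(0))     # add self-loops
    deg = A_tilde.sum(dim=1)               # degree of each node, self-loop included
    d_inv_sqrt = deg.pow(-0.5)             # diagonal of D^-1/2, stored as a vector
    # scaling row i by d_i^-1/2 and column j by d_j^-1/2 gives A_tilde[i, j] / sqrt(d_i * d_j)
    A_hat = d_inv_sqrt[:, None] * A_tilde * d_inv_sqrt[None, :]
    return A_hat @ X @ W, A_hat


# Hypothetical undirected graph with edges 0-1, 0-2 and 2-3
A = torch.tensor([[0., 1., 1., 0.],
                  [1., 0., 0., 0.],
                  [1., 0., 0., 1.],
                  [0., 0., 1., 0.]])
torch.manual_seed(0)
X = torch.randn(4, 3)
W = torch.randn(3, 2)
H, A_hat = gcn_layer(X, A, W)

# For comparison: row normalization D^-1 (A + I) averages each row but is not symmetric
A_tilde = A + torch.eye(4)
A_row = A_tilde / A_tilde.sum(dim=1, keepdim=True)
```

In the symmetric version, the entry between nodes 0 and 1 is 1/√(3·2), because node 0 has degree 3 and node 1 has degree 2 (self-loops included). The row-normalized matrix, by contrast, weights the same edge with 1/3 in one direction and 1/2 in the other.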

Let’s test this! We can train the GCN on the dataset and calculate the accuracy, just like we did before with the MLP model. In PyTorch, we can use the GCNConv layer for this:

from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

Now we train the model with the same code as before; we only need to add GCN to the list of model classes.

The output is:

GCN Test Accuracy: 78.76 ± 0.38

Wow, that looks amazing! The accuracy with the normal neural network was 54.35, and now it’s 78.76! And the only thing we changed was adding the edge information and changing the layers of the model.

Considerations

The example above works perfectly, and because of that you might think you should switch from a normal neural network to a GCN for every dataset you can put into a graph. But it’s important to be aware that these amazing results don’t occur all the time. I tested with multiple graph datasets: some I created myself (based on a beer review dataset and an Amazon review dataset), and others were benchmark datasets (CiteSeer, PubMed, OGB). For most of the benchmark datasets the results improved when switching from a normal neural net to a graph neural net, but for some of them, and for the datasets I created myself, the results didn’t improve. Below are some considerations related to graph neural networks:

  • The graph structure should really make a difference for the problem you are trying to solve. The structure should be meaningful for the prediction task at hand. Testing is important here. You can try to formulate the graph in different ways to see if one way of formulating works better than another one.
  • Training a graph neural network takes more time than training a normal neural network. So if the results improve only a little and training time is important, the normal neural network can be the best choice. Also, the effectiveness of different types of graph neural networks (GCN, GAT, GraphSAGE) can vary greatly depending on the problem.
  • Traditional neural networks can be efficiently batched during training. For graph neural networks, it’s harder to batch the data because nodes have different neighbors, resulting in potentially uneven mini-batches. Efficient sampling techniques (like GraphSAGE) or mini-batch training are necessary for scalability.
  • Just like in standard neural networks, transfer learning (pre-training a GNN on a large dataset and fine-tuning on the target dataset) can be effective for GNNs. Checking for available pre-trained models for your task can be valuable.

Conclusion

As we’ve seen, simply adding graph information to a basic neural network can dramatically boost performance, as was the case when we moved from a normal neural network to a GCN for the Cora dataset. By aggregating information from neighboring nodes, GCNs can provide a richer representation of the data, leading to more accurate predictions. But it’s crucial to remember that GNNs aren’t a magic bullet for every problem. The graph structure must be truly meaningful to the prediction task, and the increase in training complexity might not always justify the performance boost, especially when training time is critical.

Experimenting with different graph formulations and GNN architectures, like Graph Attention Networks (GATs), can lead to even better results depending on your dataset. In the next post, we’ll dive deeper into GATs and see how attention mechanisms can enhance GNNs by learning to weigh the importance of neighboring nodes.

Related

Graph Neural Networks Part 2. Graph Attention Networks vs. GCNs

Optimizing Connections: Mathematical Optimization within Graphs

Mathematical Optimization Heuristics Every Data Scientist Should Know

The post Graph Neural Networks Part 1. Graph Convolutional Networks Explained appeared first on Towards Data Science.

]]>
Combining Storytelling and Design for Unforgettable Presentations https://towardsdatascience.com/combining-storytelling-and-design-for-unforgettable-presentations-e9751e0d90d3/ Thu, 18 Apr 2024 17:34:35 +0000 https://towardsdatascience.com/combining-storytelling-and-design-for-unforgettable-presentations-e9751e0d90d3/ How to craft slide decks that stand out

The post Combining Storytelling and Design for Unforgettable Presentations appeared first on Towards Data Science.

]]>
Sometimes you work on projects that you have to share with the world (or your company). These projects make an impact, and by sharing them you can get more support or show the value you bring. It can be a challenge to tell the story in a good way. In this post, you’ll get some guidelines that can help you to create beautiful slide decks. In the end, you can apply the tips to your own project.

Note: These tips are applicable to many different types of presentations. And not only to presentations, you can also use them to create a shareable document that tells the story of a project. If you create the slides in the right way, you can share them directly with other people.


The Two Main Principles of a Storytelling Slide Deck

You can find many tips on storytelling on the internet. Some of them are important: for example, you should avoid jargon so that everyone can understand you, and of course you should create appealing visualizations that support the story. Such tips make sense, and I hope following them is not a problem for most professionals.

In my career I have sometimes struggled with ‘telling the story’. Multiple times during my presentations I saw people drop out or lose interest. Early in my career, I had a one-size-fits-all mentality and told the same story to technical and business people. As you can imagine, that didn’t work. Now, I adjust my slide decks to the audience. If I am invited to a large meeting to share project results, I ask about the goal of the meeting and the audience, so I can adjust accordingly.

Lately, I received a different kind of request. My manager asked me to create a slide deck about my project to share broadly within the company. That was new, and it introduced new difficulties, such as: "How can I explain my project in a way everyone understands, without being there to present it?" and "What should I include in and exclude from the story, while keeping it coherent and covering everything that’s important?" I started in the wrong way: my first slide was full of text, serving as a summary of the project and results. Luckily, my manager explained two principles that helped a lot in improving the slide deck. These two principles changed the way I look at slide decks, storytelling and presentations.

I hope you can use them as well, to be able to share your projects with the world in a good way!

Principle 1. Horizontal Alignment

The first principle is horizontal alignment. Each slide within a presentation should function as a coherent unit, showing one single aspect of the overall narrative. You shouldn’t just label the slide, but instead make a statement by putting the core message in the title. Instead of generic headings like "Benefits", "Improvements" or "Summary", the title should state the key takeaway or insight you want to tell at that particular slide.

Then, the visualizations (charts, graphs, images, diagrams, text boxes, etc.) on the slide should directly support or explain the message from the title. By doing this, you pair the content of the slide with the title. Every slide will be self-contained and deliver one distinct part of the larger story.

The title of a slide should state the key takeaway or insight you want to tell at that particular slide, and the visualizations should support that message.

This approach has many benefits. The biggest one is that it makes sure the communication is clear, because each slide focuses on one main idea. This makes it easier for the audience to grasp and remember the information.

Horizontal alignment. Aligning the content of one slide. Image by author.

For creating the charts on a slide, I can recommend the book Storytelling with Data. Besides the book, they have many resources online, like the chart guide and makeover examples. Some practical tips they give:

  • Remove clutter. Every element you add to a slide adds cognitive load for the audience. By removing elements and choosing colors selectively (e.g. only to highlight relevant information), you can help the audience a lot.
  • Choose an effective chart. It can be simple text, a slope graph, line graph, bar chart, histogram, table, heat map or scatterplot. Make sure the chart supports the message and is easy to understand.

Principle 2. Vertical Alignment

The second principle, vertical alignment, takes a macro view of the presentation structure. It means that if you read all the titles of the slides in sequence, these titles should tell the full story of the presentation. This makes a coherent and comprehensive narrative.

This strategy requires careful planning and organization: Each slide title should serve its individual purpose, but also contribute to the overarching message. It should be a logical progression from introducing the subject, through development of key points or arguments, and then conclude with a strong, memorable closing statement.

When reading all the titles of the slides in sequence, these titles should tell the full story of the presentation.

You can see vertical alignment as the backbone of your presentation, as it ensures that the content flows in a logical manner. This helps the audience follow along more easily, because it connects the dots between individual slides to form an understanding of the subject. By doing this correctly, your presentation will be more impactful and memorable.

Vertical alignment. Aligning the content of all slides. Image by author.

That’s it! Sounds easy, right? To be honest, it takes time to create a well-thought-out story. Revisiting, asking for feedback, and iterating to improve the slide deck is very normal and, in my opinion, preferred. Let’s apply the principles to an example.


Example: Optimizing cargo and baggage handling operations

The example we will look at is about the (fictional) airline BlueHorizon Flights. They launched a project that optimizes baggage handling operations. The results were quite good, and after the proof of concept it’s time for a go/no-go decision. The goal of the presentation is to convince stakeholders to continue with the project.

Instead of doing the following:

  • Creating slides with titles like: Introduction, Technical Details and Results in Numbers.
  • Adding lots of text to the slides (and by doing that, making them hard to follow for the audience).
  • Adding many graphs to the results slide, jumping from result to result.
  • Providing technical details to a non-technical audience.

We will:

  • Give every slide a title with the key point of the slide.
  • Make sure all the titles in sequence tell the story of the presentation (vertical alignment).
  • Make sure the visualizations on a single slide support the title (horizontal alignment).

Vertical Alignment: Fixing the Titles

Let’s start with a ‘bad’ example, so we can fix the issues. It’s quite common to create titles of slides like this:

Slide titles. Image by author.

And of course it’s not always bad; there might be people who can give a great presentation using a structure like this. But I believe a presentation will be better and easier to understand for the audience when you apply the horizontal and vertical alignment tips. Let’s create the titles of the new slide deck. They might look like this:

Slide titles telling the story. Image by author.

Just like that, we now have vertical alignment in place! Do you already have ideas for visualizations on the slides? Sometimes it can be hard to keep the titles short. You can try to move less important parts of a title to text boxes, or rephrase it. Some of the example titles could also be written with fewer words (e.g. is it really necessary to say AI-driven Dynamic Resource Allocation Software?). You might also conclude that two titles work better combined on one slide. As I said, it’s an iterative process, and while creating the slides you can adapt and change things if something else works better. Just make sure the titles work well together and tell the story.

Horizontal Alignment: Support the Title with Visualizations

Besides the titles, we will dive into one slide to ‘fix’. This gives you an idea of how to approach horizontal alignment.

Here is an ugly problem introduction slide:

Ugly problem introduction slide. Image by author (not proud of it).

I hope you don’t like it. Okay, what is wrong here? The title of course, and the long sentences. Also, the charts are ugly and hard to read: a dual-axis chart and pie charts are (almost) never a good idea (yes, even 2D pie charts are often a bad choice). Can you tell me whether the orange slice of the pie or the dark blue one is the biggest? Another issue is that the pie chart doesn’t support the message directly. It’s interesting to know the causes of the mishandled luggage, but it doesn’t relate directly to the text.

How can we fix these issues? There are many ways! Keep in mind that the following slide is just an example to give you an idea of how to apply the horizontal alignment principles:

Improved version of the problem introduction slide. Image by author.

This slide has a clear title that gets straight to the point: the problem with luggage handling. The three charts support this message and are easy to understand. The first one shows that BlueHorizon has the highest costs compared to similar airlines; the BlueHorizon bar is highlighted to draw your eyes directly to that part of the graph. The second chart is an improved version of the dual-axis chart from the previous slide: by giving the two series a shared x-axis but separate y-axes, it becomes easier to read. Clutter, such as grid lines and colors, is removed. The text below the charts explains them. These three main points matter to stakeholders: high costs, an increasing rate of mishandled luggage, and an increase in complaints from customers and coworkers.

One final note: when you are presenting, it’s often useful to add animations. Instead of showing the full slide at once, you can reveal its parts one by one (in the previous example, chart by chart). This makes it harder for people to get distracted by everything on the slide. Also, a deck that you share without presenting it may need a bit more text than one you present in person, where you can explain everything yourself.

For the 1% of people who are curious about the answer to the pie chart question: the dark blue slice is the biggest one.


Conclusion

Combining horizontal and vertical alignment principles can help a lot in creating effective and engaging presentations. Each slide on its own contributes to the story on a micro level (horizontal alignment), while all slides together tell the full story (macro level, vertical alignment). This enhances the clarity, impact and persuasiveness of your presentations.

Good luck with crafting your slide decks and with giving even better presentations!

The post Combining Storytelling and Design for Unforgettable Presentations appeared first on Towards Data Science.

]]>
Monte Carlo Methods Decoded https://towardsdatascience.com/monte-carlo-methods-decoded-d63301bde7ce/ Thu, 15 Feb 2024 17:34:35 +0000 https://towardsdatascience.com/monte-carlo-methods-decoded-d63301bde7ce/ Solving complex problems with simulations

The post Monte Carlo Methods Decoded appeared first on Towards Data Science.

]]>
Monte Carlo Methods, Decoded

The name "Monte Carlo" ties back to the famous Monte Carlo Casino in Monaco. The name was not chosen because of any direct association with the mathematical principles behind these methods, but for its metaphorical connection to randomness and chance, which are central to both gambling and Monte Carlo simulations. In this post, we will discuss this technique and show code examples related to project management, approximating irregular areas, and gaming.

Real-world systems and processes often involve uncertain parameters and variables. With Monte Carlo methods, you can model these uncertainties explicitly. By understanding the probability and impact of different risks, businesses can make better-informed decisions. Besides decision support, you can use Monte Carlo methods to enhance predictive models and to support communication.


The Basics

Imagine you have a big, mysterious jar full of different-colored marbles. There is one problem: you can’t see inside it to count how many of each color there are. You want to know which color you’re most likely to pick if you reach in the jar without looking.

You will play a game to figure this out. Instead of trying to guess or count the marbles directly (which you can’t do because you can’t see inside), you close your eyes, reach into the jar, and pull out a marble. You look at the color, write it down, and then put it back. You do this many, many times: pulling out a marble, writing down the color, and putting it back.

After you’ve done this, you look at the list. If you pulled out 100 marbles and 60 of them were red, you might start to think, "Hmm, there are probably a lot of red marbles in the jar, so I’m more likely to get a red one than any other color when I reach in."

Monte Carlo methods are all about learning something when you can’t see all the details directly. The example above is really easy: you don’t need any inputs for the model. You just start sampling to be able to draw conclusions about the outputs. Monte Carlo can be a very effective technique if you repeat your experiment for enough iterations.
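The marble game translates directly into a few lines of Python. This is a minimal sketch: the jar contents below (60% red, 25% blue, 15% green) are a made-up assumption, and the helper estimate_color_shares is a hypothetical name. It only learns about the jar through the marbles it draws, putting each one back, just like in the story.

```python
import random
from collections import Counter


def estimate_color_shares(jar, n_draws, seed=0):
    """Estimate each color's share by drawing with replacement n_draws times."""
    rng = random.Random(seed)
    # Draw a marble, write down its color, put it back (choice never removes items).
    draws = [rng.choice(jar) for _ in range(n_draws)]
    counts = Counter(draws)
    return {color: count / n_draws for color, count in counts.items()}


# Hypothetical jar: in the story we cannot see this list, we can only draw from it.
jar = ["red"] * 60 + ["blue"] * 25 + ["green"] * 15

for n in (10, 100, 10_000):
    shares = estimate_color_shares(jar, n)
    most_likely = max(shares, key=shares.get)
    print(n, most_likely, {color: round(share, 2) for color, share in sorted(shares.items())})
```

With only 10 draws the estimate can be far off (a color may not even show up), while with 10,000 draws the shares land close to the true split in the jar.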

Sampling, Distributions, Iterations and Convergence

To summarize, there are three key concepts in Monte Carlo methods. The first one is random variable sampling: for the different input variables, you generate values from probability distributions. The second concept is the probability distributions themselves, which represent the uncertainty about those variables. Common distributions include the normal (Gaussian), uniform, and log-normal distributions, which describe how likely different values are to occur. The third concept is iterations and convergence: you run the simulation many times (hundreds, thousands, or more) so that the results converge to a stable solution, which provides a robust statistical basis for predictions.
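The three concepts can be seen in a few lines of NumPy. This sketch assumes a single uncertain input with a normal distribution (mean 30, standard deviation 5, the same parameters used for the requirements task later in this post) and tracks how the running mean of the samples settles as more draws are added.

```python
import numpy as np

rng = np.random.default_rng(42)

# Sampling from a distribution: the assumed normal(30, 5) encodes our uncertainty.
samples = rng.normal(loc=30, scale=5, size=100_000)

# Iterations and convergence: the running mean after 1, 2, ..., N draws.
running_mean = np.cumsum(samples) / np.arange(1, samples.size + 1)

for n in (10, 100, 1_000, 10_000, 100_000):
    print(f"n = {n:>6}: estimated mean = {running_mean[n - 1]:.3f}")
```

The early estimates jump around; after many thousands of draws, the running mean stays close to 30.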

Distributions visualized. Image by author.

3 Example Applications

There are many applications of Monte Carlo methods. In fact, Monte Carlo is an umbrella term with broad applicability across many domains and disciplines. Monte Carlo methods encompass a wide range of computational algorithms that rely on repeated random sampling to obtain numerical results. They are used to solve problems in fields such as physics, engineering, statistics, finance, and computer science.

Let’s take a look at three common use cases and start coding! We start with Monte Carlo simulations applied in project management. After this, we will look at a completely different example: approximation of the area under (irregular) curves. We end with a technique used in game playing agents: Monte Carlo tree search.

Launch New Software

You are the project manager responsible for launching a new software product, including the development, testing, and market release phases. The goal is to complete the project within 6 months (180 days) of project initiation.

The launch is divided into six tasks, and you discussed with the people involved approximately how long each task will take. This gives you the following distributions per task:

Distributions per task. Normal, triangular and uniform distributions. Image by author.

Now it’s time to run the simulation! You draw 10,000 random samples from each distribution and add them up per simulation run. Deployment and marketing can run in parallel, so for those two tasks you take the maximum of the two values. After sampling, you plot the distribution of the total completion time together with the target:

Distribution of total completion time for the project based on the simulations. Image by author.

You can draw conclusions from this plot and decide if 180 days is a realistic goal for the project. Some statistics based on this simulation:

Mean Total Completion Time: 191.94 days
95% Confidence Interval: 163.04 - 221.18 days
Probability of completing the project within 6 months: 21.04%

These results can help you communicate with stakeholders. A more realistic schedule might be 200 days. Always be aware that it is hard or impossible to model all uncertainties (we’ll get back to the cons of Monte Carlo simulation later in this post).

Code:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

np.random.seed(42)

n_simulations = 10000

# sample the observations from different distributions per task
requirements_distribution = np.random.normal(30, 5, n_simulations)
design_distribution = np.random.triangular(20, 35, 50, n_simulations)
development_distribution = np.random.normal(60, 10, n_simulations)
testing_distribution = np.random.triangular(25, 40, 60, n_simulations)
deployment_training_distribution = np.random.uniform(20, 30, n_simulations)
marketing_distribution = np.random.uniform(15, 25, n_simulations)

# calculate the total completion time based on the observations
total_completion_time = requirements_distribution + design_distribution + development_distribution + testing_distribution + np.maximum(deployment_training_distribution, marketing_distribution)

# plot and print results
plt.figure(figsize=(10, 6))
sns.histplot(total_completion_time, kde=True, color="blue", binwidth=5)
plt.title('Total Project Completion Time')
plt.xlabel('Total Days')
plt.ylabel('Frequency')
plt.axvline(180, color='red', linestyle='--', label='6 Months Target')
plt.legend()
plt.show()

print(f"Mean Total Completion Time: {np.mean(total_completion_time):.2f} days")
print(f"95% Confidence Interval: {np.percentile(total_completion_time, 2.5):.2f} - {np.percentile(total_completion_time, 97.5):.2f} days")
print(f"Probability of completing the project within 6 months: {np.mean(total_completion_time <= 180):.2%}")
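To check the suggestion of a roughly 200-day schedule, you can reuse the simulated totals to compute the probability of meeting several candidate deadlines, or the deadline that is met in a chosen share of the runs. The sketch below is self-contained and copies the task distributions from the code above; it uses NumPy's Generator API (np.random.default_rng) instead of np.random.seed, so the exact numbers will differ slightly from the ones reported earlier. The candidate deadlines and percentiles are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 10_000

# Same task distributions as in the example above.
total = (
    rng.normal(30, 5, n)               # requirements
    + rng.triangular(20, 35, 50, n)    # design
    + rng.normal(60, 10, n)            # development
    + rng.triangular(25, 40, 60, n)    # testing
    # deployment/training and marketing run in parallel: take the longer one
    + np.maximum(rng.uniform(20, 30, n), rng.uniform(15, 25, n))
)

# Probability of finishing within a few candidate deadlines.
for deadline in (180, 190, 200, 210):
    print(f"P(done within {deadline} days) = {np.mean(total <= deadline):.1%}")

# Deadlines that would be met in 80% and 90% of the simulated projects.
p80, p90 = np.percentile(total, [80, 90])
print(f"80th percentile: {p80:.1f} days, 90th percentile: {p90:.1f} days")
```

Reporting a percentile ("we finish within X days in 80% of the scenarios") is often easier for stakeholders to act on than a mean alone.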

Area under an Irregular Curve

Do you remember the rules for calculating the area under a complex curve like this one?

$$f(x) = \sin(x) + \cos(2x) + \frac{x^2}{10}$$

I bet you don’t! Maybe you do if you are a mathematician, if you need to do calculations like this often, or if you like to practice integration in your free time.

There are situations where it becomes impractical or impossible to calculate the area under a curve analytically. This can be the case in high-dimensional problems, or when the domain of integration has an irregular shape or is defined by conditions that are difficult to express mathematically. Sometimes an approximation is good enough for the problem at hand. When the area under a curve is hard to compute, Monte Carlo simulation is an easy way to get a fast approximate answer.

In this example we will estimate the area under the curve of the function above on the domain [0, 2π]. The curve looks like this:

Curvy curve. Image by author.

The true area is equal to 8.2683. How can we estimate the area under the curve by using Monte Carlo simulations? We take the following steps:

  1. We sample n random points uniformly from the rectangle that spans the domain of the function, [0, 2π], and its range, from 0 up to the maximum of f(x).
  2. We count the number of points that fall under the curve, i.e. the points with y < f(x).
  3. Here comes the trick: we multiply the width of the domain by the height of the range (the total area of the ‘rectangle’), and multiply this by the fraction of points that fall under the curve. This gives us a good approximation.
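The three steps can be written as a formula. The notation is introduced here for illustration: [a, b] is the domain (here [0, 2π]), H is the height of the rectangle (the maximum of f(x)), n is the number of sampled points, and k is the number of points with y < f(x). Because each point independently lands under the curve with the same probability, k follows a binomial distribution, which also gives the standard error of the estimate:

```latex
\hat{A} = (b - a)\,H\,\frac{k}{n},
\qquad
\hat{p} = \frac{k}{n},
\qquad
\operatorname{SE}(\hat{A}) = (b - a)\,H\,\sqrt{\frac{\hat{p}\,(1 - \hat{p})}{n}}
```

The standard error shrinks proportionally to 1/√n, so halving the error takes roughly four times as many points. The formula assumes f(x) is non-negative on the domain, which holds for the function in this example.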

Let’s try it in practice. We sample different numbers of points and calculate the estimate, as well as the error. For the error, we take the absolute difference between the estimate and the true area.

Approximation of area under a curve with Monte Carlo simulations. Image by author.

By adding points, we get closer and closer to the actual value. Using 1 million random points, the error is only 0.0039. (Because of the randomness of Monte Carlo, this is not guaranteed for every run, but in most cases the approximation improves as you add points.) This approach is much simpler than mastering calculus!

Approximating areas can be useful in more scenarios than you might think. One example is estimating the coverage of ice sheets, forests, and bodies of water to understand the impact of climate change. Another example comes from astronomy, where researchers can estimate the surface area of celestial objects or the cross-sectional area of galaxies when direct measurements are not possible.

Code related to this example:

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(31)

def f(x):
    return np.sin(x) + np.cos(2*x) + x**2 / 10

x_range = np.linspace(0, 2*np.pi, 10000)
true_area = np.trapz(f(x_range), x_range)

n_points_list = [100, 1000, 10000, 100000, 1000000]

fig, axs = plt.subplots(3, 2, figsize=(12, 15))
fig.subplots_adjust(hspace=0.3, wspace=0.2)

# plot 1: The curve and the actual area under the curve
axs[0, 0].plot(x_range, f(x_range), color='black', linewidth=2, label='f(x)')
axs[0, 0].fill_between(x_range, 0, f(x_range), alpha=0.3, color='skyblue', label='True Area')
axs[0, 0].set_title('Curve and area\nTrue area = {:.4f}'.format(true_area))
axs[0, 0].legend()
axs[0, 0].set_xlabel('x')
axs[0, 0].set_ylabel('f(x)')

# simulation for different numbers of points
for idx, n_points in enumerate(n_points_list, start=1):
    x = np.random.uniform(0, 2*np.pi, n_points)
    y = np.random.uniform(0, max(f(x_range)), n_points)
    f_x = f(x)
    points_under_curve = y < f_x
    area_estimate = 2 * np.pi * max(f(x_range)) * np.sum(points_under_curve) / n_points

    ax = axs.flat[idx]
    ax.scatter(x[~points_under_curve], y[~points_under_curve], color='red', s=1, label='Above Curve')
    ax.scatter(x[points_under_curve], y[points_under_curve], color='blue', s=1, label='Under Curve')
    ax.plot(x_range, f(x_range), color='black', linewidth=2, label='f(x)')
    ax.set_title(f'n points = {n_points}\nEstimation = {area_estimate:.4f}; Error = {abs(area_estimate - true_area):.4f}')
    ax.set_xlabel('x')
    ax.set_ylabel('y')

    if idx == 1:
        ax.legend(loc='upper left')

plt.tight_layout()
plt.show()
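A compact, vectorized version of the same hit-or-miss estimator can also report the standard error, which makes the effect of adding points visible in numbers. This is a sketch, not the author's code: mc_area is a hypothetical helper name, the exact true area (2π)³/30 follows from the fact that sin(x) and cos(2x) both integrate to zero over [0, 2π], and the rectangle height is taken as the maximum of f on a fine grid (for this function, the maximum is at the right endpoint, which the grid includes).

```python
import numpy as np


def f(x):
    return np.sin(x) + np.cos(2 * x) + x**2 / 10


def mc_area(func, a, b, height, n, rng):
    """Hit-or-miss Monte Carlo estimate of the area under func on [a, b].

    Assumes 0 <= func(x) <= height on [a, b]. Returns (estimate, standard error).
    """
    x = rng.uniform(a, b, n)
    y = rng.uniform(0, height, n)
    p_hat = np.mean(y < func(x))        # fraction of points under the curve
    box = (b - a) * height              # area of the sampling rectangle
    return box * p_hat, box * np.sqrt(p_hat * (1 - p_hat) / n)


rng = np.random.default_rng(31)
a, b = 0.0, 2 * np.pi
height = f(np.linspace(a, b, 10_001)).max()
true_area = (2 * np.pi) ** 3 / 30       # exact value of the integral

for n in (100, 10_000, 1_000_000):
    est, se = mc_area(f, a, b, height, n, rng)
    print(f"n = {n:>9,}: estimate = {est:.4f} ± {se:.4f} (true {true_area:.4f})")
```

In most runs, the true value falls within about two standard errors of the estimate, and the printed standard error shrinks by a factor of ten for every hundredfold increase in points.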

Monte Carlo Tree Search

When creating agents playing games, Monte Carlo tree search (MCTS) is the algorithm you will find in many solutions. In AlphaGo and AlphaZero, MCTS is combined with neural networks to determine the best next action.

As an example, we will look at a really simple two-player game called Island Conquest. There are four islands on the board, numbered 1–4. On each turn, a player claims one island. A player wins if they own islands 1&2 or 1&4, and automatically loses if they own islands 2&3 or 3&4. If one player owns islands 1&3 and the other player owns 2&4, it’s a tie.

Island Conquest game board. A player wins when he/she claimed islands 1&2 or 1&4. Image created with Dall·E.

This game is super easy and would be boring to play, because the first player can always win by claiming island 1 on the first turn and island 2 or 4 on their second turn. For demonstrating MCTS, however, its simplicity makes it a good choice. We can visualize this game as a tree:

Game tree for Island Conquest. At the root node, no island is claimed. In turn 1, player 1 can claim island 1, 2, 3 or 4. The best choice in the first turn is claiming island 1. Image by author.
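Because the tree is so small, the claim that the first player can always force a win can be checked with an exhaustive minimax search. This is a standalone sketch that reimplements the rules with sets; it is not the IslandConquest class used later, and outcome and minimax are helper names introduced here. A value of +1 means a forced win for player 1, 0 a tie, and -1 a forced loss, assuming both players play perfectly.

```python
from functools import lru_cache

WIN_PAIRS = (frozenset({1, 2}), frozenset({1, 4}))    # owning one of these pairs wins
LOSE_PAIRS = (frozenset({2, 3}), frozenset({3, 4}))   # owning one of these pairs loses


def outcome(p1, p2):
    """+1 if player 1 has won, -1 if player 2 has won, 0 for a tie, None if undecided."""
    for owned, sign in ((p1, 1), (p2, -1)):
        if any(pair <= owned for pair in WIN_PAIRS):
            return sign
        if any(pair <= owned for pair in LOSE_PAIRS):
            return -sign
    if len(p1) + len(p2) == 4:  # all islands claimed without a decisive pair
        return 0
    return None


@lru_cache(maxsize=None)
def minimax(p1, p2):
    """Value of the position for player 1 when both players play perfectly."""
    result = outcome(p1, p2)
    if result is not None:
        return result
    free = frozenset({1, 2, 3, 4}) - p1 - p2
    if len(p1) == len(p2):  # player 1 starts, so equal counts means player 1 moves
        return max(minimax(p1 | {i}, p2) for i in free)
    return min(minimax(p1, p2 | {i}) for i in free)


empty = frozenset()
for first in (1, 2, 3, 4):
    print(f"Player 1 claims island {first} first: value {minimax(frozenset({first}), empty):+d}")
print("Value of the empty board:", minimax(empty, empty))
```

Only island 1 guarantees a win for player 1; brute force works here because the whole tree has a few dozen positions, which is exactly what MCTS avoids having to enumerate in bigger games.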

Before we run the algorithm, let’s go over the theory. How can MCTS help us in deciding the next best action for a game? The MCTS algorithm consists of four steps:

Step 1. Selection

The first step is to select a node (= an action) in the game tree. We start at the root node. In the beginning, we select nodes that are unvisited:

Selecting unvisited nodes. Island 1 is not visited so will be selected as action. Image by author.

When there are no unvisited nodes left, we select nodes in a smarter way, using the Upper Confidence Bound applied to Trees (UCT). This formula combines a node’s score with its number of visits. The UCT formula looks like this:

score / child_visits + math.sqrt(2 * math.log(parent_visits) / child_visits)

If the number of visits and the corresponding scores are like this:

Selection when all nodes are visited. Image by author.

The total number of visits is 27. If we need to select our first action (from the root), the UCT values for the islands are:

  • Island 1: 1.61 (child_visits = 10; score = 8; parent_visits = 27)
  • Island 2: 1.28 (child_visits = 8; score = 3; parent_visits = 27)
  • Island 3: 1.32 (child_visits = 2; score = -1; parent_visits = 27)
  • Island 4: 1.26 (child_visits = 7; score = 2; parent_visits = 27)
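These four values can be reproduced directly from the UCT formula. A small sketch, with the scores and visit counts copied from the example; uct is a helper name introduced here, and writing the constant as c = √2 outside the square root is equivalent to the factor 2 inside it.

```python
import math


def uct(score, child_visits, parent_visits, c=math.sqrt(2)):
    """Average score (exploitation) plus exploration bonus."""
    return score / child_visits + c * math.sqrt(math.log(parent_visits) / child_visits)


# island: (score, child_visits), taken from the example above
children = {1: (8, 10), 2: (3, 8), 3: (-1, 2), 4: (2, 7)}
parent_visits = sum(visits for _, visits in children.values())  # 27

values = {island: uct(score, visits, parent_visits) for island, (score, visits) in children.items()}
for island, value in values.items():
    print(f"Island {island}: {value:.2f}")
print("Selected island:", max(values, key=values.get))
```

The output matches the list: 1.61, 1.28, 1.32 and 1.26, with island 1 selected. Island 3 gets the largest exploration bonus because it has only two visits.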

We select the highest value, so island 1 will be selected. The upper confidence bound balances exploration and exploitation.

Note that the value for island 3 is higher than for islands 2 & 4. We didn’t explore island 3 that much yet (only 2 times), so it might be interesting to try or simulate this action a few more times to become more certain of how bad this action is.

Step 2. Expansion

During the expansion phase, the selected node is expanded by adding one or more child nodes to the tree, representing possible future states of the game from that point. This step only occurs if the node has unexplored actions available. The new node is added to the tree with initial scores and visit counts, ready for future simulations.

In the beginning we only have the root node for the tree. By playing, we add nodes. Image by author.

Step 3. Simulation

This step simulates a playout from the state represented by the expanded node. The simulation randomly selects actions for both players until a terminal state is reached (win, loss, or draw). The outcome of this simulation provides a score that indicates the desirability of the initial move, which is used to update the game tree in the backpropagation step. In the Island Conquest game, we use 1 for a win, -1 for a loss, and 0 for a draw.

Random simulation starting after the first action. Image by author.

Step 4. Backpropagation

In the final step, the results from the simulation are propagated back up the tree to the root. Every node visited during the selection and expansion phases is updated to reflect the new information: the visit count is incremented, and the score for every node is adjusted based on the outcome of the simulation. This updated information influences future selection decisions, gradually improving the algorithm’s decision-making as more simulations are conducted.

Backpropagation: updating the score and the number of visits to all the nodes we visited. Image by author.

Do you recognize how MCTS relates to Monte Carlo methods? Again, iterations are used to converge to the optimal action. The difference from the previous examples is that MCTS applies the Monte Carlo philosophy to environments with sequential decisions, such as board games. MCTS does not just simulate outcomes of a static set of variables (like in the project management and area under the curve examples); it dynamically builds a tree of possible future states based on the outcomes of those simulations. At each iteration, MCTS selects a path through the tree based on past results (selection), expands the tree by exploring new moves (expansion), simulates outcomes from that position (simulation), and then updates the tree with the new information (backpropagation).

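After all iterations, MCTS recommends the action with the highest number of visits. Because the selection step directs more visits to actions with higher scores, the most-visited action is usually also the most promising one, and its estimate is based on the most simulations.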

Let’s code the Island Conquest game and add MCTS to select the best possible action for every turn. The first file is islandconquest.py:

from montecarlosimulation.islandconquest.mcts import mcts

class IslandConquest:
    def __init__(self):
        """ 
        Initialize the game with all islands unclaimed
        0 for unclaimed, 1 for Player 1, -1 for Player 2
        Player 1 starts
        """
        self.islands = [0, 0, 0, 0]
        self.player_turn = 1 

    def claim_island(self, island):
        """
        Player claims an island. Island numbers are 1-4.
        Switch turns after a valid move.
        """
        if self.islands[island-1] == 0:
            self.islands[island-1] = self.player_turn
            self.player_turn *= -1  
            return True
        else:
            return False  # Island already claimed

    def check_win(self):
        """
        Mapping the win/lose/tie conditions to sums of island states.
        Calculate the sum of states for comparison with the conditions.
        Return the result of the game (winner/tie/game continues).
        """
        win_conditions = [(1, 4), (1, 2)]  
        lose_conditions = [(2, 3), (3, 4)]

        island_sums = {condition: sum(self.islands[i-1] for i in condition) for condition in win_conditions + lose_conditions}

        for condition in win_conditions:
            if island_sums[condition] == len(condition):
                return "Player 1 wins"
            elif island_sums[condition] == -len(condition):
                return "Player 2 wins"

        for condition in lose_conditions:
            if island_sums[condition] == len(condition):
                return "Player 2 wins"
            elif island_sums[condition] == -len(condition):
                return "Player 1 wins"

        if self.islands[0] + self.islands[2] in [-2, 2] and self.islands[1] + self.islands[3] in [-2, 2]:
            return "It's a tie"

        return "Game continues"

    def is_game_over(self):
        """
        Check if the game is over.
        """
        return self.check_win() != "Game continues"

    def current_state(self):
        """
        Print the current state of the game.
        """
        state = " ".join(["1" if x == 1 else "2" if x == -1 else "0" for x in self.islands])
        print(f"Islands: {state}")
        print(f"Player {'1' if self.player_turn == 1 else '2'}'s turn")
        print(self.check_win())

    def get_legal_moves(self):
        """
        Get the list of unclaimed islands.
        """
        return [i+1 for i, x in enumerate(self.islands) if x == 0]

    def copy(self):
        """
        Used in MCTS to copy the game state.
        """
        new_game = IslandConquest()
        new_game.islands = self.islands[:]
        new_game.player_turn = self.player_turn
        return new_game

    def get_result(self, player):
        """
        Get score of the game for the player for MCTS.
        """
        result = self.check_win()
        if result == "Player 1 wins" and player == 1 or result == "Player 2 wins" and player == -1:
            return 1
        elif result == "Player 1 wins" and player == -1 or result == "Player 2 wins" and player == 1:
            return -1
        return 0  # Tie or game continues

def play_game_with_mcts(game, mcts_iterations=1000):
    """
    Play the game with MCTS.
    Selects the best move for the current player and game state and performs it.
    """
    while not game.is_game_over():
        print("\nCurrent game state:")
        game.current_state()

        best_move = mcts(game.copy(), mcts_iterations)
        print(f"Recommended move: Claim island {best_move}")

        game.claim_island(best_move)

        if game.is_game_over():
            print("\nFinal game state:")
            game.current_state()
            break

if __name__=="__main__":
    game = IslandConquest()
    play_game_with_mcts(game)

This code is fairly straightforward. First, we initialize a game with four unclaimed islands [0, 0, 0, 0]. Then it’s the first player’s turn, and this player selects an island. The best move for the current player is selected with Monte Carlo tree search, and the players keep taking turns until the game is over.

The code from mcts.py is as follows:

import math
import random

class MCTSNode:
    def __init__(self, game_state, move=None, parent=None):
        """
        Initialize the node with the game state, move, and parent.
        """
        self.game_state = game_state
        self.move = move
        self.parent = parent
        self.children = []
        self.wins = 0
        self.visits = 0
        self.untried_moves = game_state.get_legal_moves()

    def select_child(self):
        """
        Select a child node with the highest UCT score.
        """
        return max(self.children, key=lambda c: c.wins / c.visits + math.sqrt(2 * math.log(self.visits) / c.visits))

    def add_child(self, move, state):
        """
        Remove the move from untried moves, create a new child node, and add it to children.
        """
        child = MCTSNode(game_state=state, move=move, parent=self)
        self.untried_moves.remove(move)
        self.children.append(child)
        return child

    def update(self, result):
        """
        Update this node - increment the visit count and update the win count based on the result.
        """
        self.visits += 1
        self.wins += result

def mcts(root_state, iterations):
    """
    MCTS algorithm. Selection, expansion, simulation, and backpropagation.
    """
    root_node = MCTSNode(game_state=root_state)

    for _ in range(iterations):
        node = root_node
        state = root_state.copy()

        # selection
        while node.untried_moves == [] and node.children != []:
            node = node.select_child()
            state.claim_island(node.move)

        # expansion
        if node.untried_moves:
            move = random.choice(node.untried_moves)
            state.claim_island(move)
            node = node.add_child(move, state)

        # simulation
        while not state.is_game_over():
            possible_moves = state.get_legal_moves()
            state.claim_island(random.choice(possible_moves))

        # backpropagation
        while node is not None:
            node.update(state.get_result(root_state.player_turn))
            node = node.parent

    for child in root_node.children:
        print(f"Move {child.move}: {child.visits} visits, {child.wins / child.visits:.2f} win rate")
    return sorted(root_node.children, key=lambda c: c.visits)[-1].move

The mcts function is called in the islandconquest.py script. Do you recognize the four steps of MCTS (selection, expansion, simulation, and backpropagation)? After all iterations, this code prints the number of visits and the average score (labeled win rate) for each move. This gives the following result:

Current game state:
Islands: 0 0 0 0
Player 1's turn
Game continues
Move 1: 397 visits, 0.96 win rate
Move 2: 300 visits, 0.94 win rate
Move 4: 299 visits, 0.93 win rate
Move 3: 4 visits, -0.75 win rate
Recommended move: Claim island 1

Current game state:
Islands: 1 0 0 0
Player 2's turn
Game continues
Move 4: 486 visits, -0.02 win rate
Move 2: 503 visits, -0.02 win rate
Move 3: 11 visits, -1.00 win rate
Recommended move: Claim island 2

Current game state:
Islands: 1 2 0 0
Player 1's turn
Game continues
Move 4: 988 visits, 1.00 win rate
Move 3: 12 visits, 0.00 win rate
Recommended move: Claim island 4

Final game state:
Islands: 1 2 0 1
Player 2's turn
Player 1 wins

It found the optimal moves for the players on each turn! MCTS recommends island 1 to player 1 on the first turn, which is indeed the best move: after it, the game can only end in a tie or in a win for player 1. On the second turn (player 2), it recommends island 2 or 4, because otherwise player 1 is certain to win, whereas after island 2 or 4 there is still a 50/50 chance that a randomly playing player 1 selects island 3, which results in a tie. You can see that the number of visits for islands 2 and 4 is almost equal. On the third turn (player 1), island 4 is selected, because player 1 wins in this case.

Note: Do you see a possible improvement here? There is a big one. In step 3 of MCTS, the simulation step, we use a random playout. This approach may not accurately represent the behavior of a skilled opponent, which can lead to suboptimal decisions. There are ways to improve this, such as training a policy network, using reinforcement learning, or, more simply, using heuristic-based decisions. This is out of scope for this post, but if you are curious, you can find resources online explaining these techniques.


Cons

Besides all these cool applications, Monte Carlo is not all sunshine. There are some cons, too. Let’s briefly go over them:

  • Monte Carlo methods can be computationally expensive. Especially when the problem is complex and you want to run many simulations, you need a powerful computer (and time). This relates to the number of inputs you have: if you add more unknowns, you typically need more iterations, and the computational cost can grow quickly.
  • If the data is biased, incomplete, or inaccurate, the simulation results will be unreliable. There should be good data collection and analysis practices in place to be able to trust the results. This problem is also known as sensitivity to assumptions.
  • Building a comprehensive and accurate model can be challenging, especially for systems with complex interactions and dependencies. Simplifications made to manage complexity can sometimes lead to oversights or misrepresentations of the real-world scenario. Besides building the model, interpreting the results can be hard as well! There’s a risk of misinterpretation, especially when considering the range of possible outcomes and their probabilities.
  • You always need to keep in mind that simulations are not the same as precise predictions. Ensuring that the simulation has run a sufficient number of trials for the results to converge on a stable solution can be tricky. In some cases, reducing the variance of the results to achieve a desired level of precision might require an impractically large number of iterations.

Conclusion

Monte Carlo simulation is a useful technique that can be applied in many different ways and fields. In this post, you learned the basics and key concepts of Monte Carlo methods. We went over three examples: project management, approximating areas, and AI for games. Thanks for reading, and until next time!

The post Monte Carlo Methods Decoded appeared first on Towards Data Science.

]]>
Is This the Solution to P-Hacking? https://towardsdatascience.com/is-this-the-solution-to-p-hacking-a04e6ed2b6a7/ Thu, 16 Nov 2023 17:34:36 +0000 https://towardsdatascience.com/is-this-the-solution-to-p-hacking-a04e6ed2b6a7/ E-values, a better alternative for p-values

The post Is This the Solution to P-Hacking? appeared first on Towards Data Science.

]]>
In scientific research, the manipulation of data and peeking at results have been problems for as long as the field has existed. Researchers often aim for a significant p-value to get published, which can lead to the temptation of stopping data collection early or manipulating the data. This practice, known as p-hacking, was the focus of my previous post. If researchers decide to deliberately change data values or fake complete datasets, there is not much we can do about it. However, for some instances of p-hacking, there might be a solution available!

In this post, we dive into the topic of safe testing. Safe tests have some strong advantages over the old (current) way of hypothesis testing. For example, this method of testing allows you to combine results from multiple studies. Another advantage is that you can stop the experiment at any time you like (optional stopping). To illustrate safe testing, we will use the R package safestats, developed by the researchers who proposed the theory. First, we will introduce e-values and explain the problem they can solve. E-values are already used by companies like Netflix and Amazon because of their benefits.

I will not delve into the proofs of the theory; instead, this post takes a more practical approach, showing how you can use e-values in your own tests. For proofs and a thorough explanation of safe testing, the original paper is a good resource.


An Introduction to E-values

In hypothesis testing, which you can brush up on here, you assess whether to retain the null hypothesis or to reject it in favor of the alternative. Usually, the p-value is used for this: if the p-value is smaller than the predetermined significance level, alpha, you reject the null hypothesis and accept the alternative.

E-values function differently from p-values, but the two are related. The easiest interpretation of e-values is this: suppose you are gambling against the null hypothesis. You invest $1, and the return value is equal to $E. If the e-value E is between 0 and 1, you lose money, and the data provide no evidence against the null hypothesis. On the other hand, if the e-value is higher than 1, you win! The null hypothesis loses the game. A modest E of 1.1 implies limited evidence against the null, whereas a substantial E, say 1000, denotes overwhelming evidence.

Some main points of e-values to be aware of:

  • An e-value is a non-negative number whose expected value is at most 1 when the null hypothesis is true. In the gambling picture, a bet against a true null cannot be expected to make money. You can use e-values as an alternative to p-values in hypothesis testing.
  • An e-value E can be interpreted like a traditional p-value p through the relation p = 1/E. Beware: this will not give you the same number as a standard p-value, but you can interpret it like one.
  • In traditional tests, you have alpha, also known as the significance level. Often this value is equal to 0.05. E-values work a bit differently: you can look at them as evidence against the null. The higher the e-value, the more evidence against the null.
  • At any point in time (!) during the test, you can stop data collection and draw a conclusion if you are using e-values. A sequence of e-values that is updated as data come in is called an e-process. Using e-processes ensures validity under optional stopping and allows sequential updates of statistical evidence.
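The gambling view and the idea of an e-process can be illustrated with a minimal Python sketch. It rests on simplifying assumptions of my own, not the construction used by the safestats package. We test whether a coin is fair (H0: p = 0.5) against one fixed alternative (p = 0.6, an assumed value), using the likelihood ratio as the bet. Each bet is fair when H0 is true, so the running product is an e-process. By Ville's inequality, the chance that it ever reaches 1/alpha = 20 under H0 is at most 5%, no matter how often we look.

```python
import random


def likelihood_ratio_e_process(observations, p0=0.5, p1=0.6):
    """Running e-values for H0: P(success) = p0 against the simple alternative p1.

    Each factor is a likelihood ratio whose expected value under H0 is exactly 1,
    so the running product is a valid e-process (a test martingale)."""
    e_value = 1.0
    path = []
    for x in observations:
        e_value *= (p1 / p0) if x == 1 else ((1 - p1) / (1 - p0))
        path.append(e_value)
    return path


def first_rejection(path, alpha=0.05):
    """Return the first sample size at which the e-value reaches 1/alpha, or None."""
    for n, e_value in enumerate(path, start=1):
        if e_value >= 1 / alpha:
            return n
    return None


rng = random.Random(3)

# The alternative is true (success probability 0.6): we may stop as soon as E >= 20
biased = [1 if rng.random() < 0.6 else 0 for _ in range(1000)]
print("biased coin, stopped after:", first_rejection(likelihood_ratio_e_process(biased)), "flips")

# H0 is true: even though we check after every single flip,
# the fraction of runs that ever reach E >= 20 should stay at or below alpha = 0.05
n_runs = 1000
false_alarms = sum(
    first_rejection(likelihood_ratio_e_process([1 if rng.random() < 0.5 else 0 for _ in range(1000)])) is not None
    for _ in range(n_runs)
)
false_alarm_rate = false_alarms / n_runs
print("fair coin, fraction of runs rejected:", false_alarm_rate)
```

The simple alternative p = 0.6 is a design choice made here for clarity. Practical safe tests usually bet in a more adaptive way, for example by mixing over many alternatives.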

Fun fact: e-values are not as ‘new’ as you might think. The first paper on them was written in 1976, although the values were not called e-values at that time.

A researcher gambling against... a hypothesis?! Image created with Dall·E 3 by the author.

Why should I care about E-values?

That is a valid question. What is wrong with traditional p-values? Is there a need to replace them with e-values? Why learn something new if there is nothing wrong with the current way of testing?

Actually, there is something wrong with p-values. There is a lot of criticism of traditional p-values, and some statisticians (over 800) even want to abandon them completely.

Let’s illustrate why with a classic example.

Imagine you are a junior researcher for a pharmaceutical company. You need to test the efficacy of a medicine the company developed. You search for test candidates: half of them receive the medicine, while the other half takes a placebo. Before the start, you determine how many test candidates you need to be able to draw conclusions.

The experiment starts, and you struggle a bit finding new participants. You are under time pressure, and your boss asks on a regular basis, "Do you have the results for me? We want to ship this product to the market!" Because of the pressure, you decide to peek at the results and calculate the p-value, although you haven’t reached the minimum number of test candidates! Looking at the p-value, now there are two options:

  • The p-value is not significant. This means you cannot prove that the medicine works. Obviously, you don’t share these results! You wait a bit longer, hoping the p-value will become significant…
  • Yes! You find a significant p-value! But what is your next step? Do you stop the experiment? Do you continue until you reach the correct number of test candidates? Do you share the results with your boss?

Once you have looked at the data, it’s tempting to do it more often. You calculate the p-value, and sometimes it’s significant, sometimes it isn’t… This might seem innocent, while in fact, you are sabotaging the process.

Significant or not? Image created with Dall·E 3 by the author.

Why is it wrong to look at the data and the corresponding p-value a few times before the experiment has officially ended? One simple, intuitive reason: if what you do next depends on what you see (e.g., you stop the experiment as soon as you find a significant p-value), you are messing with the process. The p-value assumes that the sample size was fixed in advance.

From a theoretical perspective, you violate the Type I error guarantee. A Type I error means mistakenly rejecting a true null hypothesis (= finding a significant result when there is no real effect), and the guarantee tells you how rarely that happens. It’s like a promise about how often you’ll cry wolf when there’s no wolf around. The risk of this happening is ≤ alpha, but only if you test once, at the planned sample size! If you look at the data more often, you cannot trust this value anymore: the risk of a Type I error becomes higher.
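A small simulation shows how much peeking inflates the Type I error rate. The Python sketch below is an illustration under assumptions of my own, not the medicine trial itself: normally distributed data with a known standard deviation of 1, a one-sample z-test, a planned sample size of 500, and a peek after every 10 observations. The null hypothesis is true in every simulated experiment, so every rejection is a Type I error.

```python
import math
import random


def z_test_p_value(total, n):
    """Two-sided p-value of a z-test for H0: mean = 0, assuming a known standard deviation of 1."""
    z = total / math.sqrt(n)
    return math.erfc(abs(z) / math.sqrt(2))


def false_positive_rates(n_experiments=1000, n_max=500, peek_every=10, alpha=0.05, seed=1):
    """Simulate experiments in which H0 is true and compare two ways of analyzing them."""
    rng = random.Random(seed)
    rejected_once = 0  # test only once, at the planned sample size n_max
    rejected_peek = 0  # test every `peek_every` observations, reject at the first p < alpha
    for _ in range(n_experiments):
        total = 0.0
        ever_significant = False
        for n in range(1, n_max + 1):
            total += rng.gauss(0.0, 1.0)  # H0 is true: the real mean is 0
            if not ever_significant and n % peek_every == 0 and z_test_p_value(total, n) < alpha:
                ever_significant = True  # the impatient researcher would stop and report here
        rejected_peek += ever_significant
        rejected_once += z_test_p_value(total, n_max) < alpha
    return rejected_once / n_experiments, rejected_peek / n_experiments


single, peeking = false_positive_rates()
print(f"Type I error, one look at the end: {single:.3f}")
print(f"Type I error, peeking every 10 observations: {peeking:.3f}")
```

With a single look at the end, the rejection rate stays close to alpha. With repeated peeks, it ends up several times higher, even though nothing about the data changed.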

This relates to the multiple comparisons problem. If you do multiple independent tests of the same hypothesis, you should correct the value of alpha to keep the risk of a Type I error low. There are different ways of fixing this, like the Bonferroni correction, Tukey’s range test, or Scheffé’s method.

The family-wise error rate for multiple independent tests. For one test, it is equal to alpha. Note that for 10 tests, the error rate has increased to 40%, and for 60 tests, it’s 95%. Image by author.
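The numbers in the figure follow from the formula FWER = 1 - (1 - alpha)^m for m independent tests in which every null hypothesis is true. A short Python sketch (the function name is my own) reproduces them. It also shows how the Bonferroni correction, which tests each hypothesis at alpha/m, keeps the family-wise rate at or below alpha.

```python
def family_wise_error_rate(alpha: float, n_tests: int) -> float:
    """Probability of at least one false positive among n independent tests when every null is true."""
    return 1 - (1 - alpha) ** n_tests


alpha = 0.05
for n_tests in (1, 10, 60):
    uncorrected = family_wise_error_rate(alpha, n_tests)
    bonferroni = family_wise_error_rate(alpha / n_tests, n_tests)  # each test at alpha / n_tests
    print(f"{n_tests:>2} tests: uncorrected {uncorrected:.3f}, Bonferroni {bonferroni:.3f}")
```

For 10 and 60 tests, the uncorrected rates are 0.401 and 0.954, matching the figure. With the Bonferroni correction, the rate stays just under 0.05.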

To summarize: p-values can be used, but it can be tempting for researchers to look at the data before the planned sample size is reached. This is wrong and increases the risk of a Type I error. To guarantee the quality and robustness of an experiment, e-values are the better alternative. Thanks to the properties of e-values, you have less reason to doubt such experiments (though a researcher can always decide to fabricate data 😢).

Benefits of using E-values

As mentioned earlier, we can use e-values in the same way as p-values. One major difference in that case is that large e-values are comparable with low p-values. Recall that 1/E = p. If you want to use e-values in the same way as p-values and you use a significance level of 0.05, you can reject the null hypothesis if an e-value is higher than 20 (1/0.05).

But of course, there are more use cases and benefits of e-values! If several independent experiments test the same hypothesis, we can multiply their e-values to get a new e-value that can be used for testing. This does not work for p-values, because the product of p-values is not a valid p-value. For e-values, it works.
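A small simulation illustrates the difference. The Python sketch below combines two independent studies in which the null hypothesis is true. It is built on assumptions of my own and is not how safestats constructs its e-values:

- Each study reports a z-statistic that is standard normal under H0.
- Each e-value is the likelihood ratio for an assumed effect of one standard deviation.

```python
import math
import random

rng = random.Random(7)
alpha = 0.05
delta = 1.0        # assumed effect size under the alternative, in z-units
n_sim = 20_000
reject_e = reject_p = 0
e_products = []

for _ in range(n_sim):
    # Two independent studies of the same hypothesis; H0 is true, so each z-statistic is N(0, 1)
    z1, z2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)

    # Likelihood-ratio e-values, N(delta, 1) vs N(0, 1): each has expected value 1 under H0
    e1 = math.exp(delta * z1 - delta ** 2 / 2)
    e2 = math.exp(delta * z2 - delta ** 2 / 2)

    # One-sided p-values P(Z >= z) for the same two studies
    p1 = 0.5 * math.erfc(z1 / math.sqrt(2))
    p2 = 0.5 * math.erfc(z2 / math.sqrt(2))

    e_products.append(e1 * e2)
    reject_e += e1 * e2 >= 1 / alpha   # multiplied e-values: still a valid test
    reject_p += p1 * p2 < alpha        # multiplied p-values: NOT a valid test

print("mean of multiplied e-values:", sum(e_products) / n_sim)
print("false positive rate, multiplied e-values:", reject_e / n_sim)
print("false positive rate, multiplied p-values:", reject_p / n_sim)
```

The product of the e-values averages about 1 and rejects far less than 5% of the time. Naively multiplying the p-values and comparing the result with 0.05 rejects about 20% of the time, although no effect exists.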

You can also look at the data and results during the experiment. If you want to stop the test because the results don’t look promising, that’s okay. Another possibility is to continue a test if it does look promising.

We can also create anytime-valid confidence intervals with e-values. What does this mean? It means that the confidence intervals are valid at any sample size (so during the whole experiment). They will be a bit wider than a regular confidence interval, but the good thing is that you can trust them at any time.
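The Python sketch below shows one classic construction of such an interval, a normal-mixture confidence sequence. It relies on assumptions of my own: the data are normal with a known standard deviation of 1, the true mean is 0.3, and the tuning parameter rho = 100 is arbitrary. It is not the interval that safestats computes for proportions. The simulation checks how often the true mean falls outside the interval at any point during the experiment.

```python
import math
import random


def normal_mixture_radius(n, alpha=0.05, rho=100.0):
    """Half-width of a (1 - alpha) confidence sequence for the mean of N(mu, 1) data after n samples.

    Based on a Gaussian mixture of likelihood ratios plus Ville's inequality; rho is a tuning
    parameter (value assumed here) that controls at which sample sizes the sequence is tightest."""
    return math.sqrt((n + rho) * (2 * math.log(1 / alpha) + math.log((n + rho) / rho))) / n


def fixed_radius(n):
    """Half-width of the ordinary 95% z-interval with known standard deviation 1."""
    return 1.96 / math.sqrt(n)


def ever_miss_rate(radius, n_paths=1000, n_max=500, mu=0.3, seed=11):
    """Fraction of experiments in which the interval excludes the true mean at ANY sample size."""
    rng = random.Random(seed)
    misses = 0
    for _ in range(n_paths):
        total = 0.0
        for n in range(1, n_max + 1):
            total += rng.gauss(mu, 1.0)
            if abs(total / n - mu) > radius(n):
                misses += 1
                break
    return misses / n_paths


cs_miss_rate = ever_miss_rate(normal_mixture_radius)
fixed_miss_rate = ever_miss_rate(fixed_radius)
print(f"Confidence sequence, ever misses the true mean: {cs_miss_rate:.3f}")
print(f"Ordinary 95% interval, ever misses the true mean: {fixed_miss_rate:.3f}")
```

The confidence sequence is wider than the ordinary 95% interval at every sample size. In return, it keeps the probability of ever missing the true mean at or below 5%. The ordinary interval, checked after every observation, misses the true mean at some point in a large share of the runs.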

Usage of the safestats package

In the last part of the post, we get more practical. Let’s calculate our own e-values. For this, we use the R-package safestats. To install and load it, run:

install.packages("safestats")
library(safestats)

The case we will solve is a classic one: we will compare different versions of a website. If a person buys something, we log a success (1), and if a person does not buy anything, we log a failure (0). We show the old version of the website to 50% of the visitors (group A), and the new version of the website to the other 50% (group B). In this use case, we will look at different things that can happen. Sometimes the null hypothesis is true (there is no difference between the versions, or the old version is better), and sometimes the alternative hypothesis is true (the new website is better).

The first step in creating a safe test is creating the design object. In this object, you specify values for alpha, beta and delta:

designObj <- designSafeTwoProportions(
  na = 1,
  nb = 1,        # na and nb are of equal size, so 1:1
  alpha = 0.05,  # significance level
  beta = 0.2,    # risk of a type II error (power = 1 - beta = 0.8)
  delta = 0.05   # minimal effect we would like to detect
)

designObj

In many cases, delta is set to a higher number. But for comparing different versions of a website with a lot of traffic, it makes sense to set it small, because it’s easy to get many observations.

The output looks like this:

 Safe Test of Two Proportions Design

 na±2se, nb±2se, nBlocksPlan±2se = 1±0, 1±0, 4355±180.1204
              minimal difference = 0.05
                     alternative = twoSided
         alternative restriction = none
                 power: 1 - beta = 0.8
 parameter: Beta hyperparameters = standard, REGRET optimal
                           alpha = 0.05
decision rule: e-value > 1/alpha = 20

Timestamp: 2023-11-15 10:58:37 CET

Note: Optimality of hyperparameters only verified for equal group sizes (na = nb = 1)

You can recognize the values we chose, but the package also calculated the nBlocksPlan parameter. This is the number of blocks we plan to observe (with na = nb = 1, one block is one observation from group A plus one from group B), and it is based on the delta and beta parameters. Also check the decision rule, which is based on the value of alpha: if the e-value is greater than 20 (1 divided by 0.05), we reject the null hypothesis.

Test case: Alternative Hypothesis is True

Now, let’s generate some fake data:

set.seed(10)
successProbA <- 0.05  # success probability for A: 5%
successProbB <- 0.08  # success probability for B: 8%
nTotal <- designObj[["nPlan"]]["nBlocksPlan"]  # use the nBlocksPlan value as sample size
ya <- rbinom(n = nTotal, size = 1, prob = successProbA)
yb <- rbinom(n = nTotal, size = 1, prob = successProbB)

Distribution A and B for success probabilities 0.05 and 0.08, respectively. Image by author.

It’s time to perform our first safe test!

safe.prop.test(ya=ya, yb=yb, designObj=designObj)

With output:

 Safe Test of Two Proportions

data:  ya and yb. nObsA = 4355, nObsB = 4355

test: Beta hyperparameters = standard, REGRET optimal
e-value = 77658 > 1/alpha = 20 : TRUE
alternative hypothesis: true difference between proportions in group a and b is not equal to 0 

design: the test was designed with alpha = 0.05
for experiments with na = 1, nb = 1, nBlocksPlan = 4355
to guarantee a power = 0.8 (beta = 0.2)
for minimal relevant difference = 0.05 (twoSided) 

The e-value is equal to 77658, which means we can reject the null hypothesis. Enough evidence to reject it!

A question that might arise: "Could we have stopped earlier?" That is a nice benefit of e-values. Peeking at the data is allowed before the planned sample size is reached, so you can decide to quit or continue an experiment at any time. For example, we can compute the cumulative e-value after every 50 new samples and plot it. The plot of the first 40 e-values:

In the beginning there is no evidence against the null, corresponding to low e-values. But as more samples come in, the evidence starts to show: the e-values exceed the decision boundary of 20. Image by author.

The full plot:

We can be sure: the null hypothesis should be rejected. All e-values except the last one. Image by author.

Test case: Null Hypothesis is True

If we change the fake data and make the probabilities equal to each other (success probability for version A and B equal to 0.05), we should detect no significant e- or p-value. The distributions of version A and B look similar and the null hypothesis is true. This is reflected in the e-values plot:

No effect. Image by author.

But what if we compare this with p-values? How often will we reject the null hypothesis, although in reality we shouldn’t? Let’s test it. We will repeat the experiment 1000 times, and see in how many cases we rejected the null hypothesis for p-values and e-values.

The R code:

nSims <- 1000
pValuesRejected <- numeric(nSims)  # 1 if H0 rejected with the p-value, else 0
eValuesRejected <- numeric(nSims)  # 1 if H0 rejected with the e-value, else 0
alpha <- 0.05
ealpha <- 1/alpha

# repeat the experiment 1000 times, calculate the p-value and the e-value
for (i in seq_len(nSims)) {
  # create data under H0 (equal success probabilities), same value of nTotal as before (4355)
  set.seed(i)
  ya <- rbinom(n = nTotal, size = 1, prob = 0.05)
  yb <- rbinom(n = nTotal, size = 1, prob = 0.05)

  # calculate the p-value, H0 rejected if it's smaller than alpha
  testresultp <- prop.test(c(sum(ya), sum(yb)), n = c(nTotal, nTotal))
  pValuesRejected[i] <- as.numeric(testresultp$p.value < alpha)

  # calculate the e-value, H0 rejected if it's bigger than 1/alpha
  testresulte <- safe.prop.test(ya = ya, yb = yb, designObj = designObj)
  eValuesRejected[i] <- as.numeric(testresulte[["eValue"]] > ealpha)
}

And the output if we sum the pValuesRejected and the eValuesRejected:

> sum(pValuesRejected)
[1] 48
> sum(eValuesRejected)
[1] 0

The p-value was significant in 48 of the cases (around 5%, this is what we would expect with an alpha of 0.05). On the other hand, the e-value does a great job: It never rejects the null hypothesis. In case you weren’t convinced of using e-values yet, I hope you are now!

If you are curious about other examples, I recommend the vignettes of the safestats package.

Conclusion

E-values present a compelling alternative to traditional p-values, offering several advantages. They provide the flexibility to either continue or halt an experiment at any stage. Additionally, their combinability is a benefit, and the freedom to review experimental results at any point is a big plus. The comparison of p-values and e-values revealed that e-values are more reliable; p-values carry a greater risk of falsely identifying significant differences when none exist. The safestats R package is a useful tool for implementing these robust tests.

I am convinced of the merits of e-values and look forward to the development of a Python package that supports their implementation! 😄

The post Is This the Solution to P-Hacking? appeared first on Towards Data Science.

]]>
Sneaky Science: Data Dredging Exposed https://towardsdatascience.com/sneaky-science-data-dredging-exposed-26a445f00e5c/ Wed, 18 Oct 2023 17:34:36 +0000 https://towardsdatascience.com/sneaky-science-data-dredging-exposed-26a445f00e5c/ Delve into the motivations and consequences of P-hacking

The post Sneaky Science: Data Dredging Exposed appeared first on Towards Data Science.

]]>
From pizza to the dark side of research. Image created with Dall·E 3 by the author.

Delve into the motivations and consequences of p-hacking

A recent New Yorker headline reads, "They Studied Dishonesty. Was Their Work a Lie?". What’s the story behind it? Behavioral economist Dan Ariely and behavioral scientist Francesca Gino, both acclaimed in their fields, are under scrutiny for alleged research misconduct. To put it bluntly, they’re accused of fabricating data to achieve statistically significant results.

Sadly, such instances are not rare. Scientific research has seen its share of fraud. The practice of p-hacking – e.g. manipulating data, halting experiments once a significant p-value is achieved, or only reporting significant findings – has long been a concern. In this article, we will reflect on why some researchers might be tempted to tweak their findings. We will show the consequences and explain what you can do to prevent p-hacking in your own experiments.

But before we get into the scandals and secrets, let’s start with the basics – a crash course in Hypothesis Testing 101. This knowledge will be helpful as we navigate the world of p-hacking.


Hypothesis Testing 101

Let’s recap the key concepts you need to know to fully grasp the post. If you are familiar with hypothesis testing, including the p-value, Type I/II errors, and the significance level, you can skip this part.

The Best Pizza Test

Let’s travel to Naples, the famous Italian city known for its pizza. Two pizzerias, Port’Alba and Michele’s, claim they make the best pizza in the world. You’re a curious food critic, determined to find out which pizzeria truly deserves this title. To find out, you decide to host "The Best Pizza Test" (which is essentially just a hypothesis test).

Your investigation starts with two hypotheses:

  • Null Hypothesis (H0): There is no difference in the taste of Port’Alba and Michele’s pizzas; any difference observed is due to chance.
  • Alternative Hypothesis (H1): There is a difference in the taste of Port’Alba and Michele’s pizzas, indicating that one is better than the other.

The test starts. You gather a group of participants and organize a blind taste test. Each participant is served two slices of pizza, one from Port’Alba and one from Michele’s, without knowing which is which. The participants will rate both slices (0–10).

You set a strict alpha level (significance level) of 0.05. This means you’re willing to tolerate a 5% chance of making a Type I error, which in this context would be falsely claiming that one pizzeria’s pizza is better when it isn’t.

After collecting and analyzing the data, you find that participants overwhelmingly preferred Michele’s pizza. This is what the score distributions look like:

Image by author.

There are two risks in hypothesis testing:

  • Type I error: There’s a small chance (5%, equal to the significance level alpha) that you might be making a mistake by concluding that Michele’s pizza is better when, in reality, there is no difference. You don’t want to unfairly discredit Port’Alba pizzeria.
  • Type II error: On the flip side, there is the Type II error. What if, in reality, Michele’s pizza is better, but your test failed to detect that? You’d feel like you missed out on the best pizza!

In a matrix (you can compare it with a confusion matrix):

Type I and Type II error visualized. Image by author.
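To see both risks in numbers, the Python sketch below simulates many taste tests. It makes simplifying assumptions of my own:

- Scores are normal (unclipped) with a standard deviation of 1.5.
- The true mean difference is either 0 or 0.5 points.
- It uses a two-sample z-test with a known standard deviation instead of a t-test.

When there is no real difference, every rejection is a Type I error. When there is one, every failure to reject is a Type II error, and 1 minus the Type II error rate is the power of the test.

```python
import math
import random


def two_sample_z_p_value(a, b, sigma):
    """Two-sided p-value for a difference in means, assuming a known, common standard deviation."""
    diff = sum(b) / len(b) - sum(a) / len(a)
    std_error = sigma * math.sqrt(1 / len(a) + 1 / len(b))
    return math.erfc(abs(diff / std_error) / math.sqrt(2))


def rejection_rate(true_diff, n_per_group, sigma=1.5, alpha=0.05, n_sim=2000, seed=0):
    """Fraction of simulated taste tests in which H0 (no difference) is rejected."""
    rng = random.Random(seed)
    rejections = 0
    for _ in range(n_sim):
        portalba = [rng.gauss(7.5, sigma) for _ in range(n_per_group)]
        michele = [rng.gauss(7.5 + true_diff, sigma) for _ in range(n_per_group)]
        rejections += two_sample_z_p_value(portalba, michele, sigma) < alpha
    return rejections / n_sim


type_1_rate = rejection_rate(true_diff=0.0, n_per_group=20)
power_small = rejection_rate(true_diff=0.5, n_per_group=20)
power_large = rejection_rate(true_diff=0.5, n_per_group=100)
print(f"Type I error rate (no real difference, 20 tasters each): {type_1_rate:.3f}")
print(f"Type II error rate (difference 0.5, 20 tasters each):   {1 - power_small:.3f}")
print(f"Type II error rate (difference 0.5, 100 tasters each):  {1 - power_large:.3f}")
```

The Type I error rate stays near alpha regardless of sample size. The Type II error rate drops as you recruit more tasters, which is why the sample size should be determined before the test starts.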

To be sure of your findings, you calculate the p-value. It turns out to be a tiny number, far less than 0.05. This means the probability of getting results at least as extreme as these by chance, assuming the null hypothesis is true, is exceedingly low. We have a winner! Example of calculating the p-value with a two-sample t-test:

import numpy as np
from scipy import stats

# Sample data for taste scores (out of 10) for the pizzerias
np.random.seed(42)
portalba_scores = np.random.normal(7.5, 1.5, 50)  
michele_scores = np.random.normal(8.5, 1.5, 50)
michele_scores = [round(min(score, 10), 1) for score in michele_scores]
portalba_scores = [round(min(score, 10), 1) for score in portalba_scores] 

# Perform a two-sample t-test
t_stat, p_value = stats.ttest_ind(portalba_scores, michele_scores)

# Set the significance level (alpha)
alpha = 0.05

# Compare the p-value to alpha to make a decision
if p_value < alpha:
    print("We reject the null hypothesis: {} < {}".format(round(p_value, 7), alpha))
    print("There is a significant difference in taste between Port'Alba's and Michele's pizzas.")
else:
    print("We fail to reject the null hypothesis: {} >= {}".format(round(p_value, 7), alpha))
    print("There is no significant difference in taste between Port'Alba's and Michele's pizzas.")

Output:

We reject the null hypothesis: 3.1e-06 < 0.05
There is a significant difference in taste between Port'Alba's and Michele's pizzas.

Now that you are familiar with the key concepts (I hope you are not hungry after the pizza story), let’s continue with the sneaky part of this post: p-hacking.

Image created with Dall·E 3 by the author.

P-Hacking

In the world of academia, journals serve as the gatekeepers of knowledge. These gatekeepers have a preference for studies with significant results. The pressure to secure a publication spot can be immense. This bias can subtly encourage researchers to engage in p-hacking to ensure their work sees the light of day. Unfortunately, this practice perpetuates an overrepresentation of positive findings in the scientific literature.

P-hacking (aka data dredging or data snooping) is defined as:

Manipulation of statistical analyses or experimental design in order to achieve a desired result or to obtain a statistically significant p-value.

So, in research, it’s possible to manipulate your analysis and data to obtain interesting results that you can publish. You keep investigating, analyzing, and sometimes even modifying your data until the p-value is significant. P-hacking is often driven by the desire for recognition, publication, and the allure of significant findings. Researchers may inadvertently or intentionally fall into the p-hacking trap, tempted by the promise of quick acclaim and the pressures of a competitive academic environment.

One easy way of falling into this trap is by extensive exploration of the data. Researchers can find themselves in a situation where the dataset is rich with variables and subgroups waiting to be explored. The temptation to test numerous combinations can be strong. Each variable, each subgroup presents the possibility of finding a significant result. The trap here is cherry-picking – highlighting only those variables and subgroups that support the desired outcome while ignoring the many comparisons made. This can result in misleading and unrepresentative findings.
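How easily subgroup fishing produces a "finding" can be shown with a minimal Python simulation. It assumes 20 independent subgroups, each analyzed with a z-test, and no effect in any of them. These numbers are illustrative choices, not taken from a real study. The sketch counts how often at least one subgroup comes out significant, with and without the Bonferroni correction.

```python
import math
import random


def two_sided_p_value(z):
    """Two-sided p-value for a standard normal test statistic."""
    return math.erfc(abs(z) / math.sqrt(2))


def fishing_expedition(n_studies=2000, n_subgroups=20, alpha=0.05, seed=5):
    """Fraction of studies that report at least one 'significant' subgroup when no effect exists anywhere."""
    rng = random.Random(seed)
    naive_hits = 0
    bonferroni_hits = 0
    for _ in range(n_studies):
        # H0 is true in every subgroup; the subgroup z-statistics are assumed independent N(0, 1)
        smallest_p = min(two_sided_p_value(rng.gauss(0.0, 1.0)) for _ in range(n_subgroups))
        naive_hits += smallest_p < alpha                     # report whichever subgroup "worked"
        bonferroni_hits += smallest_p < alpha / n_subgroups  # Bonferroni-corrected threshold
    return naive_hits / n_studies, bonferroni_hits / n_studies


naive_rate, bonferroni_rate = fishing_expedition()
print(f"At least one significant subgroup (no correction): {naive_rate:.2f}")
print(f"At least one significant subgroup (Bonferroni):    {bonferroni_rate:.2f}")
```

Without correction, roughly two out of three studies find "something" even though there is nothing to find. With the Bonferroni correction, this drops back to about 5%.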

Another trap is peeking at the p-values before the experiment has officially ended. This might seem innocent, but it isn’t. The pitfall here is the temptation to stop data collection prematurely once that coveted p-value threshold is reached. This can lead to biased and unreliable findings, as the sample might not be representative. And even with a representative sample, stopping at the moment the results give a significant p-value is a form of data manipulation.

Extensive data exploration and peeking at p-values can easily happen unintentionally. The following example cannot: in the pursuit of significance, a researcher may encounter moments when the initial outcome measure doesn’t yield the desired result. The temptation to adjust the primary outcome, subtly shifting the goalposts, can be strong.

In summary, how do you NOT commit p-hacking? Here are five tips and guidelines:

  • Avoid exploring your data extensively to find significant results by chance. If you are doing multiple comparisons, correct the significance level, for example with the Bonferroni correction (divide the value of alpha by the number of comparisons).
  • Don’t selectively report only the analyses that yield significant p-values. Keep it transparent and also report analyses that did not yield significant results. This also holds true for variables: Don’t test multiple variables and report only those that appear significant. In this case you are cherry-picking variables. Even better, you can consider pre-registering your research to declare your hypotheses and analysis plans in advance.
  • HARKing (Hypothesizing After Results are Known): Refrain from forming hypotheses based on the results you’ve obtained; hypotheses should be defined before data analysis.
  • Avoid stopping data collection or analysis when you achieve a significant result. In line with this: Don’t peek at the results before the test has ended. Determine the sample size beforehand. And yes, even looking at the data is wrong!
  • Ensure that you meet the assumptions of your statistical tests. An example is normality of data. Make sure the data is normally distributed if you use a test that assumes this (you can test this with the Shapiro-Wilk test).

Image created with Dall·E 3 by the author.

True (Horror) Stories

The consequences of p-hacking are severe, as it leads to false positives (rejecting a null hypothesis that is actually true). Research is built upon previous research, and every false positive is misinformation for subsequent research and decision-making. It also wastes time, money, and resources on research avenues that lead nowhere. And maybe the most important consequence: it erodes the trust in and integrity of scientific research.

Now, it’s time for the scandals… There are many examples of data manipulation, data dredging, and p-hacking. Let’s take a look at some true stories.

Diederik Stapel’s Fabricated Research

Diederik Stapel was a prominent Dutch social psychologist who, in 2011, was found to have fabricated data in numerous published studies. He manipulated and falsified data to support his hypotheses, leading to a big scandal in the field of psychology.

Before he was caught, he published numerous fake studies (58 papers were retracted, according to Retraction Watch). For example, Stapel faked data supposedly showing that thinking about meat makes people less social (this wasn’t published). Something that was actually published is a study about how a messy environment promotes discrimination. Of course, this study was retracted when his deception came to light. Still, you can find some of his retracted papers online.

The Boldt Case

A far more shocking case is that of Joachim Boldt. Boldt was once regarded as a prominent figure in the field of medicinal colloids and had been a strong proponent of using colloidal hydroxyethyl starch (HES) to raise blood pressure during surgical procedures. However, a meta-analysis that deliberately excluded the unreliable data attributed to Boldt unveiled a different story. This analysis revealed that the intravenous administration of hydroxyethyl starch is linked to a significantly higher risk of mortality and acute kidney injury when compared to alternative resuscitation solutions. A headline in the Telegraph:

Millions of patients have been treated with controversial drugs on the basis of "fraudulent research" by one of the world’s leading anesthetists.

As you would expect, the fallout from these revelations was severe. Boldt lost his professorship, and he became the subject of a criminal investigation, facing allegations of falsifying up to 90 research studies.

What made him publish these false results? Why would anyone risk lives for something like a publication or promotion? Unfortunately, in Boldt’s case, we don’t know. While he may have provided explanations or statements during investigations or legal proceedings, there is no comprehensive and widely recognized account of his motivations. Here you can find an article about the case, including a timeline.

Reproducibility Crisis

There are many more horror stories you can find on the web. In general, reproducibility in research is important. Why? It allows for the verification and validation of research findings, ensuring that they are not simply due to chance, errors, or biases. But reproducing results can be challenging. There can be variations in experimental conditions, differences in equipment, or subtle errors in data collection and analysis. (Or someone committed p-hacking, which makes it impossible to reproduce the study.)

Some researchers tried to show why reproducibility is an issue. A fun example is a study about chocolate, and how it affects weight loss. Despite intentional flaws in methodology, it was widely published and shared. This incident showcased the pitfalls of weak scientific rigor and sensationalized media coverage.

Challenges also arise when trying to reproduce existing research. In 2011, Bayer’s researchers revealed that they could only replicate around a quarter (!) of published preclinical studies. This concerning discovery made many question the reliability of some early-stage drug research, emphasizing the pressing need for good validation processes. Similarly, the field of psychology faced its own issues. To address this, the "Many Labs" project was initiated. In this global endeavor, laboratories worldwide tried to reproduce the same psychological studies, often revealing notable variations in outcomes. Such findings stressed the essential role of collaborative efforts in multiple labs to ensure research findings are genuine.

Unfortunately, this is not true. Image created with Dall·E 3 by the author.

Conclusion

This post aimed to open your eyes to the subtle yet big influence of p-hacking on scientific literature. It’s a reminder that not every research article can be taken at face value. As curious and critical thinkers, it’s essential to approach scientific findings with a discerning eye. Don’t blindly trust every publication; instead, seek multiple sources that reinforce the same statement. True knowledge often emerges when different studies converge on a shared truth.

But what if there were a solution to one of academia’s most pressing challenges? Join me in my upcoming blog post as I unveil a potential antidote to the issue of p-hacking. Stay tuned!

The post Sneaky Science: Data Dredging Exposed appeared first on Towards Data Science.

]]>
Harnessing AI for a Better World https://towardsdatascience.com/harnessing-ai-for-a-better-world-e3357cc73b09/ Fri, 06 Oct 2023 17:34:36 +0000 https://towardsdatascience.com/harnessing-ai-for-a-better-world-e3357cc73b09/ Harnessing AI for a Better World There are many examples of using AI in the wrong way, as highlighted in the thought-provoking book Weapons of Math Destruction. Also, the risks of AI should not be underestimated. AI ethics and governance have become pressing concerns in our rapidly evolving technological landscape, and numerous companies now maintain […]

The post Harnessing AI for a Better World appeared first on Towards Data Science.

]]>

Harnessing AI for a Better World

There are many examples of using AI in the wrong way, as highlighted in the thought-provoking book Weapons of Math Destruction. Also, the risks of AI should not be underestimated. AI ethics and governance have become pressing concerns in our rapidly evolving technological landscape, and numerous companies now maintain dedicated departments focused solely on addressing these issues. However, amid these challenges, there are also inspiring stories that demonstrate how mathematics and AI can be harnessed for the greater good.

This post is written to inspire you and to show how AI can help address complex global challenges, from climate change to human rights abuses. The initiatives and companies selected here have achieved remarkable results at scale, and I believe they deserve a spotlight. It’s only a selection, because there are many more examples available. If you come across something I’ve missed, please don’t hesitate to comment on this post!

One additional motivation for sharing this post is World Animal Day, celebrated on October 4th. 🐈⬛


Refugee Resettlement

Mathematical optimization is typically applied to maximize profits or minimize costs in various scenarios. However, there are exceptional cases where it serves a noble purpose, as demonstrated here.

When refugees seek shelter in a host country, the initial placement of a refugee family into a suitable home can profoundly affect their long-term prospects in terms of employment, education, and overall well-being. Many refugees are vulnerable, and it’s important to place them in a safe environment where they can start building a new life. Similarly, for foster homes, it’s essential to match their preferences with the refugees they welcome. For instance, if an elderly woman wants to care for refugees but ends up hosting a family of five men, the situation can be overwhelming. By optimizing the matchmaking process, it’s possible to significantly improve placement outcomes.

Annie MOORE is a software system, developed in 2018, that focuses on this problem. It is used by HIAS, a resettlement agency based in the United States. Annie uses machine learning and mathematical optimization to recommend matches between foster homes and refugees. The software contributed to an impressive 22% to 38% increase in employment outcomes for refugees resettled by HIAS.
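The details of Annie MOORE’s models are not covered here, so the following is only a toy sketch of the general idea behind optimized matchmaking, not the system’s actual method. It assumes a machine learning model has already produced a hypothetical success score for every family-location pair (the numbers are invented), and then picks the one-to-one matching with the highest total score by checking every permutation:

```python
from itertools import permutations

# Hypothetical predicted success scores (for example, an estimated probability
# of employment) for placing family i at host location j. In a real system a
# machine learning model would estimate these; here they are invented numbers.
scores = [
    [0.2, 0.6, 0.4],
    [0.5, 0.3, 0.7],
    [0.6, 0.4, 0.1],
]


def best_matching(scores):
    """Return the one-to-one matching (family i -> location perm[i])
    with the highest total score, by checking every permutation."""
    n = len(scores)
    best_perm, best_total = None, float("-inf")
    for perm in permutations(range(n)):
        total = sum(scores[i][perm[i]] for i in range(n))
        if total > best_total:
            best_perm, best_total = perm, total
    return best_perm, best_total


perm, total = best_matching(scores)
print(perm, round(total, 2))  # (1, 2, 0) 1.9
```

Brute force only works for a handful of families; for realistic sizes, the same matching problem would typically be solved with integer programming or the Hungarian algorithm.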

Taking this innovative approach a step further, a new platform called RUTH (Refugees Uniting Through HIAS) enhances the resettlement process. RUTH incorporates the preferences of both refugees and host families, making relocation faster and more transparent. Dr. Trapp, Associate Professor of Operations and Industrial Engineering at Worcester Polytechnic Institute, highlights the significance of RUTH by stating:

"This is the first time ever that preferences of refugees and priorities of hosts have been systematically used in the resettlement process."

RUTH helped to place refugees from Ukraine in the US during the Russian invasion.

Safeguarding Wildlife

AI is playing a vital role in the protection and conservation of endangered species and wildlife. Leveraging advanced image recognition and predictive modeling, AI-powered systems have revolutionized the way we monitor animal behavior, track migration patterns, and combat poaching threats (yes, poachers still exist). These technologies empower conservationists and law enforcement agencies to respond promptly, safeguarding the world’s biodiversity and preserving our natural heritage.

There are many remarkable examples that can be found on the web. Let’s go through some of them.

Computer vision and other machine learning techniques can help save species from extinction. One way to keep track of species is to count animals in national parks. Appsilon developed a tool called Mbaza AI, which is essentially an image classification tool specialized in wildlife. In Australia, the koala is endangered because of bushfires and animal attacks. This initiative uses AI to locate and rescue surviving koalas. Collaborating with the Snow Leopard Trust, Microsoft AI is engaged in the detection and recognition of snow leopards, aiding in their conservation and protection. Google AI uses the songs of whales to locate and protect them.

Unfortunately, worldwide there are still poachers hunting down animals. But luckily, AI is making their job way more difficult. Hack the Planet has an initiative that can alert rangers if a poacher is on the road. This is great, because the ranger can go after the poacher directly, instead of watching camera footage constantly. Another project monitors every boat that goes in and out of a park in Zambia, to prevent illegal fishing. Other animals that are monitored with the help of AI are grizzly bears, elephants and penguins.

Combatting Climate Change

Climate change is a big issue. We have to deal with the negative effects of climate change, like rising sea levels, extreme weather events, heightened health risks, food scarcity and population displacement.

The battle against climate change requires data-driven insights and sustainable solutions. AI excels in processing vast datasets and identifying patterns that inform climate policies and resource management. From optimizing renewable energy production to predicting extreme weather events, AI contributes to mitigating the effects of climate change and promoting a more sustainable future for our planet.

Accurate weather forecasting is vital for protecting lives and property. LEAP (Learning the Earth with AI and Physics) is a science and technology center dedicated to improving short-term climate forecasts. It does this by blending traditional climate science with advanced machine learning, combining the strengths of both approaches to build next-generation AI models.

As you might know, carbon emissions have an impact on the planet, and many initiatives show that it’s possible to reduce them. Eugenie.ai is a company that uses AI to help manufacturers decarbonize their operations. Mortar IO discovers ways to decarbonize existing buildings; its vision is to be the data infrastructure for decarbonizing real estate. Buildings account for 39% of global energy-related carbon emissions, so this is a valuable initiative.

An interesting question that might come to mind: what about the emission footprint of AI itself? AI is a large energy consumer (visualized in the first chart of this post), but it is possible to reduce its emissions drastically, as HuggingFace showed with the BLOOM model. New technologies can help reduce the carbon footprint of AI as well. Perhaps the best way to lower emissions is to develop smaller models that perform as well as the bigger ones.

Waste is another topic that arises when you think of climate change: people’s consumption behavior has a huge impact on methane emissions. A cool initiative that handles waste from the fashion industry is Refiberd, which uses AI to sort textiles for recycling.

Helping Farmers in Developing Nations

In developing nations, small-scale farmers often face challenges in maximizing crop yields and managing resources efficiently. AI-driven applications provide valuable support by offering tailored advice on crop selection, irrigation, and pest control. By equipping farmers with these tools, we empower them to improve their livelihoods and contribute to food security in their communities.

Plantix is an application used all over the world that uses AI to detect crop diseases and provide solutions, helping farmers maximize their yields. Hello Tractor is a company that provides a platform connecting farmers with tractor owners for plowing and other farming activities. They use mathematical optimization and machine learning to optimize tractor allocation and scheduling, making it more affordable for smallholder farmers in developing nations to access mechanized farming. Another concrete example is Apollo Agriculture, which helps farmers in Kenya and Zambia not only with farming but also with financing. The machine learning part of Apollo focuses on credit models for making credit decisions.

There are many more initiatives, like AgriPredict, Taranis and Farmshine (platforms for farmers with weather predictions, crop disease diagnoses and market access).

Exposing Human Rights Abuses and War Crimes

Many human rights organizations are critical of AI because of its risks. Check out recent posts from Amnesty International, Human Rights Watch and Bellingcat. Despite this criticism, can AI play a positive role in this field as well?

Actually, it can: AI can sift through vast amounts of data, including images, videos, and textual information, to identify and document violations. This is done through image and video analysis, NLP, facial recognition and predictive analytics. Also, AI can aggregate and cross-reference data from various sources to build a comprehensive picture of events and human rights violations. This includes collating information from eyewitness accounts, social media posts, and official reports.

Syrian Archive is an organization that uses open-source intelligence, digital forensics, and AI to document human rights violations in Syria. They collect and verify visual evidence, such as photos and videos, to hold perpetrators accountable. Here you can read more on the methods and tools they use.

Another example is Forensic Architecture. This research agency uses spatial and architectural analysis, combined with AI and machine learning, to investigate human rights violations and state violence. They often collaborate with other organizations to provide valuable insights. These examples show different use cases in which machine learning techniques were utilized.

And last but not least, the Centre for Information Resilience is a non-profit social enterprise dedicated to exposing human rights abuses and war crimes. It also counters disinformation and combats harmful online behavior. An example in which they use data science (mostly analysis and visualization) is the Eyes on Russia map, a timeline that outlines important events since the Russian invasion of Ukraine.

Hopefully, more initiatives like these will arise to use AI for good.


Conclusion

While the ethical challenges of AI are undeniable, it is essential to recognize its potential for creating positive change in the world. The examples discussed here demonstrate that, when harnessed responsibly and with a commitment to ethical principles, AI and mathematics can be powerful tools for addressing pressing global issues and making the world a better place.

If you would like to volunteer or work for one of these projects or companies, some of them have career opportunities! Check their websites for openings.

]]>
Ant Colony Optimization in Action https://towardsdatascience.com/ant-colony-optimization-in-action-6d9106de60af/ Wed, 20 Sep 2023 17:34:37 +0000 https://towardsdatascience.com/ant-colony-optimization-in-action-6d9106de60af/ Solving optimization problems and enhancing results with ACO in Python

The post Ant Colony Optimization in Action appeared first on Towards Data Science.

]]>
A skiing ant. Image created with Dall·E by the author.

Welcome back! In my previous post, I introduced the fundamentals of Ant Colony Optimization (ACO). In this installment, we’ll delve into implementing the ACO algorithm from scratch to tackle two distinct problem types.

The problems we’ll be addressing are the Traveling Salesman Problem (TSP) and the Quadratic Assignment Problem (QAP). Why these two? Well, the TSP is a classic challenge, and ACO happens to be an effective algorithm for finding the most cost-efficient path through a graph. On the other hand, the Quadratic Assignment Problem represents a different class of problems related to optimizing the arrangement of items, and in this post, I aim to demonstrate that ACO can be a valuable tool for solving such assignment-related problems as well. This versatility makes the ACO algorithm applicable to a wide range of problems. Finally, I’ll share some tips for achieving improved solutions more rapidly.


Traveling Salesman Problem

TSP is straightforward to describe but can be hard to solve. Here’s the basic definition: find the shortest route that visits every node in a graph exactly once and returns to the starting node. TSP is NP-hard, which means no known algorithm solves every instance in polynomial time, and exploring all possible routes quickly takes an impractical amount of time as the number of nodes grows. Instead, a more effective approach is to seek a high-quality solution within a reasonable timeframe, and that’s precisely what we’ll accomplish using ACO.
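To make that growth concrete, here is a small sketch (the 4-node distance matrix is a made-up example) that solves a tiny symmetric TSP exactly by enumerating every tour that starts at node 0, and counts the distinct tours, (n - 1)! / 2 for n nodes:

```python
import math
from itertools import permutations


def brute_force_tsp(dist):
    """Exact TSP by enumeration: fix node 0 as the start and try every
    ordering of the remaining nodes. Only feasible for very small n."""
    n = len(dist)
    best_tour, best_length = None, math.inf
    for order in permutations(range(1, n)):
        tour = (0, *order, 0)
        length = sum(dist[a][b] for a, b in zip(tour, tour[1:]))
        if length < best_length:
            best_tour, best_length = list(tour), length
    return best_tour, best_length


def num_distinct_tours(n):
    """Distinct tours in a symmetric TSP with n >= 3 nodes: (n - 1)! / 2."""
    return math.factorial(n - 1) // 2


# Hypothetical 4-node instance: a square with side 1 and "diagonal" 2
dist = [
    [0, 1, 2, 1],
    [1, 0, 1, 2],
    [2, 1, 0, 1],
    [1, 2, 1, 0],
]
print(brute_force_tsp(dist))              # ([0, 1, 2, 3, 0], 4)
print(f"{num_distinct_tours(30):.2e}")    # 4.42e+30
```

For the 30-node instance used below, that is roughly 4.4 × 10^30 tours, far beyond what enumeration can handle.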

Problem Definition

With the following code, we can create a TSP instance with a given number of nodes:

import itertools
import math
import random
from typing import Tuple

import networkx as nx
import networkx.algorithms.shortest_paths.dense as nxalg

class TSP:
    """
    Creates a random Euclidean TSP problem with a certain number of nodes
    """
    def __init__(self, nodes: int = 30, dimensions: Tuple[int, int] = (1000, 1000), seed: int = 5):
        if seed is not None:
            random.seed(seed)  # make the instance reproducible (also for seed=0)

        graph = nx.Graph()
        nodes_dict = dict()

        # Place each node at a random integer coordinate within the given dimensions
        for i in range(nodes):
            nodes_dict[i] = (random.randint(0, dimensions[0]), random.randint(0, dimensions[1]))
            graph.add_node(i)

        # Complete undirected graph: one weighted edge per pair of nodes
        for i, j in itertools.combinations(range(nodes), 2):
            graph.add_edge(i, j, weight=self.calculate_distance(nodes_dict[i], nodes_dict[j]))

        self.graph = graph
        self.nodes = nodes_dict
        # All-pairs shortest-path distances as a NumPy array (rows/columns in node order)
        self.distance_matrix = nxalg.floyd_warshall_numpy(graph)

    @staticmethod
    def calculate_distance(i, j):
        """
        Calculate the Euclidean distance between two nodes, truncated to an integer
        """
        return int(math.sqrt((i[0] - j[0])**2 + (i[1] - j[1])**2))

The TSP example we will use for demonstrating ACO is the following (default settings):

Visit all nodes and return back to the start node. Image by author.

The optimal solution for this problem (calculated with mixed integer programming) looks like this:

Optimal solution for the traveling salesman problem. Image by author.

The distance of this path is 4897.

Solving TSP with ACO

The next step is to solve this problem with ant colony optimization to see how close we can get to the optimal solution. If you are unfamiliar with ACO and you want to learn how the algorithm works, you can read my previous post. Then you can return here to see ACO in action.

The code for ACO:

import logging
import random
import time

import numpy as np

from problem import TSP

class AntColonyOptimization:
    """
    Ant colony optimization algorithm for finding the shortest route in a graph.

    Parameters:
        m = number of ants per iteration
        k_max = number of iterations
        alpha = pheromone importance
        beta = distance importance
        rho = pheromone evaporation rate
        Q = pheromone deposit constant
        tau = pheromone matrix
        eta = heuristic desirability, the inverse of the distance matrix
        time_limit = maximum run time in seconds
    """
    def __init__(self, problem, **kwargs):
        self.graph = problem.graph
        self.nodes = list(problem.nodes)
        self.coordinates = list(problem.nodes.values())
        self.n = len(self.nodes)
        self.distance_matrix = problem.distance_matrix

        self.m = kwargs.get("m", 100)
        self.k_max = kwargs.get("k_max", 50)
        self.alpha = kwargs.get("alpha", 1)
        self.beta = kwargs.get("beta", 5)
        self.rho = kwargs.get("rho", 0.9)
        self.Q = kwargs.get("Q", 1)
        self.time_limit = kwargs.get("time_limit", 5)

        # initialization of tau and eta
        self.tau = np.full(self.distance_matrix.shape, 0.1)
        self.eta = 1 / (self.distance_matrix + 1e-10)

        self.history = []

    def ant_colony_optimization(self):
        """
        Ant colony optimization algorithm
        """
        start_time = time.time()
        x_best, y_best = [], float("inf")
        for _ in range(self.k_max):
            self.edge_attractiveness()
            self.tau *= (1-self.rho)
            for _ in range(self.m):
                x_best, y_best = self.ant_walk(x_best, y_best)
                if time.time() - start_time > self.time_limit:
                    logging.info("Time limit reached. Stopping ACO.")
                    return x_best, y_best
        return x_best, y_best

    def edge_attractiveness(self, plot: bool = False):
        """
        Calculate edge attractiveness A = tau^alpha * eta^beta
        tau = pheromone
        eta = inverse distance (1 / distance)
        alpha = pheromone importance
        beta = distance importance
        """
        self.A = (self.tau ** self.alpha) * (self.eta ** self.beta)

    def ant_walk(self, x_best, y_best, plot: bool = True):
        """
        Ant walk
        """
        x = [0]  # Start at first node
        while len(x) < self.n:
            i = x[-1]
            neighbors = [j for j in range(self.n) if j not in x and self.distance_matrix[i][j] > 0]
            if len(neighbors) == 0:
                return x_best, y_best
            p = [self.A[(i, j)] for j in neighbors]
            sampled_neighbor = random.choices(neighbors, weights=p)[0]
            x.append(sampled_neighbor)
        x.append(0)
        y = self.score(x)
        self.history.append(y)
        for i in range(1, self.n):
            self.tau[(x[i-1], x[i])] += self.Q / y
        if y < y_best:
            logging.info("Better ACO solution found. Score: %.2f", y)
            return x, y
        return x_best, y_best

    def score(self, x):
        """
        Score a solution
        """
        y = 0
        for i in range(len(x) - 1):
            y += self.distance_matrix[x[i]][x[i + 1]]
        return y

Let’s break down the crucial parts of the code step by step:

  • The first step is initialization __init__. We define the problem, which, in this case, is a TSP instance, and optionally provide hyperparameters if needed.
  • The ant_colony_optimization section contains the core execution. Over a specified number of iterations k_max, the algorithm strives to improve the current best solution. At the start of each iteration, the edge attractiveness is recomputed and all pheromone evaporates by a factor (1 - rho); then m ants are deployed, and each ant traverses the graph. The run stops early once time_limit seconds have passed.
  • The ant_walk method simulates the journey of a single ant. In the while loop, the ant’s path is constructed by sampling the next node among the unvisited neighbors, with probabilities proportional to the edge attractiveness A. The edge_attractiveness method computes A = tau^alpha · eta^beta, where tau is the pheromone matrix and eta is the inverse of the distance matrix. After each ant’s walk, the pheromone matrix is updated: every edge the ant chose receives a deposit of Q / y, where y is the tour length.
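To see how alpha and beta shape an ant’s choice, here is a small worked example of the rule used in ant_walk: the probability of moving to neighbor j is its attractiveness divided by the sum of the attractiveness of all unvisited neighbors (random.choices normalizes the weights in the same way). The three distances are hypothetical; with equal pheromone and beta = 5, the closest neighbor dominates:

```python
def transition_probabilities(tau, eta, alpha=1, beta=5):
    """Probability of moving to each candidate neighbor, proportional to
    the attractiveness A = tau**alpha * eta**beta (as in edge_attractiveness)."""
    attractiveness = [t ** alpha * e ** beta for t, e in zip(tau, eta)]
    total = sum(attractiveness)
    return [a / total for a in attractiveness]


# Hypothetical example: three unvisited neighbors at distances 100, 200 and 400,
# all with the same pheromone level 0.1
distances = [100, 200, 400]
tau = [0.1, 0.1, 0.1]
eta = [1 / d for d in distances]
probs = transition_probabilities(tau, eta)
print([round(p, 3) for p in probs])  # [0.969, 0.03, 0.001]
```

Doubling the distance divides the weight by 2^5 = 32, which is why a large beta makes ants behave almost greedily until pheromone differences build up.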

To run the algorithm for a problem instance, all you have to do is the following:

problem = TSP()
aco = AntColonyOptimization(problem)
best_solution, best_score = aco.ant_colony_optimization()

Now, how does ACO perform compared to the optimal solution? You can visualize the progress with a GIF that displays the evolving solution. Each image in the GIF shows the current best route, allowing you to observe the improvement over time.

Improvement of the route. Gif by author.

The score of the final solution is equal to 4944, really close to the optimal solution (the gap is less than 1%)! It’s also interesting to take a look at the solving process:
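For reference, the relative gap between the ACO score and the optimum works out as:

```latex
\text{gap} = \frac{4944 - 4897}{4897} = \frac{47}{4897} \approx 0.0096 = 0.96\%
```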

Solution process. Every dot represents one ant walk. Image by author.

In this graph, the x-axis represents the ant number, while the y-axis indicates the distance covered during the ant’s journey. The horizontal red line shows the score of the optimal solution, and the red dots mark the walks in which an ant found a new, improved solution. It’s worth noting that it often takes several ants to discover a better solution. However, the last red dot is remarkably close to the optimal solution. There are strategies available to enhance the performance of ACO, which I’ll elaborate on in the last section of this post.
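If you want to reproduce the red dots yourself, a minimal helper could scan the list of walk scores (the history attribute that ant_walk appends to) and keep every walk that beat all earlier ones. This assumes the red dots mark exactly those walks; the score list in the example is made up:

```python
def improvements(history):
    """Return (ant_index, score) pairs for every walk that beat all earlier
    walks, i.e. the points where a new best solution was found."""
    points, best = [], float("inf")
    for idx, score in enumerate(history):
        if score < best:
            best = score
            points.append((idx, score))
    return points


# With the AntColonyOptimization class above, this would be improvements(aco.history)
print(improvements([5200, 5100, 5150, 5000, 5000, 4950]))
# [(0, 5200), (1, 5100), (3, 5000), (5, 4950)]
```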


Assignment Problems

TSP is a routing problem, and ACO was originally designed to address routing problems. That was a long time ago, and since then people have found ways to solve different types of challenges using ACO. One such example worth highlighting is its application to assignment problems.

Assignment problems are problems in which you assign ‘something’ to ‘something else’. One example is the quadratic assignment problem (QAP). Imagine you have a set of locations, and you want to assign a set of facilities to these locations. The goal is to determine the best assignment that minimizes the total cost. The cost of assigning facility f to location l is determined by:

  • The flow between facilities: Facilities have a certain flow or interaction between them. This represents how much "stuff" or "activity" is transferred between facilities.
  • The distance between locations: Each pair of locations has a corresponding distance, which represents the cost or effort required to transport or operate between those locations.

The cost of assigning a particular facility to a specific location is determined by the flow between facilities and the distance between locations. Specifically, the cost for a pair of facilities is computed as the product of their flow and the distance between the locations to which they are assigned.

To find the total cost of a particular assignment, you sum up the pairwise costs for all possible pairs of facilities based on their assignments to locations.
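Written out, if assignment[f] is the location of facility f, the total cost is the sum over all ordered pairs (a, b) of flow[a][b] × distance[assignment[a]][assignment[b]]. For symmetric matrices this counts each unordered pair twice, which doesn’t change which assignment is best. A minimal sketch with a hypothetical 3-facility instance whose locations lie on a line:

```python
from itertools import permutations


def qap_cost(assignment, flow, distance):
    """Total QAP cost: assignment[f] is the location of facility f.
    Sums flow * distance over all ordered pairs of facilities."""
    n = len(assignment)
    return sum(
        flow[a][b] * distance[assignment[a]][assignment[b]]
        for a in range(n)
        for b in range(n)
    )


# Hypothetical instance: facility 1 interacts with both others,
# and the three locations lie on a line (0 - 1 - 2)
flow = [[0, 3, 0],
        [3, 0, 1],
        [0, 1, 0]]
distance = [[0, 1, 2],
            [1, 0, 1],
            [2, 1, 0]]

print(qap_cost([0, 1, 2], flow, distance))  # 8
print(qap_cost([0, 2, 1], flow, distance))  # 14

# Brute force over all assignments (only feasible for tiny instances)
best = min(permutations(range(3)), key=lambda p: qap_cost(p, flow, distance))
print(best)  # (0, 1, 2)
```

As expected, the best assignment puts the facility with the most interaction in the middle location.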

Example of an optimal solution for 5 facilities placed on 5 locations. Flows in blue. Image by author.

How does the problem formulation change when you want to apply ACO to an assignment problem instead of a routing problem? I’ll leave the coding up to you, but I will provide an intuitive understanding of the difference between a routing and an assignment problem, looking at it from an ant’s perspective.

The TSP is all about finding the optimal sequence in which to visit various locations. On the other hand, the QAP shifts the focus to deciding where to place items or facilities. In ACO for TSP, the ants learn to favor specific visitation sequences over others. In contrast, when tackling the QAP, the ants lean towards selecting particular facilities for specific locations. In this scenario, the pheromone trails represent how desirable it is to assign a facility to a specific location.

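You can imagine it like this: each ant has a list of available places (locations in QAP) where it can allocate items (facilities in QAP). For each item, the ant picks one of the remaining places, favoring places with more pheromone for that item, and removes the chosen place from its list. It repeats this until every item has been placed. Over many ants and iterations, the pheromone concentrates on the best item-to-location assignments found so far.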

In simpler terms, think of it as ants collaboratively determining the most efficient way to allocate items, drawing from the knowledge they’ve acquired about what works best in terms of item-to-location assignments.
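Since the coding is left to the reader, here is one possible sketch of the ant’s side of ACO for QAP; it is an illustration, not the author’s implementation. The pheromone matrix tau[f][l] stores the desirability of putting facility f at location l, each ant places facilities one by one in random order, and a finished solution deposits Q / cost on the pairs it used. For brevity it uses pheromone only; a fuller version could also weight choices by a heuristic derived from flows and distances:

```python
import random


def construct_assignment(tau, alpha=1.0, rng=random):
    """One ant builds a complete assignment.

    tau[f][l] is the pheromone for putting facility f at location l.
    Facilities are visited in random order; each picks one of the still-free
    locations with probability proportional to tau[f][l] ** alpha.
    Returns assignment, where assignment[f] is the location of facility f.
    """
    n = len(tau)
    free_locations = list(range(n))
    assignment = [None] * n
    facilities = list(range(n))
    rng.shuffle(facilities)
    for f in facilities:
        weights = [tau[f][loc] ** alpha for loc in free_locations]
        chosen = rng.choices(free_locations, weights=weights)[0]
        assignment[f] = chosen
        free_locations.remove(chosen)  # each location holds one facility
    return assignment


def deposit(tau, assignment, cost, Q=1.0):
    """Reinforce the facility-location pairs used by a solution:
    cheaper solutions (lower cost) deposit more pheromone."""
    for f, loc in enumerate(assignment):
        tau[f][loc] += Q / cost


rng = random.Random(0)
tau = [[0.1] * 4 for _ in range(4)]
a = construct_assignment(tau, rng=rng)
print(a)  # some permutation of [0, 1, 2, 3]
```

Combined with a cost function such as qap_cost and an evaporation step, these two functions form the QAP counterpart of ant_walk.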


Improve Solution Quality

There are several strategies to achieve better solutions in a shorter timeframe. Here are three valuable tips that can significantly impact your results:

Hyperparameter Search

If you’re dealing with multiple problems of the same type, it’s highly advisable to conduct a hyperparameter search. Parameters such as the number of iterations, the number of ants, alpha, beta, rho, and Q can have a substantial influence on algorithm performance.

Let’s examine an example. In the plot below, we test two different values of alpha, while the other parameters stay the same. The line plots depict the moving average over 100 ant runs, while the dots represent the best solution found for a certain run.
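The moving-average lines can be computed from the per-walk scores (for example, the history list of the AntColonyOptimization class). A minimal sketch of a trailing moving average, assuming a window of 100 walks as in the plot:

```python
def moving_average(values, window=100):
    """Average of each run of `window` consecutive values (e.g. ant walk scores).
    Returns an empty list if there are fewer values than the window size."""
    if len(values) < window:
        return []
    running = sum(values[:window])
    averages = [running / window]
    for i in range(window, len(values)):
        # Slide the window: add the newest value, drop the oldest one
        running += values[i] - values[i - window]
        averages.append(running / window)
    return averages


print(moving_average([1, 2, 3, 4], window=2))  # [1.5, 2.5, 3.5]
```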

The algorithm with the orange settings (alpha = 2) not only discovers a superior solution (indicated by the orange dot) but also accomplishes this more rapidly compared to the algorithm with the blue settings (alpha = 1).

To emphasize the impact of hyperparameter tuning further, consider a TSP problem with 100 nodes. If we conduct a random search (10 iterations) across the hyperparameters and plot the moving average, we observe significant variations in performance.

At the top, the orange and yellow lines yield suboptimal results, while others come remarkably close to the optimal solution. For context, the OR-Tools solver took 5 minutes to find the optimum, whereas each ACO run took at most 5 seconds.
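A random search like this can be sketched in a few lines. The search space below is hypothetical (its keys match the keyword arguments of the AntColonyOptimization class above), and the objective is whatever score you want to minimize, such as the best tour length a run returns. Because runs are cut off by a time limit, scores are noisy, so averaging a few runs per configuration can make the comparison fairer:

```python
import random


def random_search(objective, space, n_iter=10, seed=0):
    """Random hyperparameter search: sample n_iter configurations uniformly
    from `space` (name -> list of candidate values) and keep the one with
    the lowest objective value."""
    rng = random.Random(seed)
    best_params, best_score = None, float("inf")
    for _ in range(n_iter):
        params = {name: rng.choice(values) for name, values in space.items()}
        score = objective(params)
        if score < best_score:
            best_params, best_score = params, score
    return best_params, best_score


# Hypothetical search space over the AntColonyOptimization keyword arguments
space = {
    "m": [20, 50, 100],
    "alpha": [1, 2, 3],
    "beta": [2, 5, 8],
    "rho": [0.1, 0.5, 0.9],
    "Q": [1, 10, 100],
}

# With the classes from this post, the objective could look like this (not run here):
# def objective(params):
#     aco = AntColonyOptimization(TSP(nodes=100), time_limit=5, **params)
#     return aco.ant_colony_optimization()[1]
```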

Warm Up Procedure

Before setting the ants in motion, you have the option to initiate a warm-up for the pheromone matrix. This preparatory step often reduces the number of iterations needed to find the best solution. In this paper, a warm-up procedure is proposed and its effects are compared.
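The paper’s exact procedure is not reproduced here. As a simpler stand-in (an assumption, not necessarily what the paper proposes), you could seed the pheromone matrix with a good tour from a fast heuristic before the ants start. The sketch below uses a nearest-neighbor tour and works on a nested list or a 2-D NumPy array such as self.tau. The boost size Q is a tuning choice: with the default initial pheromone of 0.1, a deposit of Q / length with Q = 1 is tiny for tours of a few thousand units:

```python
def nearest_neighbor_tour(dist, start=0):
    """Greedy tour: always move to the closest unvisited node."""
    n = len(dist)
    tour = [start]
    unvisited = set(range(n)) - {start}
    while unvisited:
        current = tour[-1]
        # sorted() makes tie-breaking deterministic (lowest index wins)
        nxt = min(sorted(unvisited), key=lambda j: dist[current][j])
        tour.append(nxt)
        unvisited.remove(nxt)
    tour.append(start)
    return tour


def warm_up_pheromone(tau, dist, Q=1.0):
    """Seed the pheromone matrix with a greedy tour before the ants start."""
    tour = nearest_neighbor_tour(dist)
    length = sum(dist[a][b] for a, b in zip(tour, tour[1:]))
    for a, b in zip(tour, tour[1:]):
        tau[a][b] += Q / length
        tau[b][a] += Q / length  # distances are symmetric in this TSP
    return tour, length


# Hypothetical 4-node example
dist = [[0, 1, 2, 1], [1, 0, 1, 2], [2, 1, 0, 1], [1, 2, 1, 0]]
tau = [[0.1] * 4 for _ in range(4)]
print(warm_up_pheromone(tau, dist))  # ([0, 1, 2, 3, 0], 4)
```

With the classes from this post, you would call warm_up_pheromone(aco.tau, aco.distance_matrix, Q=...) before aco.ant_colony_optimization().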

Exploration and Exploitation

Like many metaheuristics, ACO can stagnate if the ants don’t explore enough. To address this, you can use the Max-Min Ant System (MMAS). MMAS encourages ants to venture into unexplored paths by initializing all trails to a high pheromone value, and it keeps every trail between a lower and an upper bound so that no edge becomes too unattractive to be tried. When stagnation occurs, the trails are reset to the upper bound. An additional difference from the original algorithm is that only the global best tour or the iteration’s best tour is permitted to reinforce its trails with pheromone. These adjustments promote a more effective balance between exploration and exploitation.
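A minimal sketch of the MMAS pheromone update, assuming a symmetric TSP and a nested-list (or 2-D NumPy) pheromone matrix. The bounds tau_min and tau_max, and whether the iteration-best or global-best tour is passed in, are left as parameters; in MMAS every trail is also initialized to tau_max before the first iteration:

```python
def mmas_update(tau, best_tour, best_length, rho, tau_min, tau_max, Q=1.0):
    """One MAX-MIN Ant System pheromone update (sketch).

    1. Evaporate all trails.
    2. Only the best tour (iteration-best or global-best) deposits pheromone.
    3. Clamp every trail to [tau_min, tau_max].
    """
    n = len(tau)
    for i in range(n):
        for j in range(n):
            tau[i][j] *= (1 - rho)
    for a, b in zip(best_tour, best_tour[1:]):
        tau[a][b] += Q / best_length
        tau[b][a] += Q / best_length  # symmetric TSP
    for i in range(n):
        for j in range(n):
            tau[i][j] = min(tau_max, max(tau_min, tau[i][j]))


def reset_trails(tau, tau_max):
    """On stagnation, reset every trail to tau_max to encourage exploration."""
    for row in tau:
        for j in range(len(row)):
            row[j] = tau_max
```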

Photo by Shardar Tarikul Islam on Unsplash

Conclusion

Ant colony optimization is a fun algorithm to play around with. It can find high quality solutions to routing and assignment problems, even if the problem size increases. This makes it a powerful algorithm.

To further enhance the algorithm’s performance and uncover superior solutions, consider incorporating essential strategies such as hyperparameter search, a warm-up procedure, and techniques for balancing exploration and exploitation. These adjustments can lead to significant improvements in your results.

If you’ve become curious and want to read more about ACO, I recommend this work. One of its authors proposed ACO in 1992, and it offers valuable insights and a comprehensive understanding of this remarkable optimization technique.

]]>