Why CatBoost Works So Well: The Engineering Behind the Magic
https://towardsdatascience.com/catboost-inner-workings-and-optimizations/
Towards Data Science, 10 April 2025

CatBoost stands out by directly tackling a long-standing challenge in gradient boosting: how to handle categorical variables effectively without causing target leakage. By introducing innovative techniques such as Ordered Target Statistics and Ordered Boosting, and by leveraging the structure of Oblivious Trees, CatBoost efficiently balances robustness and accuracy. These methods ensure that each prediction uses only past data, preventing leakage and resulting in a model that is both fast and reliable for real-world tasks.

The post Why CatBoost Works So Well: The Engineering Behind the Magic appeared first on Towards Data Science.


Gradient boosting is a cornerstone technique for modeling tabular data due to its speed and simplicity. It delivers great results without much fuss. Look around and you’ll find several implementations, such as LightGBM and XGBoost; CatBoost is another. In this post, we will take a detailed look at CatBoost, explore its inner workings, and understand what makes it a great choice for real-world tasks.

Target Statistic

Target Encoding Example: each category is replaced by the average value of the target variable for that category (here Car → 3.9, Bike → 1.2, Bus → 11.7, Cycle → 0.8). Image by author


One of the important contributions of the CatBoost paper is a new method of calculating the Target Statistic. What is a Target Statistic? If you have worked with categorical variables before, you know that the most rudimentary way to handle them is one-hot encoding. From experience, you also know that this opens a can of worms: sparsity, the curse of dimensionality, memory issues, and so on, especially for categorical variables with high cardinality.

Greedy Target Statistic

To avoid one-hot encoding, we calculate a Target Statistic for each categorical variable instead: the mean of the target variable for each unique value of that variable. So if a categorical variable takes the values A, B, and C, we compute the average of \(y\) over the rows with value A, over the rows with value B, and over the rows with value C, and replace each value with its corresponding average.

That sounds good, right? It does, but this approach comes with a problem: target leakage. To understand it, let’s take an extreme example; extreme examples are often the easiest way to expose the flaws of an approach. Consider the dataset below, in which every value of the categorical column is unique:

| Categorical Column | Target Column |
|---|---|
| A | 0 |
| B | 1 |
| C | 0 |
| D | 1 |
| E | 0 |
Greedy Target Statistic: Compute the mean target value for each unique category


Now let’s write the equation for calculating the Target Statistic:
\[\hat{x}^i_k = \frac{
\sum_{j=1}^{n} 1_{{x^i_j = x^i_k}} \cdot y_j + a p
}{
\sum_{j=1}^{n} 1_{{x^i_j = x^i_k}} + a
}\]

Here \(x^i_j\) is the value of the i-th categorical feature for the j-th sample. So for the k-th sample, we iterate over all samples of \(x^i\), select the ones whose value equals \(x^i_k\), and take the average of \(y\) over those samples. Instead of a direct average, we take a smoothed average, which is what the \(a\) and \(p\) terms are for: \(a > 0\) is the smoothing parameter and \(p\) is the prior, here the global mean of \(y\). For rare categories, the smoothing pulls the statistic toward \(p\).
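As a minimal sketch of the smoothed Greedy Target Statistic formula above, here is a plain-Python version. The function name `greedy_target_statistic` and the choice \(a = 1\) are illustrative assumptions, not CatBoost’s API. Note that each row’s statistic is computed from all rows, including itself.

```python
from collections import defaultdict

def greedy_target_statistic(categories, targets, a=1.0, p=None):
    """Smoothed greedy TS: (sum of y in the category + a*p) / (category count + a).

    Every row is encoded using *all* rows, including its own label,
    which is exactly what causes target leakage.
    """
    if p is None:
        p = sum(targets) / len(targets)  # prior: global mean of y
    sums = defaultdict(float)
    counts = defaultdict(int)
    for c, y in zip(categories, targets):
        sums[c] += y
        counts[c] += 1
    return [(sums[c] + a * p) / (counts[c] + a) for c in categories]

# The five-row example: every category is unique.
cats = ["A", "B", "C", "D", "E"]
ys = [0, 1, 0, 1, 0]
print(greedy_target_statistic(cats, ys, a=1.0))
```

With \(p = 0.4\) and \(a = 1\), rows with label 0 map to \(ap/(1+a) = 0.2\) and rows with label 1 map to \((1+ap)/(1+a) = 0.7\), matching the table that follows.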

If we calculate the Target Statistic using the formula above, we get:

| Categorical Column | Target Column | Target Statistic |
|---|---|---|
| A | 0 | \(\frac{ap}{1+a}\) |
| B | 1 | \(\frac{1+ap}{1+a}\) |
| C | 0 | \(\frac{ap}{1+a}\) |
| D | 1 | \(\frac{1+ap}{1+a}\) |
| E | 0 | \(\frac{ap}{1+a}\) |

Calculation of Greedy Target Statistic with Smoothing


Now if I use this Target Statistic column as my training data, I will get a perfect split at \( threshold = \frac{0.5+ap}{1+a}\). Anything above this value will be classified as 1 and anything below will be classified as 0. I have a perfect classification at this point, so I get 100% accuracy on my training data.

Let’s take a look at the test data. Since we are assuming that the feature has all unique values, no training sample shares a test sample’s category, so both sums are zero and the Target Statistic becomes:
\[TS = \frac{0+ap}{0+a} = p\]
If \(threshold\) is greater than \(p\), all test predictions will be \(0\); conversely, if \(threshold\) is less than \(p\), all test predictions will be \(1\). Either way, every test sample receives the same prediction, leading to poor performance on the test set.
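To make this concrete, here is the same argument with numbers plugged in for the five-row example. Two of the five targets are 1, so \(p = 0.4\); the smoothing strength \(a = 1\) is an arbitrary illustrative choice.

```python
a, p = 1.0, 0.4  # a: smoothing strength (illustrative); p: global mean of y (2 of 5 rows are 1)

train_y = [0, 1, 0, 1, 0]
# Every category is unique, so each row's greedy TS depends only on its own label.
train_ts = [(y + a * p) / (1 + a) for y in train_y]  # 0.2 for y=0, 0.7 for y=1
threshold = (0.5 + a * p) / (1 + a)                  # 0.45, halfway between the two values

train_pred = [int(ts > threshold) for ts in train_ts]
train_acc = sum(pr == y for pr, y in zip(train_pred, train_y)) / len(train_y)

# An unseen test category matches no training rows, so its TS collapses to the prior p.
test_ts = (0 + a * p) / (0 + a)
test_pred = int(test_ts > threshold)

print(train_acc, test_ts, test_pred)
```

Training accuracy is perfect, yet because \(p = 0.4\) falls below the 0.45 threshold, every test row is predicted as 0 regardless of its true label.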

Although we rarely see datasets where values of a categorical variable are all unique, we do see cases of high cardinality. This extreme example shows the pitfalls of using Greedy Target Statistic as an encoding approach.

Leave One Out Target Statistic

So the Greedy TS didn’t work out well for us. Let’s try another method: the Leave One Out Target Statistic, which computes the statistic for row \(k\) from all rows with the same category except row \(k\) itself. At first glance, this looks promising. But, as it turns out, it too has problems. Let’s see how with another extreme example. This time, let’s assume that our categorical variable \(x^i\) has only one unique value, i.e., all values are the same. Consider the data below:

| Categorical Column | Target Column |
|---|---|
| A | 0 |
| A | 1 |
| A | 0 |
| A | 1 |
Example data for an extreme case where a categorical feature has just one unique value


If we calculate the leave-one-out target statistic, we get:

| Categorical Column | Target Column | Target Statistic |
|---|---|---|
| A | 0 | \(\frac{n^+ - y_k + ap}{n - 1 + a}\) |
| A | 1 | \(\frac{n^+ - y_k + ap}{n - 1 + a}\) |
| A | 0 | \(\frac{n^+ - y_k + ap}{n - 1 + a}\) |
| A | 1 | \(\frac{n^+ - y_k + ap}{n - 1 + a}\) |

Calculation of Leave One Out Target Statistic with Smoothing


Here:
\(n\) is the total number of samples in the data (in our case, 4)
\(n^+\) is the number of positive samples in the data (in our case, 2)
\(y_k\) is the value of the target column in that row
The denominator is \(n - 1 + a\) because row \(k\) itself is left out of the count.
Substituting the above, we get:

| Categorical Column | Target Column | Target Statistic |
|---|---|---|
| A | 0 | \(\frac{2 + ap}{3+a}\) |
| A | 1 | \(\frac{1 + ap}{3+a}\) |
| A | 0 | \(\frac{2 + ap}{3+a}\) |
| A | 1 | \(\frac{1 + ap}{3+a}\) |

Substituting values of \(n\) and \(n^+\)


Now, if I use this Target Statistic column as my training data, I will get a perfect split at \(threshold = \frac{1.5+ap}{3+a}\). Anything above this value will be classified as 0 and anything below will be classified as 1. I have a perfect classification at this point, so I again get 100% accuracy on my training data.

You see the problem, right? My categorical variable, which has only a single unique value, produces two different Target Statistic values, one per label. This encoding performs great on the training data but will fail miserably on the test data, where every sample gets the same statistic.
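The leakage is easy to reproduce. Below is a minimal sketch of a smoothed leave-one-out statistic. The function name and \(a = 1\) are illustrative assumptions. Excluding the current row removes its label from the numerator and lowers the count in the denominator by one.

```python
def leave_one_out_ts(categories, targets, a=1.0, p=None):
    """Smoothed leave-one-out TS: each row's statistic excludes that row's own label."""
    if p is None:
        p = sum(targets) / len(targets)
    totals, counts = {}, {}
    for c, y in zip(categories, targets):
        totals[c] = totals.get(c, 0.0) + y
        counts[c] = counts.get(c, 0) + 1
    # Remove the current row from both the label sum and the count.
    return [(totals[c] - y + a * p) / (counts[c] - 1 + a)
            for c, y in zip(categories, targets)]

cats = ["A", "A", "A", "A"]  # a constant feature: it carries no information
ys = [0, 1, 0, 1]
ts = leave_one_out_ts(cats, ys, a=1.0)
print(ts)
```

With \(p = 0.5\), rows labeled 0 get 0.625 and rows labeled 1 get 0.375, so a constant feature perfectly encodes the label on the training set.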

Ordered Target Statistic

Illustration of ordered learning: CatBoost processes data in a randomly permuted order and predicts each sample using only the earlier samples. Image by author

CatBoost introduces a technique called Ordered Target Statistic to address the issues discussed above. This is the core principle of CatBoost’s handling of categorical variables.

This method, inspired by online learning, uses only past data to make predictions. CatBoost generates a random permutation (random ordering) \(\sigma\) of the training data. To compute the Target Statistic for the sample at position \(k\) in this ordering, CatBoost uses only the samples at positions \(1\) to \(k-1\) that share its category value. For the test data, it uses the entire training data to compute the statistic.

Additionally, CatBoost generates a new permutation for each tree, rather than reusing the same permutation each time. This reduces the variance that can arise in the early samples.

Ordered Boosting

This visualization shows how CatBoost computes residuals and updates the model: for sample \(x_i\), the model predicts using only earlier data points. The residual is \(r^t(x_i, y_i) = y_i - M^{t-1}_{i-1}(x_i)\), the update \(\Delta M\) is learned from samples with order at most \(i\), and the model is updated as \(M_i \leftarrow M_i + \Delta M\). Source

Another important innovation introduced by the CatBoost paper is Ordered Boosting. It builds on the same principle as Ordered Target Statistics: CatBoost randomly permutes the training data at the start of each tree and makes predictions sequentially.

In traditional boosting methods, when training tree \(t\), the model computes gradients from the predictions of the ensemble built so far (trees \(1\) to \(t-1\)) for all training samples, including the one it is currently fitting. Because those earlier trees were themselves fit using that sample’s label, this can lead to target leakage: the model indirectly uses the label of the current sample during training.

To address this issue, CatBoost uses Ordered Boosting where, for a given sample, it only uses predictions from previous rows in the training data to calculate gradients and build trees. For each row \(i\) in the permutation, CatBoost calculates the output value of a leaf using only the samples before \(i\). The model uses this value to get the prediction for row \(i\). Thus, the model predicts each row without looking at its label.

CatBoost trains each tree using a new random permutation, which averages out the high variance of the estimates for samples that appear early in any single permutation.
Let’s say we have 5 data points: A, B, C, D, E. CatBoost creates a random permutation of these points. Suppose the permutation is: σ = [C, A, E, B, D]

| Step | Data Used to Train | Data Point Being Predicted | Notes |
|---|---|---|---|
| 1 | (none) | C | No previous data → use prior |
| 2 | C | A | Model trained on C only |
| 3 | C, A | E | Model trained on C, A |
| 4 | C, A, E | B | Model trained on C, A, E |
| 5 | C, A, E, B | D | Model trained on C, A, E, B |
Table highlighting how CatBoost uses random permutation to perform training

This avoids using the current row’s own label to produce its prediction, thus preventing leakage.
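The walk-through in the table can be sketched in a few lines. As an assumption for illustration, the per-prefix "model" is simply the mean of the earlier labels with a prior of 0, and the labels are made-up toy values. The point is the ordering, not the model.

```python
def ordered_predictions(perm, labels, prior=0.0):
    """Predict each point with a 'model' fit only on the points before it in perm.

    The model is deliberately trivial (mean of earlier labels); it stands in for
    the per-prefix models in the table above.
    """
    preds = {}
    running_sum, count = 0.0, 0
    for point in perm:
        preds[point] = prior if count == 0 else running_sum / count
        running_sum += labels[point]  # only now does this point's label become "past" data
        count += 1
    return preds

labels = {"A": 1.0, "B": 3.0, "C": 2.0, "D": 5.0, "E": 4.0}  # toy labels (illustrative)
sigma = ["C", "A", "E", "B", "D"]
print(ordered_predictions(sigma, labels))
```

Changing any point’s label, say D’s, leaves D’s own prediction untouched, because D’s label is only added after D has been predicted.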

Building a Tree

Each time CatBoost builds a tree, it creates a random permutation of the training data. It calculates the ordered target statistic for all the categorical variables with more than two unique values. For a binary categorical variable, it maps the values to zeros and ones.

CatBoost processes data as if the data is arriving sequentially. It begins with an initial prediction of zero for all instances, meaning the residuals are initially equivalent to the target values.

As training proceeds, CatBoost updates the leaf output for each sample using the residuals of the previous samples that fall into the same leaf. By not using the current sample’s label for prediction, CatBoost effectively prevents data leakage.
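A minimal sketch of one such step, for a tree structure that is already fixed. Assumptions: the leaf assignment is hypothetical, each leaf’s output is the plain mean of earlier residuals in that leaf (which fits squared-error loss; CatBoost supports other leaf-estimation methods), and a leaf with no earlier samples outputs 0.

```python
def ordered_leaf_values(perm, leaf_of, residuals):
    """For each sample, leaf output = mean residual of EARLIER samples in the same leaf."""
    out = [0.0] * len(residuals)
    leaf_sum, leaf_cnt = {}, {}
    for i in perm:
        leaf = leaf_of[i]
        n = leaf_cnt.get(leaf, 0)
        out[i] = leaf_sum[leaf] / n if n else 0.0  # no earlier samples in this leaf -> 0
        leaf_sum[leaf] = leaf_sum.get(leaf, 0.0) + residuals[i]
        leaf_cnt[leaf] = n + 1
    return out

y = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0]
preds = [0.0] * len(y)                           # initial prediction of zero ...
residuals = [t - p for t, p in zip(y, preds)]    # ... so residuals equal the targets
leaf_of = [0, 1, 0, 1, 0, 1]                     # hypothetical leaf assignment from one tree
perm = [4, 1, 0, 5, 2, 3]                        # this tree's random permutation
lr = 0.5                                         # learning rate (illustrative)

step = ordered_leaf_values(perm, leaf_of, residuals)
preds = [p + lr * s for p, s in zip(preds, step)]
print(step, preds)
```

Sample 4, first in the permutation, gets 0 even though its own residual is 5. Its residual only informs samples that come after it in the same leaf.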

Split Candidates

CatBoost bins continuous features to reduce the search space for optimal splits. Each bin edge and split point represents a potential decision threshold. Image by author

At the core of a decision tree lies the task of selecting the optimal feature and threshold for splitting a node. This involves evaluating multiple feature-threshold combinations and selecting the one that gives the best reduction in loss. CatBoost does the same, but it first discretizes each continuous variable into bins to shrink the search space. It then evaluates each feature-bin combination to determine the best split.
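The binned split search can be sketched as follows. Quantile-based borders are a simple stand-in here; CatBoost’s own border-selection scheme is different and configurable. The split is scored by the squared error that remains after splitting a single node.

```python
import numpy as np

def candidate_borders(x, max_bins=8):
    """Quantile-based bin edges: a simple stand-in for CatBoost's own border selection."""
    qs = np.linspace(0.0, 1.0, max_bins + 1)[1:-1]  # interior quantiles only
    return np.unique(np.quantile(x, qs))

def best_split(X, y, max_bins=8):
    """Score every (feature, border) pair by the squared error left after the split."""
    best_feature, best_border, best_sse = None, None, np.inf
    for f in range(X.shape[1]):
        for border in candidate_borders(X[:, f], max_bins):
            left = X[:, f] <= border
            if left.all() or not left.any():
                continue  # skip splits that leave one side empty
            sse = (((y[left] - y[left].mean()) ** 2).sum()
                   + ((y[~left] - y[~left].mean()) ** 2).sum())
            if sse < best_sse:
                best_feature, best_border, best_sse = f, border, sse
    return best_feature, best_border, best_sse

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))            # feature 0 is pure noise
y = np.where(X[:, 1] > 0.3, 2.0, 0.0)    # the target depends only on feature 1
print(best_split(X, y))
```

With 8 bins, each feature contributes at most 7 candidate thresholds instead of up to 199 distinct values.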

A key difference from other implementations is that CatBoost uses Oblivious Trees, which apply the same split to every node at the same depth.

Oblivious Trees

Comparison between Oblivious Trees and Regular Trees. The Oblivious Tree (left) applies the same split condition to every node at a given level, resulting in a symmetric structure; the Regular Tree (right) applies different conditions at each node, leading to an asymmetric structure. Image by author

Unlike standard decision trees, where different nodes can split on different conditions (feature-threshold pairs), Oblivious Trees apply the same condition to all nodes at the same depth. At a given depth, every sample is evaluated against the same feature-threshold combination. This symmetry has several implications:

  • Speed and simplicity: since the same condition is applied to all nodes at the same depth, the trees are simpler and faster to train and to evaluate
  • Regularization: since each tree must apply a single condition per depth level, its capacity is restricted, which has a regularizing effect on the predictions
  • Parallelization: the uniform split conditions make it easier to parallelize tree construction and evaluation, and to use GPUs to accelerate training
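The speed argument is easy to see in code. Because every level of an oblivious tree asks the same question, the leaf index is just the binary number formed by the answers. In this minimal sketch, the function name and the bit ordering (first level = most significant bit) are our own illustrative choices, not CatBoost internals.

```python
import numpy as np

def oblivious_tree_predict(X, splits, leaf_values):
    """Predict with an oblivious tree.

    splits: one (feature, threshold) pair per depth level, shared by every node at that level.
    leaf_values: array of length 2**depth, indexed by the binary code of the answers.
    """
    idx = np.zeros(X.shape[0], dtype=np.int64)
    for feature, threshold in splits:
        # Shift in one bit per level: 1 if the sample goes "right" at this level.
        idx = (idx << 1) | (X[:, feature] > threshold).astype(np.int64)
    return np.asarray(leaf_values)[idx]

X = np.array([[0.1, 5.0], [0.9, 1.0], [0.9, 7.0], [0.2, 0.5]])
splits = [(0, 0.5), (1, 3.0)]       # depth 2: level 0 tests x0 > 0.5, level 1 tests x1 > 3.0
leaves = [10.0, 20.0, 30.0, 40.0]   # hypothetical leaf values
print(oblivious_tree_predict(X, splits, leaves))
```

Prediction reduces to one vectorized comparison per level and a single array lookup, with no per-sample branching, which is also what makes the structure GPU-friendly.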

Conclusion

CatBoost stands out by directly tackling a long-standing challenge: how to handle categorical variables effectively without causing target leakage. Through innovations like Ordered Target Statistics, Ordered Boosting, and the use of Oblivious Trees, it efficiently balances robustness and accuracy.

If you found this deep dive helpful, you might enjoy another deep dive on the differences between Stochastic Gradient Classifier and Logistic Regression.


Circuit Tracing: A Step Closer to Understanding Large Language Models
https://towardsdatascience.com/circuit-tracing-a-step-closer-to-understanding-large-language-models/
Towards Data Science, 8 April 2025

Reverse-engineering large language models' computation circuits to understand their decision-making processes

The post Circuit Tracing: A Step Closer to Understanding Large Language Models appeared first on Towards Data Science.

Context

Over the years, Transformer-based large language models (LLMs) have made substantial progress across a wide range of tasks, evolving from simple information retrieval systems to sophisticated agents capable of coding, writing, conducting research, and much more. But despite their capabilities, these models are still largely black boxes. Given an input, they accomplish the task, but we lack intuitive ways to understand how the task was actually accomplished.

LLMs are designed to predict the statistically best next word (token). But do they only focus on predicting the next token, or do they plan ahead? For instance, when we ask a model to write a poem, is it generating one word at a time, or is it anticipating rhyme patterns before outputting a word? Or, when asked a basic reasoning question such as “What is the capital of the state where the city of Dallas is located?”, models often produce output that looks like a chain of reasoning, but did the model actually use that reasoning? We lack visibility into the model’s internal thought process. To understand LLMs, we need to trace their underlying logic.

The study of LLMs’ internal computation falls under “Mechanistic Interpretability,” which aims to uncover the computational circuits of models. Anthropic is one of the leading AI companies working on interpretability. In March 2025, they published a paper titled “Circuit Tracing: Revealing Computational Graphs in Language Models,” which aims to tackle the problem of circuit tracing.

This post aims to explain the core ideas behind their work and build a foundation for understanding circuit tracing in LLMs.

What is a circuit in LLMs?

Before we can define a “circuit” in language models, we first need to look inside the LLM. It’s a neural network built on the transformer architecture, so it seems natural to treat neurons as the basic computational unit and to interpret the patterns of their activations across layers as the model’s computation circuit.

However, the “Towards Monosemanticity” paper revealed that tracking neuron activations alone doesn’t provide a clear understanding of why those neurons are activated. This is because individual neurons are often polysemantic: they respond to a mix of unrelated concepts.

The paper further showed that neuron activations can be decomposed into more fundamental units called features, which capture more interpretable information. In effect, a neuron can be seen as a combination of features. So rather than tracing neuron activations, we aim to trace feature activations: the actual units of meaning driving the model’s outputs.

With that, we can define a circuit as a sequence of feature activations and connections used by the model to transform a given input into an output.

Now that we know what we’re looking for, let’s dive into the technical setup.

Technical Setup

We’ve established that we need to trace feature activations rather than neuron activations. To enable this, we need to convert the neurons of an existing LLM into features, i.e., build a replacement model that represents computations in terms of features.

Before diving into how this replacement model is constructed, let’s briefly review the architecture of Transformer-based large language models.

The following diagram illustrates how transformer-based language models operate. The input is split into tokens, and each token is converted into an embedding vector. These embeddings are passed to the attention block, which calculates the relationships between tokens. Then, each token’s representation is passed to the multi-layer perceptron (MLP) block, which further refines it using linear transformations and a non-linear activation. This process is repeated across many layers before the model generates the final output.

Image by Author

Now that we have laid out the structure of a transformer-based LLM, let’s look at what transcoders are. The authors used a “transcoder” to develop the replacement model.

Transcoders

A transcoder is itself a neural network, generally with a much higher hidden dimension than the LLM’s model dimension, designed to replace the MLP block in a transformer with a more interpretable, functionally equivalent component whose hidden units are features.

Image by Author

It processes the token representations coming from the attention block in three stages: encoding, sparse activation, and decoding. Effectively, it projects the input into a higher-dimensional space, applies an activation function that (together with a sparsity penalty during training) keeps only a few features active, and then compresses the output back to the original dimension in the decoding stage.
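The three stages map onto a few lines of NumPy. This is a minimal sketch under stated assumptions: the sizes are illustrative, the weights are random (untrained), and ReLU with an L1 penalty is a common choice in the sparse-dictionary literature. The paper’s exact activation function and loss may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_features = 16, 128  # feature space much wider than the model dimension (illustrative)

W_enc = rng.normal(scale=0.1, size=(d_model, d_features))
b_enc = np.zeros(d_features)
W_dec = rng.normal(scale=0.1, size=(d_features, d_model))
b_dec = np.zeros(d_model)

def transcoder(x):
    """Encode -> sparse activation -> decode. Returns (MLP-output estimate, feature activations)."""
    pre = x @ W_enc + b_enc         # 1. encode into the wide feature space
    feats = np.maximum(pre, 0.0)    # 2. ReLU zeroes negative pre-activations
    out = feats @ W_dec + b_dec     # 3. decode back to d_model
    return out, feats

def transcoder_loss(x, mlp_out, l1_coeff=1e-3):
    """Imitate the original MLP's output while penalizing active features (L1)."""
    out, feats = transcoder(x)
    mse = np.mean((out - mlp_out) ** 2)
    sparsity = l1_coeff * np.abs(feats).sum(axis=-1).mean()
    return mse + sparsity
```

The reconstruction term pushes the transcoder to mimic the MLP, and the L1 term drives most feature activations to exactly zero during training, so each input is explained by a handful of features. With the random weights above, about half of the features fire.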

Image by Author

With a basic understanding of transformer-based LLMs and transcoder, let’s look at how a transcoder is used to build a replacement model.

Construct a replacement model

As mentioned earlier, a transformer block typically consists of two main components: an attention block and an MLP block (feedforward network). To build a replacement model, the MLP block in the original transformer model is replaced with a transcoder. This integration is seamless because the transcoder is trained to mimic the output of the original MLP, while also exposing its internal computations through sparse and modular features.

While standard transcoders are trained to imitate the MLP behavior within a single transformer layer, the authors of the paper used a cross-layer transcoder (CLT), which captures the combined effects of multiple transcoder blocks across several layers. This is important because it allows us to track whether a feature is spread across multiple layers, which is needed for circuit tracing.

The image below illustrates how the cross-layer transcoder (CLT) setup is used to build a replacement model. The transcoder output at layer 1 contributes to the MLP-equivalent output of every subsequent layer, up to the last one.
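A minimal NumPy sketch of the cross-layer idea follows. Each layer has its own encoder, and each layer’s features have a separate decoder into every layer at or above it. Sizes and random weights are illustrative. As a simplifying assumption, each layer’s residual-stream input is taken as given rather than computed from earlier layers.

```python
import numpy as np

rng = np.random.default_rng(0)
n_layers, d_model, d_feat = 3, 8, 32  # illustrative sizes

# One encoder per layer, reading that layer's residual stream ...
W_enc = [rng.normal(scale=0.1, size=(d_model, d_feat)) for _ in range(n_layers)]
# ... and one decoder per (source layer, target layer) pair with target >= source.
W_dec = {(s, t): rng.normal(scale=0.1, size=(d_feat, d_model))
         for s in range(n_layers) for t in range(s, n_layers)}

def clt_mlp_outputs(resid):
    """resid[l]: residual-stream input at layer l. Returns the CLT's estimate of each MLP output."""
    feats = [np.maximum(resid[l] @ W_enc[l], 0.0) for l in range(n_layers)]
    outputs = []
    for t in range(n_layers):
        # Layer t's MLP output is rebuilt from the features of layer t AND every earlier layer.
        outputs.append(sum(feats[s] @ W_dec[(s, t)] for s in range(t + 1)))
    return outputs, feats
```

A feature found at layer 0 can therefore write directly into the outputs of layers 0, 1 and 2, which is what lets a single feature span several layers in the replacement model.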

Image by Author

Side note: the following image, from the paper, shows how a replacement model is constructed. It replaces the neurons of the original model with features.

Image from https://transformer-circuits.pub/2025/attribution-graphs/methods.html#graphs-constructing

Now that we understand the architecture of the replacement model, let’s look at how an interpretable representation of the replacement model’s computational path is built.

Interpretable presentation of model’s computation: Attribution graph

To build the interpretable representation of the model’s computational path, we start from the model’s output feature and trace backward through the feature network to uncover which earlier features contributed to it. This is done using the backward Jacobian, which tells us how much a feature in the previous layer contributed to the current feature’s activation, and is applied recursively until we reach the input. Each feature is considered a node and each influence an edge. This process can lead to a complex graph with millions of nodes and edges, so pruning is done to keep the graph compact and manually interpretable.
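The "activation × Jacobian, then prune" idea can be illustrated with a toy stack of ReLU feature layers. This is not Anthropic’s method, which computes edges through the residual stream with frozen components and uses a more involved pruning scheme. It assumes every target feature is active, so the local Jacobian entry is just the weight. All weights are hypothetical.

```python
import numpy as np

def forward(x, weights):
    """Run a stack of ReLU feature layers; returns the activations of every layer."""
    acts = [x]
    for W in weights:
        acts.append(np.maximum(acts[-1] @ W, 0.0))
    return acts

def attribution_edges(acts, weights, threshold=0.1):
    """Edge (l, i) -> (l+1, j) has weight acts[l][i] * weights[l][i, j]; small edges are pruned."""
    edges = []
    for l in range(len(weights) - 1, -1, -1):    # walk backward from the output layer
        contrib = acts[l][:, None] * weights[l]  # source activation x local Jacobian entry
        for i, j in zip(*np.nonzero(np.abs(contrib) >= threshold)):
            edges.append(((l, int(i)), (l + 1, int(j)), float(contrib[i, j])))
    return edges

weights = [np.array([[0.2, 0.0], [1.0, 1.0], [0.01, 0.5]]),  # layer 0 (3 features) -> layer 1 (2)
           np.array([[1.0], [0.5]])]                        # layer 1 -> output feature
acts = forward(np.array([1.0, 0.0, 2.0]), weights)
for edge in attribution_edges(acts, weights):
    print(edge)
```

Without biases, the edges into a node add up to that node’s pre-activation, so an unpruned graph fully accounts for each activation. Here the 0.02 edge from feature (0, 2) to (1, 0) is dropped.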

The authors refer to this computational graph as an attribution graph and have also developed a tool to inspect it. This forms the core contribution of the paper.

The image below illustrates a sample attribution graph.

Image from https://transformer-circuits.pub/2025/attribution-graphs/methods.html#graphs

Now, with all this understanding, we can go to feature interpretability.

Feature interpretability using an attribution graph

The researchers used attribution graphs on Anthropic’s Claude 3.5 Haiku model to study how it behaves across different tasks. In the case of poem generation, they discovered that the model doesn’t just generate the next word. It engages in a form of planning, both forward and backward. Before generating a line, the model identifies several possible rhyming or semantically appropriate words to end with, then works backward to craft a line that naturally leads to that target. Surprisingly, the model appears to hold multiple candidate end words in mind simultaneously, and it can restructure the entire sentence based on which one it ultimately chooses.

This technique offers a clear, mechanistic view of how language models generate structured, creative text. This is a significant milestone for the AI community. As we develop increasingly powerful models, the ability to trace and understand their internal planning and execution will be essential for ensuring alignment, safety, and trust in AI systems.

Limitations of the current approach

Attribution graphs offer a way to trace model behavior for a single input, but they don’t yet provide a reliable method for understanding global circuits or the consistent mechanisms a model uses across many examples. This analysis relies on replacing MLP computations with transcoders, but it is still unclear whether these transcoders truly replicate the original mechanisms or simply approximate the outputs. Additionally, the current approach highlights only active features, but inactive or inhibitory ones can be just as important for understanding the model’s behavior.

Conclusion

Circuit tracing via attribution graph is an early but important step toward understanding how language models work internally. While this approach still has a long way to go, the introduction of circuit tracing marks a major milestone on the path to true interpretability.


Kernel Case Study: Flash Attention
https://towardsdatascience.com/kernel-case-study-flash-attention/
Towards Data Science, 3 April 2025

Understanding all versions of flash attention through a triton implementation

The post Kernel Case Study: Flash Attention appeared first on Towards Data Science.

The attention mechanism is at the core of modern-day transformers. But scaling the context window of these transformers was a major challenge, and it still is, even though we are in the era of million-plus-token context windows (Qwen 2.5 [1]). Scaling the context window brings considerable compute-bound and memory-bound complexities: a naive attention mechanism scales quadratically in both compute and memory requirements. Revisiting Flash Attention helps us understand the complexities of optimizing the underlying operations on GPUs and, more importantly, gives us a better grip on what comes next.

Let’s quickly revisit a naive attention algorithm to see what’s going on.

Attention Algorithm. Image by Author

As you can see, if we are not careful, we will end up materializing the full N×M attention matrix in GPU HBM, which means the memory requirement grows quadratically with context length.
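For reference, here is a minimal NumPy version of naive attention with the usual 1/√d scaling. It deliberately builds the full score matrix S, the object whose size grows with N×M. The shapes are arbitrary small examples.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference attention that materializes the full N x M score matrix."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                      # (N, M) scores: the quadratic-memory object
    P = np.exp(S - S.max(axis=1, keepdims=True))  # (N, M) unnormalized softmax
    P /= P.sum(axis=1, keepdims=True)             # row-wise normalization
    return P @ V, S

N, M, d = 8, 12, 4
rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(N, d)), rng.normal(size=(M, d)), rng.normal(size=(M, d))
O, S = naive_attention(Q, K, V)
print(O.shape, S.shape)
```

The inputs and output grow linearly in N and M, but S and P grow as N×M. For N = M = 32,768 in fp16, a single S alone takes 2 GiB per head.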

If you want to learn more about the GPU memory hierarchy and the differences between its levels, my previous post on Triton is a good starting point. It will also come in handy later in this post, when we implement the Flash Attention kernel in Triton. The Flash Attention paper also has a really good introduction to this.

Additionally, when we look at the steps involved in executing this algorithm and its pattern of accessing the slow HBM (which, as explained later in the post, can also be a major bottleneck), we notice a few things:

  1. We have Q, K and V in HBM initially
  2. We need to read Q and K from HBM to compute the dot product
  3. We write the output scores back to HBM
  4. We read the scores again to compute the softmax (for causal attention, as in LLMs, we also have to mask the scores before the softmax). The resulting full attention matrix is written back to HBM
  5. We read both the attention weights and the Value matrix from HBM again to compute the final dot product, and write the output back to the slow GPU memory
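A rough count of the HBM traffic in the five steps above shows where the cost comes from. This model counts elements, not bytes, and assumes nothing is cached and that each step reads its full inputs from HBM. The sizes are illustrative.

```python
def naive_attention_hbm_elements(N, M, d):
    """Elements moved between HBM and the chip by the five steps above (no caching or fusion)."""
    read_qk = N * d + M * d          # step 2: read Q and K
    write_s = N * M                  # step 3: write scores S
    read_s, write_p = N * M, N * M   # step 4: read S, write softmax P
    read_pv = N * M + M * d          # step 5: read P and V
    write_o = N * d                  # step 5: write output O
    return read_qk + write_s + read_s + write_p + read_pv + write_o

N = M = 32_768
d = 128
total = naive_attention_hbm_elements(N, M, d)
print(total, total * 2 / 2**30)  # elements, and GiB at 2 bytes per element (fp16)
```

The four N×M terms account for 8 GiB of the roughly 8.03 GiB moved per head, while Q, K, V and O together contribute only about 0.03 GiB. Removing the N×M round trips is exactly what a fused kernel targets.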

I think you get the point. We could smartly read and write from the HBM to avoid redundant operations, to make some potential gains. This is exactly the primary motivation for the original Flash Attention algorithm.

Flash Attention first came out in 2022 [2]. A year later, in 2023, Flash Attention v2 [3] brought some much-needed improvements, and in 2024 Flash Attention v3 [5] added further improvements for Nvidia Hopper and Blackwell GPUs [4]. The original Flash Attention paper identified that the attention operation is limited by memory bandwidth rather than compute. (In the past, there have been attempts to reduce the computational complexity of attention from O(N²) to O(N log N) and lower through approximate algorithms.)

Flash Attention proposed a fused kernel that performs all of the above attention operations in one go, block by block, producing the final attention output without ever materializing the full N×N attention matrix in memory, which makes the algorithm significantly faster. The term `fused` simply means that we combine multiple operations while the data sits in fast GPU SRAM, instead of making a round trip through the much slower GPU memory after each one. All the while, it produces the exact attention output, without any approximation.

This lecture, from Stanford CS139, brilliantly demonstrates the impact a well-thought-out memory access pattern can have on an algorithm. I highly recommend checking it out if you haven’t already.

Before we start diving into Flash Attention in Triton (it’s getting tedious to type this over and over, so let’s agree to call it FA, shall we?), there is something else that I wanted to get out of the way.

Numerical Stability in exponents

Let’s take the example of FP32 numbers. float32 (the standard 32-bit float) uses 1 sign bit, 8 exponent bits, and 23 mantissa bits [6]. The largest power of two it can represent is 2¹²⁷ ≈ 1.7×10³⁸ (the largest finite float32 value is about 3.4×10³⁸). For exponentials, this means e⁸⁸ ≈ 1.65×10³⁸: any exponent close to 88 (in practice you would keep a much larger safety margin) and we are in trouble, as we can easily overflow. Here’s a very interesting chat with OpenAI o1 shared by folks at AllenAI in their OpenInstruct repo. Although it discusses stabilizing KL divergence calculations in the setting of RLHF/RL, the ideas translate directly to exponents in general. So, to deal with the softmax in attention, we do the following:

Softmax with rescaling. Image by Author
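Here is a small NumPy demonstration of why the max is subtracted. The inputs are arbitrary large logits. In float32, `np.exp` overflows to `inf` once its argument exceeds about 88.7, and `inf / inf` yields `nan`.

```python
import numpy as np

def softmax_naive(x):
    e = np.exp(x)            # overflows to inf for large inputs
    return e / e.sum()

def softmax_stable(x):
    m = np.max(x)
    e = np.exp(x - m)        # largest exponent is now exp(0) = 1
    return e / e.sum()

x = np.array([1000.0, 1001.0, 1002.0], dtype=np.float32)
with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive(x))  # inf / inf -> nan
print(softmax_stable(x))     # same as the softmax of [0, 1, 2]
```

Subtracting the same constant from every logit cancels out in the ratio, so the stable version returns the exact softmax while keeping every exponent at or below zero.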

TRICK: observe also that if you do the following:

Rescaling Trick. Image by Author

then you can rescale/readjust values without affecting the final softmax value. This is really useful when you have an initial estimate of the maximum value that might change when you encounter a new set of values. I know, I know; stay with me and let me explain.
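Written out in symbols (this is our rendering of the idea; the figure above may use different notation), the trick rests on two facts. The softmax does not depend on which constant \(m\) we subtract. And a sum accumulated against one reference maximum \(m\) can be moved to a new maximum \(m'\) with a single multiplication:

\[
\frac{e^{x_i - m}}{\sum_j e^{x_j - m}}
= \frac{e^{m' - m}\, e^{x_i - m'}}{e^{m' - m} \sum_j e^{x_j - m'}}
= \frac{e^{x_i - m'}}{\sum_j e^{x_j - m'}},
\qquad
\sum_j e^{x_j - m'} = e^{m - m'} \sum_j e^{x_j - m}.
\]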

Setting the scene

Let’s take a small detour into matrix multiplication.

Blocked Matrix Multiplication. Image by Author

This shows a toy example of a blocked matrix multiplication, except that we have blocks only on the rows of A (green) and the columns of B (orange? beige?). As you can see above, the outputs O1, O2, O3 and O4 are complete (those positions need no more calculations). We just need to fill in the remaining columns in the initial rows by using the remaining columns of B, like below:

The next set of blocks fills the remaining spaces. Image by Author

So we can fill these places in the output with a block of columns from B and a block of rows from A at a time.
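The same row-block × column-block pattern as a minimal NumPy sketch. Block sizes are arbitrary, and Python slicing handles a final partial block.

```python
import numpy as np

def blocked_matmul(A, B, row_block, col_block):
    """Compute A @ B one (row block of A) x (column block of B) tile at a time."""
    N, K = A.shape
    K2, M = B.shape
    assert K == K2, "inner dimensions must match"
    C = np.zeros((N, M), dtype=np.result_type(A, B))
    for r in range(0, N, row_block):
        for c in range(0, M, col_block):
            # Each output tile needs only its rows of A and its columns of B.
            C[r:r + row_block, c:c + col_block] = A[r:r + row_block, :] @ B[:, c:c + col_block]
    return C

rng = np.random.default_rng(1)
A, B = rng.normal(size=(5, 3)), rng.normal(size=(3, 7))
print(np.allclose(blocked_matmul(A, B, row_block=2, col_block=3), A @ B))
```

Each output tile is final as soon as it is written, which is the property the attention blocks below rely on.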

Connecting the dots

When I introduced FA, I said that we never have to compute the full attention matrix and store the whole thing. So here’s what we do:

  1. Compute a block of the attention matrix using a block of rows from Q and a block of columns from K. Once you have the partial attention matrix, compute a few statistics and keep them in memory.
Computing block attention scores S_b, and computing the row-wise maximums. Image by Author

I have greyed out O5 to O12 because we don’t know those values yet; they need to come from the subsequent blocks. We then transform Sb as below:

Keeping a track of the current row-sum and row-maxes. Image by Author
Exponents with the scaling trick. Image by Author

Now you have the setup for a partial softmax

Partial Softmax, as the denominator is still a partial sum. Image by Author

But:

  1. What if the true maximum is in the Oi’s that are yet to come?
  2. The sum is still local, so we need to update this every time we see new Pi’s. We know how to keep track of a sum, but what about rebasing it to the true maximum?

Recall the trick above. All we have to do is keep track of the maximum value we have encountered for each row, and update it iteratively as we see new maximums in the remaining blocks of columns from K for the same set of rows from Q.

Two consecutive blocks and its row max manipulations. Image by Author
Updating the estimate of our current sum with rescaling
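For one block of query rows, the running statistics look like this in NumPy. The full score matrix `S` is passed in only so the result can be checked; the loop itself touches one column block at a time.

```python
import numpy as np

def online_row_stats(S, block):
    """Stream over column blocks of S, keeping a running row max m and a rescaled row sum l."""
    N, M = S.shape
    m = np.full(N, -np.inf)
    l = np.zeros(N)
    for c in range(0, M, block):
        Sb = S[:, c:c + block]
        m_new = np.maximum(m, Sb.max(axis=1))
        # Rebase the old sum to the new max, then add this block's contribution.
        l = l * np.exp(m - m_new) + np.exp(Sb - m_new[:, None]).sum(axis=1)
        m = m_new
    return m, l

S = np.random.default_rng(2).normal(size=(4, 10))
m, l = online_row_stats(S, block=3)
print(np.allclose(l, np.exp(S - S.max(axis=1, keepdims=True)).sum(axis=1)))
```

On the first block, `m` is `-inf`, so the rescaling factor is exactly 0 and the empty initial sum drops out cleanly.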

We still do not want to write our partial softmax matrix into HBM. We keep it for the next step.

The final dot product

The last step in our attention computation is the dot product with V. To start, we initialize a matrix of zeros in HBM as our output, of shape N×D, where N is the number of queries as above. We use the same block size for V as we had for K, except that we apply it row-wise, as below (the subscripts just denote that this is only a block and not the full matrix):

A single block of attention scores creating a partial output. Image by Author
Whereas the full output would require the sum of all these dot products, some of which will be filled in by the blocks to come. Image by Author

Notice that we need the attention scores from all the blocks to get the final product. But if we calculate the local score and `accumulate` it, the same way we did to get the actual Ls, we can form the full output once we have processed all the column blocks (Kb) for a given row block (Qb).
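The same rescaling works for the output. Below is a NumPy sketch, assuming arbitrary small shapes and block sizes, that forms the output for one row block by visiting the K/V column blocks one at a time and compares it against plain softmax attention. Like the kernel in this post, it applies no 1/sqrt(d) scaling. `blockwise_attention_rows` and `reference_attention` are illustrative names, not library functions.

```python
import numpy as np

def blockwise_attention_rows(Qb, K, V, bc):
    """Attention output for one row block Qb, streaming K and V in column blocks.
    Keeps a running max m, a running sum l, and an already-normalized output O per row."""
    br = Qb.shape[0]
    O = np.zeros((br, V.shape[1]))
    m = np.full((br, 1), -np.inf)
    l = np.zeros((br, 1))
    for start in range(0, K.shape[0], bc):
        Kb, Vb = K[start:start + bc], V[start:start + bc]
        S = Qb @ Kb.T                        # block of attention scores
        m_ij = S.max(axis=1, keepdims=True)  # local row maxima
        P = np.exp(S - m_ij)                 # local unnormalized softmax numerators
        l_ij = P.sum(axis=1, keepdims=True)  # local row sums
        m_new = np.maximum(m, m_ij)
        l_new = np.exp(m - m_new) * l + np.exp(m_ij - m_new) * l_ij
        # Undo the old normalization, rebase both terms to m_new, renormalize by l_new
        O = (l * np.exp(m - m_new) * O + np.exp(m_ij - m_new) * (P @ Vb)) / l_new
        m, l = m_new, l_new
    return O

def reference_attention(Q, K, V):
    """Plain softmax attention over the full score matrix."""
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(4, 8)), rng.normal(size=(10, 8)), rng.normal(size=(10, 8))
# Two row blocks of 2 queries each, column blocks of 3 keys (the last one has 1)
out = np.vstack([blockwise_attention_rows(Q[i:i + 2], K, V, bc=3) for i in range(0, 4, 2)])
print(np.allclose(out, reference_attention(Q, K, V)))  # True
```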

Putting it all together

Let’s put all these ideas together to form the final algorithm

Flash Attention V1 Algorithm. Source: Tri Dao et al. [2]

To understand the notation: the subscript _ij denotes local values for a given block of rows and columns, while _i denotes global values for the output rows and query blocks. The only part we haven't explained so far is the final update to Oi; that's where we use all the ideas from above to get the right scaling.
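For reference, here are the per-block updates of the algorithm written out, using the same _ij (local) and _i (global) subscripts. The max, exponentials, rowmax, and rowsum all act row-wise, and the last three lines correspond to the `m_new`, `l_new`, and `o_new` lines in the kernel below.

```latex
\begin{aligned}
S_{ij} &= Q_i K_j^{\top}, \\
\tilde{m}_{ij} &= \operatorname{rowmax}(S_{ij}), \quad
\tilde{P}_{ij} = \exp\!\left(S_{ij} - \tilde{m}_{ij}\right), \quad
\tilde{\ell}_{ij} = \operatorname{rowsum}(\tilde{P}_{ij}), \\
m_i^{\text{new}} &= \max\!\left(m_i, \tilde{m}_{ij}\right), \\
\ell_i^{\text{new}} &= e^{m_i - m_i^{\text{new}}}\,\ell_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}}\,\tilde{\ell}_{ij}, \\
O_i &\leftarrow \operatorname{diag}\!\left(\ell_i^{\text{new}}\right)^{-1}
\left(\operatorname{diag}(\ell_i)\, e^{m_i - m_i^{\text{new}}}\, O_i
+ e^{\tilde{m}_{ij} - m_i^{\text{new}}}\, \tilde{P}_{ij} V_j\right).
\end{aligned}
```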

The whole code is available as a gist here.

Let’s see what these initializations look like in torch:

def flash_attn_v1(Q, K, V, Br, Bc):
  """Flash Attention V1"""
  B, N, D = Q.shape
  M = K.shape[1]
  Nr = int(np.ceil(N/Br))  # number of row blocks over the N queries
  Nc = int(np.ceil(M/Bc))  # number of column blocks over the M keys/values
  
  Q = Q.to('cuda')
  K = K.to('cuda')
  V = V.to('cuda')
  
  batch_stride = Q.stride(0)
  
  O = torch.zeros_like(Q).to('cuda')
  lis = torch.zeros((B, Nr, int(Br)), dtype=torch.float32).to('cuda')
  mis = torch.ones((B, Nr, int(Br)), dtype=torch.float32).to('cuda')*-torch.inf
  
  grid = (B, )
  flash_attn_v1_kernel[grid](
      Q, K, V,
      N, M, D,
      Br, Bc,
      Nr, Nc,
      batch_stride,
      Q.stride(1),
      K.stride(1),
      V.stride(1),
      lis, mis,
      O,
      O.stride(1),
  )
  return O

If you are unsure about the launch grid, check out my introduction to Triton.

Take a closer look at how we initialized our Ls and Ms. We are keeping one for each row block of Output/Query, each of size Br. There are Nr such blocks in total.

In the example above I was simply using Br = 2 and Bc = 2, but in the actual code the block sizes are based on the device's capacity. I have included the calculation for a T4 GPU; for any other GPU, you need to get its SRAM capacity and adjust these numbers accordingly. Now for the actual kernel implementation:

# Flash Attention V1
import triton
import triton.language as tl
import torch
import numpy as np

@triton.jit
def flash_attn_v1_kernel(
    Q, K, V,
    N: tl.constexpr, M: tl.constexpr, D: tl.constexpr,
    Br: tl.constexpr,
    Bc: tl.constexpr,
    Nr: tl.constexpr,
    Nc: tl.constexpr,
    batch_stride: tl.constexpr,
    q_rstride: tl.constexpr,
    k_rstride: tl.constexpr, 
    v_rstride: tl.constexpr,
    lis, mis,
    O,
    o_rstride: tl.constexpr):
    
    """Flash Attention V1 kernel"""
    
    pid = tl.program_id(0)
    

    for j in range(Nc):
        k_offset = ((tl.arange(0, Bc) + j*Bc) * k_rstride)[:, None] + (tl.arange(0, D))[None, :] + pid * M * D
        # Using k_rstride and v_rstride as we are looking at the entire row at once, for each k v block 
        v_offset = ((tl.arange(0, Bc) + j*Bc) * v_rstride)[:, None] + (tl.arange(0, D))[None, :] + pid * M * D
        k_mask = k_offset < (pid + 1) * M*D
        v_mask = v_offset < (pid + 1) * M*D
        k_load = tl.load(K + k_offset, mask=k_mask, other=0)
        v_load = tl.load(V + v_offset, mask=v_mask, other=0)
        for i in range(Nr):
            q_offset = ((tl.arange(0, Br) + i*Br) * q_rstride)[:, None] + (tl.arange(0, D))[None, :] + pid * N * D
            q_mask = q_offset < (pid + 1) * N*D
            q_load = tl.load(Q + q_offset, mask=q_mask, other=0)
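            # Compute attention scores for this (row block, column block) pair.
            # Note: no 1/sqrt(D) softmax scaling is applied in this simplified kernel.
            s_ij = tl.dot(q_load, tl.trans(k_load))
            # Padded key rows (beyond M) were loaded as zeros; give them a score of -inf
            # so they do not contribute to the row max or the softmax sum.
            kv_idx = tl.arange(0, Bc) + j * Bc
            s_ij = tl.where(kv_idx[None, :] < M, s_ij, float("-inf"))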
            m_ij = tl.max(s_ij, axis=1, keep_dims=True)
            p_ij = tl.exp(s_ij - m_ij)
            l_ij = tl.sum(p_ij, axis=1, keep_dims=True)
            
            ml_offset = tl.arange(0, Br) + Br * i + pid * Nr * Br
            m = tl.load(mis + ml_offset)[:, None]
            l = tl.load(lis + ml_offset)[:, None]

            m_new = tl.where(m < m_ij, m_ij, m)

            l_new = tl.exp(m - m_new) * l + tl.exp(m_ij - m_new) * l_ij

            o_ij = tl.dot(p_ij, v_load)

            output_offset = ((tl.arange(0, Br) + i*Br) * o_rstride)[:, None] + (tl.arange(0, D))[None, :] + pid * N * D
            output_mask = output_offset < (pid + 1) * N*D
            o_current = tl.load(O + output_offset, mask=output_mask)

            o_new = (1/l_new) * (l * tl.exp(m - m_new) * o_current + tl.exp(m_ij - m_new) * o_ij)

            tl.store(O + output_offset, o_new, mask=output_mask)
            tl.store(mis + ml_offset, tl.reshape(m_new, (Br,)))
            tl.store(lis + ml_offset, tl.reshape(l_new, (Br,)))

Let's understand what's happening here:

  1. Launch one kernel instance (a Triton program) for each NxD matrix in the batch. In reality we would have one more dimension to parallelize across, the head dimension, but for understanding the implementation this suffices.
  2. In each kernel we do the following:
    1. For each block of columns in K and V, we load the relevant part of the matrix (Bc x D) into the GPU SRAM (Current total SRAM usage = 2BcD). This stays in SRAM until we are done with all the row blocks.
    2. For each row block of Q, we load the block into SRAM as well (Current total SRAM usage = 2BcD + BrD).
    3. On chip, we compute the dot product (sij), the local row-maxes (mij), the exponentials (pij), and the exponential sums (lij).
    4. We load the running stats for the ith row block: two vectors of size Br x 1 holding the current global row-maxes (mi) and exponential sums (li). (Current SRAM usage: 2BcD + BrD + 2Br)
    5. We compute the new estimates for the global mi and li.
    6. We load the part of the output for this block of Q, update it using the new running stats and the exponent trick, and write it back into HBM. (Current SRAM usage: 2BcD + 2BrD + 2Br)
    7. We also write the updated running stats back into HBM.
  3. For a matrix of any size, i.e. any context length, we never materialize the full attention matrix, only one block of it at a time.
  4. We managed to fuse all the ops into a single kernel, reducing HBM access considerably.
  4. We managed to fuse together all the ops into a single kernel, reducing HBM access considerably.

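However, the final SRAM usage stands at 4BD + 2B (taking Br = Bc = B), where B was initially calculated as M/4D and M is the SRAM capacity. Since 4BD alone already equals M, this slightly exceeds the capacity, and I'm not sure if I am missing something here. Please comment if you know why this is the case!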
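As a rough check of this arithmetic, here is a small Python sketch of the block-size rule from [2] (Bc = ceil(M/4d), Br = min(ceil(M/4d), d), with M counted in elements). The 64 KiB of shared memory and fp32 elements are assumptions for illustration, not necessarily the numbers used for the T4 in the full code; `floor_pow2` is a hypothetical helper, included because Triton's `tl.arange` needs power-of-two extents.

```python
import math

def fa1_block_sizes(sram_bytes, d, bytes_per_element=4):
    """FlashAttention v1 block sizes: Bc = ceil(M / 4d), Br = min(ceil(M / 4d), d),
    where M is the SRAM capacity counted in elements (not bytes)."""
    m_elems = sram_bytes // bytes_per_element
    bc = math.ceil(m_elems / (4 * d))
    br = min(bc, d)
    return br, bc

def floor_pow2(x):
    """Hypothetical helper: round down to a power of two, since tl.arange
    needs power-of-two extents."""
    return 1 << (x.bit_length() - 1)

# Assumptions: 64 KiB of usable shared memory and fp32 (4-byte) elements
d = 64
sram_elems = 64 * 1024 // 4
br, bc = fa1_block_sizes(64 * 1024, d)
# Peak usage from the walkthrough above: K and V blocks, Q and O blocks, m and l
usage = 2 * bc * d + 2 * br * d + 2 * br
print(br, bc, usage, sram_elems)  # 64 64 16512 16384
```

With these assumed numbers, the peak usage of 4BD + 2B comes out exactly 2B elements above M, which is the discrepancy described above.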

Block Sparse Attention and V2 and V3

I will keep this short, as these versions keep the core idea but find better and better ways to carry it out.

For Block Sparse Attention,

  1. Suppose we have masks for each block, as in the case of causal attention. If the mask for a given block is all zeros, we can simply skip the entire block without computing anything, saving FLOPs. This is where the major gains were seen. To put this into perspective, in the case of BERT pre-training the algorithm gets a 15% speedup over the best-performing training setup at the time, whereas for GPT-2 we get a 3x speedup over the Hugging Face training implementation and ~2x over a Megatron setup.
Performance gain for autoregressive models, where we have a sparse mask. Source: Tri Dao et al. [2]

  2. You can get the same performance in GPT-2 in a fraction of the time, shaving days off the training run, which is awesome!
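A NumPy sketch of the block-skipping idea, assuming a causal mask, Q and K of equal length, and square blocks of size `b` that divide it. This illustrates the principle and is not the block-sparse kernel from [2]: blocks lying entirely above the diagonal are skipped outright, diagonal blocks apply the mask element-wise, and the result is checked against dense causal attention. For simplicity it normalizes each row block once at the end. The function names are illustrative.

```python
import numpy as np

def block_sparse_causal_attention(Q, K, V, b):
    """Causal attention in b x b blocks; blocks whose mask is all zero are skipped.
    Assumes Q and K have the same length N and N is a multiple of b."""
    N = Q.shape[0]
    O = np.zeros_like(V)
    skipped = 0
    for qs in range(0, N, b):
        m = np.full((b, 1), -np.inf)
        l = np.zeros((b, 1))
        acc = np.zeros((b, V.shape[1]))  # unnormalized output accumulator
        for ks in range(0, N, b):
            if ks > qs:                  # every key in this block comes after every query
                skipped += 1
                continue                 # all-zero mask block: no FLOPs spent
            S = Q[qs:qs + b] @ K[ks:ks + b].T
            rows = np.arange(qs, qs + b)[:, None]
            cols = np.arange(ks, ks + b)[None, :]
            S = np.where(cols <= rows, S, -np.inf)  # causal mask inside the block
            m_new = np.maximum(m, S.max(axis=1, keepdims=True))
            P = np.exp(S - m_new)
            scale = np.exp(m - m_new)    # rebase earlier partial results to the new max
            l = scale * l + P.sum(axis=1, keepdims=True)
            acc = scale * acc + P @ V[ks:ks + b]
            m = m_new
        O[qs:qs + b] = acc / l
    return O, skipped

def dense_causal_attention(Q, K, V):
    N = Q.shape[0]
    S = np.where(np.tril(np.ones((N, N), dtype=bool)), Q @ K.T, -np.inf)
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(8, 4)) for _ in range(3))
out, skipped = block_sparse_causal_attention(Q, K, V, b=2)
print(np.allclose(out, dense_causal_attention(Q, K, V)), skipped)  # True 6
```

With 4 x 4 blocks, the 6 blocks strictly above the diagonal are never computed, which is where the FLOP savings come from.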

In V2:

  1. Notice that so far we can only parallelize over the batch and head dimensions. But if we simply flip the loop order, looking at all the column blocks for a given row block, we get the following advantages:
    1. Each row block becomes embarrassingly parallel. You can see this from the illustrations above: you need all the column blocks for a given row block to fully form the attention output. If you ran all the column blocks in parallel, you would end up with a race condition, with several blocks trying to update the same rows of the output at the same time. That is not an issue if you do it the other way around. Although Triton provides atomic add operators that could help, they may potentially set us back.
    2. We can avoid hitting HBM to get the global Ms and Ls; each kernel can initialize its own on chip.
    3. We also do not have to rescale every output update by the new estimate of L. We can accumulate the output without dividing by L and, after processing all the column blocks, divide the output by the final estimate of L once, saving some FLOPs again!
  2. Much of the improvement also comes from the backward kernel. I am omitting all the backward kernels here, but they are a fun exercise to try and implement, although they are significantly more complex.
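Here is a sketch of the flipped loop order in NumPy, with a thread pool standing in for GPU thread blocks (an assumption for illustration only; it does not model the GPU's memory hierarchy). Each row block keeps its m, l, and output accumulator as local variables, the analogue of keeping them on chip, and divides by l once at the end. Because each row block writes a disjoint slice of O, no atomics are needed. `attend_row_block` and `fa2_style_attention` are hypothetical names.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def attend_row_block(Qi, K, V, bc):
    """FA2-style inner loop for one row block: stats stay local, one division at the end."""
    m = np.full((Qi.shape[0], 1), -np.inf)
    l = np.zeros((Qi.shape[0], 1))
    acc = np.zeros((Qi.shape[0], V.shape[1]))
    for start in range(0, K.shape[0], bc):
        S = Qi @ K[start:start + bc].T
        m_new = np.maximum(m, S.max(axis=1, keepdims=True))
        P = np.exp(S - m_new)
        scale = np.exp(m - m_new)                     # rebase previous stats to the new max
        l = scale * l + P.sum(axis=1, keepdims=True)
        acc = scale * acc + P @ V[start:start + bc]   # no division inside the loop
        m = m_new
    return acc / l                                    # single normalization at the end

def fa2_style_attention(Q, K, V, br, bc, workers=4):
    O = np.empty((Q.shape[0], V.shape[1]))
    starts = range(0, Q.shape[0], br)
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # Row blocks are independent, and each one writes a disjoint slice of O
        results = pool.map(lambda s: attend_row_block(Q[s:s + br], K, V, bc), starts)
        for s, block in zip(starts, results):
            O[s:s + br] = block
    return O

rng = np.random.default_rng(1)
Q, K, V = rng.normal(size=(6, 4)), rng.normal(size=(9, 4)), rng.normal(size=(9, 4))
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.allclose(fa2_style_attention(Q, K, V, br=4, bc=2), ref))  # True
```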

Here are some benchmarks:

Performance benchmark of FA v2 against existing attention algorithms. Source: Tri Dao [3]

Production implementations of these kernels need to account for various nuances that we encounter in the real world. I have tried to keep it simple, but do check them out here.

More recently in V3:

  1. Newer GPUs, especially Hopper and Blackwell, have low-precision modes (FP8 in Hopper and FP4 in Blackwell) that can double or quadruple throughput for the same power and chip area, as well as more specialized hardware for GEMM (General Matrix Multiply), which the previous versions of the algorithm fail to capitalize on. This is because many operations, like softmax, are non-GEMM, and they reduce the utilization of these specialized units.
  2. FA v1 and v2 are essentially synchronous. Recall that in the v2 description I mentioned that we are limited when column blocks try to write to the same output pointers, or when we have to go step by step using the output from previous steps. These modern GPUs can make use of special instructions to break this synchrony.

We overlap the comparatively low-throughput non-GEMM operations involved in softmax, such as floating point multiply-add and exponential, with the asynchronous WGMMA instructions for GEMM. As part of this, we rework the FlashAttention-2 algorithm to circumvent certain sequential dependencies between softmax and the GEMMs. For example, in the 2-stage version of our algorithm, while softmax executes on one block of the scores matrix, WGMMA executes in the asynchronous proxy to compute the next block.

Flash Attention v3, Shah et al. [5]
  3. They also adapted the algorithm to target the specialized low-precision Tensor Cores on these new devices, significantly increasing throughput (FLOPS).

Some more benchmarks:

FA v3 performance gain over v2. Source: Shah et al. [5]

Conclusion

There is much to admire in their work here. The barrier to entry for this kind of work has often seemed high, owing to the low-level details. But hopefully tools like Triton will change the game and get more people into this! The future is bright.

References

[1] Qwen 2.5-7B-Instruct-1M Huggingface Model Page

[2] Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Re, FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

[3] Tri Dao, FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

[4] NVIDIA Hopper Architecture Page

[5] Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, and Tri Dao, FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision

[6] Single-precision floating-point format, Wikipedia

The post Kernel Case Study: Flash Attention appeared first on Towards Data Science.

The Case for Centralized AI Model Inference Serving https://towardsdatascience.com/the-case-for-centralized-ai-model-inference-serving/ Wed, 02 Apr 2025 01:52:26 +0000 https://towardsdatascience.com/?p=605383 Optimizing highly parallel AI algorithm execution

The post The Case for Centralized AI Model Inference Serving appeared first on Towards Data Science.

As AI models continue to increase in scope and accuracy, even tasks once dominated by traditional algorithms are gradually being taken over by Deep Learning models. Algorithmic pipelines (workflows that take an input, process it through a series of algorithms, and produce an output) increasingly rely on one or more AI-based components. These AI models often have significantly different resource requirements than their classical counterparts, such as higher memory usage, reliance on specialized hardware accelerators, and increased computational demands.

In this post, we address a common challenge: efficiently processing large-scale inputs through algorithmic pipelines that include deep learning models. A typical solution is to run multiple independent jobs, each responsible for processing a single input. This setup is often managed with job orchestration frameworks (e.g., Kubernetes). However, when deep learning models are involved, this approach can become inefficient as loading and executing the same model in each individual process can lead to resource contention and scaling limitations. As AI models become increasingly prevalent in algorithmic pipelines, it is crucial that we revisit the design of such solutions.

In this post we evaluate the benefits of centralized inference serving, where a dedicated inference server handles prediction requests from multiple parallel jobs. We define a toy experiment in which we run an image-processing pipeline based on a ResNet-152 image classifier on 1,000 individual images. We compare the runtime performance and resource utilization of the following two implementations:

  1. Decentralized inference — each job loads and runs the model independently.
  2. Centralized inference — all jobs send inference requests to a dedicated inference server.

To keep the experiment focused, we make several simplifying assumptions:

  • Instead of using a full-fledged job orchestrator (like Kubernetes), we implement parallel process execution using Python’s multiprocessing module.
  • While real-world workloads often span multiple nodes, we run everything on a single node.
  • Real-world workloads typically include multiple algorithmic components. We limit our experiment to a single component — a ResNet-152 classifier running on a single input image.
  • In a real-world use case, each job would process a unique input image. To simplify our experiment setup, each job will process the same kitten.jpg image.
  • We will use a minimal deployment of a TorchServe inference server, relying mostly on its default settings. Similar results are expected with alternative inference server solutions such as NVIDIA Triton Inference Server or LitServe.

The code is shared for demonstrative purposes only. Please do not interpret our choice of TorchServe — or any other component of our demonstration — as an endorsement of its use.

Toy Experiment

We conduct our experiments on an Amazon EC2 c5.2xlarge instance, with 8 vCPUs and 16 GiB of memory, running a PyTorch Deep Learning AMI (DLAMI). We activate the PyTorch environment using the following command:

source /opt/pytorch/bin/activate

Step 1: Creating a TorchScript Model Checkpoint

We begin by creating a ResNet-152 model checkpoint. Using TorchScript, we serialize both the model definition and its weights into a single file:

import torch
from torchvision.models import resnet152, ResNet152_Weights

model = resnet152(weights=ResNet152_Weights.DEFAULT)
model = torch.jit.script(model)
model.save("resnet-152.pt")

Step 2: Model Inference Function

Our inference function performs the following steps:

  1. Load the ResNet-152 model.
  2. Load an input image.
  3. Preprocess the image to match the input format expected by the model, following the implementation defined here.
  4. Run inference to classify the image.
  5. Post-process the model output to return the top five label predictions, following the implementation defined here.

We define a constant MAX_THREADS hyperparameter that we use to restrict the number of threads used for model inference in each process. This is to prevent resource contention between the multiple jobs.

import os, time, psutil
import multiprocessing as mp
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image


def predict(image_id):
    # Limit each process to 1 thread
    MAX_THREADS = 1
    os.environ["OMP_NUM_THREADS"] = str(MAX_THREADS)
    os.environ["MKL_NUM_THREADS"] = str(MAX_THREADS)
    torch.set_num_threads(MAX_THREADS)

    # load the model
    model = torch.jit.load('resnet-152.pt').eval()

    # Define image preprocessing steps
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                             std=[0.229, 0.224, 0.225])
    ])

    # load the image
    image = Image.open('kitten.jpg').convert("RGB")
    
    # preproc
    image = transform(image).unsqueeze(0)

    # perform inference
    with torch.no_grad():
        output = model(image)

    # postproc
    probabilities = F.softmax(output[0], dim=0)
    probs, classes = torch.topk(probabilities, 5, dim=0)
    probs = probs.tolist()
    classes = classes.tolist()

    return dict(zip(classes, probs))

Step 3: Running Parallel Inference Jobs

We define a function that spawns parallel processes, each processing a single image input. This function:

  • Accepts the total number of images to process and the maximum number of concurrent jobs.
  • Dynamically launches new processes when slots become available.
  • Monitors CPU and memory usage throughout execution.

def process_image(image_id):
    print(f"Processing image {image_id} (PID: {os.getpid()})")
    predict(image_id)

def spawn_jobs(total_images, max_concurrent):
    start_time = time.time()
    max_mem_utilization = 0.
    max_utilization = 0.

    processes = []
    index = 0
    while index < total_images or processes:

        while len(processes) < max_concurrent and index < total_images:
            # Start a new process
            p = mp.Process(target=process_image, args=(index,))
            index += 1
            p.start()
            processes.append(p)

        # sample memory utilization
        mem_usage = psutil.virtual_memory().percent
        max_mem_utilization = max(max_mem_utilization, mem_usage)
        cpu_util = psutil.cpu_percent(interval=0.1)
        max_utilization = max(max_utilization, cpu_util)

        # Remove completed processes from list
        processes = [p for p in processes if p.is_alive()]

    total_time = time.time() - start_time
    print(f"\nTotal Processing Time: {total_time:.2f} seconds")
    print(f"Max CPU Utilization: {max_utilization:.2f}%")
    print(f"Max Memory Utilization: {max_mem_utilization:.2f}%")

if __name__ == "__main__":
    spawn_jobs(total_images=1000, max_concurrent=32)

Estimating the Maximum Number of Processes

While the optimal number of maximum concurrent processes is best determined empirically, we can estimate an upper bound based on the 16 GiB of system memory and the size of the resnet-152.pt file, 231 MB.
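The upper-bound estimate is simple division. A quick sketch, assuming "231 MB" means 231 × 10^6 bytes and counting only the model weights; the Python interpreter, the PyTorch runtime, activations, and the OS all need memory too, so the practical limit is lower. `max_processes_by_memory` is an illustrative name.

```python
GIB = 1024 ** 3

def max_processes_by_memory(total_mem_bytes, model_bytes):
    """Naive upper bound: how many copies of the model fit in memory,
    ignoring runtime overhead, activations, and the OS."""
    return total_mem_bytes // model_bytes

# Assumption: "231 MB" is read as 231 * 10**6 bytes
print(max_processes_by_memory(16 * GIB, 231 * 10**6))  # 74
```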

The table below summarizes the runtime results for several configurations:

Decentralized Inference Results (by Author)

Although memory becomes fully saturated at 50 concurrent processes, we observe that maximum throughput is achieved at 8 concurrent jobs — one per vCPU. This indicates that beyond this point, resource contention outweighs any potential gains from additional parallelism.

The Inefficiencies of Independent Model Execution

Running parallel jobs that each load and execute the model independently introduces significant inefficiencies and waste:

  1. Each process needs to allocate the appropriate memory resources for storing its own copy of the AI model.
  2. AI models are compute-intensive. Executing them in many processes in parallel can lead to resource contention and reduced throughput.
  3. Loading the model checkpoint file and initializing the model in each process adds overhead and can further increase latency. In our toy experiment, model initialization accounts for roughly 30% (!!) of the overall inference processing time.

A more efficient alternative is to centralize inference execution using a dedicated model inference server. This approach would eliminate redundant model loading and reduce overall system resource utilization.

In the next section we will set up an AI model inference server and assess its impact on resource utilization and runtime performance.

Note: We could have modified our multiprocessing-based approach to share a single model across processes (e.g., using torch.multiprocessing or another solution based on shared memory). However, the inference server demonstration better aligns with real-world production environments, where jobs often run in isolated containers.

TorchServe Setup

The TorchServe setup described in this section loosely follows the resnet tutorial. Please refer to the official TorchServe documentation for more in-depth guidelines.

Installation

The PyTorch environment of our DLAMI comes preinstalled with TorchServe executables. If you are running in a different environment run the following installation command:

pip install torchserve torch-model-archiver

Creating a Model Archive

The TorchServe Model Archiver packages the model and its associated files into a “.mar” file archive, the format required for deployment on TorchServe. We create a TorchServe model archive file based on our model checkpoint file and using the default image_classifier handler:

mkdir model_store
torch-model-archiver \
    --model-name resnet-152 \
    --serialized-file resnet-152.pt \
    --handler image_classifier \
    --version 1.0 \
    --export-path model_store

TorchServe Configuration

We create a TorchServe config.properties file to define how TorchServe should operate:

model_store=model_store
load_models=resnet-152.mar
models={\
  "resnet-152": {\
    "1.0": {\
        "marName": "resnet-152.mar"\
    }\
  }\
}

# Number of workers per model
default_workers_per_model=1

# Job queue size (default is 100)
job_queue_size=100

After completing these steps, our working directory should look like this:

├── config.properties
├── kitten.jpg
├── model_store
│   ├── resnet-152.mar
├── multi_job.py

Starting TorchServe

In a separate shell we start our TorchServe inference server:

source /opt/pytorch/bin/activate
torchserve \
    --start \
    --disable-token-auth \
    --enable-model-api \
    --ts-config config.properties

Inference Request Implementation

We define an alternative prediction function that calls our inference service:

import requests

def predict_client(image_id):
    with open('kitten.jpg', 'rb') as f:
        image = f.read()
    response = requests.post(
        "http://127.0.0.1:8080/predictions/resnet-152",
        data=image,
        headers={'Content-Type': 'application/octet-stream'}
    )

    if response.status_code == 200:
        return response.json()
    else:
        print(f"Error from inference server: {response.text}")

Scaling Up the Number of Concurrent Jobs

Now that inference requests are being processed by a central server (with process_image updated to call predict_client instead of predict), we can scale up parallel processing. Unlike the earlier approach where each process loaded and executed its own model, we have sufficient CPU resources to allow for many more concurrent processes. Here we choose 100 processes in accordance with the default job_queue_size capacity of the inference server:

spawn_jobs(total_images=1000, max_concurrent=100)

Results

The performance results are captured in the table below. Keep in mind that the comparative results can vary greatly based on the details of the AI model and the runtime environment.

Inference Server Results (by Author)

By using a centralized inference server, not only have we increased overall throughput by more than 2X, but we have also freed significant CPU resources for other computation tasks.

Next Steps

Now that we have effectively demonstrated the benefits of a centralized inference serving solution, we can explore several ways to enhance and optimize the setup. Recall that our experiment was intentionally simplified to focus on demonstrating the utility of inference serving. In real-world deployments, additional enhancements may be required to tailor the solution to your specific needs.

  1. Custom Inference Handlers: While we used TorchServe’s built-in image_classifier handler, defining a custom handler provides much greater control over the details of the inference implementation.
  2. Advanced Inference Server Configuration: Inference server solutions will typically include many features for tuning the service behavior according to the workload requirements. In the next sections we will explore some of the features supported by TorchServe.
  3. Expanding the Pipeline: Real world models will typically include more algorithm blocks and more sophisticated AI models than we used in our experiment.
  4. Multi-Node Deployment: While we ran our experiments on a single compute instance, production setups will typically include multiple nodes.
  5. Alternative Inference Servers: While TorchServe is a popular choice and relatively easy to set up, there are many alternative inference server solutions that may provide additional benefits and may better suit your needs. Importantly, it was recently announced that TorchServe would no longer be actively maintained. See the documentation for details.
  6. Alternative Orchestration Frameworks: In our experiment we use Python multiprocessing. Real-world workloads will typically use more advanced orchestration solutions.
  7. Utilizing Inference Accelerators: While we executed our model on a CPU, using an AI accelerator (e.g., an NVIDIA GPU, a Google Cloud TPU, or an AWS Inferentia) can drastically improve throughput.
  8. Model Optimization: Optimizing your AI models can greatly increase efficiency and throughput.
  9. Auto-Scaling for Inference Load: In some use cases inference traffic will fluctuate, requiring an inference server solution that can scale its capacity accordingly.

In the next sections we explore two simple ways to enhance our TorchServe-based inference server implementation. We leave the discussion on other enhancements to future posts.

Batch Inference with TorchServe

Many model inference service solutions support the option of grouping inference requests into batches. This usually results in increased throughput, especially when the model is running on a GPU.
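Conceptually, a server-side batcher waits for a first request, then keeps collecting more until the batch is full or a maximum delay has passed; TorchServe exposes these knobs as `batchSize` and `maxBatchDelay` (in milliseconds). The Python sketch below illustrates that policy with a standard-library queue. It is a simplified, hypothetical illustration (`collect_batch` is not a TorchServe function), not TorchServe's actual implementation.

```python
import queue
import time

def collect_batch(request_queue, batch_size, max_delay_s):
    """Block for the first request, then keep adding requests until the batch
    is full or max_delay_s has elapsed since the first one arrived."""
    batch = [request_queue.get()]              # wait for at least one request
    deadline = time.monotonic() + max_delay_s
    while len(batch) < batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break                              # delay budget spent: ship a partial batch
        try:
            batch.append(request_queue.get(timeout=remaining))
        except queue.Empty:
            break                              # no more requests arrived in time
    return batch

# 20 queued requests, batches of up to 8, at most 100 ms of extra waiting per batch
q = queue.Queue()
for i in range(20):
    q.put(i)
sizes = [len(collect_batch(q, batch_size=8, max_delay_s=0.1)) for _ in range(3)]
print(sizes)  # [8, 8, 4]
```

The delay bound is what keeps latency in check at low traffic: the last, partial batch is sent after waiting at most `max_delay_s` rather than waiting for a full batch.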

We extend our TorchServe config.properties file to support batch inference with a batch size of up to 8 samples. Please see the official documentation for details on batch inference with TorchServe.

model_store=model_store
load_models=resnet-152.mar
models={\
  "resnet-152": {\
    "1.0": {\
        "marName": "resnet-152.mar",\
        "batchSize": 8,\
        "maxBatchDelay": 100,\
        "responseTimeout": 200\
    }\
  }\
}

# Number of workers per model
default_workers_per_model=1

# Job queue size (default is 100)
job_queue_size=100

Results

We append the results in the table below:

Batch Inference Server Results (by Author)

Enabling batched inference increases the throughput by an additional 26.5%.

Multi-Worker Inference with TorchServe

Many model inference service solutions will support creating multiple inference workers for each AI model. This enables fine-tuning the number of inference workers based on expected load. Some solutions support auto-scaling of the number of inference workers.

We extend our own TorchServe setup by increasing the default_workers_per_model setting that controls the number of inference workers assigned to our image classification model.

Importantly, we must limit the number of threads allocated to each worker to prevent resource contention. This is controlled by the number_of_netty_threads setting and by the OMP_NUM_THREADS and MKL_NUM_THREADS environment variables. Here we have set the number of threads to equal the number of vCPUs (8) divided by the number of workers.

model_store=model_store
load_models=resnet-152.mar
models={\
  "resnet-152": {\
    "1.0": {\
        "marName": "resnet-152.mar",\
        "batchSize": 8,\
        "maxBatchDelay": 100,\
        "responseTimeout": 200\
    }\
  }\
}

# Number of workers per model
default_workers_per_model=2 

# Job queue size (default is 100)
job_queue_size=100

# Number of threads per worker
number_of_netty_threads=4

The modified TorchServe startup sequence appears below:

export OMP_NUM_THREADS=4
export MKL_NUM_THREADS=4
torchserve \
    --start \
    --disable-token-auth \
    --enable-model-api \
    --ts-config config.properties

Results

In the table below we append the results of running with 2, 4, and 8 inference workers:

Multi-Worker Inference Server Results (by Author)

By configuring TorchServe to use multiple inference workers, we are able to increase the throughput by an additional 36%. This amounts to a 3.75X improvement over the baseline experiment.

Summary

This experiment highlights the potential impact of inference server deployment on multi-job deep learning workloads. Our findings suggest that using an inference server can improve system resource utilization, enable higher concurrency, and significantly increase overall throughput. Keep in mind that the precise benefits will greatly depend on the details of the workload and the runtime environment.

Designing the inference serving architecture is just one part of optimizing AI model execution. Please see some of our many posts covering a wide range of AI model optimization techniques.


A Simple Implementation of the Attention Mechanism from Scratch https://towardsdatascience.com/a-simple-implementation-of-the-attention-mechanism-from-scratch/ Tue, 01 Apr 2025 01:05:51 +0000 https://towardsdatascience.com/?p=605368 How attention helped models like RNNs mitigate the vanishing gradient problem and capture long-range dependencies among words

The post A Simple Implementation of the Attention Mechanism from Scratch appeared first on Towards Data Science.

Introduction

The Attention Mechanism is often associated with the transformer architecture, but it was already used in RNNs. In Machine Translation (MT) tasks (e.g., English-Italian), when you want to predict the next Italian word, you need your model to focus on, or pay attention to, the most important English words that are useful for making a good translation.

Attention in RNNs

I will not go into the details of RNNs, but attention helped these models mitigate the vanishing gradient problem and capture more long-range dependencies among words.

At a certain point, we understood that the only important thing was the attention mechanism, and the entire RNN architecture was overkill. Hence, Attention is All You Need!

Self-Attention in Transformers

Classical attention indicates where words in the output sequence should focus attention in relation to the words in the input sequence. This is important in sequence-to-sequence tasks like MT.

Self-attention is a specific type of attention. It operates between any two elements in the same sequence and provides information on how “correlated” the words in the same sentence are.

For a given token (or word) in a sequence, self-attention generates a list of attention weights corresponding to all other tokens in the sequence. This process is applied to each token in the sentence, obtaining a matrix of attention weights (as in the picture).

This is the general idea. In practice, things are a bit more complicated because we want to add many learnable parameters to our neural network; let's see how.

K, V, Q representations

Our model input is a sentence like “my name is Marcello Politi”. With the process of tokenization, a sentence is converted into a list of numbers like [2, 6, 8, 3, 1].

Before feeding the sentence into the transformer we need to create a dense representation for each token.

How do we create this representation? We multiply each token (seen as a one-hot vector over the vocabulary) by a matrix that is learned during training.

Let’s add some complexity now.

For each token, we create three vectors instead of one, which we call key, value, and query. (We will see later how these three vectors are created.)

Conceptually, these three vectors have a particular meaning:

  • The key vector represents the core information captured by the token.
  • The value vector captures the full information of the token.
  • The query vector is a question about token relevance for the current task.

So the idea is to focus on a particular token i and ask how important the other tokens in the sentence are for understanding token i.

This means that we take the query vector q_i of token i (we ask a question about i) and combine it, through some mathematical operations, with the key vectors k_j of all the tokens in the sequence, including token i itself. This is like asking, at first glance, which tokens in the sequence look most important for understanding the meaning of token i.

What is this magical mathematical operation?

We compute the dot product of the query vector with each key vector k_j and divide it by a scaling factor, the square root of the key dimension.

This gives us a score for each pair (q_i, k_j). We turn this list of scores into a probability distribution by applying a softmax. Great, now we have the attention weights!

With the attention weights, we know how important each token j is for understanding token i. Now we multiply the value vector v_j of each token by its weight and sum the resulting vectors. This gives us the final context-aware vector of token i.

If we are computing the contextual dense vector of token_1 we calculate:

z1 = a11*v1 + a12*v2 + … + a15*v5

Where a1j are the computed attention weights and v_j are the value vectors.
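The same steps written as equations, as a compact restatement of the text above. Here n is the number of tokens and d_k is the dimension of the key vectors (in the hands-on code below, d_k = d = 16):

```latex
s_{ij} = \frac{q_i \cdot k_j}{\sqrt{d_k}}, \qquad
a_{ij} = \frac{\exp(s_{ij})}{\sum_{m=1}^{n} \exp(s_{im})}, \qquad
z_i = \sum_{j=1}^{n} a_{ij}\, v_j
```

For token 1 in a five-token sentence (n = 5), the last equation is exactly z1 = a11*v1 + a12*v2 + … + a15*v5.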

Done! Almost…

I didn’t cover how we obtained the vectors k, v and q of each token. We need to define some matrices w_k, w_v and w_q so that when we multiply:

  • token * w_k -> k
  • token * w_q -> q
  • token * w_v -> v

These three matrices are initialized at random and learned during training; this is one reason modern models such as LLMs have so many parameters.

Multi-head Self-Attention in Transformers (MHSA)

Are we sure that the previous self-attention mechanism is able to capture all important relationships among tokens (words) and create dense vectors of those tokens that really make sense?

It might not always work perfectly. What if, to mitigate the error, we ran the entire thing twice with new w_q, w_k and w_v matrices and somehow merged the two dense vectors we obtain? That way, one self-attention might capture some relationships and the other might capture different ones.

Well, this is exactly what happens in MHSA. The case we just discussed has two heads because it uses two sets of w_q, w_k and w_v matrices. We can have even more heads: 4, 8, 16, etc.

The only complicated part is that all these heads are processed in parallel: we compute them all in the same operation using tensors.

Merging the dense vectors of the heads is simple: we concatenate them (so each head’s vector must be smaller, such that the concatenation has the original dimension we wanted), and we pass the result through another learnable matrix, w_o.
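Here is a minimal sketch of this idea with two heads, written as a plain Python loop instead of the batched tensors used later. Assumptions: random untrained weights, the row-vector convention (token @ w), and a hypothetical helper named `single_head`. Each head projects to d/h = 8 dimensions, so concatenating the two heads gives back d = 16 before the final w_o projection.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def single_head(x, w_q, w_k, w_v):
    """Scaled dot-product self-attention for one head (hypothetical helper)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v        # each [seq_len, d_head]
    scores = q @ k.T / k.shape[-1] ** 0.5       # [seq_len, seq_len]
    return F.softmax(scores, dim=-1) @ v        # [seq_len, d_head]

seq_len, d, h = 5, 16, 2
d_head = d // h                                 # 8, so 2 heads x 8 = 16

x = torch.randn(seq_len, d)                     # placeholder token embeddings
# One (w_q, w_k, w_v) triplet per head, each of shape [d, d_head]
heads = [tuple(torch.randn(d, d_head) for _ in range(3)) for _ in range(h)]
w_o = torch.randn(d, d)                         # final output projection

# Run each head independently, concatenate along the feature axis, then mix with w_o
head_outputs = [single_head(x, *w) for w in heads]
concat = torch.cat(head_outputs, dim=-1)        # [seq_len, h * d_head] = [5, 16]
out = concat @ w_o                              # [5, 16]
print(concat.shape, out.shape)
```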

Hands-on

import torch

Suppose you have a sentence. After tokenization, each token (word for simplicity) corresponds to an index (number):

tokenized_sentence = torch.tensor([
    2, #my
    6, #name
    8, #is
    3, #marcello
    1  #politi
])
tokenized_sentence

Before feeding the sentence into the transformer, we need to create a dense representation for each token.

How do we create these representations? We multiply each token by a matrix, which is learned during training.

Let’s build this embedding matrix.

torch.manual_seed(0) # set a fixed seed for reproducibility
embed = torch.nn.Embedding(10, 16)

If we pass our tokenized sentence through the embedding layer, we obtain a dense representation of dimension 16 for each token:

sentence_embed = embed(tokenized_sentence).detach()
sentence_embed
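To see why an embedding lookup is the same as “multiplying each token by a matrix”, here is a small check, sketched with the same seed, vocabulary size 10, and dimension 16 as above: one-hot encoding the token ids and multiplying by `embed.weight` gives exactly the rows returned by `embed`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embed = torch.nn.Embedding(10, 16)                  # vocab size 10, embedding dim 16
tokenized_sentence = torch.tensor([2, 6, 8, 3, 1])

# Lookup: pick row i of the weight matrix for token id i
lookup = embed(tokenized_sentence).detach()         # [5, 16]

# Same thing as a matrix product: one-hot rows [5, 10] @ weight [10, 16]
one_hot = F.one_hot(tokenized_sentence, num_classes=10).float()
product = one_hot @ embed.weight.detach()           # [5, 16]

print(torch.allclose(lookup, product))              # True
```

Frameworks use the lookup because it avoids building the mostly-zero one-hot matrix.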

To use the attention mechanism, we need to define three new matrices: w_q, w_k and w_v. When we multiply an input token by w_q, we obtain its query vector q; the same goes for w_k and w_v.

d = sentence_embed.shape[1] # let's base our matrix on a shape (16,16)

w_key = torch.rand(d,d)
w_query = torch.rand(d,d)
w_value = torch.rand(d,d)

Compute attention weights

Let’s now compute the attention weights for only the first input token of the sentence.

token1_embed = sentence_embed[0]

# compute the three vectors associated with token 1: q, k, v
key_1 = w_key.matmul(token1_embed)
query_1 = w_query.matmul(token1_embed)
value_1 = w_value.matmul(token1_embed)

print("key vector for token1: \n", key_1)   
print("query vector for token1: \n", query_1)
print("value vector for token1: \n", value_1)

We need to multiply the query vector of token 1 (query_1) with the keys of all the tokens, including token 1 itself.

So now we need to compute the remaining keys (key_2, key_3, key_4, key_5). But wait: we can compute all of them at once by multiplying sentence_embed by the transpose of the w_key matrix.

keys = sentence_embed.matmul(w_key.T)
keys[0] #contains the key vector of the first token and so on

Let’s do the same thing with the values

values = sentence_embed.matmul(w_value.T)
values[0] #contains the value vector of the first token and so on
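A quick sanity check, written as a self-contained sketch that re-creates the tensors above with the same seed and call order: row 0 of the batched `keys` and `values` matches the per-token vectors, because (W x)ᵀ = xᵀ Wᵀ. A small tolerance absorbs floating-point differences between the two matmul paths.

```python
import torch

torch.manual_seed(0)
embed = torch.nn.Embedding(10, 16)
sentence_embed = embed(torch.tensor([2, 6, 8, 3, 1])).detach()  # [5, 16]
d = sentence_embed.shape[1]
w_key, w_query, w_value = torch.rand(d, d), torch.rand(d, d), torch.rand(d, d)

# Per-token (column-vector) convention: k_1 = W_k @ x_1
key_1 = w_key.matmul(sentence_embed[0])
value_1 = w_value.matmul(sentence_embed[0])

# Batched (row-vector) convention: each row of X @ W_k^T is one token's key
keys = sentence_embed.matmul(w_key.T)
values = sentence_embed.matmul(w_value.T)

print(torch.allclose(keys[0], key_1, atol=1e-5),
      torch.allclose(values[0], value_1, atol=1e-5))  # True True
```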

Let’s compute the first part of the attention formula.

import torch.nn.functional as F
# the following are the attention weights of the first token with respect to all the tokens
a1 = F.softmax(query_1.matmul(keys.T)/d**0.5, dim = 0)
a1

With the attention weights, we know how important each token is. Now we multiply the value vector of each token by its weight and sum the results to obtain the final context-aware vector of token_1.

z1 = a1.matmul(values)
z1

In the same way, we could compute the context-aware dense vectors of all the other tokens. So far we have always used the same matrices w_k, w_q, w_v: we say that we use one head.
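Computing the context vectors of all the tokens does not require a loop: stacking the queries into a matrix gives every token’s attention weights in one softmax. The sketch below re-creates the tensors from this section with the same seed and call order, recomputes z1 step by step, and checks that row 0 of the batched result matches it.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embed = torch.nn.Embedding(10, 16)
sentence_embed = embed(torch.tensor([2, 6, 8, 3, 1])).detach()  # [5, 16]
d = sentence_embed.shape[1]
w_key, w_query, w_value = torch.rand(d, d), torch.rand(d, d), torch.rand(d, d)

# Single-token version, as in the walkthrough
query_1 = w_query.matmul(sentence_embed[0])
keys = sentence_embed.matmul(w_key.T)        # [5, 16]
values = sentence_embed.matmul(w_value.T)    # [5, 16]
z1 = F.softmax(query_1.matmul(keys.T) / d**0.5, dim=0).matmul(values)

# All tokens at once: one row of queries per token
queries = sentence_embed.matmul(w_query.T)                     # [5, 16]
weights = F.softmax(queries.matmul(keys.T) / d**0.5, dim=-1)   # [5, 5], row i = token i
Z = weights.matmul(values)                                     # [5, 16], row i = z_i

print(torch.allclose(Z[0], z1, atol=1e-4))  # True
```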

But we can have multiple triplets of matrices, that is, multiple heads. That’s why it is called multi-head attention.

The dense vectors that each head outputs for an input token are then concatenated and linearly transformed to get the final dense vector.

Implementing Multi-head Self-Attention

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0) # fixed seed for reproducibility

Same steps as before…

# Tokenized sentence (same as before)
tokenized_sentence = torch.tensor([2, 6, 8, 3, 1])  # [my, name, is, marcello, politi]

# Embedding layer: vocab size = 10, embedding dim = 16
embed = nn.Embedding(10, 16)
sentence_embed = embed(tokenized_sentence).detach()  # Shape: [5, 16] (seq_len, embed_dim)

We’ll define a multi-head attention mechanism with h heads (let’s say 4 heads for this example). Each head will have its own w_q, w_k, and w_v matrices, and the output of each head will be concatenated and passed through a final linear layer.

Since the outputs of the heads will be concatenated, and we want a final dimension of d, the dimension of each head needs to be d/h. Additionally, each concatenated vector goes through a linear transformation, so we need another matrix, w_output, as you can see in the formula.

d = sentence_embed.shape[1]  # embed dimension 16
h = 4  # Number of heads
d_k = d // h  # Dimension per head (16 / 4 = 4)

Since we have 4 heads, we want 4 copies of each matrix. Instead of copies, we add a dimension, which amounts to the same thing but lets us do a single operation. (Imagine stacking the matrices on top of each other.)

# Define weight matrices for each head
w_query = torch.rand(h, d, d_k)  # Shape: [4, 16, 4] (one d x d_k matrix per head)
w_key = torch.rand(h, d, d_k)    # Shape: [4, 16, 4]
w_value = torch.rand(h, d, d_k)  # Shape: [4, 16, 4]
w_output = torch.rand(d, d)  # Final linear layer: [16, 16]

For simplicity, I’m using torch’s einsum. If you’re not familiar with it, check out my blog post.

The einsum operation torch.einsum('sd,hde->hse', sentence_embed, w_query) in PyTorch uses letters to define how to multiply and rearrange numbers. Any index that appears in the inputs but not in the output (here d) is summed over. Here’s what each part means:

  1. Input Tensors:
    • sentence_embed with the notation 'sd':
      • s represents the number of words (sequence length), which is 5.
      • d represents the embedding size (numbers per word), which is 16.
      • The shape of this tensor is [5, 16].
    • w_query with the notation 'hde':
      • h represents the number of heads, which is 4.
      • d represents the embedding size, which again is 16.
      • e represents the dimension of each head (d_k), which is 4.
      • The shape of this tensor is [4, 16, 4].
  2. Output Tensor:
    • The output has the notation 'hse':
      • h represents 4 heads.
      • s represents 5 words.
      • e represents 4 numbers per head.
      • The shape of the output tensor is [4, 5, 4].
# Compute Q, K, V for all tokens and all heads
# sentence_embed: [5, 16] -> Q: [4, 5, 4] (h, seq_len, d_k)
queries = torch.einsum('sd,hde->hse', sentence_embed, w_query)  # h heads, seq_len tokens, d_k dims
keys = torch.einsum('sd,hde->hse', sentence_embed, w_key)       # h heads, seq_len tokens, d_k dims
values = torch.einsum('sd,hde->hse', sentence_embed, w_value)   # h heads, seq_len tokens, d_k dims
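If einsum feels opaque, the same projections can be written with `torch.matmul`, which broadcasts the [5, 16] embedding matrix against each of the 4 per-head weight matrices. A minimal check with random placeholder tensors of the same shapes:

```python
import torch

torch.manual_seed(0)
h, s, d, d_k = 4, 5, 16, 4
sentence_embed = torch.randn(s, d)    # placeholder embeddings [5, 16]
w_query = torch.rand(h, d, d_k)       # one [16, 4] matrix per head

via_einsum = torch.einsum('sd,hde->hse', sentence_embed, w_query)  # [4, 5, 4]
via_matmul = torch.matmul(sentence_embed, w_query)                 # [5, 16] broadcast over h -> [4, 5, 4]

# Head 0 of the result is just sentence_embed @ w_query[0]
print(via_einsum.shape, torch.allclose(via_einsum, via_matmul, atol=1e-5))
```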

The einsum in the next snippet performs a dot product between the queries (hse) and the transposed keys (hek) to obtain scores of shape [h, seq_len, seq_len], where:

  • h -> Number of heads.
  • s and k -> Sequence length (number of tokens).
  • e -> Dimension of each head (d_k).

The division by (d_k ** 0.5) keeps the scores from growing with the head dimension, which would push the softmax into saturated regions with tiny gradients. Softmax is then applied to obtain the attention weights:

# Compute attention scores
scores = torch.einsum('hse,hek->hsk', queries, keys.transpose(-2, -1)) / (d_k ** 0.5)  # [4, 5, 5]
attention_weights = F.softmax(scores, dim=-1)  # [4, 5, 5]
# Apply attention weights
head_outputs = torch.einsum('hij,hjk->hik', attention_weights, values)  # [4, 5, 4]
head_outputs.shape
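The scaling by sqrt(d_k) has a simple statistical motivation. If the entries of q and k were independent with zero mean and unit variance (an idealized assumption; the untrained uniform weights used here do not satisfy it), their dot product would have variance d_k, so raw scores grow with the head dimension. A quick empirical sketch with d_k = 64:

```python
import torch

torch.manual_seed(0)
n, d_k = 100_000, 64                 # number of random (q, k) pairs, head dimension

q = torch.randn(n, d_k)              # entries: zero mean, unit variance
k = torch.randn(n, d_k)

raw = (q * k).sum(dim=-1)            # n unscaled dot products
scaled = raw / d_k ** 0.5            # the scaling used in the attention formula

print(f"var(raw)    = {raw.var().item():.1f}  (theory: {d_k})")
print(f"var(scaled) = {scaled.var().item():.3f}  (theory: 1)")
```

With scores of variance around 64, a softmax over the raw values would typically be dominated by a single entry; after scaling, the scores stay in a range where the softmax still spreads weight across tokens.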

Now we concatenate the outputs of all heads for each token

# Concatenate heads
concat_heads = head_outputs.permute(1, 0, 2).reshape(sentence_embed.shape[0], -1)  # [5, 16]
concat_heads.shape
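The permute/reshape trick can look like magic. A small sketch, using a random placeholder tensor with the same [h, seq_len, d_k] = [4, 5, 4] shape, shows it is the same as concatenating the heads side by side with `torch.cat`: for each token, head 0’s 4 numbers come first, then head 1’s, and so on.

```python
import torch

torch.manual_seed(0)
h, seq_len, d_k = 4, 5, 4
head_outputs = torch.randn(h, seq_len, d_k)   # placeholder per-head outputs

# Approach used above: move the head axis next to the feature axis, then flatten both
concat_reshape = head_outputs.permute(1, 0, 2).reshape(seq_len, -1)   # [5, 16]

# Explicit version: glue the per-head outputs along the last dimension
concat_cat = torch.cat([head_outputs[i] for i in range(h)], dim=-1)   # [5, 16]

print(torch.equal(concat_reshape, concat_cat))   # True
```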

Finally, let’s multiply by the w_output matrix, as in the formula above

multihead_output = concat_heads.matmul(w_output)  # [5, 16] @ [16, 16] -> [5, 16]
print("Multi-head attention output for token1:\n", multihead_output[0])
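As a cross-check against a framework implementation, the sketch below loads the same kind of random matrices into PyTorch’s `torch.nn.MultiheadAttention` (biases zeroed, default dropout of 0) and compares the outputs. The weight re-packing reflects my understanding of how that module lays out its parameters: it computes x @ W.T, stacks the query, key and value projections in `in_proj_weight`, and splits each projection into heads as contiguous chunks of d_k columns. `to_full` is a hypothetical helper introduced for this check.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
h, d = 4, 16
d_k = d // h
x = torch.randn(5, d)                                  # placeholder sentence embeddings
w_query, w_key, w_value = (torch.rand(h, d, d_k) for _ in range(3))
w_output = torch.rand(d, d)

# From-scratch version (same steps as in this section)
q, k, v = (torch.einsum('sd,hde->hse', x, w) for w in (w_query, w_key, w_value))
weights = F.softmax(q @ k.transpose(-2, -1) / d_k ** 0.5, dim=-1)   # [h, 5, 5]
concat = (weights @ v).permute(1, 0, 2).reshape(5, -1)             # [5, 16]
ours = concat @ w_output                                           # [5, 16]

def to_full(w):
    """Re-pack [h, d, d_k] into one [d, d] matrix: head i -> columns i*d_k:(i+1)*d_k."""
    return w.permute(1, 0, 2).reshape(d, d)

mha = nn.MultiheadAttention(embed_dim=d, num_heads=h)
with torch.no_grad():
    mha.in_proj_weight.copy_(torch.cat([to_full(w).T for w in (w_query, w_key, w_value)]))
    mha.in_proj_bias.zero_()
    mha.out_proj.weight.copy_(w_output.T)
    mha.out_proj.bias.zero_()

mha.eval()
xb = x.unsqueeze(1)        # default layout is [seq_len, batch, d]; here batch = 1
with torch.no_grad():
    theirs, _ = mha(xb, xb, xb)

print(torch.allclose(ours, theirs[:, 0], atol=1e-3, rtol=1e-4))
```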

Final Thoughts

In this blog post, I’ve implemented a simple version of the attention mechanism. This is not how it is really implemented in modern frameworks, but my goal is to give anyone an intuitive understanding of how it works. In future articles, I’ll go through the entire implementation of a transformer architecture.

Follow me on TDS if you like this article! 😁

💼 Linkedin | 🐦 X (Twitter) | 💻 Website


Unless otherwise noted, images are by the author

My Learning to Be Hired Again After a Year… Part 2 https://towardsdatascience.com/my-learning-to-be-hired-again-after-a-year-part-2/ Mon, 31 Mar 2025 19:24:25 +0000 https://towardsdatascience.com/?p=605343 One year later: what I learned still matters

The post My Learning to Be Hired Again After a Year… Part 2 appeared first on Towards Data Science.


This is the second part of My Learning to Be Hired Again After a Year… Part I.

Hard to believe, but it’s been a full year since I published the first part on TDS. And in that time, something beautiful happened. Every so often, someone would leave a comment, highlight a sentence, or send me a message. Most were simple notes like, “Thank you, Amy. Your post helped me.” But those words lit me up. They brightened entire days. They reminded me that I was never truly alone, not during those long months of unemployment, not in the struggle of figuring out who I was without a job title or company name beneath my email signature or LinkedIn profile.

Funny enough, those hard days turned out to be some of the most meaningful ones I’ve had. Maybe even more meaningful than my busiest days at work. Because in losing an identity, I found new ones. I didn’t need a job or a title to feel connected. To many of you, I’m just a pretty lazy writer getting back into the groove. And here I am — returning to my writing routine. So, thank you to everyone who reached out. Your messages rank second on my list of happiest things people give me. The first? That’s easy. My daughter Ellie’s three S’s: her smell, her smile, and her surprises.

Enough talk. Let’s get into Part 2. I’ll pick up where I left off — sharing the lessons that helped me get hired again. This time, I’ll also reflect on how those lessons show up in my work and life today. And for those of you curious about the methods from the book Never Search Alone, I’ve got some thoughts on that too. What worked, what didn’t, and how I made it my own.

Knock, Knock: Opportunity’s at the Door — You Won’t Lose a Penny for Trying

A year into working as a Machine Learning Engineer, I can say this was my biggest life lesson.

Here’s the backstory. I’d been working as a data scientist ever since I finished grad school. Over the past 7 years, I’ve built multiple machine learning models: linear regression, neural networks, and XGBoost. All solid stuff. But when it came to designing an entire machine learning system from start to finish? That was a different story. I hadn’t really done that. I knew how to develop models, sure. I even had some experience deploying them, but only parts of the process. If you asked me to design, build, and run an entire system end-to-end, I couldn’t say I had that experience.

And the job market? It was changing fast. Companies didn’t want someone who could just build models anymore. Generative AI was handling a lot of the data analysis now. What they really wanted was someone who could take machine learning and use it to solve real business problems, someone who could own the whole process. Meanwhile, I had just been laid off. I had time. So I decided maybe this was the right moment to pivot. Maybe it was time to go for machine learning engineering.

The first thing I did was reach out to people who had already made that move. Two friends said yes. One had gone from data scientist to machine learning engineer. The other was a data scientist, and her husband worked as an MLE at Apple. We ended up having this long phone call for two hours, maybe more. They were kind. And they didn’t sugarcoat anything. Both of them told me it was tough to make the switch. Not impossible, but tough. If you didn’t have MLOps experience or a solid GitHub portfolio to show off, landing a senior MLE job would be really hard. Especially with how competitive things were getting.

That conversation hit hard. I remember feeling my heart pound, like cold water had been poured over my head. I had two options: I could keep chasing data scientist jobs — applied scientist roles at places like Amazon — but there weren’t many out there. Or swallow my pride, let go of seven years of experience as a data scientist and go for an entry-level MLE role. Honestly, neither choice felt great.

It took me two weeks to work through it. Two long, long weeks. But in the end, I made up my mind: I’d at least try for machine learning engineer jobs, even if I had to start from the bottom. I got back to my routine and prepped for interviews. During those hard days, I started blogging on Medium and published on TDS to show my technical muscle, sharing my “Courage to Learn ML” series. Ready for a spoiler alert? I ended up with three offers for senior and even staff-level machine learning engineering roles. And I had three other final-round interviews lined up that I had to walk away from, because there just wasn’t enough time or energy for me to do them all.

No, none of those offers came from FAANG companies. But I’m more than happy with where I landed. It was worth the try.

Even now, writing this, I can still feel that chill from when my friends told me the odds were slim. And I can still laugh at how panicked I was. Just the other day, I spoke with a friend who’s looking to move from data engineering into MLE. I told him the same thing I learned for myself: You can do it. And if you decide it’s worth trying, don’t get hung up on the odds. Even if it’s a 1% chance, why not see if you’re in that 1%? But if you don’t try at all, you’re 100% in the group that never made it.

For me, the takeaway is simple. Don’t be afraid of probabilities. Even 99.999999% is not 100%. If you’re worried about the outcome, stop thinking about the outcome. Just do it for fun, for your mental health, for the chance to live without regrets.

A Year Later: I use this lesson almost every day. I blog shamelessly, pretending I don’t care whether people really read my posts. I make those awkward customer service calls, just to see if someone on the other end might actually help me. I even buy a lottery ticket now and then when the jackpot tops a billion dollars. Who knows? I might end up in that 0.0000…001%. And you know what? I recently won $12 on a ticket. So yes, it’s worth trying.

Learning During the Struggle: Don’t Beg for Jobs 

This was another hard lesson from my “to be an MLE or not to be” chapter.

When I spoke with those two friends, they made one thing clear. If I wanted to become a machine learning engineer, I needed hands-on experience with MLOps (machine learning operations). The problem? In my past roles, I’d either handed off my models to software engineers for deployment or handled just one small part of the system myself. I knew I had a gap. And my first instinct was to fill it by any means necessary. So I figured, why not get involved in some real projects? Something complex. Something I could proudly add to my resume.

Since I was out of work, I had time. I joined MLOps communities on Slack and Discord. I posted about my background and offered to work for free with any startup or team that needed help, just to get some experience in exchange. The response? Pretty discouraging. Hardly anyone replied. A few did, but they expected me to work 50+ hours a week… for free and without any working plans. I remember sending a message to a PhD student after reading his job posting. I told him how much I liked his work and that I wanted to make his product a reality. He didn’t get back to me. Instead, he changed his posting to say he was seeking experienced MLEs or someone with a PhD. Ouch.

After a few weeks of all that, I was demotivated and burned out. It was clear I was pleading for opportunities. It was then that I decided to join a Job Search Council (JSC) (I explained the JSC in detail in Part 1). We shared the emotional weight of job hunting every Friday. I slowly started letting go of the tension. And that’s when something clicked. I needed to stop pleading for jobs. Instead, I decided to sell what I had.

I rewrote my resume into two versions, one for data scientist roles and the other for MLE roles. I applied for MLE jobs like crazy, just to increase my chances. But this time around, I approached it differently. I broke down what hiring managers were actually looking for in an MLE. I saw how all the model-building experience I had acquired had actually taught me about debugging, monitoring, and resolving messy business problems. While I didn’t have a lot of MLOps experience, I wasn’t coming from zero. I had a master’s degree in computer science, I was familiar with software development, and I knew data engineering.

In those MLE interviews, I started highlighting those skills. I explained how I applied machine learning to solve business problems and offered subtle hints about my favorite model-training tricks. I showed hiring managers I knew how it felt to run systems in production. I was honest about where I needed to gain more experience. But I made it clear this wasn’t a cold start.

At some point, I stopped acting like a job-beggar and became a salesperson. I wasn’t asking someone to “please hire me. I’m willing to work more and cheaper”. I was selling something. When a company didn’t hire me, it wasn’t a rejection. It just meant they didn’t need someone like me. Maybe I need to tighten the pitch next time.

This mental shift made all the difference. Negative feedback wasn’t personal anymore. It was just feedback, a little data point I could use to make adjustments. When you ask for something, people think less of you. But when you treat yourself as a product, you’re refining and searching for the right buyers. If there’s a flaw, you fix it. If there are good things, you point them out. And sooner or later, you find your people.

A Year Later: I don’t beg anymore. Not for jobs. Not for opportunities. I exchange. I sell. That mindset has become part of me now. It’s my inner tiny salesperson. 

Mock Interviews and the Interview Marathon: Practice Really Does Make a Difference

I’ll be straight with you. Before I started interviewing for machine learning engineer roles after my layoff, I had never really practiced behavioral interviews. Not once in my seven years of working. Sure, I wrote out a few stories using the STAR method, like everyone says you should. But I never practiced them out loud, and I definitely never got feedback. It was like stepping on stage to perform in a play without ever going to rehearsal. I never realized how big a mistake that was, probably because, back when the job market was good, I didn’t have to.

But after the layoff? After spending nearly a year at home because of pregnancy? The market was ice cold. There weren’t many chances, and I couldn’t afford to blow any of them. I had to nail the behavioral interviews. Not just by memorizing my stories, but by actually practicing. For real.

So, I made my husband do mock interviews with me. I sat in one room, he sat in another, and we jumped on Zoom like it was the real thing. Poor guy — he’s been at the same job since forever and works in a totally different field, but there he was, asking me random behavioral questions. At first, I didn’t think it was going to help. I figured he didn’t get what I did anyway. But when I started answering with my “well-crafted” stories, something surprising happened. I got nervous. And wordy. Way too wordy.

And then he cut me off. Not gently, either. He told me straight up: I was spending way too much time talking about the background. The company, the project, all the setup. He said by the time I got to the part about what I actually did, he had already tuned out. You know what? He was 100% correct and I’d never noticed it before. I never thought about how much time I was wasting on details that didn’t really matter to the person listening.

After that, I went back through my stories. Almost all of them had the same problem. Too much setup, not enough focus on action and results. Honestly? I was grateful for his brutal feedback. It was a little embarrassing, but I wished I’d done mock interviews like that years ago.

From then on, I decided to practice a lot more. With my new MLE resume ready, I started applying like crazy. Interviews came in, and instead of trying to avoid them, I leaned in. Earlier in my career, I was the kind of person who’d grab the first offer just to escape the stress of interviewing. Selling myself has always made me a little panicky. After all, I’m an introvert. But this time, things were different. The book Never Search Alone and those early mock interviews changed my mindset. (I’ll talk more about the book and why it prevents me from rushing out of the interview process later.)

So I gave myself time. I said yes to almost every interview I could get. At one point, I interviewed with four companies over three days. It felt like a marathon, but somewhere along the way, I got good at telling my story. I watched how the interviewers reacted. I collected feedback from the process. And something strange happened: I stopped caring so much about the results. Whether I got a yes or a no didn’t shake me anymore. I wasn’t just interviewing to get a job. I was practicing to get the job I really wanted.

By the time I had three offers on the table and finally chose the one I liked, I knew I was done. That was my finish line. It felt like I’d run the full race and actually won the prize I wanted, not the one I settled for.

Seriously, I can’t say this enough: KEEP interviewing. Back-to-back if you can. Do mock interviews with whoever you trust, even if they aren’t in your field. Practice until you’re less worried about the outcome and more focused on getting better.

A Year Later: It’s hard to say how much of those interview skills I still have in me now. But if I ever need to practice again, you better believe I’ll be dragging my husband back into another round of mock interviews. Maybe even for business presentations. He’s a tough crowd, but he gets results :]

Panic Mode? Deep Breath, the Show Must Go On

During my interview marathon, I started noticing something that completely threw me off. Some interviewers looked… disappointed. Others seemed bored. And me? I cared. A lot. Probably too much. Every time I saw a face that wasn’t smiling or nodding, I panicked. In my head, I’d hear this loud voice saying, “Amy, you’re blowing it.” And once that thought crept in, it was over. My brain and body would scramble to fix the situation, so I’d start talking faster, throwing out more words, hoping to change their minds. I wanted to come across as sharp and impressive. But the truth is, I probably looked like a nervous, rambling mess. 

My husband confirmed it after one of our mock interviews. He didn’t sugarcoat it. “You’re not even looking at the camera,” he said. “And you seem really tense.” Again, he was right.

For an introvert like me, fixing this wasn’t easy. But I found two things that helped, so I’ll share them here.

The first was simple: breathe. Every time I spotted what I thought was a bad reaction, a frown, a yawn, that blank expression that felt like doom, I forced myself to pause. I took a breath. And instead of rushing to say more, I slowed down. Sometimes I even cracked a cold joke. (I’m surprisingly good at bad jokes. It might be my secret talent.) Then I’d apologize for the joke, take another breath, and move on. That little reset worked in two ways. First, it quieted the voice in my head screaming “You’re ruining this!” Secondly, it made the interviewer’s expression change. Maybe they smiled and got the joke. Maybe they just looked confused and didn’t like it. But at least they weren’t bored or disappointed anymore. I’ll take that.

The second thing I did was tape a picture of my daughter right behind the camera. Her big, shiny smile was right there, and every time I glanced at it, I smiled too. Which, by the way, made me look more relaxed and human on camera. Sometimes the interviewer smiled back, and just like that, the energy shifted. I wasn’t panicking anymore. I was back in control. The show was back on.

I started thinking of myself as a salesperson. Or maybe a showman. What do they do when the audience looks tired or distracted? They keep going. They adjust. They bring the energy back. If you’re like me, someone who takes those reactions personally, you need to have a plan. These were my two tricks. You’ll probably find your own. But the point is: don’t panic. Pause. Breathe. No one will notice. And then, get back to the show.

A Year Later: Honestly, this might be the most important skill I picked up during that tough year. I still use it all the time at work. When I’m presenting my work to a room full of people, I slow myself down. I picture myself in a fancy tailcoat, like an old-school showman, selling my ideas to the audience. Sometimes I throw in one of my classic cold jokes to keep things light.

When I wrap up a presentation, I make sure to give people something easy to take with them. I’ll say, “If you’re heading out and want one thing to remember about this project, here’s the punchline.” Then I boil it down to one or two sentences and say it clearly. Loud enough to stick.

I even use this trick in regular conversations, especially the awkward ones. A little pause makes everything less uncomfortable. And more often than not, things turn out better after that moment to reset.

Do the Mnookin Two-Pager exercise: How I Found a Job That Actually Fit Me

I keep mentioning the book Never Search Alone, and there’s a reason for that. When I first heard about it, I was skeptical. As an introvert, the idea of joining a group of strangers to talk about job hunting made me extremely uncertain and nervous. 

My first group didn’t go well. There were five of us, but two people refused to follow the process. They were often late or skipped meetings entirely. It was frustrating, and I almost gave up. Instead, I found another group through the Slack community. That time, it clicked. We met every Friday, and kept each other accountable. We helped one another stay sane through the search. It made a huge difference. If you want to know more about how the JSC (Job Search Council) helped me, I wrote about it in part one of this story.

Looking back, another useful thing the book offered was the Mnookin Two-Pager exercise. You sit down and write out what you love in a job, what you hate, and what your career goals are. Simple, but surprisingly powerful. It forced me to get honest with myself. Without it, I probably would have grabbed the very first offer and rushed out of the market, just to be done with it. I’ve done that before. And regretted it.

This time was different. My two-pager kept me grounded. I knew what I wanted and where I wasn’t willing to settle. That’s how I ended up at Disney. The role hits about 85% of what I was hoping for. More importantly, it steers clear of every red flag on my “hard no” list. A year later, I’m still glad I took the time to figure out exactly what I was looking for before saying yes to anything.


Finally! We Made It to the End. 

I’m so glad I finally sat down and finished this. Honestly, I’m the kind of person who thinks a lot. But writing things out like this helps me clear my head and hold on to the lessons I actually want to keep.

If you’ve enjoyed reading this, and you want to read more stories from me, or you just want to smile at how bad my jokes are, please keep an eye on my posts on TDS. Or better yet, subscribe to my newsletter, where I write more frequently about AI and ML, along with life lessons, parenting, and, of course, a few of my cold jokes! If you’d like to support my writing, you can also just buy me a coffee on https://ko-fi.com/amyma101! ☕✨

Uncertainty Quantification in Machine Learning with an Easy Python Interface https://towardsdatascience.com/uncertainty-quantification-in-machine-learning-with-an-easy-python-interface/ Wed, 26 Mar 2025 19:14:47 +0000 https://towardsdatascience.com/?p=605304 The ML Uncertainty Package

The post Uncertainty Quantification in Machine Learning with an Easy Python Interface appeared first on Towards Data Science.

Uncertainty quantification (UQ) in a Machine Learning (ML) model allows one to estimate the precision of its predictions. This is extremely important for utilizing its predictions in real-world tasks. For instance, if a machine learning model is trained to predict a property of a material, a predicted value with a 20% uncertainty (error) is likely to be used very differently from a predicted value with a 5% uncertainty (error) in the overall decision-making process. Despite its importance, UQ capabilities aren’t built into popular ML software in Python, such as scikit-learn, TensorFlow, and PyTorch.

Enter ML Uncertainty: a Python package designed to address this problem. Built on top of popular Python libraries such as SciPy and scikit-learn, ML Uncertainty provides a very intuitive interface to estimate uncertainties in ML predictions and, where possible, in model parameters. Requiring only about four lines of code to perform these estimations, the package leverages powerful and theoretically rigorous mathematical methods in the background. It exploits the underlying statistical properties of the ML model in question, which keeps it computationally inexpensive. This approach also extends its applicability to real-world use cases where often only small amounts of data are available.

Motivation

I have been an avid Python user for the last 10 years. I love the large number of powerful libraries that have been created and maintained, as well as the very active community. The idea for ML Uncertainty came to me when I was working on a hybrid ML problem. I had built an ML model to predict stress-strain curves of some polymers. Stress-strain curves, an important property of polymers, obey certain physics-based rules; for instance, they have a linear region at low strain values, and the tensile modulus decreases with temperature.

In the literature, I found some nonlinear models that describe the curves and these behaviors, thereby reducing the stress-strain curves to a set of parameters, each with some physical meaning. Then, I trained an ML model to predict these parameters from some easily measurable polymer attributes. Notably, I only had a few hundred data points, as is quite common in scientific applications. After I had trained the model, fine-tuned the hyperparameters, and performed the outlier analysis, one of the stakeholders asked me: “This is all good, but what are the error estimates on your predictions?” And I realized that there wasn’t an elegant way to estimate this with Python. I also realized that this wasn’t going to be the last time this problem would arise. That led me down the path that culminated in this package.

Having spent some time studying statistics, I suspected that the math for this wasn’t impossible or even that hard. I began researching and reading books such as An Introduction to Statistical Learning and The Elements of Statistical Learning,1,2 and found some answers there. ML Uncertainty is my attempt at implementing some of those methods in Python to integrate statistics more tightly into machine learning. I believe that the future of machine learning depends on our ability to increase the reliability of predictions and the interpretability of models, and this is a small step towards that goal. Having developed this package, I have frequently used it in my work, and it has benefited me greatly.

This is an introduction to ML Uncertainty with an overview of the theories underpinning it. I have included some equations to explain the theory, but if those are overwhelming, feel free to gloss over them. For every equation, I have stated the key idea it represents.

Getting started: An example

We often learn best by doing. So, before diving deeper, let’s consider an example. Say we are working on a good old-fashioned linear regression problem where the model is trained with scikit-learn. We think that the model has been trained well, but we want more information. For instance, what are the prediction intervals for the outputs? With ML Uncertainty, this can be done in four lines, as shown below and discussed in this example.

Illustrating ML uncertainty code (a) and plot (b) for linear regression. Image by author.

All examples for this package can be found here: https://github.com/architdatar/ml_uncertainty/tree/main/examples.

Delving deeper: A peek under the hood

ML Uncertainty performs these computations by having the ParametricModelInference class wrap around the LinearRegression estimator from scikit-learn to extract all the information it needs for the uncertainty calculations. It follows the standard procedure for uncertainty estimation, which is detailed in many statistics textbooks;2 an overview is given below.

Since this is a linear model that can be expressed in terms of parameters (\( \beta \)) as \( y = X\beta \), ML Uncertainty first computes the degrees of freedom for the model (\( p \)), the error degrees of freedom (\( n - p - 1 \)), and the residual variance (\( \hat{\sigma}^2 \), i.e., the residual sum of squares divided by the error degrees of freedom). Then, it computes the uncertainty in the model parameters, i.e., the variance-covariance matrix:3

\( \text{Var}(\hat{\beta}) = \hat{\sigma}^2 (J^T J)^{-1} \)

where \( J \) is the Jacobian of the model predictions with respect to the parameters. For linear regression, \( J = X \), so this becomes:

\( \text{Var}(\hat{\beta}) = \hat{\sigma}^2 (X^T X)^{-1} \)

Finally, the get_intervals function computes the prediction intervals by propagating the uncertainties in both the inputs and the parameters. Thus, for data \( X^* \) at which predictions and uncertainties are to be estimated, the predictions \( \hat{y^*} \) and their \( (1 - \alpha) \times 100\% \) prediction interval are:

\( \hat{y^*} \pm t_{1 - \alpha/2,\, n - p - 1} \sqrt{\text{Var}(\hat{y^*})} \)

Where,

\( \text{Var}(\hat{y^*}) = (\nabla_X f)(\delta X^*)^2(\nabla_X f)^T + (\nabla_\beta f)(\delta \hat{\beta})^2(\nabla_\beta f)^T + \hat{\sigma}^2 \)

In plain English, the uncertainty in the output depends on the uncertainty in the inputs, the uncertainty in the parameters, and the residual uncertainty. For a multiple linear regression model with no uncertainty in the inputs, this simplifies to:

\( \text{Var}(\hat{y^*}) = \hat{\sigma}^2 \left(1 + X^* (X^T X)^{-1} X^{*T} \right) \)
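To make these formulas concrete, here is a minimal from-scratch sketch in NumPy and SciPy. It is not ML Uncertainty’s API (whose internals may differ); it assumes an intercept column is added to \( X \) and no uncertainty in the inputs, and the function name ols_prediction_interval is our own.

```python
import numpy as np
from scipy import stats


def ols_prediction_interval(X, y, X_new, alpha=0.05):
    """Prediction intervals for ordinary least squares with an intercept.

    X has shape (n, p) with p predictors; an intercept column is added here,
    so the error degrees of freedom are n - p - 1.
    """
    X = np.asarray(X, dtype=float)
    X_new = np.asarray(X_new, dtype=float)
    if X.ndim == 1:
        X = X[:, None]
    if X_new.ndim == 1:
        X_new = X_new[:, None]
    y = np.asarray(y, dtype=float)
    n, p = X.shape
    A = np.column_stack([np.ones(n), X])                  # design matrix with intercept
    A_new = np.column_stack([np.ones(len(X_new)), X_new])
    beta, *_ = np.linalg.lstsq(A, y, rcond=None)          # least-squares estimate
    resid = y - A @ beta
    dof = n - p - 1                                       # error degrees of freedom
    sigma2 = resid @ resid / dof                          # residual variance estimate
    cov_beta = sigma2 * np.linalg.inv(A.T @ A)            # Var(beta_hat)
    y_hat = A_new @ beta
    # a* Var(beta_hat) a*^T + sigma^2 = sigma^2 (1 + a* (A^T A)^-1 a*^T)
    var_pred = np.einsum("ij,jk,ik->i", A_new, cov_beta, A_new) + sigma2
    half_width = stats.t.ppf(1 - alpha / 2, dof) * np.sqrt(var_pred)
    return y_hat, y_hat - half_width, y_hat + half_width


# Usage on synthetic data with a known linear trend (illustrative values)
rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(200, 1))
y = 1.5 + 0.8 * X[:, 0] + rng.normal(0, 1.0, size=200)
X_grid = np.linspace(0, 10, 50).reshape(-1, 1)
y_hat, lower, upper = ols_prediction_interval(X, y, X_grid, alpha=0.05)
```

A quick sanity check is to simulate fresh data from the same process and confirm that roughly 95% of it falls inside the intervals.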

Extensions to linear regression

So, this is what goes on under the hood when those four lines of code are executed for linear regression. But this isn’t all. ML Uncertainty comes equipped with two more powerful capabilities:

  1. Regularization: ML Uncertainty supports L1, L2, and L1+L2 regularization. Combined with linear regression, this means that it can cater to LASSO, ridge, and elastic net regressions. Check out this example.
  2. Weighted least squares regression: Sometimes, not all observations are equal. We might want to give more weight to some observations and less weight to others. Commonly, this happens in science when some observations have a high amount of uncertainty while some are more precise. We want our regression to reflect the more precise ones, but cannot fully discard the ones with high uncertainty. For such cases, the weighted least squares regression is used.

Importantly, a key assumption of linear regression is homoscedasticity, i.e., that the response variable (equivalently, the error term) has the same variance for every observation. When this does not hold, each response is weighted in proportion to the inverse of its variance. In ML Uncertainty, this is handled by specifying the sample weights used during training in the y_train_weights parameter of the ParametricModelInference class; the package takes care of the rest. An application of this is shown in this example, albeit for a nonlinear regression case.
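Here is a minimal sketch of inverse-variance weighting in plain NumPy, assuming the per-observation noise standard deviations sigma_y are known. The function name is ours and this is not how ML Uncertainty implements it; the trick is that multiplying each row of the design matrix and the response by \( \sqrt{w_i} \) turns the weighted problem into an ordinary least-squares one.

```python
import numpy as np


def weighted_least_squares(X, y, sigma_y):
    """Inverse-variance weighted least squares with an intercept.

    Returns the coefficients (intercept first) and their covariance
    (A^T W A)^-1, which is valid when sigma_y are the true noise standard
    deviations of each observation.
    """
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        X = X[:, None]                                # treat a 1-D input as one feature
    y = np.asarray(y, dtype=float)
    A = np.column_stack([np.ones(len(X)), X])         # design matrix with intercept
    w = 1.0 / np.asarray(sigma_y, dtype=float) ** 2   # weights = 1 / variance
    sw = np.sqrt(w)
    # Scaling rows by sqrt(w) turns the weighted problem into ordinary least squares
    beta, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    cov_beta = np.linalg.inv(A.T @ (w[:, None] * A))  # (A^T W A)^-1
    return beta, cov_beta


# Usage: the noise grows with x, so the noisier points get less weight
rng = np.random.default_rng(0)
x = np.linspace(0, 10, 100)
sigma = 0.2 + 0.3 * x
y = 2.0 + 0.5 * x + rng.normal(0, sigma)
beta, cov_beta = weighted_least_squares(x, y, sigma)
std_err = np.sqrt(np.diag(cov_beta))
```

If the weights are only known up to a constant factor, the covariance would instead need to be rescaled by the estimated residual variance.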

Basis expansions

I am always fascinated by how much ML we can get done just by doing linear regression properly. Many kinds of data, such as trends, time series, audio, and images, can be represented by basis expansions. These representations are linear in their coefficients, so they inherit the convenient properties of linear models. ML Uncertainty can compute uncertainties for these models easily. Check out the examples spline_synthetic_data, spline_wage_data, and fourier_basis.

Results of ML Uncertainty used for weighted least squares regression, B-Spline basis with synthetic data, B-Spline basis with wage data, and Fourier basis. Image by author.
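Because a basis expansion is linear in its coefficients, the same least-squares machinery applies once the basis columns are built. Below is a small sketch with a Fourier basis; the helper names and the choice of period are illustrative assumptions, and the parameter covariance uses \( \hat{\sigma}^2 (B^T B)^{-1} \) with the basis matrix \( B \) in place of \( X \).

```python
import numpy as np


def fourier_design_matrix(t, n_harmonics, period):
    """Columns: 1, cos(2*pi*k*t/period), sin(2*pi*k*t/period) for k = 1..n_harmonics."""
    t = np.asarray(t, dtype=float)
    cols = [np.ones_like(t)]
    for k in range(1, n_harmonics + 1):
        omega = 2 * np.pi * k / period
        cols.extend([np.cos(omega * t), np.sin(omega * t)])
    return np.column_stack(cols)


def fit_basis_expansion(B, y):
    """OLS on basis features; returns coefficients and sigma^2 (B^T B)^-1."""
    beta, *_ = np.linalg.lstsq(B, y, rcond=None)
    resid = y - B @ beta
    dof = B.shape[0] - B.shape[1]        # B already contains the constant column
    sigma2 = resid @ resid / dof
    return beta, sigma2 * np.linalg.inv(B.T @ B)


# Usage: a periodic signal with two harmonics plus noise (illustrative values)
rng = np.random.default_rng(0)
t = np.linspace(0, 3, 300)
y = 1.0 + 2.0 * np.cos(2 * np.pi * t) - 0.5 * np.sin(4 * np.pi * t) + rng.normal(0, 0.1, t.size)
B = fourier_design_matrix(t, n_harmonics=2, period=1.0)
beta, cov_beta = fit_basis_expansion(B, y)
```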

Beyond linear regression

We often encounter situations where the underlying model cannot be expressed as a linear model. This commonly occurs in science, for instance, when complex reaction kinetics, transport phenomena, or process control problems are modeled. Standard Python packages such as scikit-learn don’t allow one to directly fit these nonlinear models and perform uncertainty estimation on them. ML Uncertainty ships with a class called NonLinearRegression that is capable of handling nonlinear models. The user specifies the model to be fit, and the class handles fitting through a scikit-learn-like interface that uses SciPy’s least_squares function in the background. This integrates with the ParametricModelInference class for seamless uncertainty estimation. As with linear regression, weighted least squares and regularization are supported for nonlinear regression. Here is an example.
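For intuition, here is how such a fit and its parameter covariance could look with SciPy directly. This is a sketch, not the NonLinearRegression implementation: the exponential-decay model is a hypothetical example, the error degrees of freedom are taken as \( n - p \) because the model has no separate intercept, and the covariance is the linearized \( \hat{\sigma}^2 (J^T J)^{-1} \) from above, with \( J \) taken from the Jacobian that least_squares returns.

```python
import numpy as np
from scipy.optimize import least_squares


def exp_decay(beta, t):
    """Hypothetical example model: y = a * exp(-k * t), with beta = (a, k)."""
    return beta[0] * np.exp(-beta[1] * t)


def fit_nonlinear(model, t, y, beta0):
    """Least-squares fit plus the linearized covariance sigma^2 (J^T J)^-1."""
    result = least_squares(lambda beta: model(beta, t) - y, beta0)
    n, p = len(y), len(beta0)
    dof = n - p                           # no separate intercept in this model
    sigma2 = result.fun @ result.fun / dof
    J = result.jac                        # Jacobian of the residuals w.r.t. beta
    cov_beta = sigma2 * np.linalg.inv(J.T @ J)
    return result.x, cov_beta


# Usage on synthetic decay data (illustrative values)
rng = np.random.default_rng(0)
t = np.linspace(0, 5, 60)
y = exp_decay([3.0, 0.7], t) + rng.normal(0, 0.05, t.size)
beta_hat, cov = fit_nonlinear(exp_decay, t, y, beta0=[2.0, 0.5])
std_err = np.sqrt(np.diag(cov))
```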

Random Forests

Random Forests have gained significant popularity in the field. They operate by averaging the predictions of decision trees. Decision trees, in turn, identify a set of rules to divide the predictor variable space (input space) and assign a response value to each terminal node (leaf). The predictions from decision trees are averaged to provide a prediction for the random forest.1 They are particularly useful because they can identify complex relationships in data, are accurate, and make fewer assumptions about the data than regressions do.

While random forests are implemented in popular ML libraries like scikit-learn, there is no straightforward way to estimate their prediction intervals. This is particularly important for regression, since random forests, given their high flexibility, tend to overfit their training data. Since random forests don’t have parameters like traditional regression models do, uncertainty quantification needs to be performed differently.

We use the basic idea of estimating prediction intervals using bootstrapping as described by Hastie et al. in Chapter 7 of their book Elements of Statistical Learning.2 The central idea we can exploit is that the variance of the predictions \( S(Z) \) for some data \( Z \) can be estimated via predictions of its bootstrap samples as follows:

\( \widehat{\text{Var}}[S(Z)] = \frac{1}{B - 1} \sum_{b=1}^{B} \left( S(Z^{*b}) - \bar{S}^{*} \right)^2 \)

where \( \bar{S}^{*} = \sum_b S(Z^{*b}) / B \) and \( B \) is the number of bootstrap samples. Bootstrap samples are drawn from the original dataset with replacement, so the same observation can appear more than once in a sample. Lucky for us, random forests are trained using one bootstrap sample for each decision tree within them. So, the predictions from the individual trees form a distribution whose variance gives us the variance of the prediction. But there is still one problem. Let’s say we want to obtain the variance in prediction for the \( i^{\text{th}} \) training sample. If we simply use the formula above, some predictions will come from trees whose bootstrap sample included the \( i^{\text{th}} \) sample. This could lead to an unrealistically small variance estimate.

To solve this problem, the algorithm implemented in ML Uncertainty only considers predictions from trees which did not use the \( i^{\text{th}} \) sample for training. This results in an unbiased estimate of the variance.

The beautiful thing about this approach is that we don’t need any additional re-training steps. Instead, the EnsembleModelInference class elegantly wraps around the RandomForestRegressor estimator in scikit-learn and obtains all the necessary information from it.
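The out-of-bag variance formula can be sketched in a few lines of NumPy. In a fitted random forest the per-tree bootstrap samples already exist, so no retraining is needed; to keep this sketch dependency-free, we draw the bootstrap samples ourselves and use a cubic polynomial fit as a stand-in for a decision tree. The function names are hypothetical, and this is not the EnsembleModelInference implementation.

```python
import numpy as np


def oob_prediction_variance(X, y, fit, predict, n_boot=200, seed=0):
    """Out-of-bag bootstrap variance of the prediction for each training sample.

    For sample i, only ensemble members whose bootstrap sample excluded i
    contribute, and the variance uses the 1 / (B_i - 1) normalization.
    """
    rng = np.random.default_rng(seed)
    n = len(y)
    preds = np.full((n_boot, n), np.nan)          # NaN marks "sample i was in-bag"
    for b in range(n_boot):
        idx = rng.integers(0, n, size=n)          # bootstrap: n draws with replacement
        member = fit(X[idx], y[idx])
        oob = np.setdiff1d(np.arange(n), idx)     # samples this member never saw
        preds[b, oob] = predict(member, X[oob])
    return np.nanvar(preds, axis=0, ddof=1)       # ignores the in-bag NaN entries


# Stand-in base learner (an assumption): a cubic polynomial instead of a tree
def fit_cubic(x, y):
    return np.polyfit(x, y, deg=3)


def predict_cubic(coefs, x):
    return np.polyval(coefs, x)


# Usage on a noisy sine curve (illustrative values)
rng = np.random.default_rng(1)
x = rng.uniform(-2, 2, size=100)
y = np.sin(x) + rng.normal(0, 0.2, size=100)
var_oob = oob_prediction_variance(x, y, fit_cubic, predict_cubic)
```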

This method is benchmarked using the approach described in Zhang et al.,4 which states that a correct \( (1 - \alpha) \times 100\% \) prediction interval is one whose probability of containing the observed response is \( (1 - \alpha) \times 100\% \). Mathematically,

\( P(Y \in I_{\alpha}) \approx 1 - \alpha \)
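Checking this criterion empirically amounts to counting how often held-out responses land inside their intervals. A minimal sketch (the helper name is ours):

```python
import numpy as np


def empirical_coverage(y_true, lower, upper):
    """Fraction of observations falling inside their interval [lower, upper]."""
    y_true = np.asarray(y_true, dtype=float)
    inside = (y_true >= lower) & (y_true <= upper)
    return inside.mean()


# Usage: for standard normal data, the interval +/-1.96 should cover about 95%
rng = np.random.default_rng(0)
y_sim = rng.normal(0.0, 1.0, size=10_000)
coverage = empirical_coverage(y_sim, -1.959964, 1.959964)
```

For a nominal 95% interval, a coverage well below 0.95 on held-out data suggests the intervals are too narrow.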

Here is an example to see ML Uncertainty in action for random forest models.

Uncertainty propagation (Error propagation)

How much does a certain amount of uncertainty in input variables and/or model parameters affect the uncertainty in the response variable? How does this (epistemic) uncertainty compare to the inherent (aleatoric) uncertainty in the response variables? Answering these questions is often important for deciding on a course of action. For instance, if one finds that the uncertainty in model parameters contributes heavily to the uncertainty in predictions, one could collect more data or investigate alternative models to reduce it. Conversely, if the epistemic uncertainty is smaller than the aleatoric uncertainty, trying to reduce it further might be pointless. With ML Uncertainty, these questions can be answered easily.

Given a model relating the predictor variables to the response variable, the ErrorPropagation class can easily compute the uncertainty in responses. Say the responses (\( y \)) are related to the predictor variables (\( X \)) via some function (\( f \)) and some parameters (\( \beta \)), expressed as:

\( y = f(X, \beta) \).

We wish to obtain prediction intervals for the responses (\( \hat{y^*} \)) for some predictor data (\( X^* \)), with model parameters estimated as \( \hat{\beta} \). The uncertainties in \( X^* \) and \( \hat{\beta} \) are given by \( \delta X^* \) and \( \delta \hat{\beta} \), respectively. Then, the \( (1 - \alpha) \times 100\% \) prediction interval of the response variables is given by:

\( \hat{y^*} \pm t_{1 - \alpha/2,\, n - p - 1} \sqrt{\text{Var}(\hat{y^*})} \)

Where,

\( \text{Var}(\hat{y^*}) = (\nabla_X f)(\delta X^*)^2(\nabla_X f)^T + (\nabla_\beta f)(\delta \hat{\beta})^2(\nabla_\beta f)^T + \hat{\sigma}^2 \)

The important thing here is to notice how the uncertainty in predictions includes contributions from the inputs, parameters, as well as the inherent uncertainty of the response.
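The variance formula above can be evaluated numerically for any model using finite-difference gradients. The sketch below is ours, not the ErrorPropagation class: it assumes a scalar model output, independent input errors (so their covariance is diagonal), a full parameter covariance matrix, and central differences with a relative step size we picked. It also returns the three contributions separately, which is exactly what the epistemic-versus-aleatoric comparison needs.

```python
import numpy as np


def _central_gradient(g, v, h=1e-6):
    """Central finite-difference gradient of the scalar function g at vector v."""
    v = np.asarray(v, dtype=float)
    grad = np.empty_like(v)
    for i in range(v.size):
        step = h * max(1.0, abs(v[i]))           # relative step, floored at h
        e = np.zeros_like(v)
        e[i] = step
        grad[i] = (g(v + e) - g(v - e)) / (2 * step)
    return grad


def propagate_variance(f, x, beta, sigma_x, cov_beta, sigma2_resid=0.0):
    """First-order variance of y = f(x, beta) for a single input vector x.

    Returns the prediction, the total variance, and a dict of contributions.
    """
    x = np.asarray(x, dtype=float)
    beta = np.asarray(beta, dtype=float)
    grad_x = _central_gradient(lambda v: f(v, beta), x)
    grad_b = _central_gradient(lambda v: f(x, v), beta)
    var_inputs = grad_x @ np.diag(np.asarray(sigma_x, dtype=float) ** 2) @ grad_x
    var_params = grad_b @ np.asarray(cov_beta, dtype=float) @ grad_b
    parts = {"inputs": var_inputs, "parameters": var_params, "residual": sigma2_resid}
    return f(x, beta), var_inputs + var_params + sigma2_resid, parts
```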

The ability of the ML Uncertainty package to propagate both input and parameter uncertainties makes it very handy, particularly in science, where we strongly care about the error (uncertainty) in each value being predicted. Consider the much-discussed concept of hybrid machine learning. Here, we model known relationships in data through first principles and unknown ones using black-box models. Using ML Uncertainty, the uncertainties obtained from these different methods can be easily propagated through the computation graph.

A very simple example is the Arrhenius model for predicting reaction rate constants. The formula \( k = Ae^{-E_a / RT} \) is very well known. Say the parameters \( A \) and \( E_a \) were predicted by some ML model and each has an uncertainty of 5%. We wish to know how much error that translates to in the reaction rate constant.

This can be very easily accomplished with ML Uncertainty as shown in this example.

Illustration of uncertainty propagation through computational graph. Image by author.
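The Arrhenius case is simple enough to propagate by hand. Since \( \partial k / \partial A = k / A \) and \( \partial k / \partial E_a = -k / (RT) \), first-order propagation with independent errors gives \( (\delta k / k)^2 = (\delta A / A)^2 + (\delta E_a / RT)^2 \). The sketch below uses hypothetical values for \( A \), \( E_a \), and \( T \), not values from the example above.

```python
import numpy as np

R = 8.314  # gas constant, J/(mol K)


def arrhenius_rate(A, Ea, T):
    """Rate constant k = A * exp(-Ea / (R T))."""
    return A * np.exp(-Ea / (R * T))


def arrhenius_relative_uncertainty(Ea, T, rel_A, rel_Ea):
    """First-order relative uncertainty of k, assuming independent errors in A and Ea.

    (dk/k)^2 = rel_A^2 + (rel_Ea * Ea / (R T))^2
    """
    return np.sqrt(rel_A**2 + (rel_Ea * Ea / (R * T)) ** 2)


# Hypothetical inputs: A in 1/s, Ea in J/mol, T in K
A, Ea, T = 1e13, 80e3, 350.0
k = arrhenius_rate(A, Ea, T)
rel_k = arrhenius_relative_uncertainty(Ea, T, rel_A=0.05, rel_Ea=0.05)
```

With these values \( E_a / RT \) is about 27, so a 5% error in \( E_a \) alone translates into more than a 100% error in \( k \). At that size the first-order approximation is itself unreliable (if \( E_a \) is normally distributed, \( k \) is lognormal and skewed), so a Monte Carlo check would be a sensible follow-up.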

Limitations

As of v0.1.1, ML Uncertainty only works for ML models trained with scikit-learn. It natively supports the following ML models: random forest, linear regression, LASSO regression, ridge regression, elastic net, and regression splines. For any other model, the user can define the model, residual, loss function, etc., as shown in the nonlinear regression example. The package has not been tested on neural networks, transformers, or other deep learning models.

Contributions from the open-source ML community are welcome and highly appreciated. While there is much to be done, some key areas of effort are adapting ML Uncertainty to other frameworks such as PyTorch and TensorFlow, adding support for other ML models, highlighting issues, and improving documentation.

Benchmarking

The ML Uncertainty code has been benchmarked against the statsmodels package in Python. Specific cases can be found here.

Background

Uncertainty quantification in machine learning has been studied in the ML community and there is growing interest in this field. However, as of now, the existing solutions are applicable to very specific use cases and have key limitations.

For linear models, the statsmodels library can provide UQ capabilities. While theoretically rigorous, it cannot handle non-linear models. Moreover, the model needs to be expressed in a format specific to the package. This means that the user cannot take advantage of the powerful preprocessing, training, visualization, and other capabilities provided by ML packages like scikit-learn. While it can provide confidence intervals based on uncertainty in the model parameters, it cannot propagate uncertainty in predictor variables (input variables).

Another family of solutions is model-agnostic UQ. These solutions draw subsamples of the training data, retrain the model repeatedly on them, and use the results to estimate prediction intervals. While sometimes useful in the limit of large data, these techniques may not provide accurate estimates for small training datasets, where the samples chosen might lead to substantially different estimates. Moreover, this is computationally expensive, since the model needs to be retrained multiple times. Some packages using this approach are MAPIE, PUNCC, UQpy, and ml_uncertainty by NIST (same name, different package), among many others.5–8

With ML Uncertainty, the goals have been to keep the training of the model and its UQ separate, cater to more generic models beyond linear regression, exploit the underlying statistics of the models, and avoid retraining the model multiple times to make it computationally inexpensive.

Summary and future work

This was an introduction to ML Uncertainty, a Python software package to easily compute uncertainties in machine learning. The main features of this package have been introduced here and some of the philosophy of its development has been discussed. More detailed documentation and theory can be found in the docs. While this is only a start, there is immense scope to expand this. Questions, discussions, and contributions are always welcome. The code can be found on GitHub and the package can be installed from PyPI. Give it a try with pip install ml-uncertainty.

References

(1) James, G.; Witten, D.; Hastie, T.; Tibshirani, R. An Introduction to Statistical Learning; Springer US: New York, NY, 2021. https://doi.org/10.1007/978-1-0716-1418-1.

(2) Hastie, T.; Tibshirani, R.; Friedman, J. The Elements of Statistical Learning; Springer New York: New York, NY, 2009. https://doi.org/10.1007/978-0-387-84858-7.

(3) Börlin, N. Nonlinear Optimization. https://www8.cs.umu.se/kurser/5DA001/HT07/lectures/lsq-handouts.pdf.

(4) Zhang, H.; Zimmerman, J.; Nettleton, D.; Nordman, D. J. Random Forest Prediction Intervals. Am Stat 2020, 74 (4), 392–406. https://doi.org/10.1080/00031305.2019.1585288.

(5) Cordier, T.; Blot, V.; Lacombe, L.; Morzadec, T.; Capitaine, A.; Brunel, N. Flexible and Systematic Uncertainty Estimation with Conformal Prediction via the MAPIE Library. In Conformal and Probabilistic Prediction with Applications; 2023.

(6) Mendil, M.; Mossina, L.; Vigouroux, D. PUNCC: A Python Library for Predictive Uncertainty and Conformalization. In Proceedings of the Twelfth Symposium on Conformal and Probabilistic Prediction with Applications; Papadopoulos, H., Nguyen, K. A., Boström, H., Carlsson, L., Eds.; Proceedings of Machine Learning Research; PMLR, 2023; Vol. 204, pp 582–601.

(7) Tsapetis, D.; Shields, M. D.; Giovanis, D. G.; Olivier, A.; Novak, L.; Chakroborty, P.; Sharma, H.; Chauhan, M.; Kontolati, K.; Vandanapu, L.; Loukrezis, D.; Gardner, M. UQpy v4.1: Uncertainty Quantification with Python. SoftwareX 2023, 24, 101561. https://doi.org/10.1016/j.softx.2023.101561.

(8) Sheen, D. Machine Learning Uncertainty Estimation Toolbox. https://github.com/usnistgov/ml_uncertainty_py.



Attractors in Neural Network Circuits: Beauty and Chaos https://towardsdatascience.com/attractors-in-neural-network-circuits-beauty-and-chaos/ Tue, 25 Mar 2025 19:26:57 +0000 https://towardsdatascience.com/?p=605254 Neural networks under a different lens: generating basins of attraction in a shift register NN

The post Attractors in Neural Network Circuits: Beauty and Chaos appeared first on Towards Data Science.

The state space of the first two neuron activations over time follows an attractor.

What do memories, oscillating chemical reactions, and double pendulums have in common? All these systems have a basin of attraction for possible states, like a magnet that draws the system towards certain trajectories. Complex systems with multiple inputs usually evolve over time, generating intricate and sometimes chaotic behaviors. Attractors represent the long-term behavioral patterns of dynamical systems: patterns to which a system converges over time from a wide range of initial conditions.

Neural networks have become ubiquitous in our current Artificial Intelligence era, typically serving as powerful tools for representation extraction and pattern recognition. However, these systems can also be viewed through another fascinating lens: as dynamical systems that evolve and converge to a manifold of states over time. When implemented with feedback loops, even simple neural networks can produce strikingly beautiful attractors, ranging from limit cycles to chaotic structures.

Neural Networks as Dynamical Systems

While neural networks are most commonly known for embedding extraction tasks, they can also be viewed as dynamical systems. A dynamical system describes how points in a state space evolve over time according to a fixed set of rules or forces. In the context of neural networks, the state space consists of the activation patterns of neurons, and the evolution rule is determined by the network’s weights, biases, activation functions, and other tricks.

Traditional feedforward NNs are trained via gradient descent, which drives their weights toward a converged end state. However, when we introduce feedback, connecting the output back to the input, the network becomes a recurrent system with a different kind of temporal dynamics. These dynamics can exhibit a wide range of behaviors, from simple convergence to a fixed point to complex chaotic patterns.

Understanding Attractors

An attractor is a set of states toward which a system tends to evolve from a wide variety of starting conditions. Once a system reaches an attractor, it remains within that set of states unless perturbed by an external force. Attractors are indeed deeply involved in forming memories [1], oscillating chemical reactions [2], and other nonlinear dynamical systems. 

Types of Attractors

Dynamical systems can exhibit several types of attractors, each with distinct characteristics:

  • Point Attractors: the simplest form, where the system converges to a single fixed point regardless of starting conditions. This represents a stable equilibrium state.
  • Limit Cycles: the system settles into a repeating periodic orbit, forming a closed loop in phase space. This represents oscillatory behavior with a fixed period.
  • Toroidal (Quasiperiodic) Attractors: the system follows trajectories that wind around a donut-like structure in phase space. Unlike limit cycles, these trajectories never exactly repeat, but they remain bound to a specific region.
  • Strange (Chaotic) Attractors: characterized by aperiodic behavior that never repeats exactly yet remains bounded within a finite region of phase space. These attractors exhibit sensitive dependence on initial conditions: a tiny difference in the starting state grows into large differences over time, a hallmark of chaos. Think butterfly effect.
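The two simplest cases can be seen with a single tanh neuron whose output is fed straight back as its input, a toy cousin of the network below (our own illustration, not part of the setup that follows). With a gain of 0.5, every starting point decays to the fixed point 0 (a point attractor); with a gain of -2, that fixed point becomes unstable and the orbit settles onto a period-2 cycle, the discrete-time analogue of a limit cycle.

```python
import numpy as np


def iterate_map(f, x0, n_steps):
    """Return the orbit x_1, ..., x_n of the 1-D map x_{t+1} = f(x_t)."""
    orbit = np.empty(n_steps)
    x = x0
    for t in range(n_steps):
        x = f(x)
        orbit[t] = x
    return orbit


# Gain 0.5: |tanh(0.5 x)| < 0.5 |x|, so every orbit decays to the fixed point 0
point_orbit = iterate_map(lambda x: np.tanh(0.5 * x), 0.8, 200)

# Gain -2: the origin is unstable and the orbit alternates between +a and -a,
# where a solves a = tanh(2a)
cycle_orbit = iterate_map(lambda x: np.tanh(-2.0 * x), 0.1, 200)

print(point_orbit[-3:])
print(cycle_orbit[-4:])
```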

Setup

In the following section, we will dive deeper into an example of a very simple NN architecture capable of such behavior and demonstrate some pretty examples. We will touch on Lyapunov exponents and provide an implementation for those who wish to experiment with generating their own neural network attractor art (and not in the generative AI sense).

Figure 1. NN schematic and components that we will use for the attractor generation. [all figures are created by the author, unless stated otherwise]

We will use a grossly simplified one-layer NN with a feedback loop. The architecture consists of:

  1. Input Layer:
    • An array of D inputs (here 16-32)
    • We will unconventionally label them as y₁, y₂, y₃, …, yD to highlight that these are mapped from the outputs
    • Acts as a shift register that stores previous outputs
  2. Hidden Layer:
    • Contains N neurons (here fewer than D, ~4-8)
    • We will label them x₁, x₂, …, xN
    • tanh() activation is applied for squashing
  3. Output Layer
    • Single output neuron (y₀)
    • Combines the hidden-layer outputs using the “biases” b[i]. Biases are usually added to offset outputs; here, they multiply the hidden outputs instead, so they are effectively an array of weights
  4. Connections:
    • Input to Hidden: Weight matrix w[i,j] (randomly initialized between -1 and 1)
    • Hidden to Output: Bias weights b[i] (randomly initialized between 0 and s)
  5. Feedback Loop:
    • The output y₀ is fed back to the input layer, creating a dynamic map
    • Acts as a shift register (y₁ = previous y₀, y₂ = previous y₁, etc.)
    • This feedback is what creates the dynamical system behavior
  6. Key Formulas:
    • Hidden layer: u[i] = Σ(w[i,j] * y[j]); x[i] = tanh(u[i])
    • Output: y₀ = Σ(b[i] * x[i])

The critical aspects that make this network generate attractors:

  • The feedback loop turns a simple feedforward network into a dynamical system
  • The nonlinear activation function (tanh) enables complex behaviors
  • The random weight initialization (controlled by the random seed) creates different attractor patterns
  • The scaling factor s affects the dynamics of the system and can push it into chaotic regimes

To investigate how prone the system is to chaos, we will calculate the Lyapunov exponent for different sets of parameters. The Lyapunov exponent measures the instability of a dynamical system, i.e., the average exponential rate at which two initially close trajectories separate:

\[ |\delta Z(t)| \approx e^{\lambda t} |\delta Z(0)| \]

\[ \lambda = \frac{1}{n_t} \sum_{k=0}^{n_t-1} \ln \frac{|\Delta y_{k+1}|}{|\Delta y_k|} \]

…where \( n_t \) is the number of time steps, and \( \Delta y_k \) is the distance between the states \( y(x_i) \) and \( y(x_i + \epsilon) \) at time step \( k \); \( \delta Z(0) \) represents an initial infinitesimal (very small) separation between two nearby starting points, and \( \delta Z(t) \) is the separation after time \( t \). For stable systems converging to a fixed point or a stable attractor, \( \lambda \) is less than 0; for unstable (diverging, and therefore chaotic) systems, it is greater than 0.
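One common way to estimate the exponent numerically is the two-trajectory (Benettin-style) method: iterate a reference state and a copy displaced by a tiny \( \epsilon \), accumulate the log of how much their separation grows each step, and rescale the separation back to \( \epsilon \) after every step so it stays infinitesimal. The sketch below is our own and not necessarily how the results here were computed; it rewrites the network update as a map on the input vector y, and the default step counts and \( \epsilon \) are arbitrary choices.

```python
import numpy as np
from typing import Callable


def lyapunov_exponent(
    step: Callable[[np.ndarray], np.ndarray],
    z0: np.ndarray,
    n_steps: int = 10_000,
    eps: float = 1e-8,
    discard: int = 1_000,
) -> float:
    """Estimate the largest Lyapunov exponent of the map z -> step(z)."""
    z = np.asarray(z0, dtype=float).copy()
    for _ in range(discard):                      # let transients die out
        z = step(z)
    direction = np.ones_like(z) / np.sqrt(z.size)
    z_pert = z + eps * direction                  # nearby starting point
    log_sum = 0.0
    for _ in range(n_steps):
        z = step(z)
        z_pert = step(z_pert)
        # Guard against log(0) if the two trajectories merge numerically
        d = max(np.linalg.norm(z_pert - z), np.finfo(float).tiny)
        log_sum += np.log(d / eps)                # ln(|dy_{k+1}| / |dy_k|) with |dy_k| = eps
        z_pert = z + (eps / d) * (z_pert - z)     # rescale the separation back to eps
    return log_sum / n_steps


def shift_register_step(w: np.ndarray, b: np.ndarray) -> Callable[[np.ndarray], np.ndarray]:
    """The network update written as a map on the input vector y."""
    def step(y: np.ndarray) -> np.ndarray:
        x = np.tanh(w @ y)                        # hidden activations
        y0 = b @ x                                # new output
        return np.concatenate(([y0], y[:-1]))     # feed y0 back through the shift register
    return step


# Usage with randomly drawn weights, mirroring the initialization described here
rng = np.random.default_rng(0)
N, D, s = 4, 32, 0.22
w = 2.0 * rng.random((N, D)) - 1.0                # uniform in [-1, 1]
b = s * rng.random(N)                             # uniform in [0, s]
lam = lyapunov_exponent(shift_register_step(w, b), 0.001 * np.ones(D))
print(f"estimated Lyapunov exponent: {lam:.3f}")
```

As a sanity check, for the one-dimensional map \( z \mapsto \tanh(0.5 z) \) the estimate comes out close to \( \ln 0.5 \), as expected for a contraction toward the fixed point 0.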

Let’s code it up! The network itself only needs NumPy and the Python standard library; Matplotlib comes in later for plotting.

import numpy as np
from typing import Tuple, List, Optional


class NeuralAttractor:
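    """
    One-layer tanh network whose output is fed back through a shift register.

    Parameters
    ----------
    N : int
        Number of neurons in the hidden layer
    D : int
        Dimension of the input vector (length of the shift register)
    s : float
        Scaling factor for the output weights b
    seed : int, optional
        Random seed for reproducible weights

    """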
    
    def __init__(self, N: int = 4, D: int = 16, s: float = 0.75, seed: Optional[int] = None):
        self.N = N
        self.D = D
        self.s = s
        
        if seed is not None:
            np.random.seed(seed)
        
        # Initialize weights and biases
        self.w = 2.0 * np.random.random((N, D)) - 1.0  # Uniform in [-1, 1]
        self.b = s * np.random.random(N)  # Uniform in [0, s]
        
        # Initialize state vector structures
        self.x = np.zeros(N)  # Neuron states
        self.y = np.zeros(D)  # Input vector

We initialize the NeuralAttractor class with some basic parameters: the number of neurons in the hidden layer, the number of elements in the input array, the scaling factor for the output, and the random seed. We then initialize the weights and biases randomly, along with the x and y states. These weights and biases are never optimized; they stay fixed, with no gradient descent this time.

    def reset(self, init_value: float = 0.001):
        """Reset the network state to initial conditions."""
        self.x = np.ones(self.N) * init_value
        self.y = np.zeros(self.D)
        
    def iterate(self) -> np.ndarray:
        """
        Perform one iteration of the network and return the neuron outputs.
        
        """
        # Calculate the output y0
        y0 = np.sum(self.b * self.x)
        
        # Shift the input vector
        self.y[1:] = self.y[:-1]
        self.y[0] = y0
        
        # Calculate the neuron inputs u = w @ y and apply the tanh activation
        self.x = np.tanh(self.w @ self.y)
            
        return self.x.copy()

Next, we define the iteration logic. Every iteration starts with the feedback loop: we compute the newest output y0 from the current hidden states, shift all y elements one position to the right (the shift register), and place y0 into the first element of the input. The hidden neurons are then updated from the shifted input vector.

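    def generate_trajectory(self, tmax: int, discard: int = 0) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate a trajectory of the states for tmax iterations.

        Parameters
        ----------
        tmax : int
            Number of iterations to record
        discard : int
            Number of initial (transient) iterations to run and discard before recording

        Returns
        -------
        x1_traj, x2_traj : np.ndarray
            States of the first two hidden neurons at each recorded iteration

        """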
        self.reset()
        
        # Discard initial transient
        for _ in range(discard):
            self.iterate()
        
        x1_traj = np.zeros(tmax)
        x2_traj = np.zeros(tmax)
        
        for t in range(tmax):
            x = self.iterate()
            x1_traj[t] = x[0]
            x2_traj[t] = x[1]
            
        return x1_traj, x2_traj

Now, we define the function that iterates our network map for tmax time steps (after discarding the initial transient) and outputs the states of the first two hidden neurons for visualization. We could use any hidden neurons, and we could even visualize a 3D state space, but we will limit our imagination to two dimensions.

This is the gist of the system. Now, we will just define some line and segment magic for pretty visualizations.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.collections as mcoll
import matplotlib.path as mpath
from typing import Tuple, Optional, Callable


def make_segments(x: np.ndarray, y: np.ndarray) -> np.ndarray:
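    """
    Create an array of line segments from x and y coordinates.

    Parameters
    ----------
    x : np.ndarray
        X coordinates
    y : np.ndarray
        Y coordinates

    Returns
    -------
    np.ndarray
        Array of shape (len(x) - 1, 2, 2) holding consecutive point pairs

    """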
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    return segments


def colorline(
    x: np.ndarray,
    y: np.ndarray,
    z: Optional[np.ndarray] = None,
    cmap = plt.get_cmap("jet"),
    norm = plt.Normalize(0.0, 1.0),
    linewidth: float = 1.0,
    alpha: float = 0.05,
    ax = None
):
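    """
    Plot a colored line with coordinates x and y.

    Parameters
    ----------
    x : np.ndarray
        X coordinates
    y : np.ndarray
        Y coordinates
    z : np.ndarray, optional
        Values mapped to colors for each segment; defaults to a linear ramp from 0 to 1

    """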
    if ax is None:
        ax = plt.gca()
        
    if z is None:
        z = np.linspace(0.0, 1.0, len(x))
    
    segments = make_segments(x, y)
    lc = mcoll.LineCollection(
        segments, array=z, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha
    )
    ax.add_collection(lc)
    
    return lc


def plot_attractor_trajectory(
    x: np.ndarray,
    y: np.ndarray,
    skip_value: int = 16,
    color_function: Optional[Callable] = None,
    cmap = plt.get_cmap("Spectral"),
    linewidth: float = 0.1,
    alpha: float = 0.1,
    figsize: Tuple[float, float] = (10, 10),
    interpolate_steps: int = 3,
    output_path: Optional[str] = None,
    dpi: int = 300,
    show: bool = True
):
    """
    Plot an attractor trajectory.
    
    Parameters:
    -----------
    x : np.ndarray
        X coordinates
    y : np.ndarray
        Y coordinates
    skip_value : int
        Number of points to skip for sparser plotting

    """
    fig, ax = plt.subplots(figsize=figsize)
    
    if interpolate_steps > 1:
        path = mpath.Path(np.column_stack([x, y]))
        verts = path.interpolated(steps=interpolate_steps).vertices
        x, y = verts[:, 0], verts[:, 1]
    
    x_plot = x[::skip_value]
    y_plot = y[::skip_value]
    
    if color_function is None:
        z = abs(np.sin(1.6 * y_plot + 0.4 * x_plot))
    else:
        z = color_function(x_plot, y_plot)
    
    colorline(x_plot, y_plot, z, cmap=cmap, linewidth=linewidth, alpha=alpha, ax=ax)
    
    ax.set_xlim(x.min(), x.max())
    ax.set_ylim(y.min(), y.max())
    
    ax.set_axis_off()
    ax.set_aspect('equal')
    
    plt.tight_layout()
    
    if output_path:
        fig.savefig(output_path, dpi=dpi, bbox_inches='tight')

    # Honor the show flag (previously accepted but ignored)
    if show:
        plt.show()

    return fig

The functions above take the generated state space trajectories and visualize them. Because the state space may be densely filled, we keep only every 8th, 16th or 32nd time point (the skip_value parameter) to sparsify our vectors. We also don’t want to plot the trajectory in one solid color, so we encode the color as a periodic function of the x and y coordinates, abs(np.sin(1.6 * y_plot + 0.4 * x_plot)). The multipliers are arbitrary; they happen to generate nice, smooth color maps, and you can tune them to your liking.
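plot_attractor_trajectory also accepts a custom color_function(x, y). Since colorline normalizes colors with plt.Normalize(0.0, 1.0), such a function should return one value per point in [0, 1]. As a sketch, here is a hypothetical radial_color (not part of the original code) that colors each point by its distance from the origin instead of the periodic default:

```python
import numpy as np


def radial_color(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Color by distance from the origin, rescaled to [0, 1] for plt.Normalize(0, 1)."""
    r = np.hypot(x, y)
    span = r.max() - r.min()
    if span == 0:
        # Degenerate case (all points equally far away): use a single color
        return np.zeros_like(r)
    return (r - r.min()) / span
```

It would be passed as plot_attractor_trajectory(x1, x2, color_function=radial_color); any other function with the same signature and output range works the same way.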

N = 4
D = 32
s = 0.22
seed=174658140

tmax = 100000
discard = 1000

nn = NeuralAttractor(N, D, s, seed=seed)

# Generate trajectory
x1, x2 = nn.generate_trajectory(tmax, discard)

plot_attractor_trajectory(
    x1, x2,
    output_path='trajectory.png',
)

After defining the NN and iteration parameters, we can generate the state space trajectories. If we spend enough time poking around with the parameters, we will find something cool (I promise!). If manual parameter grid search is not exactly our thing, we could add a function that checks what proportion of the state space is covered over time. If, after t = 100,000 iterations (excluding the initial 1,000 “warm-up” time steps), the trajectory has only touched a narrow range of values in the state space, we are likely stuck at a fixed point. Once we find an attractor that is not so shy about taking up more state space, we can plot it using the default plotting parameters:
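The coverage check mentioned above could look roughly like the sketch below. The function names are hypothetical, and it assumes the outputs stay within (-1, 1), as they would with a tanh output layer; adjust bounds if your NeuralAttractor produces a different range. The grid is fixed over those bounds rather than fitted to the trajectory's own bounding box, because a trajectory collapsing onto a point would otherwise still "fill" its own tiny box.

```python
import numpy as np


def state_space_coverage(x, y, bins=100, bounds=(-1.0, 1.0)):
    """Fraction of cells in a fixed bins x bins grid over `bounds` visited by the trajectory."""
    hist, _, _ = np.histogram2d(x, y, bins=bins, range=[bounds, bounds])
    return np.count_nonzero(hist) / hist.size


def looks_stuck(x, y, threshold=0.01, **kwargs):
    """Heuristic: flag trajectories visiting less than `threshold` of the grid (e.g. a fixed point)."""
    return state_space_coverage(x, y, **kwargs) < threshold
```

In a parameter search, one would generate x1, x2 for each seed and skip it whenever looks_stuck(x1, x2) is true. The 1% threshold is an arbitrary choice; very small limit cycles could be flagged too.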

Figure 2. Limit cycle attractor.

One of the stable types of attractors is the limit cycle attractor (parameters: N = 4, D = 32, s = 0.22, seed = 174658140). It looks like a single, closed-loop trajectory in phase space, and the orbit follows a regular, periodic path over time. I will not include the code for the Lyapunov exponent calculation here, to keep the focus on the visual aspect of the generated attractors, but you can find it at this link if interested. The Lyapunov exponent for this attractor (λ=−3.65) is negative, indicating stability: perturbations of the system’s state decay over time, and nearby trajectories converge onto the attractor.
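This is not the author’s linked implementation, but a generic sketch of how the largest Lyapunov exponent of an iterated map can be estimated with the two-trajectory (Benettin-style) method: follow a reference orbit and a companion orbit a tiny distance d0 away, average the log of how much their separation grows per step, and renormalize the separation after every step. Since NeuralAttractor’s internals are not shown here, the example uses the classic Hénon map as a stand-in; largest_lyapunov and henon are illustrative names.

```python
import numpy as np
from typing import Callable


def largest_lyapunov(
    f: Callable[[np.ndarray], np.ndarray],
    x0,
    n_steps: int = 50_000,
    n_discard: int = 1_000,
    d0: float = 1e-8,
    seed: int = 0,
) -> float:
    """Estimate the largest Lyapunov exponent of the map x -> f(x)."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    # Let transients die out so we measure on the attractor itself
    for _ in range(n_discard):
        x = f(x)
    # Companion trajectory at distance d0 in a random direction
    v = rng.normal(size=x.shape)
    y = x + d0 * v / np.linalg.norm(v)
    log_sum = 0.0
    for _ in range(n_steps):
        x, y = f(x), f(y)
        diff = y - x
        d = np.linalg.norm(diff)
        log_sum += np.log(d / d0)   # local stretching (> 0) or shrinking (< 0)
        y = x + diff * (d0 / d)     # renormalize the separation back to d0
    return log_sum / n_steps


def henon(state: np.ndarray, a: float = 1.4, b: float = 0.3) -> np.ndarray:
    """Classic Henon map, used here only as a well-studied chaotic example."""
    x, y = state
    return np.array([1.0 - a * x * x + y, b * x])
```

For the Hénon map, largest_lyapunov(henon, [0.0, 0.0]) should approach the literature value of about 0.42 (chaotic), while a contracting map yields a negative exponent, mirroring the sign interpretation above.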

If we keep increasing the scaling factor s, the values circulating in the circuit get amplified, and we become more likely to find something interesting.

Figure 3. Toroidal attractor.

Here is the toroidal (quasiperiodic) attractor (parameters: N = 4, D = 32, s = 0.55, seed = 3160697950). It still has an ordered structure of sheets that wrap around in organized, quasiperiodic patterns. The Lyapunov exponent for this attractor has a higher value, but is still negative (λ=−0.20).

Figure 4. Strange attractor.

As we further increase the scaling factor s, the system becomes more prone to chaos. The strange (chaotic) attractor emerges with the following parameters: N = 4, D = 16, s = 1.4, seed = 174658140. It is characterized by an erratic, unpredictable pattern of trajectories that never repeat. The Lyapunov exponent for this attractor is positive (λ=0.32), indicating instability (trajectories starting from very close states diverge over time) and chaotic behavior. This is the “butterfly effect” attractor.

Just another confirmation that aesthetics can be very mathematical, and vice versa. The most visually compelling attractors often exist at the edge of chaos — think about it for a second! These structures are complex enough to exhibit intricate behavior, yet ordered enough to maintain coherence. This resonates with observations from various art forms, where balance between order and unpredictability often creates the most engaging experiences.

An interactive widget to generate and visualize these attractors is available here. The source code is available, too, and invites further exploration. The ideas behind this project were largely inspired by the work of J.C. Sprott [3]. 

References

[1] B. Poucet and E. Save, Attractors in Memory (2005), Science DOI:10.1126/science.1112555.

[2] Y.J.F. Kpomahou et al., Chaotic Behaviors and Coexisting Attractors in a New Nonlinear Dissipative Parametric Chemical Oscillator (2022), Complexity DOI:10.1155/2022/9350516.

[3] J.C. Sprott, Artificial Neural Net Attractors (1998), Computers & Graphics DOI:10.1016/S0097-8493(97)00089-7.

The post Attractors in Neural Network Circuits: Beauty and Chaos appeared first on Towards Data Science.

]]>
What Do Machine Learning Engineers Do? https://towardsdatascience.com/what-do-machine-learning-engineers-do/ Tue, 25 Mar 2025 07:45:20 +0000 https://towardsdatascience.com/?p=605223 Breaking down my role as a machine learning engineer

The post What Do Machine Learning Engineers Do? appeared first on Towards Data Science.

]]>
In this article, I want to explain precisely what I do as a machine learning engineer. 

The aim is to help anyone looking to enter the field gain a truthful view of what a machine learning engineer is, how we work, what we do, and what a typical day in the life is like.

I hope it can help you pinpoint whether a career in machine learning is indeed for you.

What is a machine learning engineer?

Due to the rapid acceleration of the tech/AI space, the machine learning engineer role is still not well-defined and varies between companies and geographies to a certain extent.

However, it generally refers to someone who:

Mixes machine learning, statistics and software engineering skills to train and deploy models into production.

At some companies, there will be a large cross-over with data scientists. Still, the main distinction between the two roles is that machine learning engineers deliver the solution into production. Often, data scientists won’t do this and focus more on helping in the model-building stage.

The need for machine learning engineers came from the fact that models that only live in Jupyter Notebooks have zero value. So, a role well-versed in both machine learning and software engineering was needed to bring the models “to life” and ensure they generate business value.

Because of this broad skillset, machine learning engineering is not an entry-level role, and you would typically need to be a data scientist or software engineer for a couple of years first.

So, to summarise:

  • Responsibilities: Train, build and deploy machine learning models.
  • Skills & Tech: Python, SQL, AWS, Bash/Zsh, PyTorch, Docker, Kubernetes, MLOps, Git, distributed computing (not an exhaustive list).
  • Experience: A couple of years as a data scientist or software engineer, and then up-skill yourself in the other areas.

If you want a better understanding of the different data and machine learning roles, I recommend checking out some of my previous articles.

The Difference Between ML Engineers and Data Scientists
Helping you decide whether you want to be a data scientist or machine learning engineer

Should You Become A Data Scientist, Data Analyst Or Data Engineer?
Explaining the differences and requirements between the various data roles

What do I do?

I work as a machine learning engineer within a cross-functional team. My squad specialises in classical machine learning and combinatorial optimisation-based problems.

Much of my work revolves around improving our machine learning models and optimisation solutions to enhance the customer experience and generate financial value for the business.

The general workflow for most of my projects is as follows:

  • Idea — Someone may have an idea or hypothesis about how to improve one of our models.
  • Data — We check if the data to prove or disprove this hypothesis is readily available so we can start the research.
  • Research — If the data is available, we start building or testing this new hypothesis in the model.
  • Analysis — The results of the research stage are analysed to determine if we have improved the model.
  • Ship — The improvement is “productionised” in the codebase and goes live.

Along this process, there is a lot of interaction with other functions and roles within the team and broader company.

  • The idea phase is a collaborative discussion with a product manager who can provide business insight and flag any critical impacts we may have missed in the initial scoping.
  • The data, research and analysis phases can be done in collaboration with data analysts and engineers to ensure the quality of our ETL pipelines and the use of the right data sources.
  • In the research phase, data scientists help by applying their statistics and machine learning skills to improve our model.
  • The ship phase is a joint effort with our dedicated software engineers, ensuring our deployment is robust and up to standard with best coding practices.

From experience, I know that this type of workflow is prevalent among machine learning engineers in numerous companies, although I am sure there are slight variations depending on where you are.

My job is also not just to write code day in and day out. I have other responsibilities, like conducting workshops, presenting to stakeholders, and mentoring more junior members.

What is the structure of machine learning teams?

Machine learning engineers can be organised in many different ways across a company, but there are three distinct models, and the rest are a mix of them.

  • Embedded: In this case, machine learning engineers are embedded in cross-functional teams with analysts, product managers, software engineers and data scientists, where the team solves problems in one domain within the company. This is how I work, and I really like it because you get to pick up lots of valuable skills and abilities from other team members who are specialists in their own right.
  • Consultancy: This is the flip side, where machine learning engineers form their own team that acts as an “in-house consultancy.” In this scenario, the machine learning engineers work on problems based on their perceived value to the business. You are technically less specialised in this option, as you may need to switch between different types of problems.
  • Infrastructure/Platform: Instead of solving business problems directly, these machine learning engineers develop in-house tools and a deployment platform to make productionising the algorithms much easier.

All ways of working have pros and cons, and in reality, I wouldn’t say one is better than the other; it’s really a matter of personal preference. You still do exciting work, nonetheless!

What is a typical day in the life like?

People online often glamourise working in tech, like it’s all coffee breaks, chats, and coding for an hour a day, and you make well over six figures.

This is definitely not the case (as much as I wish it were), but it’s still a fun and enjoyable workday compared to many other professions.

My general experience has been:

  • 9:00 am – 9:30 am. Start at 9 am with a morning standup to catch up with the team on the previous day’s work and what you are doing today. A “standup” meeting is very common across tech.
  • 9:30 am – 10:30 am. After the standup, there may be another hour-long meeting with stakeholders or engineers, an all-hands, or another company meeting.
  • 10:30 am – 1:00 pm. Then, it’s a work/code block of two hours or so where I focus on my projects. Depending on my work, I may pair with another data scientist, machine learning engineer or software engineer.
  • 1:00 pm – 2:00 pm. Lunch.
  • 2:00 pm – 5:45 pm. Afternoons are normally free of meetings, leaving a large block of focus time to work on your projects. This is mainly for individual contributors like myself.
  • 5:45 pm – 6:00 pm. Reply to emails and Slack messages and wrap up for the day.

Every day is different, but this is what you can expect. As you can tell, it’s nothing “extraordinary.”

This is also the workday for a junior / mid-level individual contributor (IC) like myself. Senior positions, especially managerial roles, typically have more meetings.

An important thing to note is that I don’t always code in my work blocks. I may have a presentation to prepare for stakeholders, some ad-hoc analysis for our product manager, or some writing up of my latest research. I may not even code for the whole day!

On average, I spend 3–4 hours on focused coding; the rest is meetings or ad-hoc work. Of course, this varies between companies and at different times of the year.

Why am I a machine learning engineer?

My choice to be a machine learning engineer boils down to four main reasons:

  • Interesting. As a machine learning engineer, I get to be right at the forefront of the latest tech trends like AI, LLMs, and pretty much anything that is going viral in the field. There is always something new and exciting to learn, which I love! So, if you want to constantly learn new skills and apply them, this may be a career you would be interested in.
  • Work-Life Balance. Tech jobs generally provide better work-life balance than other professions like banking, law or consulting. Most machine learning jobs are 9–6, and you can often spend a few days working from home. This flexibility allows me to pursue other passions, projects, and hobbies outside of work, such as this blog!
  • Compensation. It’s no secret that tech jobs provide some of the highest salaries. According to levels.fyi, the median salary of a machine learning engineer in the UK is £93k, which is remarkably high for a typical salary.
  • Range of Industries. As a machine learning engineer, you can work in loads of different industries during your career. However, to become a real specialist, you must find and stick to one industry you love.

I hope this article gave you more insight into machine learning engineering. If you have any questions, let me know in the comments.

Another thing!

Join my free newsletter, Dishing the Data, where I share weekly tips, insights, and advice from my experience as a practicing data scientist. Plus, as a subscriber, you’ll get my FREE Data Science Resume Template!

Dishing The Data | Egor Howell | Substack
Advice and learnings on data science, tech and entrepreneurship. newsletter.egorhowell.com


The post What Do Machine Learning Engineers Do? appeared first on Towards Data Science.

]]>
Algorithm Protection in the Context of Federated Learning  https://towardsdatascience.com/algorithm-protection-in-the-context-of-federated-learning/ Fri, 21 Mar 2025 04:32:38 +0000 https://towardsdatascience.com/?p=605195 A pragmatic look into protecting algorithms and models deployed into real-world federated analysis and learning settings in healthcare.

The post Algorithm Protection in the Context of Federated Learning  appeared first on Towards Data Science.

]]>
At the biotech company where we work, we aim to advance ML and AI algorithms that enable, for example, brain lesion segmentation to be executed at the hospital or clinic where patient data resides, so that the data is processed securely. This is, in essence, what federated learning mechanisms provide, and we have adopted them in numerous real-world hospital settings. However, when an algorithm is itself considered a company asset, we also need means to protect not only sensitive data but also the algorithms themselves in a heterogeneous federated environment.

Fig.1 High-level workflow and attack surface. Image by author

Most algorithms are assumed to be encapsulated within Docker-compatible containers, allowing them to use different libraries and runtimes independently. We also assume there is a third-party IT administrator who aims to secure patients’ data and lock down the deployment environment, making it inaccessible to algorithm providers. This article describes different mechanisms intended to package and protect containerized workloads against theft of intellectual property by a local system administrator.

To ensure a comprehensive approach, we will address protection measures across three critical layers:

  • Algorithm code protection: Measures to secure algorithm code, preventing unauthorized access or reverse engineering.
  • Runtime environment: Evaluates risks of administrators accessing confidential data within a containerized system.
  • Deployment environment: Infrastructure safeguards against unauthorized system administrator access.

Fig.2 Different layers of protection. Image by author

Methodology

After analyzing the risks, we identified two categories of protection measures:

  • Intellectual property theft and unauthorized distribution: preventing administrator users from accessing, copying or executing the algorithm.
  • Reverse engineering risk reduction: blocking administrator users from analyzing the code to uncover it and claim ownership.

While acknowledging that this assessment is subjective, we considered both qualitative and quantitative characteristics of all mechanisms.

Qualitative assessment

The following criteria were considered when selecting a suitable solution and are reflected in the summary:

  • Hardware dependency: potential lock-in and scalability challenges in federated systems.
  • Software dependency: reflects maturity and long-term stability.
  • Hardware and Software dependency: measures setup complexity, deployment and maintenance effort.
  • Cloud dependency: risks of lock-in with a single cloud hypervisor.
  • Hospital environment: evaluates technology maturity and requirements in heterogeneous hardware setups.
  • Cost: covers dedicated hardware, implementation and maintenance.

Quantitative assessment

Subjective risk reduction quantitative assessment description:

Considering the above methodology and assessment criteria, we came up with a list of mechanisms that have the potential to achieve this objective.

Confidential containers

Confidential Containers (CoCo) is an emerging CNCF technology that aims to deliver confidential runtime environments that will run CPU and GPU workloads while protecting the algorithm code and data from the hosting company.

CoCo supports multiple TEEs (trusted execution environments), including Intel TDX/SGX and AMD SEV hardware technologies, as well as extensions for NVIDIA GPU operators. These use hardware-backed protection of code and data during execution, preventing scenarios in which a determined and skillful local administrator uses a debugger to dump the contents of the container memory and gains access to both the algorithm and the data being processed.

Trust is built using cryptographic attestation of the runtime environment and of the code being executed. It ensures the code is neither tampered with nor read by the remote admin.
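To make the attestation idea concrete, here is a deliberately simplified toy sketch: the workload is "measured" (hashed), the measurement is bound to a fresh nonce from the verifier and signed, and the verifier checks both the signature and the measurement against a reference value. All names here are hypothetical, and the shared HMAC key is an assumption for brevity; real TEEs sign attestation reports with hardware-held asymmetric keys certified by the CPU vendor, which is exactly the trust chain the CoCo stack has to get right.

```python
import hashlib
import hmac
import secrets

# Hypothetical stand-in for a key sealed inside TEE hardware (real TEEs use
# vendor-certified asymmetric keys, not a shared HMAC secret).
DEVICE_KEY = secrets.token_bytes(32)


def measure(image: bytes) -> str:
    """'Measurement' of the workload: a hash of the code that will run."""
    return hashlib.sha256(image).hexdigest()


def make_quote(image: bytes, nonce: bytes, key: bytes = DEVICE_KEY) -> tuple[str, str]:
    """Inside the TEE: bind the measurement to the verifier's fresh nonce and sign it."""
    m = measure(image)
    sig = hmac.new(key, m.encode() + nonce, hashlib.sha256).hexdigest()
    return m, sig


def verify_quote(quote: tuple[str, str], nonce: bytes, expected: str, key: bytes = DEVICE_KEY) -> bool:
    """At the verifier: check the signature, then compare with the reference measurement."""
    m, sig = quote
    good_sig = hmac.new(key, m.encode() + nonce, hashlib.sha256).hexdigest()
    return hmac.compare_digest(sig, good_sig) and hmac.compare_digest(m, expected)
```

A tampered image changes the measurement, and a replayed quote fails because it was bound to an old nonce. If an administrator can obtain or forge the signing key, however, they can "attest" anything, which is the kind of gap described next.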

This appears to be a perfect fit for our problem, as the remote data site admin would not be able to access the algorithm code. Unfortunately, despite continuous efforts, the current CoCo software stack still suffers from security gaps that enable malicious administrators to issue attestations for themselves and effectively bypass all the other protection mechanisms, rendering them useless. Each time the technology gets closer to practical production readiness, a new fundamental security issue is discovered that needs to be addressed. It is worth noting that this community is fairly transparent in communicating gaps.

The often and rightfully recognized additional complexity introduced by TEEs and CoCo (specialized hardware, configuration burden, runtime overhead due to encryption) would be justifiable if the technology delivered on its promise of code protection. While TEEs seem to be well adopted, CoCo is close but not there yet; based on our experience, the horizon keeps moving as new fundamental vulnerabilities are discovered and need to be addressed.

In other words, if we had production-ready CoCo, it would have been a solution to our problem. 

Host-based container image encryption at rest (protection at rest and in transit)

This strategy is based on end-to-end protection of container images containing the algorithm.

It protects the source code of the algorithm at rest and in transit but does not protect it at runtime, as the container needs to be decrypted prior to the execution.

The malicious administrator at the site has direct or indirect access to the decryption key, so they can read the container contents right after it is decrypted for execution.

Another attack scenario is to attach a debugger to the running container image.

So host-based container image encryption at rest makes it harder to steal the algorithm from a storage device and in transit due to encryption, but moderately skilled administrators can decrypt and expose the algorithm.

In our opinion, the additional practical effort (time, skillset, infrastructure) needed by an administrator with access to the decryption key to extract the algorithm from the container is too small for this to be considered a valid algorithm protection mechanism.

Prebaked custom virtual machine

In this scenario the algorithm owner is delivering an encrypted virtual machine.

The key can be provided at boot time from the keyboard by someone other than the admin (required at each reboot), from external storage (a USB key, which is very vulnerable, as anyone with physical access can attach the key storage), or through a remote SSH session (using Dropbear, for instance) without allowing the local admin to unlock the bootloader and disk.

Effective and established technologies such as LUKS can be used to fully encrypt local VM filesystems including bootloader.

However, even if the remote key is provided using a boot-level tiny SSH session by someone other than a malicious admin, the runtime is exposed to a hypervisor-level debugger attack, as after boot, the VM memory is decrypted and can be scanned for code and data.

Still, this solution, especially with remotely provided keys by the algorithm owner, provides significantly increased algorithm code protection compared to encrypted containers because an attack requires more skills and determination than just decrypting the container image using a decryption key. 

To prevent memory dump analysis, we considered deploying a prebaked host machine with keys provided over SSH at boot time; this removes any hypervisor-level access to memory. As a side note, there are methods to freeze physical memory modules to delay the loss of data (so-called cold boot attacks).

Distroless container images

Distroless container images reduce the number of layers and components to the minimum required to run the algorithm.

The attack surface is greatly reduced, as there are fewer components prone to vulnerabilities and known attacks. They are also lighter in terms of storage, network transmission, and latency.

However, despite these improvements, the algorithm code is not protected at all. 

Distroless containers are recommended as more secure containers, but they do not protect the algorithm: the algorithm is still inside, the container image can easily be mounted, and the algorithm can be stolen without significant effort.

Being distroless does not address our goal of protecting the algorithm code.

Compiled algorithm

Most machine learning algorithms are written in Python. This interpreted language makes it really easy not only to execute the algorithm code on other machines and in other environments but also to access source code and be able to modify the algorithm.

This even enables the party that steals the algorithm code to modify it (let’s say 30% or more of the source code) and claim it is no longer the original algorithm, which could make it much harder to provide evidence of intellectual property infringement in a legal action.
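Even shipping only Python bytecode instead of .py files does little against this. As a small illustration (the lesion_score function is hypothetical), a .pyc file stores a marshal-serialized code object after a short header, and anyone holding those bytes can load it back, run it, and disassemble it with the standard library, recovering names, constants and control flow:

```python
import dis
import io
import marshal

# Hypothetical "proprietary" algorithm, distributed only as compiled bytecode
SOURCE = '''
def lesion_score(volume, threshold=0.42):
    return volume > threshold
'''

code = compile(SOURCE, "<algorithm>", "exec")
blob = marshal.dumps(code)        # roughly what a .pyc stores after its header
restored = marshal.loads(blob)    # anyone with the bytes can load it back...

listing = io.StringIO()
dis.dis(restored, file=listing)   # ...and disassemble it, including the nested function
print(listing.getvalue())
```

The listing exposes the function name, parameter names and the 0.42 default, which is usually enough for a motivated reader to rewrite the logic.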

Compiled languages, such as C, C++ and Rust, when combined with strong compiler optimization (-O3 in the case of C, link-time optimization), not only make the source code unavailable as such but also make it much harder to reverse engineer.

Compiler optimizations introduce significant control flow changes, substitutions of mathematical operations, function inlining and code restructuring, and they make stack tracing difficult.

This makes reverse engineering much harder, to the point of being practically infeasible in some scenarios, so compilation can be considered a way to increase the cost of a reverse engineering attack by orders of magnitude compared to plain Python code.

There’s an increased complexity and skill gap, as most of the algorithms are written in Python and would have to be converted to C, C++ or Rust.

This option does increase the cost of further developing the algorithm, or of modifying it to claim ownership, but it does not prevent the algorithm from being executed outside the agreed contractual scope.

Code obfuscation

The established technique of making the code much less readable, harder to understand and develop further can be used to make algorithm evolutions much harder.

Unfortunately, it does not prevent the algorithm from being executed outside the contractual scope.

Also, the de-obfuscation technologies are getting much better, thanks to advanced language models, lowering the practical effectiveness of code obfuscation.

Code obfuscation does increase the practical cost of algorithm reverse engineering, so it’s worth considering as an option combined with other options (for instance, with compiled code and custom VMs).
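As a toy illustration of the simplest obfuscation step, identifier renaming, the sketch below uses Python’s ast module to replace positional parameters and local variable names with meaningless ones while preserving behavior. It is a hypothetical helper, not a production obfuscator: it ignores keyword-only arguments, global/nonlocal declarations, nested functions and name collisions, and this kind of renaming is precisely what modern de-obfuscation tools and LLMs can often undo.

```python
import ast


class RenameLocals(ast.NodeTransformer):
    """Toy obfuscator: rename positional parameters and assigned names in flat functions."""

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        mapping = {}
        for arg in node.args.args:             # positional parameters only
            mapping[arg.arg] = f"_{len(mapping)}"
            arg.arg = mapping[arg.arg]
        for sub in ast.walk(node):              # names bound inside the body
            if isinstance(sub, ast.Name) and isinstance(sub.ctx, ast.Store):
                mapping.setdefault(sub.id, f"_{len(mapping)}")
        for sub in ast.walk(node):              # rewrite every use of a renamed name
            if isinstance(sub, ast.Name) and sub.id in mapping:
                sub.id = mapping[sub.id]
        return node


def obfuscate(source: str) -> str:
    """Return source with local identifiers renamed (requires Python 3.9+ for ast.unparse)."""
    return ast.unparse(RenameLocals().visit(ast.parse(source)))


SRC = '''
def dice_score(pred, target):
    overlap = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return 2.0 * overlap / total
'''
print(obfuscate(SRC))
```

The function name and builtins stay intact so callers keep working, which also shows the limits of renaming: the structure of the computation remains fully visible.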

Homomorphic Encryption as code protection mechanism

Homomorphic Encryption (HE) is a promising technology aimed at protecting data. It is very interesting for the secure aggregation of partial results in federated learning and analytics scenarios.

The aggregation party (with limited trust) can only process encrypted data and perform encrypted aggregations, then it can decrypt aggregated results without being able to decrypt any individual data.
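The aggregation idea can be sketched with a toy version of the Paillier cryptosystem, which is additively homomorphic: multiplying ciphertexts yields an encryption of the sum of the plaintexts. The primes below are tiny and completely insecure, the function names are hypothetical, and a single key holder decrypts the total; real deployments use large keys and typically threshold decryption so that no single party can decrypt individual contributions.

```python
import math
import secrets


def keygen(p: int = 1009, q: int = 1013):
    """Toy Paillier keys from two tiny (insecure) primes, using g = n + 1."""
    n = p * q
    lam = math.lcm(p - 1, q - 1)
    mu = pow(lam, -1, n)            # modular inverse of lambda mod n (valid for g = n + 1)
    return n, (lam, mu, n)          # (public key, private key)


def encrypt(n: int, m: int) -> int:
    """Encrypt 0 <= m < n as g^m * r^n mod n^2 with a random r coprime to n."""
    n2 = n * n
    while True:
        r = secrets.randbelow(n - 1) + 1
        if math.gcd(r, n) == 1:
            break
    return (1 + m * n) * pow(r, n, n2) % n2   # (n + 1)^m == 1 + m*n (mod n^2)


def add_encrypted(n: int, *ciphertexts: int) -> int:
    """Aggregator side: multiply ciphertexts, which adds the hidden plaintexts."""
    n2 = n * n
    total = 1
    for c in ciphertexts:
        total = total * c % n2
    return total


def decrypt(priv, c: int) -> int:
    lam, mu, n = priv
    return (pow(c, lam, n * n) - 1) // n * mu % n
```

For example, three sites could encrypt local counts 12, 30 and 7; the aggregator combines the ciphertexts without any key, and decrypting the aggregate gives 49. As the article stresses, this protects the data in transit through the aggregator, not the algorithm.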

Practical applications of HE are limited due to its complexity, performance overhead and the limited number of supported operations. There is observable progress (including GPU acceleration for HE), but it is still a niche and emerging data protection technique.

From an algorithm protection perspective, HE is neither designed to protect the algorithm nor can it be made to do so. So it is not an algorithm protection mechanism at all.

Conclusions

Fig.3 Risk reduction scores, Image by author

In essence, we described and assessed strategies and technologies to protect algorithm IP and sensitive data in the context of deploying Medical Algorithms and running them in potentially untrusted environments, such as hospitals.

As Fig.3 shows, the most promising technologies are those that provide a degree of hardware isolation. However, they make an algorithm provider completely dependent on the runtime it will be deployed on. While compilation and obfuscation do not completely mitigate the risk of intellectual property theft (even basic LLMs seem to help with de-obfuscation), these methods, especially when combined, make the algorithm code very difficult, and thus expensive, to use and modify, which already provides a degree of security.

Prebaked host/virtual machines are the most common and widely adopted methods, extended with features like full disk encryption with keys acquired during boot via SSH, which could make it fairly difficult for the local admin to access any data. However, pre-baked machines in particular could raise compliance concerns at the hospital, and this needs to be assessed before establishing a federated network.

Key hardware and software vendors (Intel, AMD, NVIDIA, Microsoft, Red Hat) have recognized significant demand and continue to evolve their offerings, which gives a promise that training IP-protected algorithms in a federated manner, without disclosing patients’ data, will soon be within reach. However, hardware-supported methods are very sensitive to hospital internal infrastructure, which by nature is quite heterogeneous. Containerisation therefore provides some promise of portability. Considering this, Confidential Containers technology is a very tempting promise from its contributors, although it is still not fully production-ready.

Combining the above mechanisms across the code, runtime and infrastructure layers, supplemented with a proper legal framework, certainly decreases residual risks. No solution provides absolute protection, particularly against determined adversaries with privileged access, but the combined effect of these measures creates substantial barriers to intellectual property theft.

We deeply appreciate and value feedback from the community helping to further steer future efforts to develop sustainable, secure and effective methods for accelerating AI development and deployment. Together, we can tackle these challenges and achieve groundbreaking progress, ensuring robust security and compliance in various contexts. 

Contributions: The author would like to thank Jacek Chmiel, Peter Fernana Richie, Vitor Gouveia and the Federated Open Science team at Roche for brainstorming, pragmatic solution-oriented thinking, and contributions.

Link & Resources

Intel Confidential Containers Guide 

Nvidia blog describing integration with CoCo

Confidential Containers GitHub & Kata Agent Policies

Commercial Vendors: Edgeless systems contrast, Redhat & Azure

Remote Unlock of LUKS encrypted disk

A perfect match to elevate privacy-enhancing healthcare analytics

Differential Privacy and Federated Learning for Medical Data

The post Algorithm Protection in the Context of Federated Learning  appeared first on Towards Data Science.

]]>