Towards Data Science https://towardsdatascience.com/ The world’s leading publication for data science, AI, and ML professionals.

Sesame Speech Model: How This Viral AI Model Generates Human-Like Speech https://towardsdatascience.com/sesame-speech-model-how-this-viral-ai-model-generates-human-like-speech/ Sat, 12 Apr 2025 01:09:27 +0000 https://towardsdatascience.com/?p=605722 A deep dive into residual vector quantizers, conversational speech AI, and talkative transformers.

The post Sesame Speech Model: How This Viral AI Model Generates Human-Like Speech appeared first on Towards Data Science.

]]>
Recently, Sesame AI published a demo of their latest speech-to-speech model: a conversational AI agent that is really good at speaking. It provides relevant answers, speaks with expression, and honestly, is just very fun and interactive to play with.

Note that a technical paper is not out yet, but they do have a short blog post that provides a lot of information about the techniques they used and previous algorithms they built upon. 

Thankfully, they provided enough information for me to write this article and make a YouTube video out of it. Read on!

Training a Conversational Speech Model

Sesame is a Conversational Speech Model, or CSM. It takes both text and audio as input and generates speech as audio. While they haven’t revealed their training data sources in the blog post, we can still make an educated guess. The blog post heavily cites another CSM, 2024’s Moshi, and fortunately, the creators of Moshi did reveal their data sources in their paper. Moshi uses 7 million hours of unsupervised speech data, 170 hours of natural and scripted conversations (for multi-stream training), and 2000 more hours of telephone conversations (the Fisher dataset).


Sesame builds upon the Moshi Paper (2024)

But what does it really take to generate audio?

In raw form, audio is just a long sequence of amplitude values — a waveform. For example, if you’re sampling audio at 24 kHz, you are capturing 24,000 float values every second.

There are 24000 values here to represent 1 second of speech! (Image generated by author)

Of course, it is quite resource-intensive to process 24000 float values for just one second of data, especially because transformer computations scale quadratically with sequence length. It would be great if we could compress this signal and reduce the number of samples required to process the audio.

We will take a deep dive into the Mimi encoder and specifically Residual Vector Quantizers (RVQ), which are the backbone of Audio/Speech modeling in Deep Learning today. We will end the article by learning about how Sesame generates audio using its special dual-transformer architecture.

Preprocessing audio

Compression and feature extraction are where convolution helps us. Sesame uses the Mimi speech encoder to process audio. Mimi was also introduced in the aforementioned Moshi paper. It is a self-supervised audio encoder-decoder model that first converts audio waveforms into discrete “latent” tokens and then reconstructs the original signal. Sesame only uses the encoder section of Mimi to tokenize the input audio. Let’s learn how.

Mimi takes the raw speech waveform at 24 kHz and passes it through several strided convolution layers to downsample the signal, with stride factors of 4, 5, 6, 8, and 2. This means that the first CNN block downsamples the audio by 4x, the next by 5x, then 6x, and so on. In the end, the signal is downsampled by a total factor of 4 × 5 × 6 × 8 × 2 = 1920, reducing it to just 24,000 / 1920 = 12.5 frames per second.

The convolution blocks also project the original float values to an embedding dimension of 512. Each embedding aggregates the local features of the original 1D waveform. One second of audio is now represented as around 12 vectors of size 512. This way, Mimi reduces the sequence length from 24000 to just 12 and converts the samples into dense continuous vectors.

Before applying any quantization, the Mimi encoder downsamples the input 24 kHz audio by a factor of 1920 and embeds it into 512 dimensions. In other words, you get 12.5 frames per second with each frame as a 512-dimensional vector. (Image from author’s video)
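
To double-check the arithmetic, here is a small sketch that multiplies the stride factors quoted above and derives the frame rate. The function name is made up for illustration. It also shows how much the quadratic attention cost shrinks, assuming cost grows with the square of the sequence length.

```python
from math import prod

SAMPLE_RATE_HZ = 24_000       # input audio sample rate
STRIDES = [4, 5, 6, 8, 2]     # stride of each convolution block, as quoted above


def frame_rate(sample_rate_hz, strides):
    """Return (total downsampling factor, frames per second)."""
    factor = prod(strides)
    return factor, sample_rate_hz / factor


factor, fps = frame_rate(SAMPLE_RATE_HZ, STRIDES)

# If cost grows with sequence_length ** 2, shortening the sequence by
# `factor` shrinks that cost by factor ** 2.
attention_cost_ratio = factor ** 2

print(factor, fps, attention_cost_ratio)
```

So one second of audio goes from 24,000 positions to 12.5, which is what makes running a transformer over it practical.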

What is Audio Quantization?

Given the continuous embeddings obtained after the convolution layers, we want to tokenize the input speech. If we can represent speech as a sequence of tokens, we can apply standard language-modeling transformers to train generative models.

Mimi uses a Residual Vector Quantizer or RVQ tokenizer to achieve this. We will talk about the residual part soon, but first, let’s look at what a simple vanilla Vector quantizer does.

Vector Quantization

The idea behind Vector Quantization is simple: you train a codebook, which is a collection of, say, 1000 code vectors (random at first, then learned), all of size 512 (same as your embedding dimension).

A Vanilla Vector Quantizer. A codebook of embeddings is trained. Given an input embedding, we map/quantize it to the nearest codebook entry. (Screenshot from author’s video)

Then, given the input vector, we will map it to the closest vector in our codebook — basically snapping a point to its nearest cluster center. This means we have effectively created a fixed vocabulary of tokens to represent each audio frame, because whatever the input frame embedding may be, we will represent it with the nearest cluster centroid. If you want to learn more about Vector Quantization, check out my video on this topic where I go much deeper with this.

More about Vector Quantization! (Video by author)
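
The "snap to the nearest centroid" step can be sketched in a few lines of plain Python. This is a toy illustration, not Mimi's implementation: the codebook below is random rather than trained, the sizes are tiny (the example above uses 1000 codes of size 512), and the function names are made up.

```python
import random


def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(vector, codebook):
    """Return (token id, code vector) of the codebook entry nearest to `vector`."""
    index = min(range(len(codebook)),
                key=lambda i: squared_distance(vector, codebook[i]))
    return index, codebook[index]


rng = random.Random(0)
DIM, CODEBOOK_SIZE = 8, 16  # toy sizes chosen for readability

# Random stand-in for a trained codebook.
codebook = [[rng.gauss(0, 1) for _ in range(DIM)] for _ in range(CODEBOOK_SIZE)]

frame = [rng.gauss(0, 1) for _ in range(DIM)]  # stand-in for one frame embedding
token, centroid = quantize(frame, codebook)
print(token)  # an integer token id in [0, CODEBOOK_SIZE)
```

Whatever the input frame is, the output is always one of `CODEBOOK_SIZE` integer ids, which is exactly the fixed vocabulary described above.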

Residual Vector Quantization

The problem with simple vector quantization is that the loss of information may be too high because we are mapping each vector to its cluster’s centroid. This “snap” is rarely perfect, so there is always an error between the original embedding and the nearest codebook entry.

The big idea of Residual Vector Quantization is that it doesn’t stop at having just one codebook. Instead, it tries to use multiple codebooks to represent the input vector.

  1. First, you quantize the original vector using the first codebook.
  2. Then, you subtract that centroid from your original vector. What you’re left with is the residual — the error that wasn’t captured in the first quantization.
  3. Now take this residual, and quantize it again, using a second codebook full of brand new code vectors — again by snapping it to the nearest centroid.
  4. Subtract that too, and you get a smaller residual. Quantize again with a third codebook… and you can keep doing this for as many codebooks as you want.
Residual Vector Quantizers (RVQ) hierarchically encode the input embeddings by using a new codebook and VQ layer to represent the previous codebook’s error. (Illustration by the author)

Each step hierarchically captures a little more detail that was missed in the previous round. If you repeat this for, let’s say, N codebooks, you get a collection of N discrete tokens from each stage of quantization to represent one audio frame.
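
The quantize, subtract, repeat loop described above can be sketched as follows. This is a toy sketch under loud assumptions: the codebooks are random stand-ins with shrinking scale rather than trained ones (so the error is not guaranteed to shrink at every stage as it would in a trained RVQ), and the function names are made up.

```python
import random


def nearest(vector, codebook):
    """Index of the codebook entry closest to `vector` (squared Euclidean distance)."""
    return min(range(len(codebook)),
               key=lambda i: sum((v - c) ** 2 for v, c in zip(vector, codebook[i])))


def rvq_encode(vector, codebooks):
    """Quantize `vector` stage by stage and return one token per codebook."""
    residual = list(vector)
    tokens = []
    for codebook in codebooks:
        idx = nearest(residual, codebook)
        tokens.append(idx)
        # What this stage failed to capture is handed to the next codebook.
        residual = [r - c for r, c in zip(residual, codebook[idx])]
    return tokens


def rvq_decode(tokens, codebooks):
    """Approximate the original vector by summing the chosen entry of each codebook."""
    out = [0.0] * len(codebooks[0][0])
    for idx, codebook in zip(tokens, codebooks):
        out = [o + c for o, c in zip(out, codebook[idx])]
    return out


rng = random.Random(0)
DIM, CODEBOOK_SIZE, N = 8, 16, 4  # toy sizes

# Random stand-ins for trained codebooks; later stages use smaller vectors.
codebooks = [
    [[rng.gauss(0, 1) / 2 ** stage for _ in range(DIM)] for _ in range(CODEBOOK_SIZE)]
    for stage in range(N)
]

frame = [rng.gauss(0, 1) for _ in range(DIM)]
tokens = rvq_encode(frame, codebooks)
reconstruction = rvq_decode(tokens, codebooks)
print(tokens)  # N integers, one per codebook
```

Note that decoding is just a sum: the frame is approximated by adding up the N selected code vectors.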

The coolest thing about RVQs is that they are designed to have a high inductive bias towards capturing the most essential content in the very first quantizer. In the subsequent quantizers, they learn more and more fine-grained features.

If you’re familiar with PCA, you can think of the first codebook as containing the primary principal components, capturing the most critical information. The subsequent codebooks represent higher-order components, containing information that adds more details.

Residual Vector Quantizers (RVQ) use multiple codebooks to encode the input vector, with one entry from each codebook. (Screenshot from author’s video)

Acoustic vs Semantic Codebooks

Since Mimi is trained on the task of audio reconstruction, the encoder compresses the signal to the discretized latent space, and the decoder reconstructs it back from the latent space. When optimizing for this task, the RVQ codebooks learn to capture the essential acoustic content of the input audio inside the compressed latent space. 

Mimi also separately trains a single codebook (vanilla VQ) that only focuses on embedding the semantic content of the audio. This is why Mimi is called a split-RVQ tokenizer – it divides the quantization process into two independent parallel paths: one for semantic information and another for acoustic information.

The Mimi Architecture (Source: Moshi paper) License: Free

To train semantic representations, Mimi used knowledge distillation with an existing speech model called WavLM as a semantic teacher. Basically, Mimi introduces an additional loss function that decreases the cosine distance between the semantic RVQ code and the WavLM-generated embedding.
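
For intuition, cosine distance is one minus the cosine similarity between two vectors. Here is a minimal sketch, assuming both vectors already have the same dimension; the function name is made up and this is not the actual Mimi training code.

```python
import math


def cosine_distance(a, b):
    """1 - cosine similarity: 0 for the same direction, 1 for orthogonal, 2 for opposite."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


same = cosine_distance([1.0, 2.0], [2.0, 4.0])        # parallel vectors
orthogonal = cosine_distance([1.0, 0.0], [0.0, 1.0])  # perpendicular vectors
print(same, orthogonal)
```

Minimizing this distance pushes the semantic code to point in the same direction as the teacher embedding, regardless of its length.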


Audio Decoder

Given a conversation containing text and audio, we first convert them into a sequence of token embeddings using the text and audio tokenizers. This token sequence is then input into a transformer model as a time series. In the blog post, this model is referred to as the Autoregressive Backbone Transformer. Its task is to process this time series and output the “zeroth” codebook token.

A lighter-weight transformer called the audio decoder then reconstructs the remaining codebook tokens, conditioned on this zeroth code generated by the backbone transformer. Note that the zeroth code already contains a lot of information about the history of the conversation since the backbone transformer has visibility of the entire past sequence. The lightweight audio decoder only operates on the zeroth token and generates the other N-1 codes. These codes are generated by using N-1 distinct linear layers that output the probability of choosing each code from their corresponding codebooks.

You can imagine this process as predicting a text token from the vocabulary in a text-only LLM. Just that a text-based LLM has a single vocabulary, but the RVQ-tokenizer has multiple vocabularies in the form of the N codebooks, so you need to train a separate linear layer to model the codes for each.
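
To make the "one linear layer per codebook" idea concrete, here is a toy sketch. Everything in it is an assumption for illustration: the hidden state is a random stand-in for the audio decoder's output, the weights are random rather than trained, the sizes are tiny, and codes are picked greedily. It only shows the shape of the computation, not Sesame's actual decoder.

```python
import math
import random


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def linear(weight, bias, x):
    """Apply y = W x + b, with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


rng = random.Random(0)
HIDDEN, CODEBOOK_SIZE, N = 4, 6, 3  # toy sizes

# One independent linear head per remaining codebook (N - 1 heads in total).
heads = [
    ([[rng.gauss(0, 1) for _ in range(HIDDEN)] for _ in range(CODEBOOK_SIZE)],
     [0.0] * CODEBOOK_SIZE)
    for _ in range(N - 1)
]

hidden_state = [rng.gauss(0, 1) for _ in range(HIDDEN)]  # stand-in for decoder output

codes = []
for weight, bias in heads:
    probs = softmax(linear(weight, bias, hidden_state))
    codes.append(max(range(CODEBOOK_SIZE), key=probs.__getitem__))  # greedy pick

print(codes)  # N - 1 code ids, one per remaining codebook
```

Each head is the audio equivalent of an LLM's output projection, just with its own codebook as the vocabulary.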

The Sesame Architecture (Illustration by the author)

Finally, after the codewords are all generated, we aggregate them to form the combined continuous audio embedding. The final job is to convert this embedding back to a waveform. For this, we apply transposed convolutional layers to upsample the embedding from 12.5 Hz back to 24 kHz waveform audio, basically reversing the transforms we had applied originally during audio preprocessing.

In Summary

Check out the accompanying video on this article! (Video by author)

So, here is the overall summary of the Sesame model in some bullet points.

  1. Sesame is built on a multimodal Conversational Speech Model, or CSM.
  2. Text and audio are tokenized together to form a sequence of tokens and input into the backbone transformer that autoregressively processes the sequence.
  3. While the text is processed like any other text-based LLM, the audio is processed directly from its waveform representation. They use the Mimi encoder to convert the waveform into latent codes using a split RVQ tokenizer.
  4. The multimodal backbone transformer consumes a sequence of tokens and predicts the next zeroth codeword.
  5. Another lightweight transformer called the Audio Decoder predicts the next codewords from the zeroth codeword.
  6. The final audio frame representation is generated by combining all the generated codewords and is then upsampled back to the waveform representation.

Thanks for reading!

References and Must-read papers

Check out my ML YouTube Channel

Sesame Blogpost and Demo

Relevant papers: 
Moshi: https://arxiv.org/abs/2410.00037 
SoundStream: https://arxiv.org/abs/2107.03312 
HuBERT: https://arxiv.org/abs/2106.07447
SpeechTokenizer: https://arxiv.org/abs/2308.16692



]]>
Learnings from a Machine Learning Engineer — Part 6: The Human Side https://towardsdatascience.com/learnings-from-a-machine-learning-engineer-part-6-the-human-side/ Fri, 11 Apr 2025 18:44:39 +0000 https://towardsdatascience.com/?p=605720 Practical advice for the humans involved with machine learning

The post Learnings from a Machine Learning Engineer — Part 6: The Human Side appeared first on Towards Data Science.

]]>
In my previous articles, I have spent a lot of time talking about the technical aspects of an image classification problem, from data collection, model evaluation, and performance optimization to a detailed look at model training.

These elements require a certain degree of in-depth expertise, and they (usually) have well-defined metrics and established processes that are within our control.

Now it’s time to consider…

The human aspects of machine learning

Yes, this may seem like an oxymoron! But it is the interaction with people (the ones you work with and the ones who use your application) that helps bring the technology to life and provides a sense of fulfillment to your work.

These human interactions include:

  • Communicating technical concepts to a non-technical audience.
  • Understanding how your end-users engage with your application.
  • Providing clear expectations on what the model can and cannot do.

I also want to touch on the impact to people’s jobs, both positive and negative, as AI becomes a part of our everyday lives.

Overview

As in my previous articles, I will gear this discussion around an image classification application. With that in mind, these are the groups of people involved with your project:

  • AI/ML Engineer (that’s you) — bringing life to the Machine Learning application.
  • MLOps team — your peers who will deploy, monitor, and enhance your application.
  • Subject matter experts — the ones who will provide the care and feeding of labeled data.
  • Stakeholders — the ones who are looking for a solution to a real world problem.
  • End-users — the ones who will be using your application. These could be internal and external customers.
  • Marketing — the ones who will be promoting usage of your application.
  • Leadership — the ones who are paying the bill and need to see business value.

Let’s dive right in…

AI/ML Engineer

You may be a part of a team or a lone wolf. You may be an individual contributor or a team leader.

Photo by Christina @ wocintechchat.com on Unsplash

Whatever your role, it is important to see the whole picture — not only the coding, the data science, and the technology behind AI/ML — but the value that it brings to your organization.

Understand the business needs

Your company faces many challenges to reduce expenses, improve customer satisfaction, and remain profitable. Position yourself as someone who can create an application that helps achieve their goals.

  • What are the pain points in a business process?
  • What is the value of using your application (time savings, cost savings)?
  • What are the risks of a poor implementation?
  • What is the roadmap for future enhancements and use-cases?
  • What other areas of the business could benefit from the application, and what design choices will help future-proof your work?

Communication

Deep technical discussions with your peers are probably your comfort zone. However, to be a more successful AI/ML Engineer, you should be able to clearly explain the work you are doing to different audiences.

With practice, you can explain these topics in ways that your non-technical business users can follow along with, and understand how your technology will benefit them.

To help you get comfortable with this, try creating a PowerPoint with 2–3 slides that you can cover in 5–10 minutes. For example, explain how a neural network can take an image of a cat or a dog and determine which one it is.

Practice giving this presentation in your mind, to a friend — even your pet dog or cat! This will get you more comfortable with the transitions, tighten up the content, and ensure you cover all the important points as clearly as possible.

  • Be sure to include visuals — pure text is boring, graphics are memorable.
  • Keep an eye on time — respect your audience’s busy schedule and stick to the 5–10 minutes you are given.
  • Put yourself in their shoes — your audience is interested in how the technology will benefit them, not on how smart you are.

Creating a technical presentation is a lot like the Feynman Technique — explaining a complex subject to your audience by breaking it into easily digestible pieces, with the added benefit of helping you understand it more completely yourself.

MLOps team

These are the people that deploy your application, manage data pipelines, and monitor infrastructure that keeps things running.

Without them, your model lives in a Jupyter notebook and helps nobody!

Photo by airfocus on Unsplash

These are your technical peers, so you should be able to connect with their skillset more naturally. You speak in jargon that sounds like a foreign language to most people. Even so, it is extremely helpful for you to create documentation to set expectations around:

  • Process and data flows.
  • Data quality standards.
  • Service level agreements for model performance and availability.
  • Infrastructure requirements for compute and storage.
  • Roles and responsibilities.

It is easy to have a more informal relationship with your MLOps team, but remember that everyone is trying to juggle many projects at the same time.

Email and chat messages are fine for quick-hit issues. But for larger tasks, you will want a system to track things like user stories, enhancement requests, and break-fix issues. This way you can prioritize the work and ensure you don’t forget something. Plus, you can show progress to your supervisor.

Some great tools exist, such as:

  • Jira, GitHub, Azure DevOps Boards, Asana, Monday, etc.

We are all professionals, so having a more formal system to avoid miscommunication and mistrust is good business.

Subject matter experts

These are the team members that have the most experience working with the data that you will be using in your AI/ML project.

Photo by National Cancer Institute on Unsplash

SMEs are very skilled at dealing with messy data — they are human, after all! They can handle one-off situations by considering knowledge outside of their area of expertise. For example, a doctor may recognize metal inserts in a patient’s X-ray that indicate prior surgery. They may also notice a faulty X-ray image due to equipment malfunction or technician error.

However, your machine learning model only knows what it knows, which comes from the data it was trained on. So, those one-off cases may not be appropriate for the model you are training. Your SMEs need to understand that clear, high quality training material is what you are looking for.

Think like a computer

In the case of an image classification application, the output from the model communicates to you how well it was trained on the data set. This comes in the form of error rates, which is very much like when a student takes an exam and you can tell how well they studied by seeing how many questions — and which ones — they get wrong.

In order to reduce error rates, your image data set needs to be objectively “good” training material. To do this, put yourself in an analytical mindset and ask yourself:

  • What images will the computer get the most useful information out of? Make sure all the relevant features are visible.
  • What is it about an image that confused the model? When it makes an error, try to understand why — objectively — by looking at the entire picture.
  • Is this image a “one-off” or a typical example of what the end-users will send? Consider creating a new subclass of exceptions to the norm.

Be sure to communicate to your SMEs that model performance is directly tied to data quality and give them clear guidance:

  • Provide visual examples of what works.
  • Provide counter-examples of what does not work.
  • Ask for a wide variety of data points. In the X-ray example, be sure to get patients with different ages, genders, and races.
  • Provide options to create subclasses of your data for further refinement. Use that X-ray from a patient with prior surgery as a subclass, and as you get more examples over time, the model can eventually handle them.

This also means that you should become familiar with the data they are working with — perhaps not expert level, but certainly above a novice level.

Lastly, when working with SMEs, be cognizant of the impression they may have that the work you are doing is somehow going to replace their job. It can feel threatening when someone asks you how to do your job, so be mindful.

Ideally, you are building a tool with honest intentions and it will enable your SMEs to augment their day-to-day work. If they can use the tool as a second opinion to validate their conclusions in less time, or perhaps even avoid mistakes, then this is a win for everyone. Ultimately, the goal is to allow them to focus on more challenging situations and achieve better outcomes.

I have more to say on this in my closing remarks.

Stakeholders

These are the people you will have the closest relationship with.

Stakeholders are the ones who created the business case to have you build the machine learning model in the first place.

Photo by Ninthgrid on Unsplash

They have a vested interest in having a model that performs well. Here are some key points when working with your stakeholders:

  • Be sure to listen to their needs and requirements.
  • Anticipate their questions and be prepared to respond.
  • Be on the lookout for opportunities to improve your model performance. Your stakeholders may not be as close to the technical details as you are and may not think there is any room for improvement.
  • Bring issues and problems to their attention. They may not want to hear bad news, but they will appreciate honesty over evasion.
  • Schedule regular updates with usage and performance reports.
  • Explain technical details in terms that are easy to understand.
  • Set expectations on regular training and deployment cycles and timelines.

Your role as an AI/ML Engineer is to bring to life the vision of your stakeholders. Your application is making their lives easier, which justifies and validates the work you are doing. It’s a two-way street, so be sure to share the road.

End-users

These are the people who are using your application. They may also be your harshest critics, but you may never even hear their feedback.

Photo by Alina Ruf on Unsplash

Think like a human

Recall above when I suggested to “think like a computer” when analyzing the data for your training set. Now it’s time to put yourself in the shoes of a non-technical user of your application.

End-users of an image classification model communicate their understanding of what’s expected of them by way of poor images. These are like the students that didn’t study for the exam, or worse didn’t read the questions, so their answers don’t make sense.

Your model may be really good, but if end-users misuse the application or are not satisfied with the output, you should be asking:

  • Are the instructions confusing or misleading? Did the user focus the camera on the subject being classified, or is it more of a wide-angle image? You can’t blame the user if they follow bad instructions.
  • What are their expectations? When the results are presented to the user, are they satisfied or are they frustrated? You may notice repeated images from frustrated users.
  • Are the usage patterns changing? Are they trying to use the application in unexpected ways? This may be an opportunity to improve the model.

Inform your stakeholders of your observations. There may be simple fixes to improve end-user satisfaction, or there may be more complex work ahead.

If you are lucky, you may discover an unexpected way to leverage the application that leads to expanded usage or exciting benefits to your business.

Explainability

Most AI/ML models are considered “black boxes” that perform millions of calculations on extremely high dimensional data and produce a rather simplistic result without any reason behind it.

The Answer to the Ultimate Question of Life, the Universe, and Everything is 42.
— The Hitchhikers Guide to the Galaxy

Depending on the situation, your end-users may require more explanation of the results, such as with medical imaging. Where possible, you should consider incorporating model explainability techniques such as LIME, SHAP, and others. These responses can help put a human touch to cold calculations.

Now it’s time to switch gears and consider higher-ups in your organization.

Marketing team

These are the people who promote the use of your hard work. If your end-users are completely unaware of your application, or don’t know where to find it, your efforts will go to waste.

The marketing team controls where users can find your app on your website and link to it through social media channels. They also see the technology through a different lens.

Gartner hype cycle. Image from Wikipedia – https://en.wikipedia.org/wiki/Gartner_hype_cycle

The above hype cycle is a good representation of how technical advancements tend to flow. At the beginning, there can be an unrealistic expectation of what your new AI/ML tool can do. It’s the greatest thing since sliced bread!

Then the “new” wears off and excitement wanes. You may face a lack of interest in your application as the marketing team (as well as your end-users) moves on to the next thing. In reality, the value of your efforts is somewhere in the middle.

Understand that the marketing team’s interest is in promoting the use of the tool because of how it will benefit the organization. They may not need to know the technical inner workings. But they should understand what the tool can do, and be aware of what it cannot do.

Honest and clear communication up-front will help smooth out the hype cycle and keep everyone interested longer. This way the crash from peak expectations to the trough of disillusionment is not so severe that the application is abandoned altogether.

Leadership team

These are the people that authorize spending and have the vision for how the application fits into the overall company strategy. They are driven by factors that you have no control over and you may not even be aware of. Be sure to provide them with the key information about your project so they can make informed decisions.

Photo by Adeolu Eletu on Unsplash

Depending on your role, you may or may not have direct interaction with executive leadership in your company. Your job is to summarize the costs and benefits associated with your project, even if that is just with your immediate supervisor who will pass this along.

Your costs will likely include:

  • Compute and storage — training and serving a model.
  • Image data collection — both real-world and synthetic or staged.
  • Hours per week — SME, MLOps, AI/ML engineering time.

Highlight the savings and/or value added:

  • Provide measures on speed and accuracy.
  • Translate efficiencies into FTE hours saved and customer satisfaction.
  • Bonus points if you can find a way to produce revenue.

Business leaders, much like the marketing team, may follow the hype cycle:

  • Be realistic about model performance. Don’t try to oversell it, but be honest about the opportunities for improvement.
  • Consider creating a human benchmark test to measure accuracy and speed for an SME. It is easy to say human accuracy is 95%, but it’s another thing to measure it.
  • Highlight short-term wins and how they can become long-term success.

Conclusion

I hope you can see that, beyond the technical challenges of creating an AI/ML application, there are many humans involved in a successful project. Being able to interact with these individuals, and meet them where they are in terms of their expectations from the technology, is vital to advancing the adoption of your application.

Photo by Vlad Hilitanu on Unsplash

Key takeaways:

  • Understand how your application fits into the business needs.
  • Practice communicating to a non-technical audience.
  • Collect measures of model performance and report these regularly to your stakeholders.
  • Expect that the hype cycle could help and hurt your cause, and that setting consistent and realistic expectations will ensure steady adoption.
  • Be aware that factors outside of your control, such as budgets and business strategy, could affect your project.

And most importantly…

Don’t let machines have all the fun learning!

Human nature gives us the curiosity we need to understand our world. Take every opportunity to grow and expand your skills, and remember that human interaction is at the heart of machine learning.

Closing remarks

Advancements in AI/ML have the potential (assuming they are properly developed) to do many tasks as well as humans. It would be a stretch to say “better than” humans because it can only be as good as the training data that humans provide. However, it is safe to say AI/ML can be faster than humans.

The next logical question would be, “Well, does that mean we can replace human workers?”

This is a delicate topic, and I want to be clear that I am not an advocate of eliminating jobs.

I see my role as an AI/ML Engineer as being one that can create tools that aid in someone else’s job or enhance their ability to complete their work successfully. When used properly, the tools can validate difficult decisions and speed through repetitive tasks, allowing your experts to spend more time on the one-off situations that require more attention.

There may also be new career opportunities, from the care-and-feeding of data, quality assessment, user experience, and even to new roles that leverage the technology in exciting and unexpected ways.

Unfortunately, business leaders may make decisions that impact people’s jobs, and this is completely out of your control. But all is not lost — even for us AI/ML Engineers…

There are things we can do

  • Be kind to the fellow human beings that we call “coworkers”.
  • Be aware of the fear and uncertainty that come with technological advancements.
  • Be on the lookout for ways to help people leverage AI/ML in their careers and to make their lives better.

This is all part of being human.


]]>
Are You Sure Your Posterior Makes Sense? https://towardsdatascience.com/are-you-sure-your-posterior-makes-sense/ Fri, 11 Apr 2025 18:38:41 +0000 https://towardsdatascience.com/?p=605717 A detailed guide on how to use diagnostics to evaluate the performance of MCMC samplers

The post Are You Sure Your Posterior Makes Sense? appeared first on Towards Data Science.

]]>
This article is co-authored by Felipe Bandeira, Giselle Fretta, Thu Than, and Elbion Redenica. We also thank Prof. Carl Scheffler for his support.

Introduction

Parameter estimation has for decades been one of the most important topics in statistics. While frequentist approaches, such as Maximum Likelihood Estimation, used to be the gold standard, the advance of computation has opened space for Bayesian methods. Estimating posterior distributions with MCMC samplers has become increasingly common, but reliable inferences depend on a task that is far from trivial: making sure that the sampler (and the processes it executes under the hood) worked as expected. Keep in mind what Lewis Carroll once wrote: “If you don’t know where you’re going, any road will take you there.”

This article is meant to help data scientists evaluate an often overlooked aspect of Bayesian parameter estimation: the reliability of the sampling process. Throughout the sections, we combine simple analogies with technical rigor to ensure our explanations are accessible to data scientists with any level of familiarity with Bayesian methods. Although our implementations are in Python with PyMC, the concepts we cover are useful to anyone using an MCMC algorithm, from Metropolis-Hastings to NUTS. 

Key Concepts

No data scientist or statistician would disagree with the importance of robust parameter estimation methods. Whether the objective is to make inferences or conduct simulations, the capacity to model the data-generating process is crucial. For a long time, estimation was mainly performed using frequentist tools, such as Maximum Likelihood Estimation (MLE) or the famous Least Squares optimization used in regressions. Yet, frequentist methods have clear shortcomings, such as the fact that they are focused on point estimates and do not incorporate prior knowledge that could improve estimates.

As an alternative to these tools, Bayesian methods have gained popularity over the past decades. They provide statisticians not only with point estimates of the unknown parameter but also with credible intervals for it, all of which are informed by the data and by the prior knowledge researchers held. Originally, Bayesian parameter estimation was done through an adapted version of Bayes’ theorem focused on unknown parameters (represented as θ) and known data points (represented as x). We can define P(θ|x), the posterior distribution of a parameter’s value given the data, as:

\[ P(\theta|x) = \frac{P(x|\theta) P(\theta)}{P(x)} \]

In this formula, P(x|θ) is the likelihood of the data given a parameter value, P(θ) is the prior distribution over the parameter, and P(x) is the evidence, which is computed by integrating the joint distribution P(x, θ) = P(x|θ)P(θ) over all possible values of the parameter:

\[ P(x) = \int_\theta P(x, \theta) d\theta \]
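
To make the two formulas concrete, here is a minimal sketch that approximates the evidence and the posterior on a grid. The coin-flip setup (7 heads in 10 flips), the flat prior, and the function name `grid_posterior` are illustrative assumptions, not part of the case study in this article.

```python
import math

def grid_posterior(heads, flips, grid_size=1001):
    """Approximate P(theta | x) for a coin's heads probability on a grid."""
    thetas = [i / (grid_size - 1) for i in range(grid_size)]
    width = 1.0 / (grid_size - 1)
    prior = [1.0 for _ in thetas]  # assumption: flat prior density on [0, 1]
    likelihood = [
        math.comb(flips, heads) * t**heads * (1 - t) ** (flips - heads)
        for t in thetas
    ]
    joint = [lik * pri for lik, pri in zip(likelihood, prior)]  # P(x, theta)
    evidence = sum(joint) * width  # Riemann sum standing in for the integral P(x)
    posterior = [j / evidence for j in joint]  # Bayes' theorem, point by point
    return thetas, posterior, evidence, width

thetas, posterior, evidence, width = grid_posterior(heads=7, flips=10)
posterior_mean = sum(t * p for t, p in zip(thetas, posterior)) * width
print(evidence, posterior_mean)
```

With a flat prior this toy problem has a closed-form answer (a Beta(8, 4) posterior with mean 2/3 and evidence 1/11), so the grid values should land very close to those numbers. Grids stop being practical once a model has more than a handful of parameters, which is where sampling comes in.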

In some cases, due to the complexity of the calculations required, deriving the posterior distribution analytically was not possible. However, with the advance of computation, running sampling algorithms (especially MCMC ones) to estimate posterior distributions has become easier, giving researchers a powerful tool for situations where analytical posteriors are not trivial to find. Yet, with such power also comes a great deal of responsibility to ensure that results make sense. This is where sampler diagnostics come in, offering a set of valuable tools to gauge 1) whether an MCMC algorithm is working well and, consequently, 2) whether the estimated distribution we see is an accurate representation of the real posterior distribution. But how can we know?

How samplers work

Before diving into the technicalities of diagnostics, we shall cover how the process of sampling a posterior (especially with an MCMC sampler) works. In simple terms, we can think of a posterior distribution as a geographical area we haven’t been to but need to know the topography of. How can we draw an accurate map of the region?  

One of our favorite analogies comes from Ben Gilbert. Suppose that the unknown region is actually a house whose floorplan we wish to map. For some reason, we cannot directly visit the house, but we can send bees inside with GPS devices attached to them. If everything works as expected, the bees will fly around the house, and using their trajectories, we can estimate what the floor plan looks like. In this analogy, the floor plan is the posterior distribution, and the sampler is the group of bees flying around the house.

The reason we are writing this article is that, in some cases, the bees won’t fly as expected. If they get stuck in a certain room for some reason (because someone dropped sugar on the floor, for example), the data they return won’t be representative of the entire house; rather than visiting all rooms, the bees only visited a few, and our picture of what the house looks like will ultimately be incomplete. Similarly, when a sampler does not work correctly, our estimation of the posterior distribution is also incomplete, and any inference we draw based on it is likely to be wrong.

Markov Chain Monte Carlo (MCMC)

In technical terms, we call an MCMC process any algorithm that undergoes transitions from one state to another with certain properties. Markov Chain refers to the fact that the next state only depends on the current one (or that the bee’s next location is only influenced by its current place, and not by all of the places where it has been before). Monte Carlo means that the next state is chosen randomly. MCMC methods like Metropolis-Hastings, Gibbs sampling, Hamiltonian Monte Carlo (HMC), and No-U-Turn Sampler (NUTS) all operate by constructing Markov Chains (a sequence of steps) that are close to random and gradually explore the posterior distribution.
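
As a rough illustration of these two ingredients, the sketch below implements random-walk Metropolis, the simplest special case of Metropolis-Hastings. The standard Normal target, the step scale, and the function names are assumptions chosen for the example; this is not the sampler PyMC runs by default.

```python
import math
import random
import statistics

def log_target(theta):
    """Log-density of a standard Normal target, up to an additive constant."""
    return -0.5 * theta * theta

def metropolis(n_steps, step_scale=1.0, start=0.0, seed=0):
    """Random-walk Metropolis: a single chain (one bee) exploring the target."""
    rng = random.Random(seed)
    current = start
    samples = []
    for _ in range(n_steps):
        # Monte Carlo: the candidate move is drawn at random.
        proposal = current + rng.gauss(0.0, step_scale)
        # Markov: the decision depends only on the current state and the proposal.
        log_ratio = log_target(proposal) - log_target(current)
        if rng.random() < math.exp(min(0.0, log_ratio)):
            current = proposal  # accept the move
        samples.append(current)  # on rejection, the current state is repeated
    return samples

samples = metropolis(20_000)
sample_mean = statistics.fmean(samples)
sample_variance = statistics.variance(samples)
print(sample_mean, sample_variance)
```

If the chain explored the target properly, the sample mean and variance should be close to 0 and 1. Repeated values after rejections are one reason consecutive draws are correlated.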

Now that you understand how a sampler works, let’s dive into a practical scenario to help us explore sampling problems.

Case Study

Imagine that, in a faraway nation, a governor wants to understand more about public annual spending on healthcare by mayors of cities with fewer than 1 million inhabitants. Rather than looking at sheer frequencies, he wants to understand the underlying distribution explaining expenditure, and a sample of spending data is about to arrive. The problem is that two of the economists involved in the project disagree about how the model should look.

Model 1

The first economist believes that all cities spend similarly, with some variation around a certain mean. As such, he creates a simple model. Although the specifics of how the economist chose his priors are irrelevant to us, we do need to keep in mind that he is trying to approximate a Normal (unimodal) distribution.

\[
x_i \sim \text{Normal}(\mu, \sigma^2) \text{ i.i.d. for all } i \\
\mu \sim \text{Normal}(10, 2) \\
\sigma^2 \sim \text{Uniform}(0,5)
\]

Model 2

The second economist disagrees, arguing that spending is more complex than his colleague believes. He believes that, given ideological differences and budget constraints, there are two kinds of cities: the ones that do their best to spend very little and the ones that are not afraid of spending a lot. As such, he creates a slightly more complex model, using a mixture of normals to reflect his belief that the true distribution is bimodal.

\[
x_i \sim \text{Normal-Mixture}([\omega, 1-\omega], [m_1, m_2], [s_1^2, s_2^2]) \text{ i.i.d. for all } i\\
m_j \sim \text{Normal}(2.3, 0.5^2) \text{ for } j = 1,2 \\
s_j^2 \sim \text{Inverse-Gamma}(1,1) \text{ for } j=1,2 \\
\omega \sim \text{Beta}(1,1)
\]

After the data arrives, each economist runs an MCMC algorithm to estimate their desired posteriors, which will be a reflection of reality (1) if their assumptions are true and (2) if the sampler worked correctly. The first if, a discussion about assumptions, shall be left to the economists. However, how can they know whether the second if holds? In other words, how can they be sure that the sampler worked correctly and, as a consequence, their posterior estimations are unbiased?

Sampler Diagnostics

To evaluate a sampler’s performance, we can explore a small set of metrics that reflect different parts of the estimation process.

Quantitative Metrics

R-hat (Potential Scale Reduction Factor)

In simple terms, R-hat evaluates whether bees that started at different places have all explored the same rooms at the end of the day. To estimate the posterior, an MCMC algorithm uses multiple chains (or bees) that start at random locations. R-hat is the metric we use to assess the convergence of the chains. It measures whether multiple MCMC chains have mixed well (i.e., if they have sampled the same topography) by comparing the variance of samples within each chain to the variance of the sample means across chains. Intuitively, this means that

\[
\hat{R} = \sqrt{\frac{\text{Pooled Variance (Within + Between Chains)}}{\text{Variance Within Chains}}}
\]

If R-hat is close to 1.0 (or below 1.01), the differences between chains add almost nothing to the variance already observed within each chain, suggesting that the chains have converged to the same distribution. In other words, the chains are behaving similarly and are indistinguishable from one another. This is precisely what we see after sampling the posterior of the first model, shown in the last column of the table below:

Figure 1. Summary statistics of the sampler highlighting ideal R-hats.

The R-hat from the second model, however, tells a different story. The fact we have such large R-hat values indicates that, at the end of the sampling process, the different chains had not converged yet. In practice, this means that the distribution they explored and returned was different, or that each bee created a map of a different room of the house. This fundamentally leaves us without a clue of how the pieces connect or what the complete floor plan looks like.

Figure 2. Summary statistics of the sampler showcasing problematic R-hats.

Given that our R-hat readouts were large, we know something went wrong with the sampling process in the second model. However, even if the R-hat had turned out within acceptable levels, this would not give us certainty that the sampling process worked. R-hat is just a diagnostic tool, not a guarantee. Sometimes, even if your R-hat readout is lower than 1.01, the sampler might not have properly explored the full posterior. This happens when multiple bees start their exploration in the same room and remain there. Likewise, if you’re using a small number of chains, and if your posterior happens to be multimodal, there is a chance that all chains started in the same mode and failed to explore other peaks.

The R-hat readout reflects convergence, not completion. In order to have a more comprehensive idea, we need to check other diagnostic metrics as well.
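
The behavior described in this section can be reproduced with a small sketch of the classic Gelman-Rubin R-hat, which divides a pooled variance estimate (within-chain plus between-chain) by the within-chain variance. The simulated chains are a hypothetical stand-in for real sampler output, and libraries such as ArviZ report a more refined rank-normalized, split-chain version, so readouts will not match exactly.

```python
import random
import statistics

def r_hat(chains):
    """Classic (non-split) R-hat for a list of equal-length chains."""
    n = len(chains[0])
    chain_means = [statistics.fmean(chain) for chain in chains]
    within = statistics.fmean([statistics.variance(chain) for chain in chains])
    between_over_n = statistics.variance(chain_means)  # B / n in the usual notation
    pooled = (n - 1) / n * within + between_over_n
    return (pooled / within) ** 0.5

def simulate_chains(centers, n_draws=1000, seed=0):
    """Hypothetical sampler output: independent Normal draws around each center."""
    rng = random.Random(seed)
    return [[rng.gauss(center, 1.0) for _ in range(n_draws)] for center in centers]

r_hat_mixed = r_hat(simulate_chains([0, 0, 0, 0]))       # all bees share one room
r_hat_split = r_hat(simulate_chains([-3, -3, 3, 3]))     # bees stuck in two rooms
r_hat_same_mode = r_hat(simulate_chains([3, 3, 3, 3]))   # all bees stuck in one of two rooms
print(r_hat_mixed, r_hat_split, r_hat_same_mode)
```

The first and third readouts should both sit very close to 1, while the second should be far above 1.01. The third case is the trap mentioned above: if the true posterior had modes at both -3 and 3, R-hat alone would not reveal that one of them was never visited.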

Effective Sample Size (ESS)

When explaining what MCMC was, we mentioned that “Monte Carlo” refers to the fact that the next state is chosen randomly. This does not necessarily mean that the states are fully independent. Even though the bees choose their next step at random, these steps are still correlated to some extent. If a bee is exploring a living room at time t=0, it will probably still be in the living room at time t=1, even though it is in a different part of the same room. Due to this natural connection between samples, we say these two data points are autocorrelated.

Due to their nature, MCMC methods inherently produce autocorrelated samples, which complicates statistical analysis and requires careful evaluation. In statistical inference, we often assume independent samples to ensure that the estimates of uncertainty are accurate, hence the need for uncorrelated samples. If two data points are too similar to each other, the correlation reduces their effective information content. Mathematically, the formula below represents the autocorrelation function between two time points (t1 and t2) in a random process:

\[
R_{XX}(t_1, t_2) = E[X_{t_1} \overline{X_{t_2}}]
\]

where E is the expected value operator and the overline denotes the complex conjugate (which changes nothing for real-valued samples). In MCMC sampling, this is crucial because high autocorrelation means that new samples don’t teach us anything different from the old ones, effectively reducing the sample size we have. Unsurprisingly, the metric that reflects this is called Effective Sample Size (ESS), and it helps us determine how many truly independent samples we have.

As hinted previously, the effective sample size accounts for autocorrelation by estimating how many truly independent samples would provide the same information as the autocorrelated samples we have. Mathematically, for a parameter θ, the ESS is defined as:

\[
ESS = \frac{n}{1 + 2 \sum_{k=1}^{\infty} \rho(\theta)_k}
\]

where n is the total number of samples and ρ(θ)_k is the autocorrelation at lag k for parameter θ.
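
A minimal sketch of this formula for a single chain follows. The infinite sum has to be truncated in practice; stopping at the first non-positive autocorrelation (and at `max_lag`) is a simple rule assumed here for illustration, not the estimator ArviZ uses for Bulk-ESS and Tail-ESS. The AR(1) chain is a hypothetical stand-in for sticky sampler output.

```python
import random
import statistics

def autocorrelation(chain, lag):
    """Sample autocorrelation of a chain at the given lag."""
    n = len(chain)
    mean = statistics.fmean(chain)
    variance = sum((v - mean) ** 2 for v in chain) / n
    covariance = sum(
        (chain[i] - mean) * (chain[i + lag] - mean) for i in range(n - lag)
    ) / n
    return covariance / variance

def effective_sample_size(chain, max_lag=200):
    """ESS = n / (1 + 2 * sum of autocorrelations), with a truncated sum."""
    rho_sum = 0.0
    for lag in range(1, max_lag + 1):
        rho = autocorrelation(chain, lag)
        if rho <= 0:  # assumption: stop once correlations die out
            break
        rho_sum += rho
    return len(chain) / (1 + 2 * rho_sum)

def ar1_chain(phi, n_draws, seed=0):
    """Hypothetical chain in which each draw keeps a fraction phi of the previous one."""
    rng = random.Random(seed)
    value, chain = 0.0, []
    for _ in range(n_draws):
        value = phi * value + rng.gauss(0.0, 1.0)
        chain.append(value)
    return chain

ess_independent = effective_sample_size(ar1_chain(0.0, 5000))
ess_sticky = effective_sample_size(ar1_chain(0.9, 5000))
print(ess_independent, ess_sticky)
```

Both chains contain 5000 draws, but the sticky chain (phi = 0.9) should report an ESS of only a few hundred, while the independent chain should stay close to 5000.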

Typically, for ESS readouts, the higher, the better. This is what we see in the readout for the first model. Two common ESS variations are Bulk-ESS, which assesses mixing in the central part of the distribution, and Tail-ESS, which focuses on the efficiency of sampling the distribution’s tails. Both inform us if our model accurately reflects the central tendency and credible intervals.

Figure 3. Summary statistics of the sampler highlighting ideal quantities for ESS bulk and tail.

In contrast, the readouts for the second model are poor. As a rule of thumb, we want to see readouts that are at least 1/10 of the total sample size. In this case, each of the 4 chains sampled 2000 observations, for a total of 8000 samples, so we should expect ESS readouts of at least 800, which is not what we observe.

Figure 4. Summary statistics of the sampler demonstrating problematic ESS bulk and tail.

Visual Diagnostics

Apart from the numerical metrics, our understanding of sampler performance can be deepened through the use of diagnostic plots. The main ones are rank plots, trace plots, and pair plots.

Rank Plots

A rank plot helps us identify whether the different chains have explored all of the posterior distribution. If we once again think of the bee analogy, rank plots tell us which bees explored which parts of the house. Therefore, to evaluate whether the posterior was explored equally by all chains, we observe the shape of the rank plots produced by the sampler. Ideally, we want the distribution of all chains to look roughly uniform, like in the rank plots generated after sampling the first model. Each color below represents a chain (or bee):

Figure 5. Rank plots for parameters ‘m’ and ‘s’ across four MCMC chains. Each bar represents the distribution of rank values for one chain, with ideally uniform ranks indicating good mixing and proper convergence.

Under the hood, a rank plot is produced with a simple sequence of steps. First, we run the sampler and let it sample from the posterior of each parameter. In our case, we are sampling posteriors for parameters m and s of the first model. Then, parameter by parameter, we get all samples from all chains, put them together, and order them from smallest to largest. We then ask ourselves, for each sample, what was the chain where it came from? This will allow us to create plots like the ones we see above. 
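
The steps above can be sketched in a few lines. The function name `rank_histograms`, the bin count, and the simulated chains are illustrative assumptions; in practice a plotting library draws these histograms for you.

```python
import random

def rank_histograms(chains, n_bins=10):
    """For each chain, count how many of its draws fall in each bin of pooled ranks."""
    # Pool every draw together with the id of the chain it came from.
    pooled = [(value, chain_id) for chain_id, chain in enumerate(chains) for value in chain]
    pooled.sort(key=lambda pair: pair[0])  # order from smallest to largest
    total = len(pooled)
    counts = [[0] * n_bins for _ in chains]
    for rank, (_, chain_id) in enumerate(pooled):  # rank 0 is the smallest draw
        counts[chain_id][rank * n_bins // total] += 1
    return counts

rng = random.Random(0)
# Hypothetical sampler output: four well-mixed chains, then two chains stuck apart.
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
stuck = [[rng.gauss(center, 1.0) for _ in range(1000)] for center in (-3.0, 3.0)]

counts_mixed = rank_histograms(mixed)
counts_stuck = rank_histograms(stuck)
print(counts_mixed[0])
print(counts_stuck[0])
```

For the well-mixed chains, every bin should hold roughly 100 draws per chain, which is the uniform shape we hope to see. For the stuck chains, the first chain piles up in the low-rank bins and leaves the high-rank bins empty.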

In contrast, bad rank plots are easy to spot. Unlike the previous example, the distributions from the second model, shown below, are not uniform. From the plots, what we interpret is that each chain, after beginning at different random locations, got stuck in a region and did not explore the entirety of the posterior. Consequently, we cannot make inferences from the results, as they are unreliable and not representative of the true posterior distribution. This would be equivalent to having four bees that started at different rooms of the house and got stuck somewhere during their exploration, never covering the entirety of the property.

Figure 6. Rank plots for parameters m, s_squared, and w across four MCMC chains. Each subplot shows the distribution of ranks by chain. There are noticeable deviations from uniformity (e.g., stair-step patterns or imbalances across chains) suggesting potential sampling issues.

KDE and Trace Plots

Similar to R-hat, trace plots help us assess the convergence of MCMC samples by visualizing how the algorithm explores the parameter space over time. PyMC provides two types of trace plots to diagnose mixing issues: Kernel Density Estimate (KDE) plots and iteration-based trace plots. Each of these serves a distinct purpose in evaluating whether the sampler has properly explored the target distribution.

The KDE plot (usually on the left) estimates the posterior density for each chain, where each line represents a separate chain. This allows us to check whether all chains have converged to the same distribution. If the KDEs overlap, it suggests that the chains are sampling from the same posterior and that mixing has occurred. On the other hand, the trace plot (usually on the right) visualizes how parameter values change over MCMC iterations (steps), with each line representing a different chain. A well-mixed sampler will produce trace plots that look noisy and random, with no clear structure or separation between chains.

Using the bee analogy, trace plots can be thought of as snapshots of the “features” of the house at different locations. If the sampler is working correctly, the KDEs in the left plot should align closely, showing that all bees (chains) have explored the house similarly. Meanwhile, the right plot should show highly variable traces that blend together, confirming that the chains are actively moving through the space rather than getting stuck in specific regions.

Figure 7. Density and trace plots for parameters m and s from the first model across four MCMC chains. The left panel shows kernel density estimates (KDE) of the marginal posterior distribution for each chain, indicating consistent central tendency and spread. The right panel displays the trace plot over iterations, with overlapping chains and no apparent divergences, suggesting good mixing and convergence.

However, if your sampler has poor mixing or convergence issues, you will see something like the figure below. In this case, the KDEs will not overlap, meaning that different chains have sampled from different distributions rather than a shared posterior. The trace plot will also show structured patterns instead of random noise, indicating that chains are stuck in different regions of the parameter space and failing to fully explore it.

Figure 8. KDE (left) and trace plots (right) for parameters m, s_squared, and w across MCMC chains for the second model. Multimodal distributions are visible for m and w, suggesting potential identifiability issues. Trace plots reveal that chains explore different modes with limited mixing, particularly for m, highlighting challenges in convergence and effective sampling.

By using trace plots alongside the other diagnostics, you can identify sampling issues and determine whether your MCMC algorithm is effectively exploring the posterior distribution.

Pair Plots

A third kind of plot that is often useful for diagnostics is the pair plot. In models where we want to estimate the posterior distribution of multiple parameters, pair plots allow us to observe how different parameters are correlated. To understand how such plots are formed, think again about the bee analogy. If you imagine that we’ll create a plot with the width and length of the house, each “step” that the bees take can be represented by an (x, y) combination. Likewise, each parameter of the posterior is represented as a dimension, and we create scatter plots showing where the sampler walked using parameter values as coordinates. Here, we are plotting each unique pair (x, y), resulting in the scatter plot you see in the middle of the image below. The one-dimensional plots you see on the edges are the marginal distributions over each parameter, giving us additional information on the sampler’s behavior when exploring them.

Take a look at the pair plot from the first model.

Figure 9. Joint posterior distribution of parameters m and s, with marginal densities. The scatter plot shows a roughly symmetric, elliptical shape, suggesting a low correlation between m and s.

Each axis represents one of the two parameters whose posteriors we are estimating. For now, let’s focus on the scatter plot in the middle, which shows the parameter combinations sampled from the posterior. The fact we have a very even distribution means that, for any particular value of m, there was a range of values of s that were equally likely to be sampled. Additionally, we don’t see any correlation between the two parameters, which is usually good! There are cases when we would expect some correlation, such as when our model involves a regression line. However, in this instance, we have no reason to believe two parameters should be highly correlated, so the fact we don’t observe unusual behavior is positive news. 

Now, take a look at the pair plots from the second model.

Figure 10. Pair plot of the joint posterior distributions for parameters m, s_squared, and w. The scatter plots reveal strong correlations between several parameters.

Given that this model has five parameters to be estimated, we naturally have a greater number of plots since we are analyzing them pair-wise. However, they look odd compared to the previous example. Namely, rather than having an even distribution of points, the samples here either seem to be divided across two regions or seem somewhat correlated. This is another way of visualizing what the rank plots have shown: the sampler did not explore the full posterior distribution. Below we isolated the top left plot, which contains the samples from m0 and m1. Unlike the plot from model 1, here we see that the value of one parameter greatly influences the value of the other. If we sampled m1 around 2.5, for example, m0 is likely to be sampled from a very narrow range around 1.5.

Figure 11. Joint posterior distribution of parameters m₀ and m₁, with marginal densities.

Certain shapes can be observed in problematic pair plots relatively frequently. Diagonal patterns, for example, indicate a high correlation between parameters. Banana shapes are often connected to parametrization issues and tend to appear in models with tight priors or constrained parameters. Funnel shapes might indicate hierarchical models with bad geometry. When we have two separate islands, like in the plot above, this can indicate that the posterior is bimodal AND that the chains haven’t mixed well. However, keep in mind that these shapes might indicate problems but do not necessarily do so. It’s up to the data scientist to examine the model and determine which behaviors are expected and which ones are not!

Some Fixing Techniques

When your diagnostics indicate sampling problems — whether concerning R-hat values, low ESS, unusual rank plots, separated trace plots, or strange parameter correlations in pair plots — several strategies can help you address the underlying issues. Sampling problems typically stem from the target posterior being too complex for the sampler to explore efficiently. Complex target distributions might have:

  • Multiple modes (peaks) that the sampler struggles to move between
  • Irregular shapes with narrow “corridors” connecting different regions
  • Areas of drastically different scales (like the “neck” of a funnel)
  • Heavy tails that are difficult to sample accurately

In the bee analogy, these complexities represent houses with unusual floor plans — disconnected rooms, extremely narrow hallways, or areas that change dramatically in size. Just as bees might get trapped in specific regions of such houses, MCMC chains can get stuck in certain areas of the posterior.

Figure 12. Examples of multimodal target distributions.
Figure 13. Examples of weirdly shaped distributions.

To help the sampler in its exploration, there are simple strategies we can use.

Strategy 1: Reparameterization

Reparameterization is particularly effective for hierarchical models and distributions with challenging geometries. It involves transforming your model’s parameters to make them easier to sample. Back to the bee analogy, imagine the bees are exploring a house with a peculiar layout: a spacious living room that connects to the kitchen through a very, very narrow hallway. One aspect we hadn’t mentioned before is that the bees have to fly in the same way through the entire house. That means that if we dictate the bees should use large “steps,” they will explore the living room very well but hit the walls in the hallway head-on. Likewise, if their steps are small, they will explore the narrow hallway well, but take forever to cover the entire living room. The difference in scales, which is natural to the house, makes the bees’ job more difficult.

A classic example that represents this scenario is Neal’s funnel, where the scale of one parameter depends on another:

\[
p(y, x) = \text{Normal}(y|0, 3) \times \prod_{n=1}^{9} \text{Normal}(x_n|0, e^{y/2})
\]

Figure 14. Log of the marginal density of y and the first dimension of Neal’s funnel. The neck is where the sampler struggles, because it requires a much smaller step size than the body. (Image source: Stan User’s Guide)

We can see that the scale of x is dependent on the value of y. To fix this problem, we can separate x and y as independent standard Normals and then transform these variables into the desired funnel distribution. Instead of sampling directly like this:

\[
\begin{align*}
y &\sim \text{Normal}(0, 3) \\
x &\sim \text{Normal}(0, e^{y/2})
\end{align*}
\]

You can reparameterize to sample from standard Normals first:

\[
\begin{align*}
y_{raw} &\sim \text{Normal}(0, 1) \\
x_{raw} &\sim \text{Normal}(0, 1) \\
y &= 3y_{raw} \\
x &= e^{y/2} x_{raw}
\end{align*}
\]

This technique separates the hierarchical parameters and makes sampling more efficient by eliminating the dependency between them. 

Reparameterization is like redesigning the house such that instead of forcing the bees to find a single narrow hallway, we create a new layout where all passages have similar widths. This helps the bees use a consistent flying pattern throughout their exploration.
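
The non-centered transformation can be sketched as follows. Note the assumption: this draws the raw variables directly instead of running MCMC, purely to show that transforming independent standard Normals reproduces the funnel. In a real model, the sampler would explore the well-behaved (y_raw, x_raw) space and the transformation would be applied to its draws.

```python
import math
import random
import statistics

def sample_funnel_noncentered(n_draws, seed=0):
    """Build funnel draws (y, x) from independent standard Normal draws."""
    rng = random.Random(seed)
    draws = []
    for _ in range(n_draws):
        y_raw = rng.gauss(0.0, 1.0)
        x_raw = rng.gauss(0.0, 1.0)
        y = 3.0 * y_raw                 # y ~ Normal(0, 3)
        x = math.exp(y / 2.0) * x_raw   # x ~ Normal(0, e^{y/2}) given y
        draws.append((y, x))
    return draws

draws = sample_funnel_noncentered(20_000)
y_std = statistics.stdev(y for y, _ in draws)
x_std_neck = statistics.stdev(x for y, x in draws if y < -2)   # narrow part of the funnel
x_std_mouth = statistics.stdev(x for y, x in draws if y > 2)   # wide part of the funnel
print(y_std, x_std_neck, x_std_mouth)
```

The spread of x should be tiny in the neck and large in the mouth, even though the raw variables that were actually drawn have the same unit scale everywhere. That constant scale is what lets a sampler use one step size throughout.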

Strategy 2: Handling Heavy-tailed Distributions

Heavy-tailed distributions like Cauchy and Student-T make it hard for samplers to find an ideal step size. Their tails require larger step sizes than their central regions (similar to very long hallways that require the bees to travel long distances), which creates a trade-off:

  • Small step sizes lead to inefficient sampling in the tails
  • Large step sizes cause too many rejections in the center

Figure 15. Probability density functions for various Cauchy distributions illustrate the effects of changing the location parameter and scale parameter. (Image source: Wikipedia)

Reparameterization solutions include:

  • For Cauchy: Defining the variable as a transformation of a Uniform distribution using the Cauchy inverse CDF
  • For Student-T: Using a Gamma-Mixture representation
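
Both transformations can be sketched with the standard library alone. The function names are hypothetical, and the draws are generated directly rather than by MCMC. In a model, the Uniform variable (for Cauchy) or the Gamma precision and standard Normal (for Student-T) would be the quantities the sampler explores.

```python
import math
import random
import statistics

def cauchy_from_uniform(rng, location=0.0, scale=1.0):
    """Cauchy draw via the inverse CDF applied to a Uniform(0, 1) draw."""
    u = rng.random()
    return location + scale * math.tan(math.pi * (u - 0.5))

def student_t_from_gamma_mixture(rng, nu):
    """Student-T draw as a Normal whose precision follows a Gamma distribution."""
    # tau ~ Gamma(shape=nu/2, rate=nu/2); gammavariate expects (shape, scale = 1/rate).
    tau = rng.gammavariate(nu / 2.0, 2.0 / nu)
    return rng.gauss(0.0, 1.0) / math.sqrt(tau)

rng = random.Random(0)
cauchy_draws = [cauchy_from_uniform(rng) for _ in range(20_000)]
t_draws = [student_t_from_gamma_mixture(rng, nu=10) for _ in range(20_000)]

cauchy_median = statistics.median(cauchy_draws)
t_variance = statistics.variance(t_draws)
print(cauchy_median, t_variance)
```

A standard Cauchy has no mean or variance, so the median (which should be near 0) is the sensible check. A Student-T with 10 degrees of freedom has variance 10/8 = 1.25, which the sample variance should approximate.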

Strategy 3: Hyperparameter Tuning

Sometimes the solution lies in adjusting the sampler’s hyperparameters:

  • Increase total iterations: The simplest approach — give the sampler more time to explore.
  • Increase target acceptance rate (adapt_delta in Stan, target_accept in PyMC): Reduce divergent transitions (for example, try 0.9 instead of the default 0.8 for complex models).
  • Increase max_treedepth: Allow the sampler to take more steps per iteration.
  • Extend warmup/adaptation phase: Give the sampler more time to adapt to the posterior geometry.

Remember that while these adjustments may improve your diagnostic metrics, they often treat symptoms rather than underlying causes. The other strategies (reparameterization and better proposal distributions) typically offer more fundamental solutions.

Strategy 4: Better Proposal Distributions

This solution is for function fitting processes, rather than sampling estimations of the posterior. It basically asks the question: “I’m currently here in this landscape. Where should I jump to next so that I explore the full landscape, or how do I know that the next jump is the jump I should make?” Thus, choosing a good distribution means making sure that the sampling process explores the full parameter space instead of just a specific region. A good proposal distribution should:

  1. Have substantial probability mass where the target distribution does.
  2. Allow the sampler to make jumps of the appropriate size.

One common choice of proposal distribution is the Gaussian (Normal) distribution with mean μ and standard deviation σ, where σ is the scale we can tune to decide how far to jump from the current position to the next. If the scale is too small, the sampler might take too long to explore the entire posterior, or it might get stuck in one region and never explore the full distribution. If the scale is too large, the sampler might jump over some regions and never explore them. It’s like playing ping-pong where we only reach the two edges but not the middle.
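
To see the effect of the proposal scale, the sketch below runs random-walk Metropolis on a standard Normal target with three different scales and records how often proposals are accepted. The target, the three scale values, and the function name are assumptions chosen for illustration.

```python
import math
import random

def acceptance_rate(step_scale, n_steps=5000, seed=0):
    """Fraction of accepted random-walk Metropolis proposals on a standard Normal target."""
    rng = random.Random(seed)
    current, accepted = 0.0, 0
    for _ in range(n_steps):
        proposal = current + rng.gauss(0.0, step_scale)  # Gaussian proposal around the current position
        log_ratio = 0.5 * (current**2 - proposal**2)     # log target(proposal) - log target(current)
        if rng.random() < math.exp(min(0.0, log_ratio)):
            current = proposal
            accepted += 1
    return accepted / n_steps

rates = {scale: acceptance_rate(scale) for scale in (0.05, 2.4, 50.0)}
for scale, rate in rates.items():
    print(f"scale={scale}: acceptance rate {rate:.2f}")
```

With a tiny scale, nearly every proposal is accepted but each jump barely moves the chain. With a huge scale, almost every proposal lands somewhere implausible and is rejected. The intermediate scale accepts a moderate share of proposals while still covering ground, which is the balance tuning aims for.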

Strategy 5: Improve Prior Specification

When all else fails, reconsider your model’s prior specifications. Vague or weakly informative priors (like uniformly distributed priors) can sometimes lead to sampling difficulties. More informative priors, when justified by domain knowledge, can help guide the sampler toward more reasonable regions of the parameter space. Sometimes, despite your best efforts, a model may remain challenging to sample effectively. In such cases, consider whether a simpler model might achieve similar inferential goals while being more computationally tractable. The best model is often not the most complex one, but the one that balances complexity with reliability. The table below shows the summary of fixing strategies for different issues.

| Diagnostic Signal | Potential Issue | Recommended Fix |
| --- | --- | --- |
| High R-hat | Poor mixing between chains | Increase iterations, adjust the step size |
| Low ESS | High autocorrelation | Reparameterization, increase adapt_delta |
| Non-uniform rank plots | Chains stuck in different regions | Better proposal distribution, start with multiple chains |
| Separated KDEs in trace plots | Chains exploring different distributions | Reparameterization |
| Funnel shapes in pair plots | Hierarchical model issues | Non-centered reparameterization |
| Disjoint clusters in pair plots | Multimodality with poor mixing | Adjusted distribution, simulated annealing |

Conclusion

Assessing the quality of MCMC sampling is crucial for ensuring reliable inference. In this article, we explored key diagnostic metrics such as R-hat, ESS, rank plots, trace plots, and pair plots, discussing how each helps determine whether the sampler is performing properly.

If there’s one takeaway we want you to keep in mind it’s that you should always run diagnostics before drawing conclusions from your samples. No single metric provides a definitive answer — each serves as a tool that highlights potential issues rather than proving convergence. When problems arise, strategies such as reparameterization, hyperparameter tuning, and prior specification can help improve sampling efficiency.

By combining these diagnostics with thoughtful modeling decisions, you can ensure a more robust analysis, reducing the risk of misleading inferences due to poor sampling behavior.

References

B. Gilbert, Bob’s bees: the importance of using multiple bees (chains) to judge MCMC convergence (2018), Youtube

Chi-Feng, MCMC demo (n.d.), GitHub

D. Simpson, Maybe it’s time to let the old ways die; or We broke R-hat so now we have to fix it. (2019), Statistical Modeling, Causal Inference, and Social Science

M. Taboga, Markov Chain Monte Carlo (MCMC) methods (2021), Lectures on probability theory and mathematical Statistics. Kindle Direct Publishing. 

T. Wiecki, MCMC Sampling for Dummies (2024), twiecki.io

Stan User’s Guide, Reparametrization (n.d.), Stan Documentation

The post Are You Sure Your Posterior Makes Sense? appeared first on Towards Data Science.

]]>
The Basis of Cognitive Complexity: Teaching CNNs to See Connections https://towardsdatascience.com/the-basis-of-cognitive-complexity-teaching-cnns-to-see-connections/ Fri, 11 Apr 2025 05:44:46 +0000 https://towardsdatascience.com/?p=605715 Transforming CNNs: From task-specific learning to abstract generalization

The post The Basis of Cognitive Complexity: Teaching CNNs to See Connections appeared first on Towards Data Science.

]]>

Liberating education consists in acts of cognition, not transferrals of information.

Paulo Freire

One of the most heated discussions around artificial intelligence is: What aspects of human learning is it capable of capturing?

Many authors suggest that artificial intelligence models do not possess the same capabilities as humans, especially when it comes to plasticity, flexibility, and adaptation.

Among the aspects that models do not capture are several causal relationships about the external world.

This article discusses these issues:

  • The parallelism between convolutional neural networks (CNNs) and the human visual cortex
  • Limitations of CNNs in understanding causal relations and learning abstract concepts
  • How to make CNNs learn simple causal relations

Is it the same? Is it different?

Convolutional networks (CNNs) [2] are multi-layered neural networks that take images as input and can be used for multiple tasks. One of the most fascinating aspects of CNNs is their inspiration from the human visual cortex [1]:

  • Hierarchical processing. The visual cortex processes images hierarchically, where early visual areas capture simple features (such as edges, lines, and colors) and deeper areas capture more complex features such as shapes, objects, and scenes. A CNN, thanks to its layered structure, captures edges and textures in its early layers, while deeper layers capture parts or whole objects.
  • Receptive fields. Neurons in the visual cortex respond to stimuli in a specific local region of the visual field (commonly called receptive fields). As we go deeper, the receptive fields of the neurons widen, allowing more spatial information to be integrated. Thanks to pooling steps, the same happens in CNNs.
  • Feature sharing. Although biological neurons are not identical, similar features are recognized across different parts of the visual field. In CNNs, the various filters scan the entire image, allowing patterns to be recognized regardless of location.
  • Spatial invariance. Humans can recognize objects even when they are moved, scaled, or rotated. CNNs possess this property only in part: they handle shifts in position well, but are less robust to changes in scale or rotation.
The relationship between components of the visual system and CNN. Image source: here
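To make the receptive-field point concrete, here is a minimal sketch of how the receptive field grows as convolution and pooling layers are stacked. The layer stack is hypothetical and chosen only for illustration; the function name is mine, and the recurrence is the standard one for kernel size and stride.

```python
def receptive_fields(layers):
    """Return the receptive field size (in input pixels) after each layer.

    `layers` is a list of (kernel_size, stride) tuples, applied in order.
    """
    rf, jump = 1, 1  # a single input pixel; neighbouring outputs start 1 pixel apart
    sizes = []
    for kernel, stride in layers:
        rf += (kernel - 1) * jump  # each layer widens the field by (k - 1) * jump
        jump *= stride             # strides push neighbouring outputs further apart
        sizes.append(rf)
    return sizes


# Hypothetical stack: conv 3x3 -> pool 2x2 -> conv 3x3 -> pool 2x2
stack = [(3, 1), (2, 2), (3, 1), (2, 2)]
sizes = receptive_fields(stack)
print(sizes)  # [3, 4, 8, 10]
```

Each pooling step doubles the spacing between neighbouring units, so every later convolution adds more input pixels to the field than the one before it.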

These features have made CNNs perform well in visual tasks to the point of superhuman performance:

Russakovsky et al. [22] recently reported that human performance yields a 5.1% top-5 error on the ImageNet dataset. This number is achieved by a human annotator who is well-trained on the validation images to be better aware of the existence of relevant classes. […] Our result (4.94%) exceeds the reported human-level performance. —source [3]

Although CNNs perform better than humans in several tasks, there are still cases where they fail spectacularly. For example, in a 2024 study [4], AI models failed to generalize in image classification: state-of-the-art models perform better than humans on objects in upright poses but fail when objects appear in unusual poses.

The correct label is shown above each object, and the AI’s incorrect prediction is shown below. Image source: here

In conclusion, our results show that (1) humans are still much more robust than most networks at recognizing objects in unusual poses, (2) time is of the essence for such ability to emerge, and (3) even time-limited humans are dissimilar to deep neural networks. —source [4]

The authors of [4] note that humans need time to succeed at the task. Some tasks require not only visual recognition but also abstract cognition, which takes time.

The generalization abilities that make humans so capable come from understanding the laws that govern relations among objects. Humans recognize objects by extrapolating rules and chaining these rules to adapt to new situations. One of the simplest rules is the “same-different relation”: the ability to determine whether two objects are the same or different. This ability develops rapidly during infancy and is closely associated with language development [5-7]. In addition, some animals such as ducks and chimpanzees also have it [8]. In contrast, learning same-different relations is very difficult for neural networks [9-10].

Example of a same-different task for a CNN. The network should return a label of 1 if the two objects are the same or a label of 0 if they are different. Image source: here

Convolutional networks show difficulty in learning this relationship. Likewise, they fail to learn other types of causal relationships that are simple for humans. Therefore, many researchers have concluded that CNNs lack the inductive bias necessary to be able to learn these relationships.

These negative results do not mean that neural networks are completely incapable of learning same-different relations. Much larger models trained for longer can learn this relation. For example, vision-transformer models pre-trained on ImageNet with contrastive learning can show this ability [12].

Can CNNs learn same-different relationships?

The fact that large models can learn these kinds of relationships has rekindled interest in CNNs. The same-different relationship is considered one of the basic logical operations that form the foundation for higher-order cognition and reasoning. Showing that shallow CNNs can learn this concept would allow us to experiment with other relationships. Moreover, it would allow models to learn increasingly complex causal relationships. This is an important step in advancing the generalization capabilities of AI.

Previous work suggests that CNNs do not have the architectural inductive biases to be able to learn abstract visual relations. Other authors assume that the problem is in the training paradigm. In general, classical gradient descent is used to learn a single task or a set of tasks. Given a task t or a set of tasks T, the weights φ are optimized to minimize a loss function L:

Image source from here

This can be viewed as simply the sum of the losses across different tasks (if we have more than one task). Instead, the Model-Agnostic Meta-Learning (MAML) algorithm [13] is designed to search for an optimal point in weight space for a set of related tasks. MAML seeks to find an initial set of weights θ that minimizes the loss function across tasks, facilitating rapid adaptation:

Image source from here
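The equation images are not reproduced here. As a sketch in the notation of the surrounding text, and following the standard formulation in [13], the two objectives can be written as below. The per-task loss \mathcal{L}_t and the inner learning rate \alpha are labels introduced for this sketch, and a single inner gradient step is assumed.

```latex
% Classic training: one set of weights minimizing the summed task losses
\phi^{*} = \arg\min_{\phi} \sum_{t \in T} \mathcal{L}_{t}(\phi)

% MAML: an initialization that performs well after one adaptation step per task
\theta^{*} = \arg\min_{\theta} \sum_{t \in T}
    \mathcal{L}_{t}\bigl(\theta - \alpha \nabla_{\theta} \mathcal{L}_{t}(\theta)\bigr)
```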

The difference may seem small, but conceptually, this approach is directed toward abstraction and generalization. If there are multiple tasks, traditional training tries to optimize weights for different tasks. MAML tries to identify a set of weights that is optimal for different tasks but at the same time equidistant in the weight space. This starting point θ allows the model to generalize more effectively across different tasks.
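As a rough illustration of the idea, here is a toy sketch of the MAML loop on one-dimensional quadratic "tasks", where the gradients can be written by hand. Everything in it (the task targets, the learning rates, the function names) is made up for illustration; real networks would need automatic differentiation, and this is not the setup used in [11] or [13].

```python
import numpy as np

# Hypothetical toy setup: task t has the loss L_t(w) = 0.5 * (w - c_t)**2,
# so its gradient is simply (w - c_t). The targets c_t are made up.
task_targets = np.array([-2.0, 0.5, 4.0])
alpha = 0.1  # inner (adaptation) learning rate
beta = 0.5   # outer (meta) learning rate


def adapt(theta, c, alpha):
    """One inner gradient step on a task with target c, starting from theta."""
    return theta - alpha * (theta - c)


def meta_gradient(theta, targets, alpha):
    """Gradient of the mean post-adaptation loss with respect to theta.

    For this quadratic loss, d(adapted)/d(theta) = (1 - alpha), so the chain
    rule gives (1 - alpha) * (adapted - c) for each task.
    """
    adapted = adapt(theta, targets, alpha)
    return np.mean((1 - alpha) * (adapted - targets))


theta = 10.0  # deliberately poor starting point
for _ in range(200):
    theta -= beta * meta_gradient(theta, task_targets, alpha)

print(round(float(theta), 4))
```

In this symmetric toy the meta-learned starting point settles at the mean of the task targets, from which a single inner step moves toward any individual task.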

Meta-learning initial weights for generalization. Image source from here

Since we now have a method biased toward generalization and abstraction, we can test whether we can make CNNs learn the same-different relationship.

In [11], the authors compared shallow CNNs trained with classic gradient descent and with meta-learning on a dataset designed for the study. The dataset consists of 10 different tasks that test for the same-different relationship.

The Same-Different dataset. Image source from here

The authors [11] compare CNNs of 2, 4, or 6 layers trained in a traditional way or with meta-learning, showing several interesting results:

  1. The performance of traditional CNNs shows similar behavior to random guessing.
  2. Meta-learning significantly improves performance, suggesting that the model can learn the same-different relationship. A 2-layer CNN performs little better than chance, but by increasing the depth of the network, performance improves to near-perfect accuracy.
Comparison between traditional training and meta-learning for CNNs. Image source from here

One of the most intriguing results of [11] is that the model can be trained in a leave-one-out way (use 9 tasks and leave one out) and show out-of-distribution generalization capabilities. Thus, the model has learned abstracting behavior that is hardly seen in such a small model (6 layers).
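The leave-one-out protocol itself is simple to sketch: each task is held out once while the remaining nine are used for training. The task names below are placeholders, not the names used in [11].

```python
def leave_one_out_splits(tasks):
    """Yield (train_tasks, held_out_task) pairs, holding out each task once."""
    for i, held_out in enumerate(tasks):
        train = tasks[:i] + tasks[i + 1:]
        yield train, held_out


# Placeholder names for the 10 same-different tasks
tasks = [f"task_{i}" for i in range(10)]
splits = list(leave_one_out_splits(tasks))

for train, held_out in splits[:2]:
    print(f"train on {len(train)} tasks, test on {held_out}")
```

Because the held-out task is never seen during training, accuracy on it measures out-of-distribution generalization rather than memorization.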

Out-of-distribution generalization for same-different classification. Image source from here

Conclusions

Although convolutional networks were inspired by how the human brain processes visual stimuli, they do not capture some of its basic capabilities. This is especially true when it comes to causal relations or abstract concepts. Some of these relationships can be learned only by large models with extensive training. This has led to the assumption that small CNNs cannot learn these relations due to a lack of architectural inductive bias. In recent years, efforts have been made to create new architectures that could have an advantage in learning relational reasoning. Yet most of these architectures fail to learn these kinds of relationships. Intriguingly, this can be overcome through the use of meta-learning.

The advantage of meta-learning is that it incentivizes more abstract learning. Meta-learning applies pressure toward generalization by trying to optimize for all tasks at the same time. To do this, learning more abstract features is favored (low-level features, such as the angles of a particular shape, are not useful for generalization and are disfavored). Meta-learning allows a shallow CNN to learn abstract behavior that would otherwise require many more parameters and training.

The shallow CNNs and same-different relationship are a model for higher cognitive functions. Meta-learning and different forms of training could be useful to improve the reasoning capabilities of the models.

Another thing!

You can look for my other articles on Medium, and you can also connect with or reach me on LinkedIn or Bluesky. Check this repository, which contains weekly updated ML & AI news, or here for other tutorials and here for AI reviews. I am open to collaborations and projects.

References

Here is the list of the principal references I consulted to write this article; only the first author of each article is cited.

  1. Lindsay, 2020, Convolutional Neural Networks as a Model of the Visual System: Past, Present, and Future, link
  2. Li, 2020, A Survey of Convolutional Neural Networks: Analysis, Applications, and Prospects, link
  3. He, 2015, Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification, link
  4. Ollikka, 2024, A comparison between humans and AI at recognizing objects in unusual poses, link
  5. Premack, 1981, The codes of man and beasts, link
  6. Blote, 1999, Young children’s organizational strategies on a same–different task: A microgenetic study and a training study, link
  7. Lupker, 2015, Is there phonologically based priming in the same-different task? Evidence from Japanese-English bilinguals, link
  8. Gentner, 2021, Learning same and different relations: cross-species comparisons, link
  9. Kim, 2018, Not-so-clevr: learning same–different relations strains feedforward neural networks, link
  10. Puebla, 2021, Can deep convolutional neural networks support relational reasoning in the same-different task? link
  11. Gupta, 2025, Convolutional Neural Networks Can (Meta-)Learn the Same-Different Relation, link
  12. Tartaglini, 2023, Deep Neural Networks Can Learn Generalizable Same-Different Visual Relations, link
  13. Finn, 2017, Model-agnostic meta-learning for fast adaptation of deep networks, link

]]>
The Invisible Revolution: How Vectors Are (Re)defining Business Success https://towardsdatascience.com/the-invisible-revolution-how-vectors-are-redefining-business-success/ Thu, 10 Apr 2025 20:52:15 +0000 https://towardsdatascience.com/?p=605712 The hidden force behind AI is powering the next wave of business transformation

The post The Invisible Revolution: How Vectors Are (Re)defining Business Success appeared first on Towards Data Science.

]]>
In a world that increasingly revolves around data, business leaders must understand vector thinking. At first, vectors may appear as complicated as algebra was in school, but they serve as a fundamental building block. Just as algebra is essential for tasks like splitting a bill or computing interest, vectors underpin our digital systems for decision making, customer engagement, and data protection.

Vectors represent a radically different way of thinking about relationships and patterns. They do not simply divide data into rigid categories. Instead, they offer a dynamic, multidimensional view of the underlying connections. For example, two customers being “similar” may mean more than shared demographics or purchase histories: it is their behaviors, preferences, and habits that align. Such associations can be defined and measured accurately in a vector space. But for many modern businesses, this logic still seems too complex, so leaders tend to fall back on old, learned, rule-based patterns instead. Fraud detection, for example, used to rely on simple rules such as transaction limits. Since then, we’ve evolved to recognize patterns and anomalies.

While it might have been common to block transactions that allocate 50% of your credit card limit at once just a few years ago, we are now able to analyze your retailer-specific spend history, look at average baskets of other customers at the very same retailers, and do some slight logic checks such as the physical location of your previous spends.

So a $7,000 transaction at McDonald’s in Dubai might just not go through if you just spent $3 on a bike rental in Amsterdam. Even $20 wouldn’t work, since vector-based patterns can flag the physical distance between the two purchases as implausible. The $7,000 transaction for your new e-bike at a retailer near Amsterdam’s city center, on the other hand, may just work flawlessly. Welcome to life in a world managed by vectors.

The danger of ignoring the paradigm of vectors is huge. Not mastering algebra can lead to bad financial decisions. Similarly, not knowing vectors can leave you vulnerable as a business leader. While the average customer may stay unaware of vectors as much as an average passenger in a plane is of aerodynamics, a business leader should be at least aware of what kerosene is and how many seats are to be occupied to break even for a specific flight. You may not need to fully understand the systems you rely on. A basic understanding helps to know when to reach out to the experts. And this is exactly my aim in this little journey into the world of vectors: become aware of the basic principles and know when to ask for more to better steer and manage your business.

In the hushed hallways of research labs and tech companies, a revolution was brewing. It would change how computers understood the world. This revolution had nothing to do with processing power or storage capacity. It was all about teaching machines to understand context, meaning, and nuance in words, using mathematical representations called vectors. Before we can appreciate the magnitude of this shift, we first need to understand what it differs from.

Think about the way humans take in information. When we look at a cat, we don’t just process a checklist of components: whiskers, fur, four legs. Instead, our brains work through a network of relationships, contexts, and associations. We know a cat is more like a lion than a bicycle, not because we memorized this fact, but because our brains have naturally learned these relationships. Vector representations let computers consume content in a similarly human-like way. And we ought to understand how and why this is true. It’s as fundamental as knowing algebra in the time of an impending AI revolution.

In this brief jaunt through the vector realm, I will explain how vector-based computing works and why it’s so transformative. The code examples are for illustration only and are not meant as stand-alone programs. You don’t have to be an engineer to understand these concepts. All you have to do is follow along as I walk you through the examples step by step, with plain-language commentary. I don’t aim to be a world-class mathematician. I want to make vectors understandable to everyone: business leaders, managers, engineers, musicians, and others.


What are vectors, anyway?

Photo by Pete F on Unsplash

The vector-based computing journey did not start recently. Its roots go back to the 1950s and the development of distributed representations in cognitive science. James McClelland and David Rumelhart, among other researchers, later theorized that the brain holds concepts not as individual entities but as patterns of activity distributed across neural networks. This idea paved the way for contemporary vector representations.

The real breakthrough was three things coming together: the exponential growth in computational power, the development of sophisticated neural network architectures, and the availability of massive datasets for training.

It is the combination of these elements that makes vector-based systems not only theoretically possible but practically implementable at scale. Mainstream AI as people have come to know it (with the likes of ChatGPT and others) is the direct consequence of this.

To better understand, let me put this in context: conventional computing systems work on symbols, that is, discrete, human-readable symbols and rules. A traditional system, for instance, might represent a customer as a record:

customer = {
    'id': '12345',
    'age': 34,
    'purchase_history': ['electronics', 'books'],
    'risk_level': 'low'
}

This representation may be readable and logical, but it misses subtle patterns and relationships. In contrast, vector representations encode information in a high-dimensional space where relationships arise naturally through geometric proximity. That same customer might be represented as a 384-dimensional vector in which each dimension contributes to a rich, nuanced profile. A few lines of code are enough to transform flat, two-dimensional (tabular) customer data into vectors. Let’s take a look at how simple this is:

from sentence_transformers import SentenceTransformer
import numpy as np

class CustomerVectorization:
    def __init__(self):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        
    def create_customer_vector(self, customer_data):
        """
        Transform customer data into a rich vector representation
        that captures subtle patterns and relationships
        """
        # Combine various customer attributes into a meaningful text representation
        customer_text = f"""
        Customer profile: {customer_data['age']} year old,
        interested in {', '.join(customer_data['purchase_history'])},
        risk level: {customer_data['risk_level']}
        """
        
        # Generate base vector from text description
        base_vector = self.model.encode(customer_text)
        
        # Enrich vector with numerical features
        numerical_features = np.array([
            customer_data['age'] / 100,  # Normalized age
            len(customer_data['purchase_history']) / 10,  # Purchase history length
            self._risk_level_to_numeric(customer_data['risk_level'])
        ])
        
        # Combine text-based and numerical features
        combined_vector = np.concatenate([
            base_vector,
            numerical_features
        ])
        
        return combined_vector
    
    def _risk_level_to_numeric(self, risk_level):
        """Convert categorical risk level to normalized numeric value"""
        risk_mapping = {'low': 0.1, 'medium': 0.5, 'high': 0.9}
        return risk_mapping.get(risk_level.lower(), 0.5)

I trust that this code example has helped demonstrate how easily complex customer data can be encoded into meaningful vectors. The method seems complex at first, but it is simple. We merge text and numerical data on customers. This gives us rich, information-dense vectors that capture each customer’s essence. What I love most about this technique is its simplicity and flexibility. Just as we encoded age, purchase history, and risk level here, you could replicate this pattern to capture any other customer attributes that are relevant to your use case. Just recall the credit card spending patterns we described earlier. It’s similar data being turned into vectors, giving it a meaning far greater than it could ever have had if it had stayed two-dimensional and been used for traditional rule-based logic.

What our little code example gives us is two complementary representations, one in a semantically rich space and one in a normalized value space, combined so that every record maps to a single vector that can be compared directly with any other.

This allows the systems to identify complex patterns and relations that traditional data structures won’t be able to reflect adequately. With the geometric nature of vector spaces, the shape of these structures tells the stories of similarities, differences, and relationships, allowing for an inherently standardized yet flexible representation of complex data. 

From here on, you will see this structure repeated across other applications of vector-based customer analysis: pick the relevant data, aggregate it in a format we can work with, and combine the heterogeneous pieces into a common vector representation. Whether it’s recommendation systems, customer segmentation models, or predictive analytics tools, this fundamental approach to thoughtful vectorization underpins all of them. That is why it is worth knowing and understanding, even if you consider yourself non-technical and more on the business side.

Just keep in mind: the key is considering which parts of your data carry meaningful signals and how to encode them in a way that preserves their relationships. It is nothing more than following your business logic in a way of thinking other than algebra. A more modern, multi-dimensional way.


The Mathematics of Meaning (Kings and Queens)

Photo by Debbie Fan on Unsplash

All human communication carries rich networks of meaning that our brains are wired to make sense of automatically. We can capture these meanings mathematically using vector-based computing: words are represented as points in a multi-dimensional word space. This geometric treatment allows us to think about abstract semantic relations in spatial terms, as distances and directions.

For instance, the relationship “King is to Queen as Man is to Woman” is encoded in a vector space in such a way that the direction and distance between the words “King” and “Queen” are similar to those between the words “Man” and “Woman.”

Let’s take a step back to understand why this might be. The key component that makes this system work is word embeddings: numerical representations that encode words as vectors in a dense vector space. These embeddings are derived from examining co-occurrences of words across large bodies of text. Just as we learn that “dog” and “puppy” are related concepts by observing that they occur in similar contexts, embedding algorithms learn to place these words close to each other in a vector space.

Word embeddings reveal their real power when we look at how they encode analogical relationships. Think about what we know about the relationship between “king” and “queen.” We can tell through intuition that these words are different in gender but share associations related to the palace, authority, and leadership. Through a wonderful property of vector space systems — vector arithmetic — this relationship can be captured mathematically.

The classic example shows this beautifully:

vector('king') - vector('man') + vector('woman') ≈ vector('queen')

This equation tells us that if we have the vector for “king,” and we subtract out the “man” vector (we remove the concept of “male”), and then we add the “woman” vector (we add the concept of “female”), we get a new point in space very close to that of “queen.” That’s not some mathematical coincidence — it’s based on how the embedding space has arranged the meaning in a sort of structured way.
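To see the arithmetic without downloading a model, here is a minimal sketch with hand-made 2D vectors. The values are fabricated for illustration (the first axis loosely stands for “royalty,” the second for “gender”); real embeddings have hundreds of dimensions and no such clean axes. The helper names are mine.

```python
import numpy as np

# Fabricated 2-D vectors: first axis ~ "royalty", second axis ~ "gender"
vectors = {
    "king":  np.array([1.0, 1.0]),
    "queen": np.array([1.0, -1.0]),
    "man":   np.array([0.0, 1.0]),
    "woman": np.array([0.0, -1.0]),
    "apple": np.array([-1.0, 0.2]),
}


def cosine(a, b):
    """Cosine of the angle between two vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


def complete_analogy(a, b, c):
    """Return the word closest to vector(a) - vector(b) + vector(c), excluding the inputs."""
    target = vectors[a] - vectors[b] + vectors[c]
    candidates = {w: v for w, v in vectors.items() if w not in (a, b, c)}
    return max(candidates, key=lambda w: cosine(candidates[w], target))


answer = complete_analogy("king", "man", "woman")
print(answer)  # queen
```

Excluding the input words from the candidates mirrors what embedding libraries typically do, since the nearest neighbor of the result is otherwise often one of the inputs.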

We can try this analogy in Python with pre-trained word embeddings:

import gensim.downloader as api

# Load a pre-trained model that contains word vectors learned from Google News
model = api.load('word2vec-google-news-300')

# Define our analogy words
source_pair = ('king', 'man')
target_word = 'woman'

# Find which word completes the analogy using vector arithmetic
result = model.most_similar(
    positive=[target_word, source_pair[0]], 
    negative=[source_pair[1]], 
    topn=1
)

# Display the result
print(f"{source_pair[0]} is to {source_pair[1]} as {target_word} is to {result[0][0]}")

The structure of this vector space exposes many basic principles:

  1. Semantic similarity is present as spatial proximity. Related words congregate: the neighborhoods of ideas. “Dog,” “puppy,” and “canine” would be one such cluster; meanwhile, “cat,” “kitten,” and “feline” would create another cluster nearby.
  2. Relationships between words become directions in the space. The vector from “man” to “woman” encodes a gender relationship, and other such relationships (for example, “king” to “queen” or “actor” to “actress”) typically point in the same direction.
  3. The magnitude of vectors can carry meaning about word importance or specificity. Common words often have shorter vectors than specialized terms, reflecting their broader, less specific meanings.

Working with relationships between words in this way gives us a geometric encoding of meaning and the mathematical precision needed to convey the nuances of natural language to machines. Instead of treating words as separate symbols, vector-based systems can recognize patterns, make analogies, and even uncover relationships that were never programmed.

To make this more tangible, I took the liberty of mapping the words we mentioned before (“King, Man, Woman”; “Dog, Puppy, Canine”; “Cat, Kitten, Feline”) to corresponding 2D vectors. These vectors numerically represent semantic meaning.

Visualization of the before-mentioned example terms as 2D word embeddings. Showing grouped categories for explanatory purposes. Data is fabricated and axes are simplified for educational purposes.
  • Human-related words have high positive values on both dimensions.
  • Dog-related words have negative x-values and positive y-values.
  • Cat-related words have positive x-values and negative y-values.

Be aware that I fabricated those values for illustration. In the 2D space where the vectors are plotted, you can observe groups based on the positions of the dots representing the vectors. The three dog-related words, for example, can be clustered into a “Dog” category, and so on.

Grasping these basic principles gives us insight into both the capabilities and the limitations of modern language AI, such as large language models (LLMs). Though these systems can do amazing analogical and relational gymnastics, they ultimately work with geometric patterns derived from the ways words appear near one another in a body of text: an elaborate but, by definition, partial reflection of human linguistic comprehension. Because it is built on vectors, an LLM can only generate output from what it has received as input. That doesn’t mean it reproduces its training data 1:1 (we all know about the fantastic hallucination capabilities of LLMs); it means that an LLM, unless specifically instructed, wouldn’t come up with neologisms or new language to describe things. This basic understanding is still lacking among many business leaders who, unaware of the underlying principles of vectors, expect LLMs to be miracle machines.


A Tale of Distances, Angles, and Dinner Parties

Photo by OurWhisky Foundation on Unsplash

Now, let’s assume you’re throwing a dinner party and it’s all about Hollywood and the big movies, and you want to seat people based on what they like. You could just calculate “distance” between their preferences (genres, perhaps even hobbies?) and find out who should sit together. But deciding how you measure that distance can be the difference between compelling conversations and annoyed participants. Or awkward silences. And yes, that company party flashback is repeating itself. Sorry for that!

The same is true in the world of vectors. The distance metric defines how “similar” two vectors look and therefore, ultimately, how well your system predicts an outcome.

Euclidean Distance: Straightforward, but Limited

Euclidean distance measures the straight-line distance between two points in space, making it easy to understand:

  • Euclidean distance is fine as long as vectors are physical locations.
  • However, in high-dimensional spaces (like vectors representing user behavior or preferences), this metric often falls short. Differences in scale or magnitude can skew results, focusing on scale over actual similarity.

Example: Two vectors might represent how much your dinner guests stream in each of three genres:

vec1 = [5, 10, 5]
# Dinner guest A watches action, drama, and comedy in a 1:2:1 ratio.

vec2 = [1, 2, 1] 
# Dinner guest B likes the same genres but consumes less streaming overall.

While their preferences align, Euclidean distance would make them seem vastly different because of the disparity in overall activity.
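A quick check with NumPy makes the disparity visible. This is only a sketch using the two example vectors from above.

```python
import numpy as np

vec1 = np.array([5, 10, 5])  # dinner guest A
vec2 = np.array([1, 2, 1])   # dinner guest B

# Straight-line (Euclidean) distance between the two preference vectors
distance = np.linalg.norm(vec1 - vec2)
print(round(float(distance), 2))  # 9.8

# For comparison: the length of guest B's own vector
print(round(float(np.linalg.norm(vec2)), 2))  # 2.45
```

The distance between the guests is about four times the length of guest B’s entire vector, even though both vectors point in exactly the same direction.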

In higher-dimensional spaces, such as user behavior or textual meaning, Euclidean distance becomes increasingly less informative. It overweights magnitude, which can obscure comparisons. Consider two moviegoers: one has seen 200 action movies, the other has seen 10, but they both like the same genres. Because of the difference in sheer activity level, the second viewer would appear very dissimilar to the first under Euclidean distance, even though all either of them ever watched is Bruce Willis movies.

Cosine Similarity: Focused on Direction

The cosine similarity method takes a different approach. It focuses on the angle between vectors, not their magnitudes. It’s like comparing the path of two arrows. If they point the same way, they are aligned, no matter their lengths. This makes it well suited to high-dimensional data, where we care about relationships, not scale.

  • If two vectors point in the same direction, they’re considered similar (cosine similarity ≈ 1).
  • If they point in opposite directions, they are opposed (cosine similarity ≈ -1).
  • If they’re perpendicular (at a right angle of 90° to one another), they are unrelated (cosine similarity ≈ 0).

This normalizing property ensures that the similarity score correctly measures alignment, regardless of how one vector is scaled in comparison to another.
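For reference, the standard definition of cosine similarity between two vectors is given below. The symbols a and b are generic placeholders for any two vectors of the same dimension n.

```latex
\cos(\theta)
  = \frac{\mathbf{a} \cdot \mathbf{b}}{\lVert \mathbf{a} \rVert \, \lVert \mathbf{b} \rVert}
  = \frac{\sum_{i=1}^{n} a_i b_i}
         {\sqrt{\sum_{i=1}^{n} a_i^{2}} \; \sqrt{\sum_{i=1}^{n} b_i^{2}}}
```

Dividing by both lengths is what removes scale from the comparison: multiplying either vector by a positive constant leaves the result unchanged.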

Example: Returning to our streaming preferences, let’s take a look at what our dinner guests’ preferences look like as vectors:

vec1 = [5, 10, 5]
# Dinner guest A watches action, drama, and comedy in a 1:2:1 ratio.

vec2 = [1, 2, 1] 
# Dinner guest B likes the same genres but consumes less streaming overall.

Let us discuss why cosine similarity is so effective in this case. When we compute cosine similarity for vec1 [5, 10, 5] and vec2 [1, 2, 1], we’re essentially measuring the angle between these vectors.

Cosine similarity effectively normalizes the vectors first, dividing each component by the length of the vector, and then takes their dot product. This operation “cancels” the differences in magnitude:

  • For vec1: normalization gives us roughly [0.41, 0.82, 0.41].
  • For vec2: normalization also gives us roughly [0.41, 0.82, 0.41].

And now we also understand why these vectors would be considered identical with regard to cosine similarity because their normalized versions are identical!
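The normalization can be verified in a few lines. This is a minimal sketch with NumPy, using the same two guest vectors.

```python
import numpy as np

vec1 = np.array([5, 10, 5], dtype=float)  # dinner guest A
vec2 = np.array([1, 2, 1], dtype=float)   # dinner guest B

# Scale each vector to length 1 (unit vectors)
unit1 = vec1 / np.linalg.norm(vec1)
unit2 = vec2 / np.linalg.norm(vec2)

# Cosine similarity is the dot product of the two unit vectors
similarity = float(np.dot(unit1, unit2))

print(np.round(unit1, 2), np.round(unit2, 2))
print(round(similarity, 4))
```

Both unit vectors come out the same, so the similarity is 1 up to floating-point rounding.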

This tells us that even though dinner guest A views more total content, the proportion they allocate to any given genre perfectly mirrors dinner guest B’s preferences. It’s like saying both your guests dedicate 25% of their time to action, 50% to drama, and 25% to comedy, no matter the total hours viewed.

It’s this normalization that makes cosine similarity particularly effective for high-dimensional data such as text embeddings or user preferences.

When dealing with data of many dimensions (think hundreds or thousands of vector components for various features of a movie), it is often the relative weight of each dimension within the complete profile, rather than the absolute values, that matters most. Cosine similarity captures precisely this pattern of relative importance, which makes it a powerful tool for identifying meaningful relationships in complex data.


Hiking up the Euclidean Mountain Trail

Photo by Christian Mikhael on Unsplash

In this part, we will see how different approaches to measuring similarity behave in practice, with a concrete example from the real world and a small code example. Even if you are a non-techie, the code will be easy to follow. It’s there to illustrate the simplicity of it all. No fear!

How about we quickly discuss a 10-mile-long hiking trail? Two friends, Alex and Blake, write trail reviews of the same hike, but each ascribes it a different character:

The trail gained 2,000 feet in elevation over just 2 miles! Easily doable with some high spikes in between!
Alex

and

Beware, we hiked 100 straight feet up in the forest terrain at the spike! Overall, 10 beautiful miles of forest!
Blake

These descriptions can be represented as vectors:

alex_description = [2000, 2]  # [elevation_gain, trail_distance]
blake_description = [100, 10]  # [elevation_gain, trail_distance]

Let’s combine both similarity measures and see what it tells us:

import numpy as np

def cosine_similarity(vec1, vec2):
    """
    Measures how similar the pattern or shape of two descriptions is,
    ignoring differences in scale. Returns 1.0 for perfectly aligned patterns.
    """
    dot_product = np.dot(vec1, vec2)
    norm1 = np.linalg.norm(vec1)
    norm2 = np.linalg.norm(vec2)
    return dot_product / (norm1 * norm2)

def euclidean_distance(vec1, vec2):
    """
    Measures the direct 'as-the-crow-flies' difference between descriptions.
    Smaller numbers mean descriptions are more similar.
    """
    return np.linalg.norm(np.array(vec1) - np.array(vec2))

# Alex focuses on the steep part: 2000ft elevation over 2 miles
alex_description = [2000, 2]  # [elevation_gain, trail_distance]

# Blake describes the whole trail: a 100ft climb at the spike, 10 miles overall
blake_description = [100, 10]  # [elevation_gain, trail_distance]

# Let's see how different these descriptions appear using each measure
print("Comparing how Alex and Blake described the same trail:")
print(f"\nEuclidean distance: {euclidean_distance(alex_description, blake_description):.4f}")
print("(A larger number here suggests very different descriptions)")

print(f"\nCosine similarity: {cosine_similarity(alex_description, blake_description):.4f}")
print("(A number close to 1.0 suggests similar patterns)")

# Let's also normalize the vectors to see what cosine similarity is looking at
alex_normalized = np.array(alex_description) / np.linalg.norm(alex_description)
blake_normalized = np.array(blake_description) / np.linalg.norm(blake_description)

# Round to 5 decimals and convert to plain lists for readable output
print("\nAlex's normalized description:", np.round(alex_normalized, 5).tolist())
print("Blake's normalized description:", np.round(blake_normalized, 5).tolist())

So now, running this code, something magical happens:

Comparing how Alex and Blake described the same trail:

Euclidean distance: 1900.0168
(A larger number here suggests very different descriptions)

Cosine similarity: 0.9951
(A number close to 1.0 suggests similar patterns)

Alex's normalized description: [1.0, 0.001]
Blake's normalized description: [0.99504, 0.0995]

This output shows why, depending on what you are measuring, the same trail may appear different or similar.

The large Euclidean distance (about 1,900) suggests these are very different descriptions. That is understandable: 2000 is very different from 100, and 2 is very different from 10. It’s like taking the raw difference between these numbers without understanding their meaning.

But the high cosine similarity (about 0.995) tells us something more interesting: both descriptions capture a similar pattern.

If we look at the normalized vectors, we can see it, too: both Alex and Blake are describing a trail in which elevation gain is the prominent feature. The first number in each normalized vector (elevation gain) is much larger than the second (trail distance). Once both descriptions are normalized, so that proportion rather than volume is compared, they share the same trait that defines the trail.

Perfectly true to life: Alex and Blake hiked the same trail but focused on different parts of it when writing their reviews. Alex focused on the steeper section and described a 2,000-foot gain over 2 miles, while Blake described the whole trail: 10 miles of forest with a 100-foot climb at the spike. Cosine similarity identifies these descriptions as variations of the same basic trail pattern, whereas Euclidean distance regards them as completely different trails.

This example highlights the need to select the appropriate similarity measure. In real use cases, normalizing and taking cosine similarity surfaces many meaningful relationships that a raw distance such as Euclidean would miss.
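One way to see what normalization changes: for unit-length vectors, the squared Euclidean distance equals 2 − 2 × cosine similarity, a standard identity. The sketch below checks this on the two trail descriptions using plain Python; the helper names are chosen for this illustration only.

```python
import math

def norm(vec):
    """Euclidean length of a vector."""
    return math.sqrt(sum(x * x for x in vec))

def normalize(vec):
    """Scale a vector to unit length."""
    length = norm(vec)
    return [x / length for x in vec]

def cosine_similarity(a, b):
    return sum(x * y for x, y in zip(a, b)) / (norm(a) * norm(b))

def euclidean_distance(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

alex = [2000, 2]    # [elevation_gain, trail_distance]
blake = [100, 10]   # [elevation_gain, trail_distance]

cos = cosine_similarity(alex, blake)
unit_distance = euclidean_distance(normalize(alex), normalize(blake))

print(round(cos, 4))             # 0.9951
print(round(unit_distance, 4))   # 0.0986
# For unit vectors: distance^2 == 2 - 2 * cosine similarity
print(math.isclose(unit_distance ** 2, 2 - 2 * cos, abs_tol=1e-12))  # True
```

After normalization the two measures carry the same information, so the disagreement between them on the raw vectors comes entirely from magnitude.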


Real-World Impacts of Metric Choices

Photo by fabio on Unsplash

The metric you pick doesn’t merely change the numbers; it influences the results of complex systems. Here’s how it breaks down in various domains:

  • In Recommendation Engines: With cosine similarity, we can group users who have the same tastes, even if their overall activity levels differ. A streaming service could use this to recommend movies that align with a user’s genre preferences, regardless of what is popular among a small subset of very active viewers.
  • In Document Retrieval: When querying a database of documents or research papers, cosine similarity ranks documents according to whether their content is similar in meaning to the user’s query, rather than their text length. This enables systems to retrieve results that are contextually relevant to the query, even though the documents are of a wide range of sizes.
  • In Fraud Detection: Patterns of behavior are often more important than pure numbers. Cosine similarity can be used to detect anomalies in spending habits, as it compares the direction of the transaction vectors — type of merchant, time of day, transaction amount, etc. — rather than the absolute magnitude.
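The document retrieval case can be sketched in a few lines. The vocabulary, the document names, and all term counts below are invented for illustration; real systems typically use learned embeddings rather than raw counts.

```python
import math

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Hypothetical term counts over the vocabulary ["vector", "embedding", "recipe"]
documents = {
    "short_note_on_vectors": [2, 1, 0],
    "long_paper_on_vectors": [40, 20, 1],
    "long_cookbook": [1, 0, 60],
}
query = [1, 1, 0]

# Rank documents from most to least similar to the query
ranked = sorted(
    documents,
    key=lambda name: cosine_similarity(query, documents[name]),
    reverse=True,
)
print(ranked)
# ['short_note_on_vectors', 'long_paper_on_vectors', 'long_cookbook']
```

The short note and the long paper score almost identically despite a twentyfold difference in length, while the long cookbook lands last because its content points in a different direction.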

And these differences matter because they give a sense of how systems “think”. Let’s get back to that credit card example one more time: a system using Euclidean distance might, for example, flag a high-value $7,000 transaction for your new e-bike as suspicious, even if that transaction is normal for you given that you spend an average of $20,000 a month.

A cosine-based system, on the other hand, understands that the transaction is consistent with what the user typically spends their money on, thus avoiding unnecessary false notifications.
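As a rough sketch of this contrast, the snippet below compares a customer’s usual monthly spending with one unusually expensive week. The categories, amounts, and both vectors are invented for illustration; a production fraud system would use many more features and is not described in this article.

```python
import math

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def euclidean_distance(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

# Hypothetical spend per category: [groceries, travel, sports_and_outdoors]
typical_month = [2000, 6000, 12000]   # sums to the $20,000 monthly average
ebike_week = [500, 1500, 7000]        # includes the $7,000 e-bike purchase

distance = euclidean_distance(typical_month, ebike_week)
similarity = cosine_similarity(typical_month, ebike_week)

print(f"Euclidean distance: {distance:.0f}")   # Euclidean distance: 6892
print(f"Cosine similarity: {similarity:.2f}")  # Cosine similarity: 0.97
```

The raw distance is large because the amounts differ, while the cosine similarity stays high because the money goes to the same categories in similar proportions.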

But measures like Euclidean distance and cosine similarity are not merely theoretical. They’re the blueprints on which real-world systems stand. Whether it’s recommendation engines or fraud detection, the metrics we choose will directly impact how systems make sense of relationships in data.

Vector Representations in Practice: Industry Transformations

Photo by Louis Reed on Unsplash

This ability for abstraction is what makes vector representations so powerful — they transform complex and abstract field data into concepts that can be scored and actioned. These insights are catalyzing fundamental transformations in business processes, decision-making, and customer value delivery across sectors.

Next, we will look at a concrete use case to see how vectors are freeing up time to solve big problems and creating new, high-impact opportunities. I picked one industry to show what vector-based approaches can achieve, so here is a healthcare example from a clinical setting. Why? Because it matters to us all and is easier to relate to than digging into the depths of the finance system, insurance, renewable energy, or chemistry.

Healthcare Spotlight: Pattern Recognition in Complex Medical Data

The healthcare industry poses a perfect storm of challenges that vector representations can uniquely solve. Think of the complexities of patient data: medical histories, genetic information, lifestyle factors, and treatment outcomes all interact in nuanced ways that traditional rule-based systems are incapable of capturing.

At Massachusetts General Hospital, researchers implemented a vector-based early detection system for sepsis, a condition in which every hour of early detection increases the chances of survival by 7.6% (see the full study at pmc.ncbi.nlm.nih.gov/articles/PMC6166236/).

In this new methodology, spontaneous neutrophil velocity profiles (SVP) are used to describe the movement patterns of neutrophils from a drop of blood. We won’t get too medically detailed here, because we’re vector-focused today, but a neutrophil is an immune cell that is kind of a first responder in what the body uses to fight off infections.

The system then encodes each neutrophil’s motion as a vector that captures not just its magnitude (i.e., speed) but also its direction. By converting biological patterns into high-dimensional vector spaces, the researchers captured subtle differences and showed that healthy individuals and sepsis patients exhibited statistically significant differences in movement. These numeric vectors were then processed by a machine learning model trained to detect early signs of sepsis. The result was a diagnostic tool that reached impressive sensitivity (97%) and specificity (98%), enabling rapid and accurate identification of this fatal condition, probably using the cosine similarity we learned about a moment ago (the paper doesn’t go into much detail, so this is pure speculation, but it would be the most suitable).
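To illustrate what “magnitude and direction” means for a moving cell, here is a small sketch that turns a track of positions into velocity vectors. The positions, units, and sampling interval are hypothetical, and this is not the study’s actual pipeline, only the basic vector arithmetic behind the idea.

```python
import math

# Hypothetical (x, y) positions of one cell in micrometres, one sample per minute
positions = [(0.0, 0.0), (3.0, 4.0), (6.0, 8.0), (6.0, 13.0)]

def velocity_vectors(track, dt=1.0):
    """Displacement between consecutive positions divided by the time step."""
    return [
        ((x2 - x1) / dt, (y2 - y1) / dt)
        for (x1, y1), (x2, y2) in zip(track, track[1:])
    ]

def speed(velocity):
    """Magnitude of the velocity vector."""
    return math.hypot(velocity[0], velocity[1])

def direction_degrees(velocity):
    """Angle of the velocity vector, measured from the positive x-axis."""
    return math.degrees(math.atan2(velocity[1], velocity[0]))

velocities = velocity_vectors(positions)
for v in velocities:
    print(v, speed(v), round(direction_degrees(v), 1))
# (3.0, 4.0) 5.0 53.1
# (3.0, 4.0) 5.0 53.1
# (0.0, 5.0) 5.0 90.0
```

The cell keeps the same speed throughout but changes direction in the last step, a difference that a speed-only measurement would miss.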

This is just one example of how medical data can be encoded into its vector representations and turned into malleable, actionable insights. This approach made it possible to re-contextualize complex relationships and, along with tread-based machine learning, worked around the limitations of previous diagnostic modalities and proved to be a potent tool for clinicians to save lives. It’s a powerful reminder that Vectors aren’t merely theoretical constructs — they’re practical, life-saving solutions that are powering the future of healthcare as much as your credit card risk detection software and hopefully also your business.


Lead and understand, or face disruption. The naked truth.

Photo by Hunters Race on Unsplash

With all you have read so far, think of a decision as small as choosing the metric by which data relationships are evaluated. Leaders who make it blindly risk assumptions that are subtle yet disastrous. You are basically using algebra as a tool, and while you get some result, you cannot know whether it is right: making leadership decisions without understanding the fundamentals of vectors is like using a calculator without knowing which formulas you are applying.

The good news is this doesn’t mean that business leaders have to become data scientists. Vectors are delightful because, once the core ideas have been grasped, they become very easy to work with. An understanding of a handful of concepts (for example, how vectors encode relationships, why distance metrics are important, and how embedding models function) can fundamentally change how you make high-level decisions. These tools will help you ask better questions, work with technical teams more effectively, and make sound decisions about the systems that will govern your business.

The returns on this small investment in comprehension are huge. There is much talk about personalization. Yet, few organizations use vector-based thinking in their business strategies. It could help them leverage personalization to its full potential. Such an approach would delight customers with tailored experiences and build loyalty. You could innovate in areas like fraud detection and operational efficiency, leveraging subtle patterns in data that traditional approaches miss, or perhaps even save lives, as described above. Equally important, you can avoid expensive missteps that happen when leaders defer to others for key decisions without understanding what they mean.

The truth is, vectors are here now, driving a vast majority of all the hyped AI technology behind the scenes to help create the world we navigate in today and tomorrow. Companies that do not adapt their leadership to think in vectors risk falling behind a competitive landscape that becomes ever more data-driven. One who adopts this new paradigm will not just survive but will prosper in an age of never-ending AI innovation.

Now is the moment to act. Start to view the world through vectors. Study their tongue, examine their doctrine, and ask how the new could change your tactics and your lodestars. Much in the way that algebra became an essential tool for writing one’s way through practical life challenges, vectors will soon serve as the literacy of the data age. Actually they do already. It is the future of which the powerful know how to take control. The question is not if vectors will define the next era of businesses; it is whether you are prepared to lead it.

The post The Invisible Revolution: How Vectors Are (Re)defining Business Success appeared first on Towards Data Science.

How to Measure Real Model Accuracy When Labels Are Noisy https://towardsdatascience.com/how-to-measure-real-model-accuracy-when-labels-are-noisy/ Thu, 10 Apr 2025 19:22:26 +0000 https://towardsdatascience.com/?p=605709 The math behind “true” accuracy and error correlation

The post How to Measure Real Model Accuracy When Labels Are Noisy appeared first on Towards Data Science.

Ground truth is never perfect. From scientific measurements to human annotations used to train deep learning models, ground truth always contains some errors. ImageNet, arguably the most well-curated image dataset, has 0.3% errors in human annotations. How, then, can we evaluate predictive models using such erroneous labels?

In this article, we explore how to account for errors in test data labels and estimate a model’s “true” accuracy.

Example: image classification

Let’s say there are 100 images, each containing either a cat or a dog. The images are labeled by human annotators who are known to have 96% accuracy (Aᵍʳᵒᵘⁿᵈᵗʳᵘᵗʰ). If we train an image classifier on some of this data and find that it has 90% accuracy on a hold-out set (Aᵐᵒᵈᵉˡ), what is the “true” accuracy of the model (Aᵗʳᵘᵉ)? A couple of observations first:

  1. Within the 90% of predictions that the model got “right,” some examples may have been incorrectly labeled, meaning both the model and the ground truth are wrong. This artificially inflates the measured accuracy.
  2. Conversely, within the 10% of “incorrect” predictions, some may actually be cases where the model is right and the ground truth label is wrong. This artificially deflates the measured accuracy.

Given these complications, how much can the true accuracy vary?

Range of true accuracy

True accuracy of model for perfectly correlated and perfectly uncorrelated errors of model and label. Figure by author.

The true accuracy of our model depends on how its errors correlate with the errors in the ground truth labels. If our model’s errors perfectly overlap with the ground truth errors (i.e., the model is wrong in exactly the same way as human labelers), its true accuracy is:

Aᵗʳᵘᵉ = 0.90 − (1 − 0.96) = 86%

Alternatively, if our model is wrong in exactly the opposite way as human labelers (perfect negative correlation), its true accuracy is:

Aᵗʳᵘᵉ = 0.90 + (1 − 0.96) = 94%

Or more generally:

Aᵗʳᵘᵉ = Aᵐᵒᵈᵉˡ ± (1 − Aᵍʳᵒᵘⁿᵈᵗʳᵘᵗʰ)

It’s important to note that the model’s true accuracy can be both lower and higher than its reported accuracy, depending on the correlation between model errors and ground truth errors.
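The two extremes can be wrapped in a small helper. This is a minimal sketch: the function name is made up, and clipping the result to the interval [0, 1] is an added assumption so the bounds stay valid accuracies for extreme inputs.

```python
def true_accuracy_bounds(model_acc, ground_truth_acc):
    """Lower and upper bound on true accuracy given label noise.

    Lower bound: every label error is counted as a model "hit".
    Upper bound: every label error is counted as a model "miss".
    """
    label_error = 1.0 - ground_truth_acc
    lower = max(0.0, model_acc - label_error)  # clipped (assumption)
    upper = min(1.0, model_acc + label_error)  # clipped (assumption)
    return lower, upper

lower, upper = true_accuracy_bounds(0.90, 0.96)
print(round(lower, 2), round(upper, 2))  # 0.86 0.94
```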

Probabilistic estimate of true accuracy

In some cases, inaccuracies among labels are randomly spread among the examples and not systematically biased toward certain labels or regions of the feature space. If the model’s inaccuracies are independent of the inaccuracies in the labels, we can derive a more precise estimate of its true accuracy.

When we measure Aᵐᵒᵈᵉˡ (90%), we’re counting cases where the model’s prediction matches the ground truth label. This can happen in two scenarios:

  1. Both model and ground truth are correct. This happens with probability Aᵗʳᵘᵉ × Aᵍʳᵒᵘⁿᵈᵗʳᵘᵗʰ.
  2. Both model and ground truth are wrong (in the same way). This happens with probability (1 − Aᵗʳᵘᵉ) × (1 − Aᵍʳᵒᵘⁿᵈᵗʳᵘᵗʰ).

Under independence, we can express this as:

Aᵐᵒᵈᵉˡ = Aᵗʳᵘᵉ × Aᵍʳᵒᵘⁿᵈᵗʳᵘᵗʰ + (1 − Aᵗʳᵘᵉ) × (1 − Aᵍʳᵒᵘⁿᵈᵗʳᵘᵗʰ)

Rearranging the terms, we get:

Aᵗʳᵘᵉ = (Aᵐᵒᵈᵉˡ + Aᵍʳᵒᵘⁿᵈᵗʳᵘᵗʰ − 1) / (2 × Aᵍʳᵒᵘⁿᵈᵗʳᵘᵗʰ − 1)

In our example, that equals (0.90 + 0.96 − 1) / (2 × 0.96 − 1) = 93.5%, which is within the range of 86% to 94% that we derived above.
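The estimate is easy to sanity-check with a simulation. The sketch below assumes a binary task (so two wrong answers always coincide) and draws model and label correctness independently; the function names, sample size, and seed are arbitrary choices for this illustration.

```python
import random

def true_accuracy_independent(model_acc, ground_truth_acc):
    """Estimate of true accuracy assuming independent model and label errors."""
    if ground_truth_acc == 0.5:
        raise ValueError("undefined when labels are no better than a coin flip")
    return (model_acc + ground_truth_acc - 1.0) / (2.0 * ground_truth_acc - 1.0)

def simulate_measured_accuracy(true_acc, ground_truth_acc, n=200_000, seed=0):
    """Fraction of examples where the model's prediction matches the noisy label."""
    rng = random.Random(seed)
    matches = 0
    for _ in range(n):
        model_correct = rng.random() < true_acc
        label_correct = rng.random() < ground_truth_acc
        # Binary task: prediction and label agree when both are right or both are wrong
        if model_correct == label_correct:
            matches += 1
    return matches / n

estimate = true_accuracy_independent(0.90, 0.96)
print(round(estimate, 4))  # 0.9348

# Feeding the estimate back in should reproduce a measured accuracy near 0.90
measured = simulate_measured_accuracy(estimate, 0.96)
print(round(measured, 2))
```

If the simulated value lands close to the reported 90%, the rearranged formula is consistent with the independence model it was derived from.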

The independence paradox

Plugging in Aᵍʳᵒᵘⁿᵈᵗʳᵘᵗʰ as 0.96 from our example, we get

Aᵗʳᵘᵉ = (Aᵐᵒᵈᵉˡ − 0.04) / 0.92. Let’s plot this below.

True accuracy as a function of model’s reported accuracy when ground truth accuracy = 96%. Figure by author.

Strange, isn’t it? If we assume that the model’s errors are uncorrelated with ground truth errors, then its true accuracy Aᵗʳᵘᵉ is always above the 1:1 line when the reported accuracy is > 0.5. This holds true even if we vary Aᵍʳᵒᵘⁿᵈᵗʳᵘᵗʰ:

Model’s “true” accuracy as a function of its reported accuracy and ground truth accuracy. Figure by author.
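The reason becomes visible when we subtract the reported accuracy from the independence estimate. The short derivation below uses only the formula given earlier, written in LaTeX notation:

```latex
\begin{aligned}
A^{\text{true}} - A^{\text{model}}
  &= \frac{A^{\text{model}} + A^{\text{groundtruth}} - 1}{2A^{\text{groundtruth}} - 1} - A^{\text{model}} \\
  &= \frac{\left(2A^{\text{model}} - 1\right)\left(1 - A^{\text{groundtruth}}\right)}{2A^{\text{groundtruth}} - 1}
\end{aligned}
```

For labels that are better than chance but imperfect (0.5 < Aᵍʳᵒᵘⁿᵈᵗʳᵘᵗʰ < 1), the denominator and the second factor are positive, so the sign of the gap is the sign of 2Aᵐᵒᵈᵉˡ − 1: positive exactly when the reported accuracy exceeds 0.5.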

Error correlation: why models often struggle where humans do

The independence assumption is crucial but often doesn’t hold in practice. If some images of cats are very blurry, or some small dogs look like cats, then the ground truth and model errors are likely to be correlated. This causes Aᵗʳᵘᵉ to be closer to the lower bound (Aᵐᵒᵈᵉˡ − (1 − Aᵍʳᵒᵘⁿᵈᵗʳᵘᵗʰ)) than the upper bound.

More generally, model errors tend to be correlated with ground truth errors when:

  1. Both humans and models struggle with the same “difficult” examples (e.g., ambiguous images, edge cases)
  2. The model has learned the same biases present in the human labeling process
  3. Certain classes or examples are inherently ambiguous or challenging for any classifier, human or machine
  4. The labels themselves are generated from another model
  5. There are too many classes (and thus too many different ways of being wrong)

Best practices

The true accuracy of a model can differ significantly from its measured accuracy. Understanding this difference is crucial for proper model evaluation, especially in domains where obtaining perfect ground truth is impossible or prohibitively expensive.

When evaluating model performance with imperfect ground truth:

  1. Conduct targeted error analysis: Examine examples where the model disagrees with ground truth to identify potential ground truth errors.
  2. Consider the correlation between errors: If you suspect correlation between model and ground truth errors, the true accuracy is likely closer to the lower bound (Aᵐᵒᵈᵉˡ − (1 − Aᵍʳᵒᵘⁿᵈᵗʳᵘᵗʰ)).
  3. Obtain multiple independent annotations: Having multiple annotators can help estimate ground truth accuracy more reliably.
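As one possible way to act on the third point, the sketch below backs out annotator accuracy from how often two annotators agree. It rests on strong assumptions that are not part of the article: a binary task, two annotators with equal accuracy above 50%, and independent errors. Under those assumptions the agreement rate is a² + (1 − a)², which can be solved for a.

```python
import math

def annotator_accuracy_from_agreement(agreement_rate):
    """Solve agreement = a^2 + (1 - a)^2 for a, taking the root with a >= 0.5.

    Assumes a binary task and two equally accurate, independent annotators.
    """
    if not 0.5 <= agreement_rate <= 1.0:
        raise ValueError("agreement rate must lie between 0.5 and 1.0")
    return (1.0 + math.sqrt(2.0 * agreement_rate - 1.0)) / 2.0

# Two annotators who are each 96% accurate agree on 0.96^2 + 0.04^2 = 92.32% of examples
print(round(annotator_accuracy_from_agreement(0.9232), 4))  # 0.96
```

If annotators tend to make the same mistakes, the observed agreement overstates their accuracy, so the result is best read as an optimistic estimate.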

Conclusion

In summary, we learned that:

  1. The range of possible true accuracy depends on the error rate in the ground truth
  2. When errors are independent, the true accuracy is often higher than measured for models better than random chance
  3. In real-world scenarios, errors are rarely independent, and the true accuracy is likely closer to the lower bound

Ivory Tower Notes: The Problem https://towardsdatascience.com/ivory-tower-notes-the-problem/ Thu, 10 Apr 2025 18:48:08 +0000 https://towardsdatascience.com/?p=605707 When a data science problem is "the" problem

The post Ivory Tower Notes: The Problem appeared first on Towards Data Science.

Did you ever spend months on a Machine Learning project, only to discover you never defined the “correct” problem at the start? If so, or even if not, and you are only starting with the data science or AI field, welcome to my first Ivory Tower Note, where I will address this topic. 


The term “Ivory Tower” is a metaphor for a situation in which someone is isolated from the practical realities of everyday life. In academia, the term often refers to researchers who engage deeply in theoretical pursuits and remain distant from the realities that practitioners face outside academia.

As a former researcher, I wrote a short series of posts from my old Ivory Tower notes — the notes before the LLM era.

Scary, I know. I am writing this to manage expectations and the question, “Why ever did you do things this way?” — “Because no LLM told me how to do otherwise 10+ years ago.”

That’s why my notes contain “legacy” topics such as data mining, machine learning, multi-criteria decision-making, and (sometimes) human interactions, airplanes ✈ and art.

Nonetheless, whenever there is an opportunity, I will map my “old” knowledge to generative AI advances and explain how I applied it to datasets beyond the Ivory Tower.

Welcome to post #1…


How every Machine Learning and AI journey starts

 — It starts with a problem. 

For you, this is usually “the” problem because you need to live with it for months or, in the case of research, years.

With “the” problem, I am addressing the business problem you don’t fully understand or know how to solve at first. 

An even worse scenario is when you think you fully understand and know how to solve it quickly. This then creates only more problems that are again only yours to solve. But more about this in the upcoming sections. 

So, what’s “the” problem about?

Causa: It’s mostly about not managing or leveraging resources properly —  workforce, equipment, money, or time. 

Ratio: It’s usually about generating business value, which can span from improved accuracy, increased productivity, cost savings, revenue gains, faster reaction, decision, planning, delivery or turnaround times. 

Veritas: It’s always about finding a solution that relies on, and is hidden somewhere in, the existing dataset. 

Or, more than one dataset that someone labelled as “the one”, and that’s waiting for you to solve the problem. Because datasets follow and are created from technical or business process logs, “there has to be a solution lying somewhere within them.”

Ah, if only it were so easy.

Avoiding a different chain of thought again, the point is you will need to:

1 — Understand the problem fully,
2 — If not given, find the dataset “behind” it, and 
3 — Create a methodology to get to the solution that will generate business value from it. 

On this path, you will be tracked and measured, and time will not be on your side to deliver the solution that will solve “the universe equation.” 

That’s why you will need to approach the problem methodologically, drill down to smaller problems first, and focus entirely on them because they are the root cause of the overall problem. 

That’s why it’s good to learn how to…

Think like a Data Scientist.

Returning to the problem itself, let’s imagine that you are a tourist lost somewhere in the big museum, and you want to figure out where you are. What you do next is walk to the closest info map on the floor, which will show your current location. 

At this moment, in front of you, you see something like this: 

Data Science Process. Image by Author, inspired by Microsoft Learn

The next thing you might tell yourself is, “I want to get to Frida Kahlo’s painting.” (Note: These are the insights you want to get.)

Because your goal is to see this one painting that brought you miles away from your home and now sits two floors below, you head straight to the second floor. Beforehand, you memorized the shortest path to reach your goal. (Note: This is the initial data collection and discovery phase.)

However, along the way, you stumble upon some obstacles — the elevator is shut down for renovation, so you have to use the stairs. The museum paintings were reordered just two days ago, and the info plans didn’t reflect the changes, so the path you had in mind to get to the painting is not accurate. 

Then you find yourself wandering around the third floor already, asking quietly again, “How do I get out of this labyrinth and get to my painting faster?”

While you don’t know the answer, you ask the museum staff on the third floor to help you out, and you start collecting the new data to get the correct route to your painting. (Note: This is a new data collection and discovery phase.)

Nonetheless, once you get to the second floor, you get lost again, but what you do next is start noticing a pattern in how the paintings have been ordered chronologically and thematically to group the artists whose styles overlap, thus giving you an indication of where to go to find your painting. (Note: This is a modelling phase overlapped with the enrichment phase from the dataset you collected during school days — your art knowledge.)

Finally, after adapting the pattern analysis and recalling the collected inputs on the museum route, you arrive in front of the painting you had been planning to see since booking your flight a few months ago. 

What I described now is how you approach data science and, nowadays, generative AI problems. You always start with the end goal in mind and ask yourself:

“What is the expected outcome I want or need to get from this?”

Then you start planning from this question backwards. The example above started with requesting holidays, booking flights, arranging accommodation, traveling to a destination, buying museum tickets, wandering around in a museum, and then seeing the painting you’ve been reading about for ages. 

Of course, there is more to it, and this process should be approached differently if you need to solve someone else’s problem, which is a bit more complex than locating the painting in the museum. 

In this case, you have to…

Ask the “good” questions.

To do this, let’s define what a good question means [1]: 

A good data science question must be concrete, tractable, and answerable. Your question works well if it naturally points to a feasible approach for your project. If your question is too vague to suggest what data you need, it won’t effectively guide your work.

Formulating good questions keeps you on track, so you don’t get lost in the data needed to reach the specific problem solution and don’t end up solving the wrong problem.

Going into more detail, good questions will help identify gaps in reasoning, avoid faulty premises, and create alternative scenarios in case things do go south (which almost always happens)👇🏼.

Image created by Author after analyzing “Chapter 2. Setting goals by asking good questions” from “Think Like a Data Scientist” book [2]

From the above-presented diagram, you understand how good questions, first and foremost, need to support concrete assumptions. This means they need to be formulated in a way that your premises are clear and ensure they can be tested without mixing up facts with opinions.

Good questions produce answers that move you closer to your goal, whether through confirming hypotheses, providing new insights, or eliminating wrong paths. They are measurable, and with this, they connect to project goals because they are formulated with consideration of what’s possible, valuable, and efficient [2].

Good questions are answerable with available data, considering current data relevance and limitations. 

Last but not least, good questions anticipate obstacles. If anything is certain in data science, it is uncertainty, so having backup plans for when things don’t work as expected is important to produce results for your project.

Let’s exemplify this with one use case of an airline company that has a challenge with increasing its fleet availability due to unplanned technical groundings (UTG).

These unexpected maintenance events disrupt flights and cost the company significant money. Because of this, executives decided to react to the problem and call in a data scientist (you) to help them improve aircraft availability.

Now, if this were the first data science task you ever got, you might start the investigation by asking:

“How can we eliminate all unplanned maintenance events?”

You understand how this question is an example of the wrong or “poor” one because:

  • It is not realistic: It includes every possible defect, both small and big, into one impossible goal of “zero operational interruptions”.
  • It doesn’t hold a measure of success: There’s no concrete metric to show progress, and if you’re not at zero, you’re at “failure.”
  • It is not data-driven: The question didn’t cover which data is recorded before delays occur, and how the aircraft unavailability is measured and reported from it.

So, instead of this vague question, you would probably ask a set of targeted questions:

  1. Which aircraft (sub)system is most critical to flight disruptions?
    (Concrete, specific, answerable) This question narrows down your scope, focusing on only one or two specific (sub)systems affecting most delays.
  2. What constitutes “critical downtime” from an operational perspective?
    (Valuable, ties to business goals) If the airline (or regulatory body) doesn’t define how many minutes of unscheduled downtime matter for schedule disruptions, you might waste effort solving less urgent issues.
  3. Which data sources capture the root causes, and how can we fuse them?
    (Manageable, narrows the scope of the project further) This clarifies which data sources one would need to find the problem solution.

With these sharper questions, you will drill down to the real problem:

  • Not all delays weigh the same in cost or impact. The “correct” data science problem is to predict critical subsystem failures that lead to operationally costly interruptions so maintenance crews can prioritize them.
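A first pass at the targeted questions could look like the sketch below. The event log, the subsystem names, and the 60-minute “critical downtime” threshold are all invented for illustration; in a real project they would come from the airline’s maintenance records and operational definitions.

```python
from collections import defaultdict

# Hypothetical unplanned-grounding log: (subsystem, delay in minutes)
events = [
    ("hydraulics", 240),
    ("avionics", 35),
    ("hydraulics", 180),
    ("cabin", 15),
    ("landing_gear", 300),
    ("avionics", 50),
]

CRITICAL_DOWNTIME_MINUTES = 60  # assumed threshold, to be agreed with operations

def critical_delay_by_subsystem(event_log, threshold):
    """Total delay per subsystem, counting only events at or above the threshold."""
    totals = defaultdict(int)
    for subsystem, delay in event_log:
        if delay >= threshold:
            totals[subsystem] += delay
    # Most costly subsystems first
    return sorted(totals.items(), key=lambda item: item[1], reverse=True)

ranking = critical_delay_by_subsystem(events, CRITICAL_DOWNTIME_MINUTES)
print(ranking)
# [('hydraulics', 420), ('landing_gear', 300)]
```

Even this simple aggregation shows how the questions narrow the scope: short delays drop out, and the remaining subsystems become candidates for a failure-prediction model.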

That’s why…

Defining the problem determines every step after. 

It’s the foundation upon which your data, modelling, and evaluation phases are built 👇🏼.

Image created by Author after analyzing and overlapping different images from “Chapter 2. Setting goals by asking good questions, Think Like a Data Scientist” book [2]

It means you are clarifying the project’s objectives, constraints, and scope; you need to articulate the ultimate goal first and, besides asking “What’s the expected outcome I want or need to get from this?”, also ask: 

What would success look like and how can we measure it?

From there, drill down to (possible) next-level questions that you (I) have learned from the Ivory Tower days:
 — History questions: “Has anyone tried to solve this before? What happened? What is still missing?”
 —  Context questions: “Who is affected by this problem and how? How are they partially resolving it now? Which sources, methods, and tools are they using now, and can they still be reused in the new models?”
 — Impact Questions: “What happens if we don’t solve this? What changes if we do? Is there a value we can create by default? How much will this approach cost?”
 — Assumption Questions: “What are we taking for granted that might not be true (especially when it comes to data and stakeholders’ ideas)?”
 — ….

Then, do this in the loop and always “ask, ask again, and don’t stop asking” questions so you can drill down and understand which data and analysis are needed and what the ground problem is. 

This is the evergreen knowledge you can apply nowadays, too, when deciding if your problem is of a predictive or generative nature.

(More about this in some other note where I will explain how problematic it is trying to solve the problem with the models that have never seen — or have never been trained on — similar problems before.)

Now, going back to memory lane…

I want to add one important note: I have learned from late nights in the Ivory Tower that no amount of data or data science knowledge can save you if you’re solving the wrong problem and trying to get the solution (answer) from a question that was simply wrong and vague. 

When you have a problem at hand, do not rush into assumptions or into building models without understanding what you need to do (festina lente).

In addition, prepare yourself for unexpected situations and do a proper investigation with your stakeholders and domain experts because their patience will be limited, too. 

With this, I want to say that the “real art” of being successful in data projects is knowing precisely what the problem is, figuring out if it can be solved in the first place, and then coming up with the “how” part. 

You get there by learning to ask good questions.

To end this narrative, recall the saying often attributed to Einstein:

If I were given one hour to save the planet, I would spend 59 minutes defining the problem and one minute solving it.


Thank you for reading, and stay tuned for the next Ivory Tower note.

If you found this post valuable, feel free to share it with your network. 👏

Connect for more stories on Medium ✍ and LinkedIn 🖇.


References: 

[1] DS4Humans, Backwards Design, accessed: April 5th 2025, https://ds4humans.com/40_in_practice/05_backwards_design.html#defining-a-good-question

[2] Godsey, B. (2017), Think Like a Data Scientist: Tackle the data science process step-by-step, Manning Publications.

The post Ivory Tower Notes: The Problem appeared first on Towards Data Science.

]]>
Deb8flow: Orchestrating Autonomous AI Debates with LangGraph and GPT-4o https://towardsdatascience.com/deb8flow-orchestrating-autonomous-ai-debates-with-langgraph-and-gpt-4o/ Thu, 10 Apr 2025 05:14:56 +0000 https://towardsdatascience.com/?p=605704 Inside Deb8flow: Real-time AI debates with LangGraph and GPT-4o

The post Deb8flow: Orchestrating Autonomous AI Debates with LangGraph and GPT-4o appeared first on Towards Data Science.

]]>
Introduction

I’ve always been fascinated by debates—the strategic framing, the sharp retorts, and the carefully timed comebacks. Debates aren’t just entertaining; they’re structured battles of ideas, driven by logic and evidence. Recently, I started wondering: could we replicate that dynamic using AI agents—having them debate each other autonomously, complete with real-time fact-checking and moderation? The result was Deb8flow, an autonomous AI debating environment powered by LangGraph, OpenAI’s GPT-4o model, and the new integrated Web Search feature.

In Deb8flow, two agents—Pro and Con—square off on a given topic while a Moderator manages turn-taking. A dedicated Fact Checker reviews every claim in real time using GPT-4o’s new browsing capabilities, and a final Judge evaluates the arguments for quality and coherence. If an agent repeatedly makes factual errors, they’re automatically disqualified—ensuring the debate stays grounded in truth.

This article offers an in-depth look at the advanced architecture and dynamic workflows that power autonomous AI debates. I’ll walk you through how Deb8flow’s modular design leverages LangGraph’s state management and conditional routing, alongside GPT-4o’s capabilities.

Even if you’re new to AI agents or LangGraph (see resources [1] and [2] for primers), I’ll explain the key concepts clearly. And if you’d like to explore further, the full project is available on GitHub: iason-solomos/Deb8flow.

Ready to see how AI agents can debate autonomously in practice?

Let’s dive in.

High-Level Overview: Autonomous Debates with Multiple Agents

In Deb8flow, we orchestrate a formal debate between two AI agents – one arguing Pro and one Con – complete with a Moderator, a Fact Checker, and a final Judge. The debate unfolds autonomously, with each agent playing a role in a structured format.

At its core, Deb8flow is a LangGraph-powered agent system, built atop LangChain, using GPT-4o to power each role: Pro, Con, Judge, and beyond. We use GPT-4o’s preview model with browsing capabilities to enable real-time fact-checking. In essence, the Pro and Con agents debate; after each statement, a fact-checker agent uses GPT-4o’s web search to catch any hallucinations or inaccuracies in that statement in real time. The debate only continues once the statement is verified. The whole process is coordinated by a LangGraph-defined workflow that ensures proper turn-taking and conditional logic.


High-level debate flow graph. Each rectangle is an agent node (Pro/Con debaters, Fact Checker, Judge, etc.), and diamonds are control nodes (Moderator and a router after fact-checking). Solid arrows denote the normal progression, while dashed arrows indicate retries if a claim fails fact-check. The Judge node outputs the final verdict, then the workflow ends.
Image generated by the author with DALL-E

The debate workflow goes through these stages:

  • Topic Generation: A Topic Generator agent produces a nuanced, debatable topic for the session (e.g. “Should AI be used in classroom education?”).
  • Opening: The Pro Argument Agent makes an opening statement in favor of the topic, kicking off the debate.
  • Rebuttal: The Debate Moderator then gives the floor to the Con Argument agent, who rebuts the Pro’s opening statement.
  • Counter: The Moderator gives the floor back to the Pro agent, who counters the Con agent’s points.
  • Closing: The Moderator switches the floor to the Con agent one last time for a closing argument.
  • Judgment: Finally, the Judge agent reviews the full debate history and evaluates both sides based on argument quality, clarity, and persuasiveness. The most convincing side wins.

After every single speech, the Fact Checker agent steps in to verify the factual accuracy of that statement. If a debater’s claim doesn’t hold up (e.g. cites a wrong statistic or “hallucinates” a fact), the workflow triggers a retry: the speaker has to correct or modify their statement. (If either debater accumulates 3 fact-check failures, they are automatically disqualified for repeatedly spreading inaccuracies, and their opponent wins by default.) This mechanism keeps our AI debaters honest and grounded in reality!
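
The turn order and the three-strike rule described above can be sketched in plain Python. This is only an illustration of the rules, not the project’s moderator code: the names TURN_ORDER, MAX_FACT_CHECK_FAILURES, next_turn, and is_disqualified are hypothetical.

```python
from typing import Optional, Tuple

# (stage, speaker) pairs in the order described above.
TURN_ORDER = [
    ("opening", "pro"),
    ("rebuttal", "con"),
    ("counter", "pro"),
    ("final_argument", "con"),
]

# Three failed fact-checks disqualify a debater.
MAX_FACT_CHECK_FAILURES = 3


def next_turn(stage: str) -> Optional[Tuple[str, str]]:
    """Return the (stage, speaker) that follows `stage`, or None when the Judge is next."""
    stages = [name for name, _ in TURN_ORDER]
    index = stages.index(stage)
    if index + 1 < len(TURN_ORDER):
        return TURN_ORDER[index + 1]
    return None


def is_disqualified(failures: int) -> bool:
    """True once a debater has reached the failure limit."""
    return failures >= MAX_FACT_CHECK_FAILURES
```

Keeping the order in a single list means that adding or removing a stage only touches one place.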

Prerequisites and Setup

Before diving into the code, make sure you have the following in place:

  • Python 3.12+ installed.
  • An OpenAI API key with access to the GPT-4o model. You can create your own API key here: https://platform.openai.com/settings/organization/api-keys
  • Project Code: Clone the Deb8flow repository from GitHub (git clone https://github.com/iason-solomos/Deb8flow.git). The repo includes a requirements.txt for all required packages. Key dependencies include LangChain/LangGraph (for building the agent graph) and the OpenAI Python client.
  • Install Dependencies: In your project directory, run: pip install -r requirements.txt to install the necessary libraries.
  • Create a .env file in the project root to hold your OpenAI API credentials. It should be of the form: OPENAI_API_KEY_GPT4O = "sk-…"
  • You can also at any time check out the README file: https://github.com/iason-solomos/Deb8flow if you simply want to run the finished app.

Once dependencies are installed and the environment variable is set, you should be ready to run the app. The project structure is organized for clarity:

Deb8flow/
├── configurations/
│ ├── debate_constants.py
│ └── llm_config.py
├── nodes/
│ ├── base_component.py
│ ├── topic_generator_node.py
│ ├── pro_debater_node.py
│ ├── con_debater_node.py
│ ├── debate_moderator_node.py
│ ├── fact_checker_node.py
│ ├── fact_check_router_node.py
│ └── judge_node.py
├── prompts/
│ ├── topic_generator_prompts.py
│ ├── pro_debater_prompts.py
│ ├── con_debater_prompts.py
│ └── … (prompts for other agents)
├── tests/ (contains unit and whole workflow tests)
└── debate_workflow.py

A quick tour of this structure:

configurations/ holds constant definitions and LLM configuration classes.

nodes/ contains the implementation of each agent or functional node in the debate (each of these is a module defining one agent’s behavior).

prompts/ stores the prompt templates for the language model (so each agent knows how to prompt GPT-4o for its specific task).

debate_workflow.py ties everything together by defining the LangGraph workflow (the graph of nodes and transitions).

debate_state.py defines the shared data structure that the agents will be using on each run.

tests/ includes some basic tests and example runs to help you verify everything is working.

Under the Hood: State Management and Workflow Setup

To coordinate a complex multi-turn debate, we need a shared state and a well-defined flow. We’ll start by looking at how Deb8flow defines the debate state and constants, and then see how the LangGraph workflow is constructed.

Defining the Debate State Schema (debate_state.py)

Deb8flow uses a shared state (https://langchain-ai.github.io/langgraph/concepts/low_level/#state) in the form of a Python TypedDict that all agents can read from and update. This state tracks the debate’s progress and context – things like the topic, the history of messages, whose turn it is, etc. By centralizing this information, each agent node can make decisions based on the current state of the debate.

Link: debate_state.py

from typing import TypedDict, List, Dict, Literal


DebateStage = Literal["opening", "rebuttal", "counter", "final_argument"]

class DebateMessage(TypedDict):
    speaker: str  # e.g. pro or con
    content: str  # The message each speaker produced
    validated: bool  # Whether the FactChecker ok’d this message
    stage: DebateStage # The stage of the debate when this message was produced

class DebateState(TypedDict):
    debate_topic: str
    positions: Dict[str, str]
    messages: List[DebateMessage]
    opening_statement_pro_agent: str
    stage: str  # "opening", "rebuttal", "counter", "final_argument"
    speaker: str  # "pro" or "con"
    times_pro_fact_checked: int # The number of times the pro agent has been fact-checked. If it reaches 3, the pro agent is disqualified.
    times_con_fact_checked: int # The number of times the con agent has been fact-checked. If it reaches 3, the con agent is disqualified.

Key fields that we need to have in the DebateState include:

  • debate_topic (str): The topic being debated.
  • messages (List[DebateMessage]): A list of all messages exchanged so far. Each message is a dictionary with fields for speaker (e.g. "pro" or "con" or "fact_checker"), the message content (text), a validated flag (whether it passed fact-check), and the stage of the debate when it was produced.
  • stage (str): The current debate stage (one of "opening", "rebuttal", "counter", "final_argument").
  • speaker (str): Whose turn it is currently ("pro" or "con").
  • times_pro_fact_checked / times_con_fact_checked (int): Counters for how many times each side has been caught with a false claim. (Under our rules, a debater who fails fact-check 3 times is disqualified and automatically loses.)
  • positions (Dict[str, str]): (Optional) A mapping of each side’s general stance (e.g., "pro": "In favor of the topic").

By structuring the debate’s state, agents find it easy to access the conversation history or check the current stage, and the control logic can update the state between turns. The state is essentially the memory of the debate.
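
To make the shape of this state concrete, here is a small self-contained sketch of what a snapshot might look like midway through a debate. The TypedDicts are repeated (slightly simplified) so the snippet runs on its own; the example values and the helper fact_check_failures are made up for illustration.

```python
from typing import Dict, List, TypedDict


class DebateMessage(TypedDict):
    speaker: str
    content: str
    validated: bool
    stage: str


class DebateState(TypedDict):
    debate_topic: str
    positions: Dict[str, str]
    messages: List[DebateMessage]
    opening_statement_pro_agent: str
    stage: str
    speaker: str
    times_pro_fact_checked: int
    times_con_fact_checked: int


# Illustrative snapshot: Pro's opening passed fact-check after one failed
# attempt, and it is now Con's turn to rebut.
example_state: DebateState = {
    "debate_topic": "Should AI be used in classroom education?",
    "positions": {"pro": "In favor of the topic", "con": "Against the topic"},
    "messages": [
        {
            "speaker": "pro",
            "content": "AI tutors can adapt to each student's pace.",
            "validated": True,
            "stage": "opening",
        }
    ],
    "opening_statement_pro_agent": "",
    "stage": "rebuttal",
    "speaker": "con",
    "times_pro_fact_checked": 1,
    "times_con_fact_checked": 0,
}


def fact_check_failures(state: DebateState, speaker: str) -> int:
    """Read the failure counter for "pro" or "con" (hypothetical helper)."""
    if speaker == "pro":
        return state["times_pro_fact_checked"]
    return state["times_con_fact_checked"]
```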

Constants and Configuration

To avoid “magic strings” scattered in the code, we define some constants in debate_constants.py. For example, constants for stage names (STAGE_OPENING = "opening", etc.), speaker identifiers (SPEAKER_PRO = "pro", SPEAKER_CON = "con", etc.), and node names (NODE_PRO_DEBATER = "pro_debater_node", etc.). These make the code easier to maintain and read.

debate_constants.py:

# Stage names
STAGE_OPENING = "opening"
STAGE_REBUTTAL = "rebuttal"
STAGE_COUNTER = "counter"
STAGE_FINAL_ARGUMENT = "final_argument"
STAGE_END = "end"

# Speakers
SPEAKER_PRO = "pro"
SPEAKER_CON = "con"
SPEAKER_JUDGE = "judge"

# Node names
NODE_PRO_DEBATER = "pro_debater_node"
NODE_CON_DEBATER = "con_debater_node"
NODE_DEBATE_MODERATOR = "debate_moderator_node"
NODE_JUDGE = "judge_node"

We also set up the LLM configuration in llm_config.py. Here, we define configuration classes for OpenAI and Azure OpenAI (the excerpt below shows only the OpenAI one) and then create a dictionary llm_config_map mapping model names to their config. For instance, we map "gpt-4o" to an OpenAILLMConfig that holds the model name and API key. This way, whenever we need to initialize a GPT-4o agent, we can just do llm_config_map["gpt-4o"] to get the right config. All our main agents (debaters, topic generator, judge) use this same GPT-4o configuration.

import os
from dataclasses import dataclass
from typing import Union

@dataclass
class OpenAILLMConfig:
    """
    A data class to store configuration details for OpenAI models.

    Attributes:
        model_name (str): The name of the OpenAI model to use.
        openai_api_key (str): The API key for authenticating with the OpenAI service.
    """
    model_name: str
    openai_api_key: str


llm_config_map = {
    "gpt-4o": OpenAILLMConfig(
        model_name="gpt-4o",
        openai_api_key=os.getenv("OPENAI_API_KEY_GPT4O"),
    )
}
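
One thing to keep in mind with the snippet above: os.getenv returns None when the variable is missing, so a forgotten key only surfaces later, at the first API call. A minimal sketch of failing early instead is shown here; require_api_key is a hypothetical helper, not part of the repository, and it assumes the .env file has already been loaded into the environment by whatever mechanism the project uses.

```python
import os


def require_api_key(env_var: str = "OPENAI_API_KEY_GPT4O") -> str:
    """Return the API key stored in `env_var`, or raise a clear error if it is unset."""
    key = os.getenv(env_var)
    if not key:
        raise RuntimeError(
            f"{env_var} is not set; add it to your .env file or environment."
        )
    return key
```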

Building the LangGraph Workflow (debate_workflow.py)

With state and configs in place, we construct the debate workflow graph. LangGraph’s StateGraph is the backbone that connects all our agent nodes in the order they should execute. Here’s how we set it up:

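from langgraph.graph import END, StateGraph

# Project imports (node classes, DebateState, llm_config_map) are omitted here.


class DebateWorkflow: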

    def _initialize_workflow(self) -> StateGraph:
        workflow = StateGraph(DebateState)
        # Nodes
        workflow.add_node("generate_topic_node", GenerateTopicNode(llm_config_map["gpt-4o"]))
        workflow.add_node("pro_debater_node", ProDebaterNode(llm_config_map["gpt-4o"]))
        workflow.add_node("con_debater_node", ConDebaterNode(llm_config_map["gpt-4o"]))
        workflow.add_node("fact_check_node", FactCheckNode())
        workflow.add_node("fact_check_router_node", FactCheckRouterNode())
        workflow.add_node("debate_moderator_node", DebateModeratorNode())
        workflow.add_node("judge_node", JudgeNode(llm_config_map["gpt-4o"]))

        # Entry point
        workflow.set_entry_point("generate_topic_node")

        # Flow
        workflow.add_edge("generate_topic_node", "pro_debater_node")
        workflow.add_edge("pro_debater_node", "fact_check_node")
        workflow.add_edge("con_debater_node", "fact_check_node")
        workflow.add_edge("fact_check_node", "fact_check_router_node")
        workflow.add_edge("judge_node", END)
        return workflow

    async def run(self):
        workflow = self._initialize_workflow()
        graph = workflow.compile()
        # graph.get_graph().draw_mermaid_png(output_file_path="workflow_graph.png")
        initial_state = {
            "debate_topic": "",  # key matches DebateState; the topic node fills it in
            "positions": {}
        }
        final_state = await graph.ainvoke(initial_state, config={"recursion_limit": 50})
        return final_state

Let’s break down what’s happening:

  • We initialize a new StateGraph with our DebateState type as the state schema.
  • We add each node (agent) to the graph with a name. For nodes that need an LLM, we pass in the GPT-4o config. For example, "pro_debater_node" is added as ProDebaterNode(llm_config_map["gpt-4o"]), meaning the Pro debater agent will use GPT-4o as its underlying model.
  • We set the entry point of the graph to "generate_topic_node". This means the first step of the workflow is to generate a debate topic.
  • Then we add directed edges to connect nodes. The edges above encode the primary sequence: topic -> pro’s turn -> fact-check -> (then a routing decision) -> … eventually -> judge -> END. We don’t add static outgoing edges from the Moderator or the Fact Check Router, since these nodes use dynamic commands to redirect the flow. The final edge connects the judge to an END marker to terminate the graph.

When the workflow runs, control will pass along these edges in order, but whenever we hit a router or moderator node, that node will output a command telling the graph which node to go to next (overriding the default edge). This is how we create conditional loops: the fact_check_router_node might send us back to a debater node for a retry, instead of following a straight line. LangGraph supports this by allowing nodes to return a special Command object with goto instructions.
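
The decision a router like fact_check_router_node has to make can be sketched as a plain function that returns the name of the next node. This is a hedged illustration only: the function name, the "disqualified" sentinel, and the choice to hand validated statements to the moderator are assumptions, not code from the repository. In LangGraph, the returned name would then be wrapped in the Command object’s goto instruction.

```python
from typing import Any, Dict

MAX_FAILURES = 3  # assumed limit, matching the three-strike rule


def route_after_fact_check(state: Dict[str, Any]) -> str:
    """Pick the next node based on the last message and the failure counters."""
    last_message = state["messages"][-1]
    speaker = last_message["speaker"]  # "pro" or "con"
    failures = state[f"times_{speaker}_fact_checked"]

    if last_message["validated"]:
        # Statement passed: let the moderator hand over the floor.
        return "debate_moderator_node"
    if failures >= MAX_FAILURES:
        # Hypothetical sentinel; the real project decides how to end the debate.
        return "disqualified"
    # Statement failed: send the same debater back for a retry.
    return f"{speaker}_debater_node"
```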

In summary, at a high level we’ve defined an agentic workflow: a graph of autonomous agents where control can branch and loop based on the agents’ outputs. Now, let’s explore what each of these agent nodes actually does.

Agent Nodes Breakdown

Each stage or role in the debate is encapsulated in a node (agent). In LangGraph, nodes are often simple functions, but I wanted a more object-oriented approach for clarity and reusability. So in Deb8flow, every node is a class with a __call__ method. All the main agent classes inherit from a common BaseComponent for shared functionality. This design makes the system modular: we can easily swap out or extend agents by modifying their class definitions, and each agent class is responsible for its piece of the workflow.
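
As a toy illustration of this class-with-__call__ pattern, the sketch below defines a node whose instances can be called like functions and that returns only the keys it changes. SetStageNode is a made-up class, and the dictionary merge at the end is a simplified stand-in for how the graph applies a node’s partial update to the shared state.

```python
from typing import Any, Dict


class SetStageNode:
    """A toy node: configuration lives in __init__, behavior lives in __call__."""

    def __init__(self, stage: str) -> None:
        self.stage = stage

    def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        # Return a partial update containing only the keys this node changes.
        return {"stage": self.stage}


node = SetStageNode("opening")
state = {"debate_topic": "example topic", "stage": ""}

# Simplified stand-in for merging the partial update into the shared state.
state = {**state, **node(state)}
```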

Let’s go through the key agents one by one.

BaseComponent – A Reusable Agent Base Class

Most of our agent nodes (like the debaters and judge) share common needs: they use an LLM to generate output, they might need to retry on errors, and they should track token usage. The BaseComponent class (defined in nodes/base_component.py: https://github.com/iason-solomos/Deb8flow/blob/main/nodes/base_component.py) provides these common features so we don’t repeat code.

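import logging
import time
from typing import Any, List, Optional, Type

# Third-party imports inferred from the names used below; the exact module
# paths in the repository may differ. Project types (DebateState, LLMConfig,
# OpenAILLMConfig, AzureOpenAILLMConfig) come from the Deb8flow codebase.
from langchain_community.callbacks import get_openai_callback
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableSequence
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from opentelemetry import trace
from opentelemetry.trace import get_tracer_provider
from pydantic import BaseModel


class BaseComponent: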
    """
    A foundational class for managing LLM-based workflows with token tracking.
    Can handle both Azure OpenAI (AzureChatOpenAI) and OpenAI (ChatOpenAI).
    """

    def __init__(
        self,
        llm_config: Optional[LLMConfig] = None,
        temperature: float = 0.0,
        max_retries: int = 5,
    ):
        """
        Initializes the BaseComponent with optional LLM configuration and temperature.

        Args:
            llm_config (Optional[LLMConfig]): Configuration for either Azure or OpenAI.
            temperature (float): Controls the randomness of LLM outputs. Defaults to 0.0.
            max_retries (int): How many times to retry on 429 errors.
        """
        logger = logging.getLogger(self.__class__.__name__)
        tracer = trace.get_tracer(__name__, tracer_provider=get_tracer_provider())

        self.logger = logger
        self.tracer = tracer
        self.llm: Optional[ChatOpenAI] = None
        self.output_parser: Optional[StrOutputParser] = None
        self.state: Optional[DebateState] = None
        self.prompt_template: Optional[ChatPromptTemplate] = None
        self.chain: Optional[RunnableSequence] = None
        self.documents: Optional[List] = None
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.max_retries = max_retries

        if llm_config is not None:
            self.llm = self._init_llm(llm_config, temperature)
            self.output_parser = StrOutputParser()

    def _init_llm(self, config: LLMConfig, temperature: float):
        """
        Initializes an LLM instance for either Azure OpenAI or standard OpenAI.
        """
        if isinstance(config, AzureOpenAILLMConfig):
            # If it's Azure, use the AzureChatOpenAI class
            return AzureChatOpenAI(
                deployment_name=config.deployment_name,
                azure_endpoint=config.azure_endpoint,
                openai_api_version=config.openai_api_version,
                openai_api_key=config.openai_api_key,
                temperature=temperature,
            )
        elif isinstance(config, OpenAILLMConfig):
            # If it's standard OpenAI, use the ChatOpenAI class
            return ChatOpenAI(
                model_name=config.model_name,
                openai_api_key=config.openai_api_key,
                temperature=temperature,
            )
        else:
            raise ValueError("Unsupported LLMConfig type.")

    def validate_initialization(self) -> None:
        """
        Ensures we have an LLM and an output parser.
        """
        if not self.llm:
            raise ValueError("LLM is not initialized. Ensure `llm_config` is provided.")
        if not self.output_parser:
            raise ValueError("Output parser is not initialized.")

    def execute_chain(self, inputs: Any) -> Any:
        """
        Executes the LLM chain, tracks token usage, and retries on 429 errors.
        """
        if not self.chain:
            raise ValueError("No chain is initialized for execution.")

        retry_wait = 1  # Initial wait time in seconds

        for attempt in range(self.max_retries):
            try:
                with get_openai_callback() as cb:
                    result = self.chain.invoke(inputs)
                    self.logger.info("Prompt Token usage: %s", cb.prompt_tokens)
                    self.logger.info("Completion Token usage: %s", cb.completion_tokens)
                    self.prompt_tokens = cb.prompt_tokens
                    self.completion_tokens = cb.completion_tokens

                return result

            except Exception as e:
                # If the error mentions 429, do exponential backoff and retry
                if "429" in str(e):
                    self.logger.warning(
                        f"Rate limit reached. Retrying in {retry_wait} seconds... "
                        f"(Attempt {attempt + 1}/{self.max_retries})"
                    )
                    time.sleep(retry_wait)
                    retry_wait *= 2
                else:
                    self.logger.error("Unexpected error: %s", e)
                    raise

        raise Exception("API request failed after maximum number of retries")

    def create_chain(
        self, system_template: str, human_template: str
    ) -> RunnableSequence:
        """
        Creates a chain for unstructured outputs.
        """
        self.validate_initialization()
        self.prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                ("human", human_template),
            ]
        )
        self.chain = self.prompt_template | self.llm | self.output_parser
        return self.chain

    def create_structured_output_chain(
        self, system_template: str, human_template: str, output_model: Type[BaseModel]
    ) -> RunnableSequence:
        """
        Creates a chain that yields structured outputs (parsed into a Pydantic model).
        """
        self.validate_initialization()
        self.prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                ("human", human_template),
            ]
        )
        self.chain = self.prompt_template | self.llm.with_structured_output(output_model)
        return self.chain

    def build_return_with_tokens(self, node_specific_data: dict) -> dict:
        """
        Convenience method to add token usage info into the return values.
        """
        return {
            **node_specific_data,
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
        }

    def __call__(self, state: DebateState) -> None:
        """
        Updates the node's local copy of the state.
        """
        self.state = state
        for key, value in state.items():
            setattr(self, key, value)

Key features of BaseComponent:

  • It stores an LLM client (e.g. an OpenAI ChatOpenAI instance) initialized with a given model and API key, as well as an output parser.
  • It provides a method create_chain(system_template, human_template) which sets up a LangChain prompt chain (a RunnableSequence) combining a system prompt and a human prompt. This chain is what actually generates outputs when run.
  • It has an execute_chain(inputs) method that invokes the chain and includes logic to retry if the OpenAI API returns a rate-limit error (HTTP 429). This is done with exponential backoff up to a max_retries count.
  • It keeps track of token usage (prompt tokens and completion tokens) for logging or analysis.
  • The __call__ method of BaseComponent (which each subclass calls via super().__call__(state)) stores the incoming state on the instance and mirrors each state key as an attribute, so the node’s main logic can read the current debate state directly.
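
The retry behavior in execute_chain can be isolated into a small standalone sketch. It is not the project’s code: RateLimitError and call_with_backoff are hypothetical names, the sleep function is injectable so the logic can be exercised without waiting, and unlike the method above this version does not sleep after the final failed attempt.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


class RateLimitError(Exception):
    """Stand-in for an HTTP 429 error raised by an API client."""


def call_with_backoff(
    fn: Callable[[], T],
    max_retries: int = 5,
    initial_wait: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call `fn`, doubling the wait after each rate-limit error."""
    wait = initial_wait
    for attempt in range(max_retries):
        try:
            return fn()
        except RateLimitError:
            if attempt == max_retries - 1:
                break  # no point sleeping after the last attempt
            sleep(wait)
            wait *= 2
    raise RuntimeError("API request failed after maximum number of retries")
```

With the defaults, the waits between attempts are 1, 2, 4, and 8 seconds before the call gives up.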

By building on BaseComponent, each agent class can focus on its unique logic (like what prompt to use and how to handle the state), while inheriting the heavy lifting of interacting with GPT-4o reliably.

Topic Generator Agent (GenerateTopicNode)

The Topic Generator (topic_generator_node.py) is the first agent in the graph. Its job is to come up with a debatable topic for the session. We give it a prompt that instructs it to output a nuanced topic that could reasonably have a pro and con side.

This agent inherits from BaseComponent and uses a prompt chain (system + human prompt) to generate one item of text – the debate topic. When called, it executes the chain (with no special input, just using the prompt) and gets back a topic_text. It then updates the state with:

  • debate_topic: the generated topic (stripped of any extra whitespace),
  • positions: a dictionary assigning the pro and con stances (by default we use "In favor of the topic" and "Against the topic"),
  • stage: set to "opening",
  • speaker: set to "pro" (so the Pro side will speak first).

In code, the return might look like:

return {
    "debate_topic": debate_topic,
    "positions": positions,
    "stage": "opening",
    "speaker": first_speaker  # "pro"
}

Here are the prompts for the topic generator:

SYSTEM_PROMPT = """\
You are a brainstorming AI that suggests debate topics.
You will provide a single, interesting or timely topic that can have two opposing views.
"""

HUMAN_PROMPT = """\
Please suggest one debate topic for two AI agents to discuss.
For example, it could be about technology, politics, philosophy, or any interesting domain.
Just provide the topic in a concise sentence.
"""

Then we pass these prompts in the constructor of the class itself.

class GenerateTopicNode(BaseComponent):
    def __init__(self, llm_config, temperature: float = 0.7):
        super().__init__(llm_config, temperature)
        # Create the prompt chain.
        self.chain: RunnableSequence = self.create_chain(
            system_template=SYSTEM_PROMPT,
            human_template=HUMAN_PROMPT
        )

    def __call__(self, state: DebateState) -> Dict[str, str]:
        """
        Generates a debate topic and assigns positions to the two debaters.
        """
        super().__call__(state)

        topic_text = self.execute_chain({})

        # Store the topic and assign stances in the DebateState
        debate_topic = topic_text.strip()
        positions = {
            "pro": "In favor of the topic",
            "con": "Against the topic"
        }

        
        first_speaker = "pro"
        self.logger.info("Welcome to our debate panel! Today's debate topic is: %s", debate_topic)
        return {
            "debate_topic": debate_topic,
            "positions": positions,
            "stage": "opening",
            "speaker": first_speaker
        }

We will repeat this pattern for every class except the fact checker and the nodes that don’t use an LLM.

Now we can implement the 2 stars of the show, the Pro and Con argument agents!

Debater Agents (Pro and Con)

Link: pro_debater_node.py

The two debater agents are very similar in structure, but each uses different prompt templates tailored to their role (pro vs con) and the stage of the debate.

The Pro debater, for example, has to handle an opening statement and a counter-argument (countering the Con’s rebuttal). We also need logic for retries in case a statement fails fact-check. In code, the ProDebater class sets up multiple prompt chains:

  • opening_chain and an opening_retry_chain (using slightly different human prompts – the retry prompt might instruct it to try again without repeating any factually dubious claims).
  • counter_chain and counter_retry_chain for the counter-argument stage.

class ProDebaterNode(BaseComponent):
    def __init__(self, llm_config, temperature: float = 0.7):
        super().__init__(llm_config, temperature)
        self.opening_chain = self.create_chain(SYSTEM_PROMPT, OPENING_HUMAN_PROMPT)
        self.opening_retry_chain = self.create_chain(SYSTEM_PROMPT, OPENING_RETRY_HUMAN_PROMPT)
        self.counter_chain = self.create_chain(SYSTEM_PROMPT, COUNTER_HUMAN_PROMPT)
        self.counter_retry_chain = self.create_chain(SYSTEM_PROMPT, COUNTER_RETRY_HUMAN_PROMPT)

    def __call__(self, state: DebateState) -> Dict[str, Any]:
        super().__call__(state)

        debate_topic = state.get("debate_topic")
        messages = state.get("messages", [])
        stage = state.get("stage")
        speaker = state.get("speaker")

        # Check if retrying (last message was by pro and not validated)
        last_msg = messages[-1] if messages else None
        retrying = last_msg and last_msg["speaker"] == SPEAKER_PRO and not last_msg["validated"]

        if stage == STAGE_OPENING and speaker == SPEAKER_PRO:
            # Select the chain to trigger: the normal one or the fact-checked retry
            chain = self.opening_retry_chain if retrying else self.opening_chain
            result = chain.invoke({
                "debate_topic": debate_topic
            })
        elif stage == STAGE_COUNTER and speaker == SPEAKER_PRO:
            opponent_msg = self._get_last_message_by(SPEAKER_CON, messages)
            debate_history = get_debate_history(messages)
            chain = self.counter_retry_chain if retrying else self.counter_chain
            result = chain.invoke({
                "debate_topic": debate_topic,
                "opponent_statement": opponent_msg,
                "debate_history": debate_history
            })
        else:
            raise ValueError(f"Unknown turn for ProDebater: stage={stage}, speaker={speaker}")
        new_message = create_debate_message(speaker=SPEAKER_PRO, content=result, stage=stage)
        self.logger.info("Speaker: %s, Stage: %s, Retry: %s\nMessage:\n%s", speaker, stage, retrying, result)
        return {
            "messages": messages + [new_message]
        }

    def _get_last_message_by(self, speaker_prefix, messages):
        for m in reversed(messages):
            if m.get("speaker") == speaker_prefix:
                return m["content"]
        return ""

When the ProDebater’s __call__ runs, it looks at the current stage and speaker in the state to decide what to do:

  • If it’s the opening stage and the speaker is “pro”, it uses the opening_chain to generate an opening argument. If the last message from Pro was marked invalid (not validated), it knows this is a retry, so it would use the opening_retry_chain instead.
  • If it’s the counter stage and speaker is “pro”, it generates a counter-argument to whatever the opponent (Con) just said. It will fetch the last message by the Con from the messages history, and feed that into the prompt (so that the Pro can directly counter it). Again, if the last Pro message was invalid, it would switch to the retry chain.

After generating its argument, the Debater agent creates a new message entry (with speaker="pro", the content text, validated=False initially, and the stage) and appends it to the state’s message list. That becomes the output of the node (LangGraph will merge this partial state update into the global state).
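
The node above relies on two helpers that are not shown, create_debate_message and get_debate_history. A minimal sketch of what they might look like follows, together with the retry check pulled out into its own function. The bodies are assumptions: in particular, the history format and the choice to include only validated messages are guesses, and is_retry is a hypothetical name.

```python
from typing import List, TypedDict


class DebateMessage(TypedDict):
    speaker: str
    content: str
    validated: bool
    stage: str


def create_debate_message(
    speaker: str, content: str, stage: str, validated: bool = False
) -> DebateMessage:
    """New messages start unvalidated until the Fact Checker approves them."""
    return {
        "speaker": speaker,
        "content": content,
        "validated": validated,
        "stage": stage,
    }


def get_debate_history(messages: List[DebateMessage]) -> str:
    """Render validated messages as one line each (assumed format)."""
    return "\n".join(
        f"{m['speaker'].upper()} ({m['stage']}): {m['content']}"
        for m in messages
        if m["validated"]
    )


def is_retry(messages: List[DebateMessage], speaker: str) -> bool:
    """True when the last message is an unvalidated one from `speaker`."""
    if not messages:
        return False
    last = messages[-1]
    return last["speaker"] == speaker and not last["validated"]
```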

The Con Debater agent mirrors this logic for its stages:

  • It has a rebuttal and closing argument (final argument) stage, each with a normal and a retry chain.
  • It checks if it’s the rebuttal stage (speaker “con”) or final argument stage (speaker “con”) and invokes the appropriate chain, possibly using the last Pro message for context when rebutting.
  • It similarly appends its message to the state.

Link: con_debater_node.py

By using a class-based implementation, our debaters’ code is easier to maintain. We can clearly separate what the Pro does vs what the Con does, even if they share structure. Also, by encapsulating prompt chains inside the class, each debater can manage multiple possible outputs (regular vs retry) cleanly.

Prompt design: The actual prompts (in prompts/pro_debater_prompts.py and con_debater_prompts.py) guide the GPT-4o model to take on a persona (“You are a debater arguing for/against the topic…”) and produce the argument. They also instruct the model to keep statements factual and logical. If a fact check fails, the retry prompt may say something like: “Your previous statement had an unverified claim. Revise your argument to be factually correct while maintaining your position.” – encouraging the model to correct itself.

With this, our AI debaters can engage in a multi-turn duel, and even recover from factual missteps.

Fact Checker Agent (FactCheckNode)

After each debater speaks, the Fact Checker agent swoops in to verify their claims. This agent is implemented in <a href="https://github.com/iason-solomos/Deb8flow/blob/main/nodes/fact_checker_node.py">fact_checker_node.py</a>, and interestingly, it relies on the built-in web search ability of OpenAI’s gpt-4o-search-preview model rather than on a search pipeline of our own. Essentially, we delegate the fact-checking to GPT-4o with web search.

How does this work? The OpenAI Python client lets us send a user message to the search-enabled model and get back a structured response. In FactCheckNode.__call__, we do something like:

# Ask the search-enabled model to verify the claim and parse its reply into FactCheck
completion = self.client.beta.chat.completions.parse(
    model="gpt-4o-search-preview",
    web_search_options={},
    messages=[{
        "role": "user",
        "content": (
            f"Consider the following statement from a debate. "
            f"If the statement contains numbers, or figures from studies, fact-check it online.\n\n"
            f"Statement:\n\"{claim}\"\n\n"
            f"Reply clearly whether any numbers or studies might be inaccurate or hallucinated, and why."
            f"\n"
            f"If the statement doesn't contain references to studies or numbers cited, don't go online to fact-check, and just consider it successfully fact-checked, with a 'yes' score.\n\n"
        )
    }],
    response_format=FactCheck
)

If the result is “yes” (meaning the claim seems truthful or at least not factually wrong), the Fact Checker will mark the last message’s validated field as True in the state, and output {"validated": True} with no further changes. This signals that the debate can continue normally.

If the result is “no” (meaning it found the claim to be incorrect or dubious), the Fact Checker will append a new message to the state with speaker="fact_checker" describing the finding (or we could simply mark it, but providing a brief note like “(Fact Checker: The statistic cited could not be verified.)” can be useful). It will also set validated: False and increment a counter for whichever side made the claim. The output state from this node includes validated: False and an updated times_pro_fact_checked or times_con_fact_checked count.

We also use a Pydantic BaseModel to control the output of the LLM:

from pydantic import BaseModel, Field


class FactCheck(BaseModel):
    """
    Pydantic model for fact-checking the claims made by debaters.

    Attributes:
        binary_score (str): 'yes' if the claim is verifiable and truthful, 'no' otherwise.
        justification (str): Explanation of the reasoning behind the score.
    """

    binary_score: str = Field(
        description="Indicates if the claim is verifiable and truthful. 'yes' or 'no'."
    )
    justification: str = Field(
        description="Explanation of the reasoning behind the score."
    )
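To make the yes/no handling described above concrete, here is a minimal sketch of how a fact-check result could be turned into a partial state update. The helper name apply_fact_check and the exact shape of the note message are assumptions for illustration, not the repository's code; only the field names (validated, times_pro_fact_checked, times_con_fact_checked) come from the description above.

```python
from typing import Any, Dict, List


def apply_fact_check(
    state: Dict[str, Any], binary_score: str, justification: str
) -> Dict[str, Any]:
    """Return a partial state update for a fact-check result (hypothetical helper)."""
    # Copy the messages so the incoming state is not mutated.
    messages: List[Dict[str, Any]] = [dict(m) for m in state["messages"]]
    last = messages[-1]

    if binary_score.strip().lower() == "yes":
        # Claim passed: mark the debater's last message as validated.
        last["validated"] = True
        return {"messages": messages, "validated": True}

    # Claim failed: keep it unvalidated, add a short note, bump the speaker's counter.
    last["validated"] = False
    messages.append({
        "speaker": "fact_checker",
        "content": f"(Fact Checker: {justification})",
        "validated": False,  # assumption: the note itself is not a validated argument
        "stage": last.get("stage"),
    })
    counter = f"times_{last['speaker']}_fact_checked"
    return {
        "messages": messages,
        "validated": False,
        counter: state.get(counter, 0) + 1,
    }
```

Returning a fresh dictionary instead of mutating the input mirrors the way a node hands a partial update back to the graph.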

Debate Moderator Agent (DebateModeratorNode)

The Debate Moderator is the conductor of the debate. Instead of producing lengthy text, this agent’s job is to manage turn-taking and stage progression. In the workflow, after a statement is validated by the Fact Checker, control passes to the Moderator node. The Moderator then issues a Command that updates the state for the next turn and directs the flow to the appropriate next agent.

The logic in DebateModeratorNode.__call__ (see <a href="https://github.com/iason-solomos/Deb8flow/blob/main/nodes/debate_moderator_node.py">nodes/debate_moderator_node.py</a>) goes roughly like this:

# Each branch handles the turn that just ended and sets up the next one
if stage == STAGE_OPENING and speaker == SPEAKER_PRO:
    return Command(
        update={"stage": STAGE_REBUTTAL, "speaker": SPEAKER_CON},
        goto=NODE_CON_DEBATER
    )
elif stage == STAGE_REBUTTAL and speaker == SPEAKER_CON:
    return Command(
        update={"stage": STAGE_COUNTER, "speaker": SPEAKER_PRO},
        goto=NODE_PRO_DEBATER
    )
elif stage == STAGE_COUNTER and speaker == SPEAKER_PRO:
    return Command(
        update={"stage": STAGE_FINAL_ARGUMENT, "speaker": SPEAKER_CON},
        goto=NODE_CON_DEBATER
    )
elif stage == STAGE_FINAL_ARGUMENT and speaker == SPEAKER_CON:
    # Debate is over: hand the transcript to the Judge without changing the state
    return Command(
        update={},
        goto=NODE_JUDGE
    )

raise ValueError(f"Unexpected stage/speaker combo: stage={stage}, speaker={speaker}")

Each conditional corresponds to a point in the debate where a turn just ended, and sets up the next turn. For example, after the opening (Pro just spoke), it sets stage to rebuttal, switches speaker to Con, and directs the workflow to the Con debater node. After the final_argument (Con’s closing), it directs to the Judge with no further update (the debate stage effectively ends).
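The same turn-taking rules can also be read as a lookup table. The sketch below is a plain-Python illustration of that idea, not the project's implementation: the string values stand in for the STAGE_*, SPEAKER_* and NODE_* constants, and next_turn is a hypothetical name.

```python
from typing import Dict, Optional, Tuple

# (stage, speaker) that just finished -> (next stage, next speaker, next node).
# The strings are placeholders for the project's constants (assumed values).
TRANSITIONS: Dict[Tuple[str, str], Tuple[Optional[str], Optional[str], str]] = {
    ("opening", "pro"): ("rebuttal", "con", "con_debater"),
    ("rebuttal", "con"): ("counter", "pro", "pro_debater"),
    ("counter", "pro"): ("final_argument", "con", "con_debater"),
    ("final_argument", "con"): (None, None, "judge"),
}


def next_turn(stage: str, speaker: str) -> Tuple[dict, str]:
    """Return (state update, next node name) for the turn that just ended."""
    try:
        next_stage, next_speaker, node = TRANSITIONS[(stage, speaker)]
    except KeyError:
        raise ValueError(
            f"Unexpected stage/speaker combo: stage={stage}, speaker={speaker}"
        )
    # The final turn hands over to the judge with no state change.
    update = {} if next_stage is None else {"stage": next_stage, "speaker": next_speaker}
    return update, node
```

A table like this keeps the debate format in one place, which can make it easier to experiment with a different speaking order.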

Fact Check Router (FactCheckRouterNode)

This is another control node (like the Moderator) that introduces conditional logic. The Fact Check Router sits right after the Fact Checker agent in the flow. Its purpose is to branch the workflow depending on the fact-check result.

In <a href="https://github.com/iason-solomos/Deb8flow/blob/main/nodes/fact_check_router_node.py">nodes/fact_check_router_node.py</a>, the logic is:

# End the debate early if either side has failed three fact checks
if pro_fact_checks >= 3 or con_fact_checks >= 3:
    disqualified = SPEAKER_PRO if pro_fact_checks >= 3 else SPEAKER_CON
    winner = SPEAKER_CON if disqualified == SPEAKER_PRO else SPEAKER_PRO

    verdict_msg = {
        "speaker": "moderator",
        "content": (
            f"Debate ended early due to excessive factual inaccuracies.\n\n"
            f"DISQUALIFIED: {disqualified.upper()} (exceeded fact check limit)\n"
            f"WINNER: {winner.upper()}"
        ),
        "validated": True,
        "stage": "verdict"
    }
    return Command(
        update={"messages": messages + [verdict_msg]},
        goto=END
    )
# Otherwise continue the debate, or send the speaker back for a retry
if last_message.get("validated"):
    return Command(goto=NODE_DEBATE_MODERATOR)
elif speaker == SPEAKER_PRO:
    return Command(goto=NODE_PRO_DEBATER)
elif speaker == SPEAKER_CON:
    return Command(goto=NODE_CON_DEBATER)
raise ValueError("Unable to determine routing in FactCheckRouterNode.")

First, the Fact Check Router checks if either side’s fact-check count has reached 3. If so, it creates a Moderator-style message announcing an early end: the offending side is disqualified and the other side is the winner. It appends this verdict to the messages and returns a Command that jumps to END, effectively terminating the debate without going to the Judge (because we already know the outcome).

If we’re not ending the debate early, it then looks at the Fact Checker’s result for the last message (which is stored as validated on that message). If validated is True, we go to the debate moderator: Command(goto=NODE_DEBATE_MODERATOR).

Otherwise, if the statement fails the fact check, the workflow goes back to the debater to produce a revised statement (with the state counters updated to reflect the failure). This loop can happen multiple times if needed (up to the disqualification limit).

This dynamic control is the heart of Deb8flow’s “agentic” nature – the ability to adapt the path of execution based on the content of the agents’ outputs. It showcases LangGraph’s strength: combining control flow with state. We’re essentially encoding debate rules (like allowing retries for false claims, or ending the debate if someone cheats too often) directly into the workflow graph.

Judge Agent (JudgeNode)

Last but not least, the Judge agent delivers the final verdict based on rhetorical skill, clarity, structure, and overall persuasiveness. Its system prompt and human prompt make this explicit:

  • System Prompt: “You are an impartial debate judge AI. … Evaluate which debater presented their case more clearly, persuasively, and logically. You must focus on communication skills, structure of argument, rhetorical strength, and overall coherence.”
  • Human Prompt: “Here is the full debate transcript. Please analyze the performance of both debaters—PRO and CON. Evaluate rhetorical performance—clarity, structure, persuasion, and relevance—and decide who presented their case more effectively.”

When the Judge node runs, it receives the entire debate transcript (all validated messages) alongside the original topic. It then uses GPT-4o to examine how each side framed their arguments, handled counterpoints, and supported (or failed to support) claims with examples or logic. Crucially, the Judge is forbidden to evaluate which position is objectively correct (or who it thinks might be correct)—only who argued more persuasively.

Below is an example final verdict from a Deb8flow run on the topic:
“Should governments implement a universal basic income in response to increasing automation in the workforce?”

WINNER: PRO

REASON: The PRO debater presented a more compelling and rhetorically effective case for universal basic income. Their arguments were well-structured, beginning with a clear statement of the issue and the necessity of UBI in response to automation. They effectively addressed potential counterarguments by highlighting the unprecedented speed and scope of current technological changes, which distinguishes the current situation from past technological shifts. The PRO also provided empirical evidence from UBI pilot programs to counter the CON's claims about work disincentives and economic inefficiencies, reinforcing their argument with real-world examples.

In contrast, the CON debater, while presenting valid concerns about UBI, relied heavily on historical analogies and assumptions about workforce adaptability without adequately addressing the unique challenges posed by modern automation. Their arguments about the fiscal burden and potential inefficiencies of UBI were less supported by specific evidence compared to the PRO's rebuttals.

Overall, the PRO's arguments were more coherent, persuasive, and backed by empirical evidence, making their case more convincing to a neutral observer.

LangSmith Tracing

Throughout Deb8flow’s development, I relied on LangSmith (LangChain’s tracing and observability toolkit) to ensure the entire debate pipeline was behaving correctly. Because we have multiple agents passing control between themselves, it’s easy for unexpected loops or misrouted states to occur. LangSmith provides a convenient way to:

  • Visualize Execution Flow: You can see each agent’s prompt, the tokens consumed (so you can also track costs), and any intermediate states. This makes it much simpler to confirm that, say, the Con Debater is properly referencing the Pro Debater’s last message, or that the Fact Checker is accurately receiving the claim to verify.
  • Debug State Updates: If the Moderator or Fact Check Router is sending the flow to the wrong node, the trace will highlight that mismatch. You can trace which agent was invoked at each step and why, helping you spot stage or speaker misalignments early.
  • Track Prompt and Completion Tokens: With multiple GPT-4o calls, it’s useful to see how many tokens each stage is using, which LangSmith logs automatically if you enable tracing.

Integrating LangSmith is surprisingly easy. You just need to provide these three keys in your .env file: LANGCHAIN_API_KEY, LANGCHAIN_TRACING_V2, and LANGCHAIN_PROJECT.
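As a small safeguard, a startup check along these lines can confirm the three variables are present before a run. This is a sketch: missing_langsmith_keys is a hypothetical helper, and it assumes the .env file has already been loaded into the process environment (for example by a dotenv loader).

```python
import os
from typing import List, Mapping

# The three variable names listed above.
REQUIRED_KEYS = ("LANGCHAIN_API_KEY", "LANGCHAIN_TRACING_V2", "LANGCHAIN_PROJECT")


def missing_langsmith_keys(env: Mapping[str, str] = os.environ) -> List[str]:
    """Return the required variable names that are unset or empty."""
    return [key for key in REQUIRED_KEYS if not env.get(key)]


missing = missing_langsmith_keys()
if missing:
    print("Tracing is not fully configured; missing:", ", ".join(missing))
```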

Then you can open the LangSmith UI to see a structured trace of each run. This greatly reduces the guesswork involved in debugging multi-agent systems and is, in my experience, essential for more complex AI orchestration like ours. Example of a single run:

The trace of one run in LangSmith’s waterfall mode, showing how the whole flow ran. Source: Generated by the author using LangSmith.

Reflections and Next Steps

Building Deb8flow was an eye-opening exercise in orchestrating autonomous agent workflows. We didn’t just chain a single model call – we created an entire debate simulation with AI agents, each with a specific role, and allowed them to interact according to a set of rules. LangGraph provided a clear framework to define how data and control flows between agents, making the complex sequence manageable in code. By using class-based agents and a shared state, we maintained modularity and clarity, which will pay off for any software engineering project in the long run.

An exciting aspect of this project was seeing emergent behavior. Even though each agent follows a script (a prompt), the unscripted combination – a debater trying to deceive, a fact-checker catching it, the debater rephrasing – felt surprisingly realistic! It’s a small step toward more agentic AI systems that can perform non-trivial multi-step tasks with oversight on each other.

There are plenty of ideas for improvement:

  • User Interaction: Currently it’s fully autonomous, but one could add a mode where a human provides the topic or even takes the role of one side against an AI opponent.
  • We can switch the order in which the Debaters talk.
  • We can change the prompts, and thus to a good degree the behavior of the agents, and experiment with different prompts.
  • Make the debaters also perform web search before producing their statements, thus providing them with the latest information.

The broader implication of Deb8flow is how it showcases a pattern for composable AI agents. By defining clear boundaries and interactions (just like microservices in software), we can have complex AI-driven processes that remain interpretable and controllable. Each agent is like a cog in a machine, and LangGraph is the gear system making them work in unison.

I found this project energizing, and I hope it inspires you to explore multi-agent workflows. Whether it’s debating, collaborating on writing, or solving problems from different expert angles, the combination of GPT, tools, and structured agentic workflows opens up a new world of possibilities for AI development. Happy hacking!

References

[1] D. Bouchard, “From Basics to Advanced: Exploring LangGraph,” Medium, Nov. 22, 2023. [Online]. Available: https://medium.com/data-science/from-basics-to-advanced-exploring-langgraph-e8c1cf4db787. [Accessed: Apr. 1, 2025].

[2] A. W. T. Ng, “Building a Research Agent that Can Write to Google Docs: Part 1,” Towards Data Science, Jan. 11, 2024. [Online]. Available: https://towardsdatascience.com/building-a-research-agent-that-can-write-to-google-docs-part-1-4b49ea05a292/. [Accessed: Apr. 1, 2025].

The post Deb8flow: Orchestrating Autonomous AI Debates with LangGraph and GPT-4o appeared first on Towards Data Science.

]]>
Why CatBoost Works So Well: The Engineering Behind the Magic https://towardsdatascience.com/catboost-inner-workings-and-optimizations/ Thu, 10 Apr 2025 00:28:11 +0000 https://towardsdatascience.com/?p=605702 CatBoost stands out by directly tackling a long-standing challenge in gradient boosting—how to handle categorical variables effectively without causing target leakage. By introducing innovative techniques such as Ordered Target Statistics and Ordered Boosting, and by leveraging the structure of Oblivious Trees, CatBoost efficiently balances robustness and accuracy. These methods ensure that each prediction uses only past data, preventing leakage and resulting in a model that is both fast and reliable for real-world tasks.

The post Why CatBoost Works So Well: The Engineering Behind the Magic appeared first on Towards Data Science.

]]>

Gradient boosting is a cornerstone technique for modeling tabular data due to its speed and simplicity. It delivers great results without any fuss. When you look around, you’ll see multiple options like LightGBM and XGBoost. CatBoost is one such variant. In this post, we will take a detailed look at this model, explore its inner workings, and understand what makes it a great choice for real-world tasks.

Target Statistic

Target Encoding Example: the average value of the target variable for a category is used to replace each category (here Car → 3.9, Bike → 1.2, Bus → 11.7, Cycle → 0.8). Image by author


One of the important contributions of the CatBoost paper is a new method of calculating the Target Statistic. What is a Target Statistic? If you have worked with categorical variables before, you’d know that the most rudimentary way to deal with them is one-hot encoding. From experience, you’d also know that this opens a can of worms: sparsity, the curse of dimensionality, memory issues, and so on, especially for categorical variables with high cardinality.

Greedy Target Statistic

To avoid one-hot encoding, we calculate the Target Statistic for the categorical variables instead. This means we calculate the mean of the target variable at each unique value of the categorical variable. So if a categorical variable takes the values A, B, and C, we compute the average of \(\text{y}\) over the rows for each value and replace that value with its average.

That sounds good, right? It does, but this approach comes with its problems, namely Target Leakage. To understand this, let’s take an extreme example. Extreme examples are often the easiest way to expose issues in an approach. Consider the dataset below:

Categorical Column | Target Column
A | 0
B | 1
C | 0
D | 1
E | 0
Greedy Target Statistic: Compute the mean target value for each unique category


Now let’s write the equation for calculating the Target Statistic:
\[\hat{x}^i_k = \frac{
\sum_{j=1}^{n} \mathbb{1}_{\{x^i_j = x^i_k\}} \cdot y_j + a p
}{
\sum_{j=1}^{n} \mathbb{1}_{\{x^i_j = x^i_k\}} + a
}\]

Here \(x^i_j\) is the value of the i-th categorical feature for the j-th sample, and the indicator term is 1 when sample j has the same category as sample k and 0 otherwise. So for the k-th sample, we iterate over all samples of \(x^i\), select the ones having the value \(x^i_k\), and take the average value of \(y\) over those samples. Instead of taking a direct average, we take a smoothed average, which is what the \(a\) and \(p\) terms are for. The \(a\) parameter controls the amount of smoothing and \(p\) is the global mean of \(y\).
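The formula translates almost line by line into code. Below is a minimal pure-Python sketch applied to the five-row dataset above; the function name greedy_target_statistic and the choice \(a = 1\) are illustrative assumptions.

```python
from typing import Dict, List, Sequence


def greedy_target_statistic(
    categories: Sequence[str], targets: Sequence[float], a: float = 1.0
) -> List[float]:
    """Smoothed mean of the target per category, computed over ALL rows."""
    p = sum(targets) / len(targets)  # global mean of y (the prior)
    sums: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    for c, y in zip(categories, targets):
        sums[c] = sums.get(c, 0.0) + y
        counts[c] = counts.get(c, 0) + 1
    # Each row's own target is part of its category sum: this is the leak.
    return [(sums[c] + a * p) / (counts[c] + a) for c in categories]


cats = ["A", "B", "C", "D", "E"]
ys = [0, 1, 0, 1, 0]
ts = greedy_target_statistic(cats, ys, a=1.0)
print(ts)  # rows with y = 0 land near 0.2, rows with y = 1 near 0.7
```

With \(p = 0.4\) and \(a = 1\), the encoded value alone separates the two classes, which is exactly the leakage the example is meant to expose.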

If we calculate the Target Statistic using the formula above, we get:

Categorical Column | Target Column | Target Statistic
A | 0 | \(\frac{ap}{1+a}\)
B | 1 | \(\frac{1+ap}{1+a}\)
C | 0 | \(\frac{ap}{1+a}\)
D | 1 | \(\frac{1+ap}{1+a}\)
E | 0 | \(\frac{ap}{1+a}\)
Calculation of Greedy Target Statistic with Smoothing


Now if I use this Target Statistic column as my training data, I will get a perfect split at \( threshold = \frac{0.5+ap}{1+a}\). Anything above this value will be classified as 1 and anything below will be classified as 0. I have a perfect classification at this point, so I get 100% accuracy on my training data.

Let’s take a look at the test data. Since we are assuming that every value of the feature is unique, no test category was seen during training, so both sums are zero and the Target Statistic collapses to the prior:
\[TS = \frac{0+ap}{0+a} = p\]
If \(threshold\) is greater than \(p\), all test data predictions will be \(0\). Conversely, if \(threshold\) is less than \(p\), all test data predictions will be \(1\) leading to poor performance on the test set.

Although we rarely see datasets where values of a categorical variable are all unique, we do see cases of high cardinality. This extreme example shows the pitfalls of using Greedy Target Statistic as an encoding approach.

Leave One Out Target Statistic

So the Greedy TS didn’t work out well for us. Let’s try another method: the Leave One Out Target Statistic, which excludes a row’s own target when computing that row’s statistic. At first glance, this looks promising. But, as it turns out, this too has its problems. Let’s see how with another extreme example. This time let’s assume that our categorical variable \(x^i\) has only one unique value, i.e., all values are the same. Consider the data below:

Categorical Column | Target Column
A | 0
A | 1
A | 0
A | 1
Example data for an extreme case where a categorical feature has just one unique value


If we calculate the leave-one-out Target Statistic, we get:

Categorical Column | Target Column | Target Statistic
A | 0 | \(\frac{n^+ -y_k + ap}{n-1+a}\)
A | 1 | \(\frac{n^+ -y_k + ap}{n-1+a}\)
A | 0 | \(\frac{n^+ -y_k + ap}{n-1+a}\)
A | 1 | \(\frac{n^+ -y_k + ap}{n-1+a}\)
Calculation of Leave One Out Target Statistic with Smoothing


Here:
\(n\) is the total number of samples in the data (in our case, 4)
\(n^+\) is the number of positive samples in the data (in our case, 2)
\(y_k\) is the value of the target column in that row
The denominator is \(n-1+a\) rather than \(n+a\) because the current row is left out of the count as well as the sum.
Substituting the above, we get:

Categorical Column | Target Column | Target Statistic
A | 0 | \(\frac{2 + ap}{3+a}\)
A | 1 | \(\frac{1 + ap}{3+a}\)
A | 0 | \(\frac{2 + ap}{3+a}\)
A | 1 | \(\frac{1 + ap}{3+a}\)
Substituting the values of \(n\) and \(n^+\)


Now, if I use this Target Statistic column as my training data, I will get a perfect split at \( threshold = \frac{1.5+ap}{3+a}\). Anything above this value will be classified as 0 and anything below will be classified as 1. I have a perfect classification at this point, so I again get 100% accuracy on my training data.

You see the problem, right? My categorical variable, which has only a single unique value, produces different Target Statistic values depending on the label. A model trained on them will perform great on the training data but fail miserably on the test data.
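The effect is easy to reproduce. The sketch below assumes the current row is removed from both the sum and the count of its category; the function name leave_one_out_ts and the value \(a = 1\) are illustrative choices.

```python
from typing import Dict, List, Sequence


def leave_one_out_ts(
    categories: Sequence[str], targets: Sequence[float], a: float = 1.0
) -> List[float]:
    """Smoothed target mean per category, excluding each row's own target."""
    p = sum(targets) / len(targets)  # global mean of y (the prior)
    sums: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    for c, y in zip(categories, targets):
        sums[c] = sums.get(c, 0.0) + y
        counts[c] = counts.get(c, 0) + 1
    # Subtract the row's own label from the sum and 1 from the count.
    return [
        (sums[c] - y + a * p) / (counts[c] - 1 + a)
        for c, y in zip(categories, targets)
    ]


cats = ["A", "A", "A", "A"]
ys = [0, 1, 0, 1]
ts = leave_one_out_ts(cats, ys, a=1.0)
print(ts)  # [0.625, 0.375, 0.625, 0.375]
```

Even though the feature is constant, rows with label 0 receive a higher encoded value than rows with label 1, so the encoding is simply the label in disguise.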

Ordered Target Statistic

Illustration of ordered learning: CatBoost processes data in a randomly permuted order and predicts each sample using only the earlier samples. Image by author

CatBoost introduces a technique called Ordered Target Statistic to address the issues discussed above. This is the core principle of CatBoost’s handling of categorical variables.

This method, inspired by online learning, uses only past data to make predictions. CatBoost generates a random permutation (random ordering) \(\sigma\) of the training data. To compute the Target Statistic for the sample at row \(k\) of the permuted data, CatBoost uses only the samples from rows \(1\) to \(k-1\). For the test data, it uses the entire training data to compute the statistic.

Additionally, CatBoost generates a new permutation for each tree, rather than reusing the same permutation each time. This reduces the variance that can arise in the early samples.
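A minimal sketch of the idea, for a single permutation, is shown below. It is a simplification under stated assumptions: the function name ordered_target_statistic, the seed argument and the use of the global target mean as the prior \(p\) are illustrative choices, not CatBoost's internal implementation.

```python
import random
from typing import Dict, List, Optional, Sequence


def ordered_target_statistic(
    categories: Sequence[str],
    targets: Sequence[float],
    a: float = 1.0,
    seed: Optional[int] = 0,
) -> List[float]:
    """Encode each row using only rows that come earlier in a random permutation."""
    n = len(targets)
    p = sum(targets) / n  # prior (assumption: global mean of y)
    order = list(range(n))
    random.Random(seed).shuffle(order)  # the permutation sigma

    sums: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    encoded = [0.0] * n
    for idx in order:
        c = categories[idx]
        # Running totals hold only already-visited rows, never the row's own label.
        encoded[idx] = (sums.get(c, 0.0) + a * p) / (counts.get(c, 0) + a)
        sums[c] = sums.get(c, 0.0) + targets[idx]
        counts[c] = counts.get(c, 0) + 1
    return encoded


print(ordered_target_statistic(["A", "B", "C", "D", "E"], [0, 1, 0, 1, 0]))
```

On the all-unique dataset from the greedy example, every row now receives the prior, so the encoded column no longer reveals the label.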

Ordered Boosting

This visualization shows how CatBoost computes residuals and updates the model: for sample \(x_i\), the residual \(r^t(x_i, y_i) = y_i - M^{t-1}_{i-1}(x_i)\) comes from a model trained only on earlier data points, and the update \(\Delta M\) is learned from samples whose order is at most \(i\). Source

Another important innovation introduced by the CatBoost paper is its use of Ordered Boosting. It builds on similar principles as ordered target statistics, where CatBoost randomly permutes the training data at the start of each tree and makes predictions sequentially.

In traditional boosting methods, when training tree \(t\), the model uses predictions from the previous tree \(t−1\) for all training samples, including the one it is currently predicting. This can lead to target leakage, as the model may indirectly use the label of the current sample during training.

To address this issue, CatBoost uses Ordered Boosting where, for a given sample, it only uses predictions from previous rows in the training data to calculate gradients and build trees. For each row \(i\) in the permutation, CatBoost calculates the output value of a leaf using only the samples before \(i\). The model uses this value to get the prediction for row \(i\). Thus, the model predicts each row without looking at its label.

CatBoost trains each tree using a new random permutation, which averages out the variance caused by the early samples of any single permutation.
Let’s say we have 5 data points: A, B, C, D, E. CatBoost creates a random permutation of these points. Suppose the permutation is: σ = [C, A, E, B, D]

Step | Data Used to Train | Data Point Being Predicted | Notes
1 | (none) | C | No previous data → use prior
2 | C | A | Model trained on C only
3 | C, A | E | Model trained on C, A
4 | C, A, E | B | Model trained on C, A, E
5 | C, A, E, B | D | Model trained on C, A, E, B
Table highlighting how CatBoost uses random permutation to perform training

This avoids using the actual label of the current row to get its prediction, thus preventing leakage.
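The walk through the permutation can be reproduced with a toy model. In the sketch below, the "model" is just the mean target of the earlier points; the target values are made up for illustration, and ordered_predictions is a hypothetical name.

```python
from typing import Dict, List, Sequence, Tuple


def ordered_predictions(
    permutation: Sequence[str], targets: Dict[str, float], prior: float = 0.0
) -> List[Tuple[str, List[str], float]]:
    """For each point, predict from the points that precede it in the permutation."""
    steps = []
    seen: List[str] = []
    for point in permutation:
        if seen:
            # Toy "model": the mean target of the earlier points.
            prediction = sum(targets[s] for s in seen) / len(seen)
        else:
            prediction = prior  # no history yet, fall back to the prior
        steps.append((point, list(seen), prediction))
        seen.append(point)
    return steps


sigma = ["C", "A", "E", "B", "D"]
y = {"A": 1.0, "B": 3.0, "C": 2.0, "D": 5.0, "E": 4.0}  # made-up targets
for point, history, pred in ordered_predictions(sigma, y):
    print(point, history, pred)
```

At no step does a point's own target enter its prediction, which is the property ordered boosting is built around.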

Building a Tree

Each time CatBoost builds a tree, it creates a random permutation of the training data. It calculates the ordered target statistic for all the categorical variables with more than two unique values. For a binary categorical variable, it maps the values to zeros and ones.

CatBoost processes data as if the data is arriving sequentially. It begins with an initial prediction of zero for all instances, meaning the residuals are initially equivalent to the target values.

As training proceeds, CatBoost updates the leaf output for each sample using the residuals of the previous samples that fall into the same leaf. By not using the current sample’s label for prediction, CatBoost effectively prevents data leakage.
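As a rough sketch of that update rule, the function below gives each sample the average residual of the earlier samples in its leaf. It is a simplified illustration under assumptions: the leaf assignment is taken as given, the leaf value is a plain mean with no regularization, and ordered_leaf_outputs is a hypothetical name.

```python
from typing import Dict, List, Sequence


def ordered_leaf_outputs(
    leaf_ids: Sequence[int], residuals: Sequence[float], order: Sequence[int]
) -> List[float]:
    """Leaf output per sample, using only earlier samples that fell into the same leaf."""
    sums: Dict[int, float] = {}
    counts: Dict[int, int] = {}
    outputs = [0.0] * len(residuals)
    for idx in order:  # visit samples in permutation order
        leaf = leaf_ids[idx]
        if counts.get(leaf, 0) > 0:
            outputs[idx] = sums[leaf] / counts[leaf]
        # Otherwise keep 0.0, matching the initial prediction of zero.
        sums[leaf] = sums.get(leaf, 0.0) + residuals[idx]
        counts[leaf] = counts.get(leaf, 0) + 1
    return outputs


print(ordered_leaf_outputs([0, 0, 1, 0], [2.0, 4.0, 10.0, 6.0], order=[0, 1, 2, 3]))
# [0.0, 2.0, 0.0, 3.0]
```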

Split Candidates

CatBoost bins continuous features to reduce the search space for optimal splits. Each bin edge and split point represents a potential decision threshold. Image by author

At the core of a decision tree lies the task of selecting the optimal feature and threshold for splitting a node. This involves evaluating multiple feature-threshold combinations and selecting the one that gives the best reduction in loss. CatBoost does something similar. It discretizes each continuous feature into bins to simplify the search for the optimal combination, then evaluates each feature-bin combination to determine the best split.
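To illustrate how binning shrinks the search space, here is a small sketch that derives candidate thresholds from quantile-style borders. Quantile borders are used here as a simple stand-in; CatBoost's own border selection strategy may differ, and both function names are hypothetical.

```python
from typing import List, Sequence


def quantile_split_candidates(values: Sequence[float], n_bins: int = 4) -> List[float]:
    """Candidate thresholds placed at approximate quantile borders of a feature."""
    ordered = sorted(values)
    n = len(ordered)
    borders: List[float] = []
    for b in range(1, n_bins):
        pos = b * n // n_bins
        if pos == 0:
            continue  # not enough data for this border
        # Midpoint between the two neighbours straddling the quantile position.
        threshold = (ordered[pos - 1] + ordered[pos]) / 2
        if not borders or threshold > borders[-1]:
            borders.append(threshold)  # skip duplicate borders
    return borders


def bin_index(value: float, borders: Sequence[float]) -> int:
    """Index of the bin a value falls into (the number of borders it exceeds)."""
    return sum(value > t for t in borders)


feature = [1, 2, 3, 4, 5, 6, 7, 8]
borders = quantile_split_candidates(feature, n_bins=4)
print(borders)  # [2.5, 4.5, 6.5]
```

Instead of testing a threshold between every pair of adjacent values, the split search only needs to score these few borders per feature.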

CatBoost uses Oblivious Trees, which differ from standard trees in one key respect: every node at the same depth uses the same split.

Oblivious Trees

Comparison between Oblivious Trees and Regular Trees: the Oblivious Tree (left) applies the same split condition at each level across all nodes, giving a symmetric structure, while the Regular Tree (right) applies different conditions at each node, giving an asymmetric structure. Image by author

Unlike standard decision trees, where different nodes can split on different conditions (feature-threshold pairs), Oblivious Trees apply the same condition to all nodes at the same depth of a tree. At a given depth, all samples are evaluated against the same feature-threshold combination. This symmetry has several implications:

  • Speed and simplicity: since the same condition is applied across all nodes at the same depth, the trees produced are simpler and faster to train.
  • Regularization: since every tree is forced to apply the same condition at a given depth, there is a regularization effect on the predictions.
  • Parallelization: the uniformity of the split condition makes it easier to parallelize tree construction and to use GPUs to accelerate training.
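One way to see why this symmetry helps: because every level asks a single question, a sample's leaf can be found by collecting one bit per level. The sketch below shows that lookup; the feature names, thresholds and leaf values are made up for illustration.

```python
from typing import Dict, Sequence, Tuple

# One (feature name, threshold) pair per depth level, shared by every node at that depth.
Split = Tuple[str, float]


def oblivious_leaf_index(row: Dict[str, float], splits: Sequence[Split]) -> int:
    """Build the leaf index bit by bit, one comparison per tree level."""
    index = 0
    for feature, threshold in splits:
        bit = 1 if row[feature] > threshold else 0
        index = (index << 1) | bit
    return index


def predict(row: Dict[str, float], splits: Sequence[Split], leaf_values: Sequence[float]) -> float:
    """A depth-d oblivious tree has 2**d leaves, addressed by the bit pattern."""
    return leaf_values[oblivious_leaf_index(row, splits)]


splits = [("age", 30.0), ("income", 50.0)]  # hypothetical depth-2 tree
leaf_values = [0.1, 0.2, 0.3, 0.4]
print(predict({"age": 40.0, "income": 20.0}, splits, leaf_values))  # 0.3
```

The whole tree reduces to a short list of comparisons and an array lookup, which is what makes evaluation fast and easy to vectorize.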

Conclusion

CatBoost stands out by directly tackling a long-standing challenge: how to handle categorical variables effectively without causing target leakage. Through innovations like Ordered Target Statistics, Ordered Boosting, and the use of Oblivious Trees, it efficiently balances robustness and accuracy.

If you found this deep dive helpful, you might enjoy another deep dive on the differences between Stochastic Gradient Classifier and Logistic Regression.


]]>
Time Series Forecasting Made Simple (Part 1): Decomposition and Baseline Models https://towardsdatascience.com/time-series-forecasting-made-simple-part-1-decomposition-baseline-models/ Wed, 09 Apr 2025 19:53:52 +0000 https://towardsdatascience.com/?p=605699 Learn the intuition behind time series decomposition, additive vs. multiplicative models and build your first forecasting baseline model using Python

The post Time Series Forecasting Made Simple (Part 1): Decomposition and Baseline Models appeared first on Towards Data Science.

]]>
I used to avoid time series analysis. Every time I took an online course, I’d see a module titled “Time Series Analysis” with subtopics like Fourier Transforms, autocorrelation functions and other intimidating terms. I don’t know why, but I always found a reason to avoid it.

But here’s what I’ve learned: any complex topic becomes manageable when we start from the basics and focus on understanding the intuition behind it. That’s exactly what this blog series is about: making time series feel less like a maze and more like a conversation with your data over time.

We understand complex topics much more easily when they’re explained through real-world examples. That’s exactly how I’ll approach this series.

In each post, we’ll work with a simple dataset and explore what’s needed from a time series perspective. We’ll build intuition around each concept, understand why it matters, and implement it step by step on the data.

Time Series Analysis is the process of understanding, modeling and forecasting data that is observed over time. It involves identifying patterns such as trends, seasonality and noise, and using past observations to make informed predictions about future values.

Let’s start by considering a dataset named Daily Minimum Temperatures in Melbourne (open license). This dataset contains daily records of the lowest temperature (in Celsius) observed in Melbourne, Australia, over a 10-year period from 1981 to 1990. Each entry includes just two columns:

Date: The calendar day (from 1981-01-01 to 1990-12-31)
Temp: The minimum temperature recorded on that day

You’ve probably heard of models like ARIMA, SARIMA or Exponential Smoothing. But before we go there, it’s a good idea to try out some simple baseline models first, to see how well a basic approach performs on our data.

While there are many types of baseline models used in time series forecasting, here we’ll focus on the three most essential ones, which are simple, effective, and widely applicable across industries.

Naive Forecast: Assumes the next value will be the same as the last observed one.
Seasonal Naive Forecast: Assumes the value will repeat from the same point last season (e.g., last week or last month).
Moving Average: Takes the average of the last n points.
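Each of the three baselines fits in a line or two of code. The sketch below uses plain Python lists and made-up temperature values; the function names and the weekly season length are illustrative assumptions.

```python
from typing import Sequence


def naive_forecast(history: Sequence[float]) -> float:
    """Next value = last observed value."""
    return history[-1]


def seasonal_naive_forecast(history: Sequence[float], season_length: int) -> float:
    """Next value = value observed one full season before the forecast point."""
    if len(history) < season_length:
        raise ValueError("need at least one full season of history")
    return history[-season_length]


def moving_average_forecast(history: Sequence[float], n: int) -> float:
    """Next value = mean of the last n observations."""
    window = history[-n:]
    return sum(window) / len(window)


temps = [10.0, 12.0, 14.0, 11.0, 13.0, 15.0, 12.0]  # made-up daily values
print(naive_forecast(temps))              # 12.0
print(seasonal_naive_forecast(temps, 7))  # 10.0 (same weekday, one week earlier)
print(moving_average_forecast(temps, 3))  # mean of 13, 15 and 12
```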

You might be wondering, why use baseline models at all? Why not just go straight to the well-known forecasting methods like ARIMA or SARIMA?

Let’s consider a shop owner who wants to forecast next month’s sales. By applying a moving average baseline model, they can estimate next month’s sales as the average of previous months. This simple approach might already deliver around 80% accuracy — good enough for planning and inventory decisions.

Now, if we switch to a more advanced model like ARIMA or SARIMA, we might increase accuracy to around 85%. But the key question is: is that extra 5% worth the additional time, effort and resources? In this case, the baseline model does the job.

In fact, in most everyday business scenarios, baseline models are sufficient. We typically turn to classical models like ARIMA or SARIMA in high-impact industries such as finance or energy, where even a small improvement in accuracy can have a significant financial or operational impact. Even then, a baseline model is usually applied first — not only to provide quick insights but also to act as a benchmark that more complex models must outperform.

Okay, now that we’re ready to implement some baseline models, there’s one key thing we need to understand first:
Every time series is made up of three main components — trend, seasonality and residuals.

Time series decomposition separates data into trend, seasonality and residuals (noise), helping us uncover the true patterns beneath the surface. This understanding guides the choice of forecasting models and improves accuracy. It’s also a vital first step before building both simple and advanced forecasting solutions.

Trend
This is the overall direction your data is moving in over time — going up, down or staying flat.
Example: Steady decrease in monthly cigarette sales.

Seasonality
These are the patterns that repeat at regular intervals — daily, weekly, monthly or yearly.
Example: Cool drinks sales in summer.

Residuals (Noise)
This is the random “leftover” part of the data, the unpredictable ups and downs that can’t be explained by trend or seasonality.
Example: A one-time car purchase showing up in your monthly expense pattern.
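One way to build intuition for these three components is to construct a series from them by hand. The sketch below uses made-up numbers (a slow decline, a 12-month sine cycle and Gaussian noise); none of it comes from a real dataset.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
dates = pd.date_range("2020-01-01", periods=36, freq="MS")  # 36 monthly points
t = np.arange(len(dates))

trend = 100 - 0.5 * t                          # steady decline over time
seasonality = 10 * np.sin(2 * np.pi * t / 12)  # pattern that repeats every 12 months
residuals = rng.normal(0, 2, size=len(dates))  # unpredictable noise

# The observed series is simply the three parts added together
synthetic = pd.Series(trend + seasonality + residuals, index=dates)
print(synthetic.head())
```

Decomposition is the reverse of this exercise: we start from the observed series and try to recover the three parts.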

Now that we understand the key components of a time series, let’s put that into practice using a real dataset: Daily Minimum Temperatures in Melbourne, Australia.

We’ll use Python to decompose the time series into its trend, seasonality, and residual components so we can better understand its structure and choose an appropriate baseline model.

Code:

import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.tsa.seasonal import seasonal_decompose

# Load the dataset
df = pd.read_csv("minimum daily temperatures data.csv")

# Convert 'Date' to datetime and set as index
df['Date'] = pd.to_datetime(df['Date'], dayfirst=True)
df.set_index('Date', inplace=True)

# Set a regular daily frequency and fill missing values using forward fill
df = df.asfreq('D')
df['Temp'] = df['Temp'].ffill()

# Decompose the daily series (365-day seasonality for yearly patterns)
decomposition = seasonal_decompose(df['Temp'], model='additive', period=365)

# Plot the decomposed components
decomposition.plot()
plt.suptitle('Decomposition of Daily Minimum Temperatures (Daily)', fontsize=14)
plt.tight_layout()
plt.show()

Output:

Decomposition of daily temperatures showing trend, seasonal cycles and random fluctuations.

The decomposition plot clearly shows a strong seasonal pattern that repeats each year, along with a mild trend that shifts over time. The residual component captures the random noise that isn’t explained by trend or seasonality.

In the code earlier, you might have noticed that I used an additive model for decomposing the time series. But what exactly does that mean, and why is it the right choice for this dataset?

Let’s break it down.
In an additive model, we assume Trend, Seasonality and Residuals (Noise) combine linearly, like this:
Y = T + S + R

Where:
Y is the actual value at time t
T is the trend
S is the seasonal component
R is the residual (random noise)

This means we’re treating the observed value as the sum of its parts: each component contributes independently to the final output.
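To see where T, S and R can come from, here is a rough hand-rolled additive decomposition on a synthetic series. It is a sketch of the classical moving-average idea, similar in spirit to what seasonal_decompose does but not a reproduction of its exact implementation. The toy series and the season length of 7 are assumptions for illustration.

```python
import numpy as np
import pandas as pd

period = 7                 # assumed season length for this toy example
t = np.arange(period * 8)  # eight full cycles
y = pd.Series(0.1 * t + 3 * np.sin(2 * np.pi * t / period))

# Trend: centered moving average spanning one full season
trend = y.rolling(window=period, center=True).mean()

# Seasonal: average detrended value at each position in the cycle, centered on zero
detrended = y - trend
position = t % period
seasonal_profile = detrended.groupby(position).mean()
seasonal_profile = seasonal_profile - seasonal_profile.mean()
seasonal = pd.Series(seasonal_profile.loc[position].to_numpy(), index=y.index)

# Residual: whatever is left once trend and seasonality are removed
resid = y - trend - seasonal

# Additive identity: the three parts sum back to the observed values
rebuilt = (trend + seasonal + resid).dropna()
print(np.allclose(rebuilt, y.loc[rebuilt.index]))
print("Largest residual:", resid.abs().max())
```

Because this toy series has no noise, the residuals come out essentially zero. The first and last few trend values are missing because a centered window needs data on both sides.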

I chose the additive model because when I looked at the pattern in daily minimum temperatures, I noticed something important:

The line plot above shows the daily minimum temperatures from 1981 to 1990. We can clearly see a strong seasonal cycle that repeats each year: colder temperatures in winter, warmer in summer.

Importantly, the amplitude of these seasonal swings stays relatively consistent over the years. For example, the temperature difference between summer and winter doesn’t appear to grow or shrink over time. This stability in seasonal variation is a key sign that the additive model is appropriate for decomposition, since the seasonal component appears to be independent of any trend.

We use an additive model when the trend is relatively stable and does not amplify or distort the seasonal pattern, and when the seasonality stays within a consistent range over time, even if there are minor fluctuations.
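Eyeballing a plot is usually enough, but the same check can be done numerically by comparing the size of the seasonal swing from year to year. The sketch below does this on two synthetic monthly series, one built additively and one multiplicatively. Both series and the helper name yearly_range are made up for illustration.

```python
import numpy as np
import pandas as pd

dates = pd.date_range("2015-01-01", periods=60, freq="MS")  # five years, monthly
t = np.arange(len(dates))
trend = 100 + 2 * t
cycle = np.sin(2 * np.pi * t / 12)

additive = pd.Series(trend + 10 * cycle, index=dates)              # fixed-size swing
multiplicative = pd.Series(trend * (1 + 0.1 * cycle), index=dates)  # swing scales with level

def yearly_range(series):
    """Difference between the highest and lowest value within each calendar year."""
    by_year = series.groupby(series.index.year)
    return by_year.max() - by_year.min()

additive_range = yearly_range(additive)
multiplicative_range = yearly_range(multiplicative)

print("Additive series, yearly range:")
print(additive_range.round(2))
print("Multiplicative series, yearly range:")
print(multiplicative_range.round(2))
```

A roughly constant yearly range points towards an additive model, while a range that grows with the level of the series points towards a multiplicative one.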

Now that we understand how the additive model works, let’s explore the multiplicative model, which is often used when the seasonal effect scales with the trend. Seeing the contrast will also help us understand the additive model more clearly.

Consider a household’s electricity consumption. Suppose the household uses 20% more electricity in summer compared to winter. That means the seasonal effect isn’t a fixed number — it’s a proportion of their baseline usage.

Let’s see how this looks with real numbers:

In 2021, the household used 300 kWh in winter and 360 kWh in summer (20% more than winter).

In 2022, their winter consumption increased to 330 kWh, and summer usage rose to 396 kWh (still 20% more than winter).

Across the two years, the seasonal difference grows with the trend, from +60 kWh in 2021 to +66 kWh in 2022, even though the percentage increase stays the same. This is exactly the kind of behavior that a multiplicative model captures well.
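The arithmetic behind the household example can be checked in a few lines. The winter figures and the 20% uplift are the ones from the example; the variable names are just for illustration.

```python
winter_kwh = {2021: 300, 2022: 330}  # winter usage from the example
uplift_pct = 20                      # summer is 20% above winter

summer_kwh = {}
seasonal_gap_kwh = {}
for year, winter in winter_kwh.items():
    summer = winter * (100 + uplift_pct) / 100
    summer_kwh[year] = summer
    seasonal_gap_kwh[year] = summer - winter
    print(f"{year}: winter={winter} kWh, summer={summer:.0f} kWh, "
          f"gap={summer - winter:.0f} kWh ({uplift_pct}% of winter)")
```

The percentage is identical in both years, but the gap in kWh is not, which is what separates a proportional seasonal effect from a fixed one.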

In mathematical terms:
Y = T × S × R
Where:
Y: Observed value
T: Trend component
S: Seasonal component
R: Residual (noise)
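A useful side note, which this post does not rely on: taking logarithms turns the multiplicative form into an additive one, since log(Y) = log(T) + log(S) + log(R). The sketch below verifies this on made-up component values, all of which must be positive for the logarithm to exist.

```python
import numpy as np

# Hypothetical components for four periods (all strictly positive)
trend = np.array([300.0, 310.0, 320.0, 330.0])
seasonal = np.array([1.0, 1.2, 1.0, 1.2])      # 20% uplift every other period
residual = np.array([1.01, 0.99, 1.02, 0.98])  # noise expressed as factors near 1

observed = trend * seasonal * residual

# On the log scale the product becomes a sum
log_sum = np.log(trend) + np.log(seasonal) + np.log(residual)
print(np.allclose(np.log(observed), log_sum))
```

This is one reason a log transform is often tried before applying additive tools to data whose seasonal swings grow with the level.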

By looking at the decomposition plot, we can figure out whether an additive or multiplicative model fits our data better.

There are also other powerful decomposition tools available, which I’ll be covering in one of my upcoming blog posts.

Now that we have a clear understanding of additive and multiplicative models, let’s shift our focus to applying a baseline model that fits this dataset.

Based on the decomposition plot, we can see a strong seasonal pattern in the data, which suggests that a Seasonal Naive model might be a good fit for this time series.

This model assumes that the value at a given time will be the same as it was in the same period of the previous season — making it a simple yet effective choice when seasonality is dominant and consistent. For example, if temperatures typically follow the same yearly cycle, then the forecast for July 1st, 1990, would simply be the temperature recorded on July 1st, 1989.

Code:

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Load the dataset
df = pd.read_csv("minimum daily temperatures data.csv")

# Convert 'Date' column to datetime and set as index
df['Date'] = pd.to_datetime(df['Date'], dayfirst=True)
df.set_index('Date', inplace=True)

# Ensure regular daily frequency and fill missing values
df = df.asfreq('D')
df['Temp'] = df['Temp'].ffill()

# Step 1: Create the Seasonal Naive Forecast
seasonal_period = 365  # Assuming yearly seasonality for daily data
# Create the Seasonal Naive forecast by shifting the temperature values by 365 days
df['Seasonal_Naive'] = df['Temp'].shift(seasonal_period)

# Step 2: Plot the actual vs forecasted values
# Plot the last 2 years (730 days) of data to compare
plt.figure(figsize=(12, 5))
plt.plot(df['Temp'][-730:], label='Actual')
plt.plot(df['Seasonal_Naive'][-730:], label='Seasonal Naive Forecast', linestyle='--')
plt.title('Seasonal Naive Forecast vs Actual Temperatures')
plt.xlabel('Date')
plt.ylabel('Temperature (°C)')
plt.legend()
plt.tight_layout()
plt.show()

# Step 3: Evaluate using MAPE (Mean Absolute Percentage Error)
# Use the last 365 days for testing
test = df[['Temp', 'Seasonal_Naive']].iloc[-365:].copy()
test.dropna(inplace=True)

# MAPE Calculation
mape = np.mean(np.abs((test['Temp'] - test['Seasonal_Naive']) / test['Temp'])) * 100
print(f"MAPE (Seasonal Naive Forecast): {mape:.2f}%")

Output:

Seasonal Naive Forecast vs. Actual Temperatures (1989–1990)


To keep the visualization clear and focused, we’ve plotted the last two years of the dataset (1989–1990) instead of all 10 years.

This plot compares the actual daily minimum temperatures in Melbourne with the values predicted by the Seasonal Naive model, which simply assumes that each day’s temperature will be the same as it was on the same day one year ago.

As seen in the plot, the Seasonal Naive forecast captures the broad shape of the seasonal cycles quite well — it mirrors the rise and fall of temperatures throughout the year. However, it doesn’t capture day-to-day variations, nor does it respond to slight shifts in seasonal timing. This is expected, as the model is designed to repeat the previous year’s pattern exactly, without adjusting for trend or noise.

To evaluate how well this model performs, we calculate the Mean Absolute Percentage Error (MAPE) over the final 365 days of the dataset (i.e., 1990). The Seasonal Naive forecast needs a full year of historical data before it can begin making predictions, so the first year of the dataset has no forecast at all.

Mean Absolute Percentage Error (MAPE) is a commonly used metric to evaluate the accuracy of forecasting models. It measures the average absolute difference between the actual and predicted values, expressed as a percentage of the actual values.
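As a quick worked example of that definition, here is MAPE wrapped in a small helper and applied to three made-up data points. The function name mape and the numbers are illustrative only.

```python
import numpy as np

def mape(actual, predicted):
    """Mean absolute percentage error, in percent.

    Undefined when any actual value is zero, because of the division.
    """
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    return np.mean(np.abs((actual - predicted) / actual)) * 100

actual = [10.0, 20.0, 8.0]
predicted = [12.0, 18.0, 10.0]

# Individual errors are 20%, 10% and 25% of the actual values
score = mape(actual, predicted)
print(f"MAPE: {score:.2f}%")
```

Because each error is divided by the actual value, MAPE becomes unstable when actual values are close to zero. That is worth keeping in mind for temperature data, where readings near 0°C can inflate the score.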

In time series forecasting, we typically evaluate model performance on the most recent or target time period — not on the middle years. This reflects how forecasts are used in the real world: we build models on historical data to predict what’s coming next.

That’s why we calculate MAPE only on the final 365 days of the dataset: this simulates forecasting for a future period and gives us a realistic measure of how well the model would perform in practice.
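The same idea can be written as an explicit train/test split, where the forecast for the final year is built only from data that came before it. The sketch below uses a synthetic temperature-like series (a yearly cosine plus noise) rather than the Melbourne file, so the numbers it prints are not comparable to the results in this post.

```python
import numpy as np
import pandas as pd

# Synthetic daily "temperatures" for 1988-1990 (made-up values)
dates = pd.date_range("1988-01-01", "1990-12-31", freq="D")
t = np.arange(len(dates))
rng = np.random.default_rng(42)
temps = pd.Series(
    11 + 4 * np.cos(2 * np.pi * t / 365.25) + rng.normal(0, 1.5, len(dates)),
    index=dates,
)

# Hold out the final year; everything before it is history
train = temps.loc[:"1989-12-31"]
test = temps.loc["1990-01-01":]

# Seasonal naive forecast: reuse the last 365 training days as the prediction
forecast = pd.Series(train.iloc[-len(test):].to_numpy(), index=test.index)

error_pct = np.mean(np.abs((test - forecast) / test)) * 100
print(f"Train days: {len(train)}, test days: {len(test)}")
print(f"MAPE on the held-out year: {error_pct:.2f}%")
```

Splitting by time rather than at random keeps future observations out of the data the forecast is built from.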

The model achieves a MAPE of 28.23%, which gives us a baseline level of forecasting error. Any model we build next, whether it’s customized or more advanced, should aim to outperform this benchmark.

A MAPE of 28.23% means that, on average, the model’s predictions were 28.23% off from the actual daily temperature values over the last year.

In other words, if the true temperature on a given day was 10°C, the Seasonal Naive forecast might have been around 7.2°C or 12.8°C, reflecting a deviation of roughly 28%.

I’ll dive deeper into evaluation metrics in a future post.

In this post, we laid the foundation for time series forecasting by understanding how real-world data can be broken down into trend, seasonality, and residuals through decomposition. We explored the difference between additive and multiplicative models, implemented the Seasonal Naive baseline forecast and evaluated its performance using MAPE.

While the Seasonal Naive model is simple and intuitive, it comes with limitations, especially for this dataset. It assumes that the temperature on any given day is identical to the same day last year. But as the plot and MAPE of 28.23% showed, this assumption doesn’t hold perfectly. The data displays slight shifts in seasonal patterns and long-term variations that the model fails to capture.

In the next part of this series, we’ll go further. We’ll explore how to customize a baseline model, compare it to the Seasonal Naive approach and evaluate which one performs better using error metrics like MAPE, MAE and RMSE.

We’ll also begin building the foundation needed to understand more advanced models like ARIMA, including key concepts such as:

  • Stationarity
  • Autocorrelation and Partial Autocorrelation 
  • Differencing
  • Lag-based modeling (AR and MA terms)

Part 2 will dive into these topics in more detail, starting with custom baselines and ending with the foundations of ARIMA.

Thanks for reading. I hope you found this post helpful and insightful.

The post Time Series Forecasting Made Simple (Part 1): Decomposition and Baseline Models appeared first on Towards Data Science.

]]>