Mining Rules from Data
Using decision trees for quick segmentation
Mariya Mansurova, Towards Data Science, Wed, 09 Apr 2025
https://towardsdatascience.com/mining-rules-from-data/

When working on products, we might need to introduce some “rules”. Let me explain what I mean by “rules” with some practical examples: 

  • Imagine that we’re seeing a massive wave of fraud in our product, and we want to restrict onboarding for a particular segment of customers to lower this risk. For example, we found out that the majority of fraudsters had specific user agents and IP addresses from certain countries. 
  • Another example is sending coupons to customers to use in our online shop. However, we would like to target only customers who are likely to churn, since loyal users will return to the product anyway. We might find that the most feasible group is customers who joined less than a year ago and decreased their spending by 30%+ last month. 
  • Transactional businesses often have a segment of customers on whom they are losing money. For example, a bank customer passed verification and regularly reached out to customer support (generating onboarding and servicing costs) while making almost no transactions (generating no revenue). The bank might introduce a small monthly subscription fee for customers with less than $1,000 in their account, since they are likely unprofitable.

Of course, in all these cases, we might have used a complex Machine Learning model that would take into account all the factors and predict the probability (either of a customer being a fraudster or churning). Still, under some circumstances, we might prefer just a set of static rules for the following reasons:  

  • The speed and complexity of implementation. Deploying an ML model in production takes time and effort. If you are experiencing a fraud wave right now, it might be more feasible to go live with a set of static rules that can be implemented quickly and then work on a comprehensive solution. 
  • Interpretability. ML models are often black boxes. Even if we understand at a high level how they work and which features are the most important, it’s challenging to explain them to customers. In the example of subscription fees for unprofitable customers, it’s important to share a set of transparent rules with customers so that they can understand the pricing. 
  • Compliance. Some industries, like finance or healthcare, might require auditable and rule-based decisions to meet compliance requirements.

In this article, I want to show you how we can solve business problems using such rules. We will take a practical example and go really deep into this topic:

  • we will discuss which models we can use to mine such rules from data,
  • we will build a Decision Tree Classifier from scratch to learn how it works,
  • we will fit the sklearn Decision Tree Classifier model to extract the rules from the data,
  • we will learn how to parse the Decision Tree structure to get the resulting segments,
  • finally, we will explore different options for category encoding, since the sklearn implementation doesn’t support categorical variables.

We have lots of topics to cover, so let’s jump into it.

Case

As usual, it’s easier to learn something with a practical example. So, let’s start by discussing the task we will be solving in this article. 

We will work with the Bank Marketing dataset (CC BY 4.0 license). This dataset contains data about the direct marketing campaigns of a Portuguese banking institution. For each customer, we know a bunch of features and whether they subscribed to a term deposit (our target). 

Our business goal is to maximise the number of conversions (subscriptions) with limited operational resources. So, we can’t call the whole user base, and we want to reach the best outcome with the resources we have.

The first step is to look at the data. So, let’s load the data set.

import pandas as pd
pd.set_option('display.max_colwidth', 5000)
pd.set_option('display.float_format', lambda x: '%.2f' % x)

df = pd.read_csv('bank-full.csv', sep = ';')
df = df.drop(['duration', 'campaign'], axis = 1)
# removed columns related to the current marketing campaign, 
# since they introduce data leakage

df.head()

We know quite a lot about the customers, including personal data (such as job type or marital status) and their previous behaviour (such as whether they have a loan or their average yearly balance).

Image by author

The next step is to select a machine-learning model. There are two classes of models that are usually used when we need something easily interpretable:

  • decision trees,
  • linear or logistic regression.

Both options are feasible and can give us good models that are easy to implement and interpret. However, in this article, I will stick to the decision tree model because it produces actual rules, while logistic regression gives us a probability obtained by passing a weighted sum of features through the logistic function.

Data Preprocessing 

As we’ve seen in the data, there are lots of categorical variables (such as education or marital status). Unfortunately, the sklearn decision tree implementation can’t handle categorical data, so we need to do some preprocessing.

Let’s start by transforming yes/no flags into integers. 

for p in ['default', 'housing', 'loan', 'y']:
    df[p] = df[p].map(lambda x: 1 if x == 'yes' else 0)

The next step is to transform the month variable. We could use one-hot encoding for months, introducing flags like month_jan, month_feb, etc. However, there might be seasonal effects, so I think it’s more reasonable to convert months into integers following their order; this way, the tree can split on ranges of months. 

month_map = {
    'jan': 1, 'feb': 2, 'mar': 3, 'apr': 4, 'may': 5, 'jun': 6, 
    'jul': 7, 'aug': 8, 'sep': 9, 'oct': 10, 'nov': 11, 'dec': 12
}
# I saved 5 mins by asking ChatGPT to do this mapping

df['month'] = df.month.map(lambda x: month_map[x] if x in month_map else x)

For all other categorical variables, let’s use one-hot encoding. We will discuss different strategies for category encoding later, but for now, let’s stick to the default approach.

The easiest way to do one-hot encoding is to use the get_dummies function in pandas.

fin_df = pd.get_dummies(
  df, columns=['job', 'marital', 'education', 'poutcome', 'contact'], 
  dtype = int, # to convert to flags 0/1
  drop_first = False # to keep all possible values
)

This function transforms each categorical variable into a separate 1/0 column for each possible value. We can see how it works for the poutcome column. 

# get_dummies keeps the original index, so we can attach the raw column back
fin_df.assign(poutcome = df['poutcome'])\
    .groupby(['poutcome', 'poutcome_unknown', 'poutcome_failure', 
      'poutcome_other', 'poutcome_success'], as_index = False).y.count()\
    .rename(columns = {'y': 'cases'})\
    .sort_values('cases', ascending = False)
Image by author
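The resulting table is only shown as an image, so here is a minimal toy illustration of the same transformation on a hypothetical five-row frame (not the real dataset):

```python
import pandas as pd

# Hypothetical five-row frame, not the real dataset
toy_outcomes = pd.DataFrame({'poutcome': ['unknown', 'failure', 'success', 'other', 'success']})

# One 0/1 column per category; drop_first=False keeps a flag for every category
toy_dummies = pd.get_dummies(toy_outcomes, columns=['poutcome'], dtype=int, drop_first=False)
print(toy_dummies)
```

Each row has exactly one poutcome_* flag set to 1. Keeping all categories gives every value its own flag, which makes the resulting rules easier to read.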

Our data is now ready, and it’s time to discuss how decision tree classifiers work.

Decision Tree Classifier: Theory

In this section, we’ll explore the theory behind the Decision Tree Classifier and build the algorithm from scratch. If you’re more interested in a practical example, feel free to skip ahead to the next part.

The easiest way to understand the decision tree model is to look at an example. So, let’s build a simple model based on our data. We will use DecisionTreeClassifier from sklearn.

import sklearn.tree

feature_names = fin_df.drop(['y'], axis = 1).columns
model = sklearn.tree.DecisionTreeClassifier(
  max_depth = 2, min_samples_leaf = 1000)
model.fit(fin_df[feature_names], fin_df['y'])

The next step is to visualise the tree.

import graphviz

dot_data = sklearn.tree.export_graphviz(
    model, out_file=None, feature_names = feature_names, filled = True, 
    proportion = True, # to show shares of classes instead of absolute numbers
    precision = 2
)

graph = graphviz.Source(dot_data)
graph
Image by author

So, we can see that the model is straightforward. It’s a set of binary splits that we can use as heuristics. 

Let’s figure out how the classifier works under the hood. As usual, the best way to understand the model is to build the logic from scratch. 

The cornerstone of any model is the objective function it optimises. By default, the decision tree classifier uses the Gini impurity, which I’ll call the Gini coefficient in this article. Imagine drawing one random item from the sample and then another one (with replacement). The Gini coefficient equals the probability that these two items belong to different classes. So, our goal is to minimise the Gini coefficient. 

In the case of just two classes (as in our example, where the marketing intervention was either successful or not), the Gini coefficient is defined by just one parameter p, where p is the probability of getting an item from one of the classes. Here’s the formula:

\[\textbf{gini}(\textsf{p}) = 1 - \textsf{p}^2 - (1 - \textsf{p})^2 = 2 * \textsf{p} * (1 - \textsf{p}) \]

If our classification is ideal and we are able to separate the classes perfectly, the Gini coefficient equals 0. The worst-case scenario is p = 0.5, when the Gini coefficient equals 0.5.

With the formula above, we can calculate the Gini coefficient for each leaf of the tree. To evaluate a binary split, we need to combine the Gini coefficients of the two resulting nodes. For that, we take their average weighted by the number of items in each node:

\[\textbf{gini}_{\textsf{total}} = \textbf{gini}_{\textsf{left}} * \frac{\textbf{n}_{\textsf{left}}}{\textbf{n}_{\textsf{left}} + \textbf{n}_{\textsf{right}}} + \textbf{gini}_{\textsf{right}} * \frac{\textbf{n}_{\textsf{right}}}{\textbf{n}_{\textsf{left}} + \textbf{n}_{\textsf{right}}}\]
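To make both formulas concrete, here is a small sketch with made-up numbers: a node of 20 items, half of them positive, split into two children of 10 items with 2 and 8 positives respectively. The function names are illustrative only.

```python
def gini(p):
    """Gini coefficient (impurity) for a binary target with positive-class share p."""
    return 2 * p * (1 - p)

def gini_split(n_left, p_left, n_right, p_right):
    """Size-weighted average of the two child impurities."""
    n = n_left + n_right
    return gini(p_left) * n_left / n + gini(p_right) * n_right / n

# Made-up node: 20 items, 10 positive, split into 10 items (2 positive) + 10 items (8 positive)
print(gini(10 / 20))                       # 0.5, the worst case
print(gini_split(10, 2 / 10, 10, 8 / 10))  # about 0.32, so the split lowers the impurity
```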

Now that we know what value we’re optimising, we only need to define all possible binary splits, iterate through them and choose the best option. 

Defining all possible binary splits is also quite straightforward. We can go through the parameters one by one, sort each parameter’s unique values, and pick thresholds midway between neighbouring values. For example, for months (integers from 1 to 12):

Image by author
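Here is that step as a tiny standalone sketch (candidate_thresholds is a hypothetical helper name); it uses the same midpoint logic as optimise_one_parameter below:

```python
def candidate_thresholds(values):
    """Midpoints between neighbouring unique values, used as candidate split points."""
    unique_sorted = sorted(set(values))
    return [(a + b) / 2 for a, b in zip(unique_sorted, unique_sorted[1:])]

print(candidate_thresholds(range(1, 13)))  # months: 1.5, 2.5, ..., 11.5 (11 candidates)
print(candidate_thresholds([0, 1, 1, 0]))  # binary flag: [0.5]
```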

Let’s code it and see whether we get the same result. First, we will define functions that calculate the Gini coefficient for one dataset and for a combination of two datasets.

def get_gini(df):
    p = df.y.mean()
    return 2*p*(1-p)

print(get_gini(fin_df)) 
# 0.2065
# close to what we see at the root node of Decision Tree

def get_gini_comb(df1, df2):
    n1 = df1.shape[0]
    n2 = df2.shape[0]

    gini1 = get_gini(df1)
    gini2 = get_gini(df2)
    return (gini1*n1 + gini2*n2)/(n1 + n2)

The next step is to get all possible thresholds for one parameter and calculate their Gini coefficients. 

import tqdm
def optimise_one_parameter(df, param):
    tmp = []
    possible_values = list(sorted(df[param].unique()))
    print(param)

    for i in tqdm.tqdm(range(1, len(possible_values))): 
        threshold = (possible_values[i-1] + possible_values[i])/2
        gini = get_gini_comb(df[df[param] <= threshold], 
          df[df[param] > threshold])
        tmp.append(
            {'param': param, 
            'threshold': threshold, 
            'gini': gini, 
            'sizes': (df[df[param] <= threshold].shape[0], df[df[param] > threshold].shape[0])
            }
        )
    return pd.DataFrame(tmp)

The final step is to iterate through all features and calculate all possible splits. 

tmp_dfs = []
for feature in feature_names:
    tmp_dfs.append(optimise_one_parameter(fin_df, feature))
opt_df = pd.concat(tmp_dfs)
opt_df.sort_values('gini', ascending = True).head(5)
Image by author

Wonderful, we’ve got the same result as in our DecisionTreeClassifier model. The optimal split is whether poutcome = success or not. We’ve reduced the Gini coefficient from 0.2065 to 0.1872. 

To continue building the tree, we need to repeat the process recursively. For example, going down for the poutcome_success <= 0.5 branch:

tmp_dfs = []
for feature in feature_names:
    tmp_dfs.append(optimise_one_parameter(
      fin_df[fin_df.poutcome_success <= 0.5], feature))

opt_df = pd.concat(tmp_dfs)
opt_df.sort_values('gini', ascending = True).head(5)
Image by author

The only question we still need to discuss is the stopping criteria. In our initial example, we’ve used two conditions:

  • max_depth = 2 — it just limits the maximum depth of the tree, 
  • min_samples_leaf = 1000 prevents us from getting leaf nodes with fewer than 1K samples. Because of this condition, we chose the binary split by contact_unknown even though a split by age led to a lower Gini coefficient.

I also usually set min_impurity_decrease, which prevents us from splitting further if the gain is too small. By gain, we mean the decrease in the Gini coefficient.
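One detail worth knowing: scikit-learn doesn’t compare min_impurity_decrease with the raw drop in Gini. Its documentation describes a weighted decrease, N_t / N * (impurity - N_t_R / N_t * right_impurity - N_t_L / N_t * left_impurity), where N is the total number of samples and N_t, N_t_L and N_t_R are the samples in the node and its left and right children. A small sketch of that formula, plugging in the root split on poutcome_success with the sample counts and class shares from the tree_ arrays printed in the next section:

```python
def gini(p):
    """Gini impurity for a binary target with positive-class share p."""
    return 2 * p * (1 - p)

def weighted_impurity_decrease(n_total, n_node, p_node, n_left, p_left, n_right, p_right):
    """Weighted decrease as described for min_impurity_decrease in the scikit-learn docs."""
    return n_node / n_total * (
        gini(p_node)
        - n_right / n_node * gini(p_right)
        - n_left / n_node * gini(p_left)
    )

# Root split on poutcome_success: n_node_samples and class-1 shares from model.tree_
dec = weighted_impurity_decrease(
    n_total=45211, n_node=45211, p_node=0.1169848,
    n_left=43700, p_left=0.09864989,
    n_right=1511, p_right=0.64725347,
)
print(round(dec, 4))  # about 0.0194, well above min_impurity_decrease = 0.001
```

At the root, N_t / N = 1, so this is simply 0.2066 - 0.1872. Deeper in the tree the N_t / N factor shrinks, so the same local improvement counts for less in a small node.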

So, we’ve understood how the Decision Tree Classifier works, and now it’s time to use it in practice.

If you’re interested in how the Decision Tree Regressor works in detail, you can read my previous article.

Decision Trees: practice

We’ve already built a simple two-level tree model, but it’s definitely not enough since it’s too simple to capture all the insights in the data. Let’s train another Decision Tree, this time limiting the minimum number of samples in leaves and the minimum impurity decrease (reduction of the Gini coefficient) instead of the depth. 

model = sklearn.tree.DecisionTreeClassifier(
  min_samples_leaf = 1000, min_impurity_decrease=0.001)
model.fit(fin_df[feature_names], fin_df['y'])

dot_data = sklearn.tree.export_graphviz(
    model, out_file=None, feature_names = feature_names, filled = True, 
    proportion = True, precision=2, impurity = True)

graph = graphviz.Source(dot_data)

# saving graph to png file
png_bytes = graph.pipe(format='png')
with open('decision_tree.png','wb') as f:
    f.write(png_bytes)
Image by author

That’s it. We’ve got our rules to split customers into groups (leaves). Now, we can iterate through groups and see which groups of customers we want to contact. Even though our model is relatively small, it’s daunting to copy all conditions from the image. Luckily, we can parse the tree structure and get all the groups from the model.

The Decision Tree classifier has an attribute tree_ that gives us access to low-level attributes of the tree, such as node_count.

n_nodes = model.tree_.node_count
print(n_nodes)
# 13

The tree_ variable also stores the entire tree structure as parallel arrays, where the i-th element of each array stores information about node i. The root node has i = 0.

Here are the arrays we have to represent the tree structure: 

  • children_left and children_right: IDs of the left and right child nodes, respectively; -1 if the node is a leaf.
  • feature: index of the feature used to split node i (-2 for leaves).
  • threshold: threshold value used for the binary split of node i.
  • n_node_samples: number of training samples that reached node i.
  • value: class distribution at node i; for classifiers in recent scikit-learn versions, these are the shares of samples from each class (older versions stored counts instead).

Let’s save all these arrays. 

children_left = model.tree_.children_left
# [ 1,  2,  3,  4,  5,  6, -1, -1, -1, -1, -1, -1, -1]
children_right = model.tree_.children_right
# [12, 11, 10,  9,  8,  7, -1, -1, -1, -1, -1, -1, -1]
features = model.tree_.feature
# [30, 34,  0,  3,  6,  6, -2, -2, -2, -2, -2, -2, -2]
thresholds = model.tree_.threshold
# [ 0.5,  0.5, 59.5,  0.5,  6.5,  2.5, -2. , -2. , -2. , -2. , -2. , -2. , -2. ]
num_nodes = model.tree_.n_node_samples
# [45211, 43700, 30692, 29328, 14165,  4165,  2053,  2112, 10000, 
#  15163,  1364, 13008,  1511] 
values = model.tree_.value
# [[[0.8830152 , 0.1169848 ]],
# [[0.90135011, 0.09864989]],
# [[0.87671054, 0.12328946]],
# [[0.88550191, 0.11449809]],
# [[0.8530886 , 0.1469114 ]],
# [[0.76686675, 0.23313325]],
# [[0.87043351, 0.12956649]],
# [[0.66619318, 0.33380682]],
# [[0.889     , 0.111     ]],
# [[0.91578184, 0.08421816]],
# [[0.68768328, 0.31231672]],
# [[0.95948647, 0.04051353]],
# [[0.35274653, 0.64725347]]]

It will be more convenient for us to work with a hierarchical view of the tree structure, so let’s iterate through all nodes and, for each node, save the parent node ID and whether it was a right or left branch. 

hierarchy = {}

for node_id in range(n_nodes):
  if children_left[node_id] != -1: 
    hierarchy[children_left[node_id]] = {
      'parent': node_id, 
      'condition': 'left'
    }

  if children_right[node_id] != -1:
    hierarchy[children_right[node_id]] = {
      'parent': node_id, 
      'condition': 'right'
    }

print(hierarchy)
# {1: {'parent': 0, 'condition': 'left'},
# 12: {'parent': 0, 'condition': 'right'},
# 2: {'parent': 1, 'condition': 'left'},
# 11: {'parent': 1, 'condition': 'right'},
# 3: {'parent': 2, 'condition': 'left'},
# 10: {'parent': 2, 'condition': 'right'},
# 4: {'parent': 3, 'condition': 'left'},
# 9: {'parent': 3, 'condition': 'right'},
# 5: {'parent': 4, 'condition': 'left'},
# 8: {'parent': 4, 'condition': 'right'},
# 6: {'parent': 5, 'condition': 'left'},
# 7: {'parent': 5, 'condition': 'right'}}

The next step is to select the leaf nodes, since they are terminal and the most interesting for us: they define the customer segments. 

leaves = []
for node_id in range(n_nodes):
    if (children_left[node_id] == -1) and (children_right[node_id] == -1):
        leaves.append(node_id)
print(leaves)
# [6, 7, 8, 9, 10, 11, 12]
leaves_df = pd.DataFrame({'node_id': leaves})

The next step is to determine all the conditions applied to each group since they will define our customer segments. The first function get_condition will give us the tuple of feature, condition type and threshold for a node. 

def get_condition(node_id, condition, features, thresholds, feature_names):
    feature = feature_names[features[node_id]]
    threshold = thresholds[node_id]
    cond = '>' if condition == 'right' else '<='
    return (feature, cond, threshold)

print(get_condition(0, 'left', features, thresholds, feature_names)) 
# ('poutcome_success', '<=', 0.5)

print(get_condition(0, 'right', features, thresholds, feature_names))
# ('poutcome_success', '>', 0.5)

The next function will allow us to recursively go from the leaf node to the root and get all the binary splits. 

def get_decision_path_rec(node_id, decision_path, hierarchy):
  if node_id == 0:
    yield decision_path 
  else:
    parent_id = hierarchy[node_id]['parent']
    condition = hierarchy[node_id]['condition']
    for res in get_decision_path_rec(parent_id, decision_path + [(parent_id, condition)], hierarchy):
        yield res

decision_path = list(get_decision_path_rec(12, [], hierarchy))[0]
print(decision_path) 
# [(0, 'right')]

fmt_decision_path = list(map(
  lambda x: get_condition(x[0], x[1], features, thresholds, feature_names), 
  decision_path))
print(fmt_decision_path)
# [('poutcome_success', '>', 0.5)]

Let’s save the logic of executing the recursion and formatting into a wrapper function.

def get_decision_path(node_id, features, thresholds, hierarchy, feature_names):
  decision_path = list(get_decision_path_rec(node_id, [], hierarchy))[0]
  return list(map(lambda x: get_condition(x[0], x[1], features, thresholds, 
    feature_names), decision_path))
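The recursion above walks from each leaf up to the root. An alternative, shown here as a sketch rather than the author’s code, is a single top-down pass with an explicit stack that records the path to every leaf at once. It uses the children_left and children_right arrays printed earlier, copied as plain lists; leaf_paths is a hypothetical name. The paths come out root-first rather than leaf-first, which doesn’t matter once the conditions are deduplicated.

```python
# children_left / children_right copied from the model.tree_ output shown earlier
example_left = [1, 2, 3, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1]
example_right = [12, 11, 10, 9, 8, 7, -1, -1, -1, -1, -1, -1, -1]

def leaf_paths(children_left, children_right):
    """Return {leaf_id: [(node_id, 'left' | 'right'), ...]} listed from the root down."""
    paths = {}
    stack = [(0, [])]  # node 0 is the root
    while stack:
        node_id, path = stack.pop()
        left, right = children_left[node_id], children_right[node_id]
        if left == -1 and right == -1:  # leaf: the path is complete
            paths[node_id] = path
            continue
        # record which branch we take from this node for each child
        stack.append((left, path + [(node_id, 'left')]))
        stack.append((right, path + [(node_id, 'right')]))
    return paths

paths = leaf_paths(example_left, example_right)
print(paths[12])  # [(0, 'right')]
print(paths[7])
```

With a fitted model, you could pass model.tree_.children_left and model.tree_.children_right directly instead of the hard-coded lists.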

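We’ve learned how to get each node’s binary split conditions. The only remaining step is to combine them. The same feature can appear several times along one path (for example, month <= 6.5 and then month <= 2.5), so for each feature we keep only the tightest bound: the minimum threshold for <= conditions and the maximum threshold for > conditions. 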

def get_decision_path_string(node_id, features, thresholds, hierarchy, 
  feature_names):
  conditions_df = pd.DataFrame(get_decision_path(node_id, features, thresholds, hierarchy, feature_names))
  conditions_df.columns = ['feature', 'condition', 'threshold']

  left_conditions_df = conditions_df[conditions_df.condition == '<=']
  right_conditions_df = conditions_df[conditions_df.condition == '>']

  # deduplication 
  left_conditions_df = left_conditions_df.groupby(['feature', 'condition'], as_index = False).min()
  right_conditions_df = right_conditions_df.groupby(['feature', 'condition'], as_index = False).max()
  
  # concatenation
  fin_conditions_df = pd.concat([left_conditions_df, right_conditions_df])\
      .sort_values(['feature', 'condition'], ascending = False)
  
  # formatting 
  fin_conditions_df['cond_string'] = list(map(
      lambda x, y, z: '(%s %s %.2f)' % (x, y, z),
      fin_conditions_df.feature,
      fin_conditions_df.condition,
      fin_conditions_df.threshold
  ))
  return ' and '.join(fin_conditions_df.cond_string.values)

print(get_decision_path_string(12, features, thresholds, hierarchy, 
  feature_names))
# (poutcome_success > 0.50)
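The deduplication step can also be sketched without pandas. The helper below (combine_conditions, a hypothetical name) keeps the tightest threshold per feature and operator and formats the result like get_decision_path_string; the example path is made up.

```python
def combine_conditions(conditions):
    """Collapse (feature, op, threshold) triples into the tightest bound per feature and op."""
    tightest = {}
    for feature, op, threshold in conditions:
        key = (feature, op)
        if key not in tightest:
            tightest[key] = threshold
        elif op == '<=':
            tightest[key] = min(tightest[key], threshold)  # a smaller upper bound is stricter
        else:
            tightest[key] = max(tightest[key], threshold)  # a larger lower bound is stricter
    # sort by feature (and operator) in descending order, like the pandas version
    return ' and '.join(
        '(%s %s %.2f)' % (feature, op, threshold)
        for (feature, op), threshold in sorted(tightest.items(), reverse=True)
    )

# Hypothetical path where month is split twice on the way down
example_path = [('month', '<=', 6.5), ('month', '<=', 2.5),
                ('age', '>', 30.5), ('poutcome_success', '<=', 0.5)]
print(combine_conditions(example_path))
# (poutcome_success <= 0.50) and (month <= 2.50) and (age > 30.50)
```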

Now, we can calculate the conditions for each group. 

leaves_df['condition'] = leaves_df['node_id'].map(
  lambda x: get_decision_path_string(x, features, thresholds, hierarchy, 
  feature_names)
)

The last step is to add each group’s size and conversion rate.

leaves_df['total'] = leaves_df.node_id.map(lambda x: num_nodes[x])
leaves_df['conversion'] = leaves_df['node_id'].map(lambda x: values[x][0][1])*100
leaves_df['converted_users'] = (leaves_df.conversion * leaves_df.total)\
  .map(lambda x: int(round(x/100)))
leaves_df['share_of_converted'] = 100*leaves_df['converted_users']/leaves_df['converted_users'].sum()
leaves_df['share_of_total'] = 100*leaves_df['total']/leaves_df['total'].sum()

Now, we can use these rules to make decisions. We can sort groups by conversion (probability of successful contact) and pick the customers with the highest probability. 

leaves_df.sort_values('conversion', ascending = False)\
  .drop('node_id', axis = 1).set_index('condition')
Image by author

If we have resources to contact only around 10% of our user base, we can focus on the first three groups. Even with such limited capacity, we would expect a conversion rate of over 40%, compared with about 12% across the whole base. That’s a really good result, and we’ve achieved it with just a bunch of straightforward heuristics.  
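As a rough check of that estimate, we can plug in the sizes and class shares of the three best leaves (nodes 12, 7 and 10) from the n_node_samples and value arrays printed earlier:

```python
# (n_node_samples, share of class 1) for leaves 12, 7 and 10, from the arrays printed above
top_leaves = {12: (1511, 0.64725347), 7: (2112, 0.33380682), 10: (1364, 0.31231672)}
total_customers = 45211  # n_node_samples of the root node

n_selected = sum(n for n, _ in top_leaves.values())
expected_conversions = sum(n * p for n, p in top_leaves.values())

print(n_selected, round(100 * n_selected / total_customers, 1))  # 4987 customers, about 11% of the base
print(round(100 * expected_conversions / n_selected, 1))        # about 42% expected conversion
```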

In real life, it’s also worth testing the model (or heuristics) before deploying it in production. I would split the training dataset into training and validation parts (by time, to avoid leakage) and check the heuristics’ performance on the validation set to get a better view of the actual model quality.
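A time-based split could look like the sketch below. It assumes a hypothetical timestamp column (contact_date), since the Bank Marketing data used here doesn’t include a full date, and puts the most recent rows into validation so that the heuristics are always evaluated on “future” customers.

```python
import pandas as pd

def time_based_split(df, time_col, valid_share=0.2):
    """Return (train, valid), with the most recent `valid_share` of rows in valid."""
    df_sorted = df.sort_values(time_col)
    cutoff = int(len(df_sorted) * (1 - valid_share))
    return df_sorted.iloc[:cutoff], df_sorted.iloc[cutoff:]

# Toy data; 'contact_date' is a hypothetical timestamp column
toy_contacts = pd.DataFrame({
    'contact_date': pd.date_range('2024-01-01', periods=10, freq='D')[::-1],
    'y': [0, 1, 0, 0, 1, 0, 0, 1, 0, 0],
})

train_part, valid_part = time_based_split(toy_contacts, 'contact_date', valid_share=0.2)
print(len(train_part), len(valid_part))
print(train_part.contact_date.max() < valid_part.contact_date.min())
```

Unlike a random split, every validation row here is later than every training row, which mimics how the rules would be used after deployment.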

Working with high cardinality categories

Another topic worth discussing in this context is category encoding, since we have to encode categorical variables for the sklearn implementation. We’ve used a straightforward approach with one-hot encoding, but in some cases, it doesn’t work well.

Imagine we also have a region column in the data. I’ve synthetically generated English cities for each row. There are 155 unique regions, so the number of features increases to 190. 

model = sklearn.tree.DecisionTreeClassifier(min_samples_leaf = 100, min_impurity_decrease=0.001)
model.fit(fin_df[feature_names], fin_df['y'])

As a result, the basic tree now has lots of conditions based on individual regions, which makes it inconvenient to work with.

Image by author

In such a case, it might not be meaningful to explode the number of features, and it’s time to think about encoding. There’s a comprehensive article, “Categorically: Don’t explode — encode!”, that shares a bunch of different options to handle high cardinality categorical variables. I think the most feasible ones in our case will be the following two options:

  • Count or Frequency Encoder, which performed well in that article’s benchmarks. This encoding assumes that categories of similar size have similar characteristics. 
  • Target Encoder, where we encode each category by the mean value of the target variable. It lets us prioritise segments with higher conversion and deprioritise segments with lower conversion. Ideally, we would use historical data to calculate the averages for the encoding, but here we will use the existing dataset. 
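Both encodings boil down to a groupby. A minimal sketch on a hypothetical toy frame (not the article’s synthetic data):

```python
import pandas as pd

# Toy training data with a hypothetical 'region' column
train_toy = pd.DataFrame({
    'region': ['Manchester', 'Manchester', 'Manchester', 'Leeds', 'Leeds', 'York'],
    'y': [1, 0, 1, 0, 0, 1],
})

# Count encoding = rows per region; target encoding = mean target per region
encoding = train_toy.groupby('region', as_index=False).agg(
    region_count=('y', 'count'),
    region_target=('y', 'mean'),
)

# Attach both encodings as numeric features the tree can split on
train_toy = train_toy.merge(encoding, on='region', how='left')
print(encoding)
```

Note how York gets a perfect target encoding of 1.0 from a single row: small categories can look great purely by chance.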

However, it will be interesting to test different approaches, so let’s split our dataset into train and test, saving 10% for validation. For simplicity, I’ve used one-hot encoding for all columns except for region (since it has the highest cardinality).

from sklearn.model_selection import train_test_split
fin_df = pd.get_dummies(df, columns=['job', 'marital', 'education', 
  'poutcome', 'contact'], dtype = int, drop_first = False)
train_df, test_df = train_test_split(fin_df,test_size=0.1, random_state=42)
print(train_df.shape[0], test_df.shape[0])
# 40689 4522

For convenience, let’s combine all the logic for parsing the tree into one function.

def get_model_definition(model, feature_names):
  n_nodes = model.tree_.node_count
  children_left = model.tree_.children_left
  children_right = model.tree_.children_right
  features = model.tree_.feature
  thresholds = model.tree_.threshold
  num_nodes = model.tree_.n_node_samples
  values = model.tree_.value

  hierarchy = {}

  for node_id in range(n_nodes):
      if children_left[node_id] != -1: 
          hierarchy[children_left[node_id]] = {
            'parent': node_id, 
            'condition': 'left'
          }
    
      if children_right[node_id] != -1:
            hierarchy[children_right[node_id]] = {
             'parent': node_id, 
             'condition': 'right'
            }

  leaves = []
  for node_id in range(n_nodes):
      if (children_left[node_id] == -1) and (children_right[node_id] == -1):
          leaves.append(node_id)
  leaves_df = pd.DataFrame({'node_id': leaves})
  leaves_df['condition'] = leaves_df['node_id'].map(
    lambda x: get_decision_path_string(x, features, thresholds, hierarchy, feature_names)
  )

  leaves_df['total'] = leaves_df.node_id.map(lambda x: num_nodes[x])
  leaves_df['conversion'] = leaves_df['node_id'].map(lambda x: values[x][0][1])*100
  leaves_df['converted_users'] = (leaves_df.conversion * leaves_df.total).map(lambda x: int(round(x/100)))
  leaves_df['share_of_converted'] = 100*leaves_df['converted_users']/leaves_df['converted_users'].sum()
  leaves_df['share_of_total'] = 100*leaves_df['total']/leaves_df['total'].sum()
  leaves_df = leaves_df.sort_values('conversion', ascending = False)\
    .drop('node_id', axis = 1).set_index('condition')
  leaves_df['cum_share_of_total'] = leaves_df['share_of_total'].cumsum()
  leaves_df['cum_share_of_converted'] = leaves_df['share_of_converted'].cumsum()
  return leaves_df

Let’s create an encodings data frame, calculating frequencies and conversions. 

region_encoding_df = train_df.groupby('region', as_index = False)\
  .aggregate({'id': 'count', 'y': 'mean'}).rename(columns = 
    {'id': 'region_count', 'y': 'region_target'})

Then, we merge it into our training and validation sets. For regions that don’t appear in the training data, we fill the missing validation values with the averages across regions.

train_df = train_df.merge(region_encoding_df, on = 'region')

test_df = test_df.merge(region_encoding_df, on = 'region', how = 'left')
test_df['region_target'] = test_df['region_target']\
  .fillna(region_encoding_df.region_target.mean())
test_df['region_count'] = test_df['region_count']\
  .fillna(region_encoding_df.region_count.mean())

Now, we can fit the models and get their structures.

count_feature_names = train_df.drop(
  ['y', 'id', 'region_target', 'region'], axis = 1).columns
target_feature_names = train_df.drop(
  ['y', 'id', 'region_count', 'region'], axis = 1).columns
print(len(count_feature_names), len(target_feature_names))
# 36 36

count_model = sklearn.tree.DecisionTreeClassifier(min_samples_leaf = 500, 
  min_impurity_decrease=0.001)
count_model.fit(train_df[count_feature_names], train_df['y'])

target_model = sklearn.tree.DecisionTreeClassifier(min_samples_leaf = 500, 
  min_impurity_decrease=0.001)
target_model.fit(train_df[target_feature_names], train_df['y'])

count_model_def_df = get_model_definition(count_model, count_feature_names)
target_model_def_df = get_model_definition(target_model, target_feature_names)

Let’s look at the structures and select the top categories up to 10–15% of our target audience. We can also apply these conditions to our validation sets to test our approach in practice. 
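Instead of copying the conditions by hand, one option is to pass the condition strings to pandas’ DataFrame.query, since the strings produced by get_decision_path_string are valid query expressions. A minimal sketch with hypothetical segment strings and a toy frame; keep in mind that '%.2f' rounds thresholds to two decimals, so for continuous features a few borderline rows may land differently than in the model (model.apply, which returns the leaf index for each row, avoids that).

```python
import pandas as pd

# Hypothetical segment conditions in the format produced by get_decision_path_string
selected_segments = [
    '(poutcome_success > 0.50)',
    '(poutcome_success <= 0.50) and (age > 60.50)',
]

# Toy validation rows (not the real test set)
toy_test = pd.DataFrame({
    'poutcome_success': [1, 0, 0, 0],
    'age': [35, 65, 40, 70],
    'y': [1, 1, 0, 0],
})

# A customer is selected if they fall into any of the chosen segments
rule = ' or '.join('(%s)' % segment for segment in selected_segments)
selected = toy_test.query(rule)
print(selected.shape[0], selected.y.sum())  # 3 2
```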

Let’s start with Count Encoder. 

Image by author
count_selected_df = test_df[
    (test_df.poutcome_success > 0.50) | 
    ((test_df.poutcome_success <= 0.50) & (test_df.age > 60.50)) | 
    ((test_df.region_count > 3645.50) & (test_df.region_count <= 8151.50) & 
         (test_df.poutcome_success <= 0.50) & (test_df.contact_cellular > 0.50) & (test_df.age <= 60.50))
]

print(count_selected_df.shape[0], count_selected_df.y.sum())
# 508 227

We can also see what regions have been selected, and it’s only Manchester.

Image by author

Let’s continue with the Target encoding. 

Image by author
target_selected_df = test_df[
    ((test_df.region_target > 0.21) & (test_df.poutcome_success > 0.50)) | 
    ((test_df.region_target > 0.21) & (test_df.poutcome_success <= 0.50) & (test_df.month <= 6.50) & (test_df.housing <= 0.50) & (test_df.contact_unknown <= 0.50)) | 
    ((test_df.region_target > 0.21) & (test_df.poutcome_success <= 0.50) & (test_df.month > 8.50) & (test_df.housing <= 0.50) 
         & (test_df.contact_unknown <= 0.50)) |
    ((test_df.region_target <= 0.21) & (test_df.poutcome_success > 0.50)) |
    ((test_df.region_target > 0.21) & (test_df.poutcome_success <= 0.50) & (test_df.month > 6.50) & (test_df.month <= 8.50) 
         & (test_df.housing <= 0.50) & (test_df.contact_unknown <= 0.50))
]

print(target_selected_df.shape[0], target_selected_df.y.sum())
# 502 248

We see a slightly lower number of selected users for communication but a significantly higher number of conversions: 248 vs. 227 (+9.3%).

Let’s also look at the selected categories. The model picked up all the cities with high conversions (Manchester, Liverpool, Bristol, Leicester, and Newcastle), but also many small regions whose high conversion is most likely due to chance.

region_encoding_df[region_encoding_df.region_target > 0.21]\
  .sort_values('region_count', ascending = False)
Image by author

In our case, this doesn’t have much impact, since the share of such small cities is low. However, if you have many more small categories, overfitting could become a significant problem. Target encoding can be tricky here, so it’s worth keeping an eye on your model’s output. 

Luckily, there’s an approach that can help you overcome this issue. Following the article “Encoding Categorical Variables: A Deep Dive into Target Encoding”, we can add smoothing. The idea is to combine the group’s conversion rate with the overall average: the larger the group, the more weight its data carries, while smaller segments will lean more towards the global average.
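Written out, with n as the number of customers in a region, k as the size at which the region’s own conversion and the global mean get equal weight, and f controlling how quickly the weight moves from 0 to 1, the smoothed encoding used in the code below is:

\[\lambda(\textsf{n}) = \frac{1}{1 + e^{-(\textsf{n} - \textsf{k}) / \textsf{f}}}\]

\[\textbf{encoding}(\textsf{region}) = \lambda(\textsf{n}) * \textsf{mean}_{\textsf{region}} + (1 - \lambda(\textsf{n})) * \textsf{mean}_{\textsf{global}}\]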

First, I selected parameters that make sense for our distribution after looking at a bunch of options. With k = 100 and f = 10, regions with fewer than roughly 100 customers are encoded mostly with the global average. This part is a bit subjective, so use common sense and your knowledge of the business domain.

import numpy as np
import matplotlib.pyplot as plt

global_mean = train_df.y.mean()

k = 100
f = 10
smooth_df = pd.DataFrame({'region_count':np.arange(1, 100001, 1) })
smooth_df['smoothing'] = (1 / (1 + np.exp(-(smooth_df.region_count - k) / f)))

ax = plt.scatter(smooth_df.region_count, smooth_df.smoothing)
plt.xscale('log')
plt.ylim([-.1, 1.1])
plt.title('Smoothing')
Image by author

Then, we can calculate, based on the selected parameters, the smoothing coefficients and blended averages.

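# keep the raw (unsmoothed) conversion under a separate name
region_encoding_df = region_encoding_df.rename(
  columns = {'region_target': 'raw_region_target'})
region_encoding_df['smoothing'] = (1 / (1 + np.exp(-(region_encoding_df.region_count - k) / f)))
region_encoding_df['region_target'] = region_encoding_df.smoothing * region_encoding_df.raw_region_target \
    + (1 - region_encoding_df.smoothing) * global_mean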
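To see what the blending does numerically, here is a self-contained sketch of the same formula for three hypothetical regions that all have a raw conversion of 40% but different sizes. The overall conversion of about 11.7% comes from the root node of the tree fitted earlier; the real global_mean is computed on train_df and may differ slightly.

```python
import math

def smoothed_target(region_mean, region_count, global_mean, k=100, f=10):
    """Blend a region's conversion with the global one using a sigmoid weight."""
    weight = 1 / (1 + math.exp(-(region_count - k) / f))  # ~0 for small, 0.5 at k, ~1 for large
    return weight * region_mean + (1 - weight) * global_mean

example_global_mean = 0.117  # roughly the overall conversion at the root node

# Hypothetical regions: same raw conversion (40%), different sizes
for n in (20, 100, 500):
    print(n, round(smoothed_target(0.40, n, example_global_mean), 3))
```

The small region is pulled almost entirely to the global mean, the region with exactly k customers lands halfway, and the large region keeps its own conversion rate.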

Then, we can fit another model with smoothed target category encoding.

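# drop the unsmoothed encodings merged earlier to avoid duplicated column names
train_df = train_df.drop(['region_count', 'region_target'], axis = 1)\
  .merge(region_encoding_df[['region', 'region_target']], on = 'region')
test_df = test_df.drop(['region_count', 'region_target'], axis = 1)\
  .merge(region_encoding_df[['region', 'region_target']], on = 'region', how = 'left')
test_df['region_target'] = test_df['region_target']\
  .fillna(region_encoding_df.region_target.mean())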

target_v2_feature_names = train_df.drop(['y', 'id', 'region'], axis = 1)\
  .columns

target_v2_model = sklearn.tree.DecisionTreeClassifier(min_samples_leaf = 500, 
  min_impurity_decrease=0.001)
target_v2_model.fit(train_df[target_v2_feature_names], train_df['y'])
target_v2_model_def_df = get_model_definition(target_v2_model, 
  target_v2_feature_names)
Image by author
target_v2_selected_df = test_df[
    ((test_df.region_target > 0.12) & (test_df.poutcome_success > 0.50)) | 
    ((test_df.region_target > 0.12) & (test_df.poutcome_success <= 0.50) & (test_df.month <= 6.50) & (test_df.housing <= 0.50) & (test_df.contact_unknown <= 0.50)) | 
    ((test_df.region_target > 0.12) & (test_df.poutcome_success <= 0.50) & (test_df.month > 8.50) & (test_df.housing <= 0.50) 
         & (test_df.contact_unknown <= 0.50)) | 
    ((test_df.region_target <= 0.12) & (test_df.poutcome_success > 0.50) ) | 
    ((test_df.region_target > 0.12) & (test_df.poutcome_success <= 0.50) & (test_df.month > 6.50) & (test_df.month <= 8.50) 
         & (test_df.housing <= 0.50) & (test_df.contact_unknown <= 0.50) )
]

target_v2_selected_df.shape[0], target_v2_selected_df.y.sum()
# (500, 247)

We can see that we’ve eliminated the small cities and reduced the risk of overfitting while keeping roughly the same performance: the rules capture 247 conversions.

region_encoding_df[region_encoding_df.region_target > 0.12]
Image by author

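You can also use TargetEncoder from scikit-learn, which smooths the encoding by blending the category and global means depending on the segment size. However, its fit_transform uses internal cross fitting (a shuffled K-fold split by default), so the encodings for the training data vary from fold to fold, which is not ideal when we want a single, stable set of heuristics.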

You can find the full code on GitHub.

Summary

In this article, we explored how to extract simple “rules” from data and use them to inform business decisions. We generated heuristics using a Decision Tree Classifier and touched on the important topic of categorical encoding, since scikit-learn’s decision trees require categorical variables to be converted to numbers.

We saw that this rule-based approach can be surprisingly effective, helping you reach business decisions quickly. However, it’s worth noting that this simplistic approach has its drawbacks:

  • We are trading off the model’s power and accuracy for its simplicity and interpretability, so if you’re optimising for accuracy, choose another approach.
  • Even though we’re using a set of static heuristics, your data can still change and the rules may become outdated, so you need to recheck your model from time to time.

Thank you very much for reading this article. I hope it was insightful to you. If you have any follow-up questions or comments, please leave them in the comments section.

Reference

Dataset: Moro, S., Rita, P., & Cortez, P. (2014). Bank Marketing [Dataset]. UCI Machine Learning Repository. https://doi.org/10.24432/C5K306

The post Mining Rules from Data appeared first on Towards Data Science.

]]>
The Next Frontier in LLM Accuracy https://towardsdatascience.com/the-next-frontier-in-llm-accuracy-cb2491a740d4/ Sat, 04 Jan 2025 12:02:35 +0000 https://towardsdatascience.com/the-next-frontier-in-llm-accuracy-cb2491a740d4/ Exploring the Power of Lamini Memory Tuning

The post The Next Frontier in LLM Accuracy appeared first on Towards Data Science.

]]>
Image generated by DALL-E 3

Accuracy is often critical for LLM applications, especially in cases such as API calling or summarisation of financial reports. Fortunately, there are ways to enhance precision. The best practices to improve accuracy include the following steps:

  • You can start simply with prompt engineering techniques – adding more detailed instructions, using few-shot prompting, or asking the model to think step-by-step.
  • If accuracy is still insufficient, you can incorporate a self-reflection step, for example, to return errors from the API calls and ask the LLM to correct mistakes.
  • The next option is to provide the most relevant context to the LLM using RAG (Retrieval-Augmented Generation) to boost precision further.

We’ve explored this approach in my previous TDS article, "From Prototype to Production: Enhancing LLM Accuracy". In that project, we built an SQL Agent and went from 0% valid SQL queries to 70% accuracy. However, there are limits to what we can achieve with prompting. To break through this barrier and reach the next frontier of accuracy, we need to adopt more advanced techniques.

The most promising option is fine-tuning. With fine-tuning, we can move from relying solely on information in prompts to embedding additional information directly into the model’s weights.

Fine-tuning

Let’s start by understanding what fine-tuning is. Fine-tuning is the process of refining pre-trained models by training them on smaller, task-specific datasets to enhance their performance in particular applications. Base models are initially trained on vast amounts of data, which allows them to develop a broad understanding of language. Fine-tuning, however, tailors these models to specialized tasks, transforming them from general-purpose systems into highly targeted tools. For example, instruction fine-tuning (combined with reinforcement learning from human feedback) taught a GPT-3.5 model to chat and follow instructions, which is how ChatGPT emerged.

Base LLMs are initially trained to predict the next token based on vast text corpora. Fine-tuning typically adopts a supervised approach, where the model is presented with specific questions and corresponding answers, allowing it to adjust its weights to improve accuracy.

Historically, fine-tuning required updating all model weights, a method known as full fine-tuning. This process was computationally expensive, since it required storing all the model weights, optimizer states, gradients and forward activations in memory. To address these challenges, parameter-efficient fine-tuning (PEFT) techniques were introduced. PEFT methods update only a small set of the model parameters while keeping the rest frozen. Among these methods, one of the most widely adopted is LoRA (Low-Rank Adaptation), which significantly reduces the computational cost while often matching the quality of full fine-tuning.
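To make the parameter savings concrete, here is a minimal NumPy sketch of the LoRA idea: the frozen weight matrix W stays untouched, and the update is expressed as the product of two small trainable matrices, B (d_out × r) and A (r × d_in). The layer size, rank and alpha are arbitrary illustrative values, and this is a single forward pass, not a training loop or any particular library's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

d_out, d_in, r = 512, 512, 8  # toy layer size and LoRA rank (illustrative)
alpha = 16                    # LoRA scaling hyperparameter (illustrative)

W = rng.normal(size=(d_out, d_in))     # frozen pre-trained weight
A = rng.normal(size=(r, d_in)) * 0.01  # trainable down-projection
B = np.zeros((d_out, r))               # trainable up-projection, initialised to zeros

def lora_forward(x: np.ndarray) -> np.ndarray:
    """y = W x + (alpha / r) * B (A x); during training only A and B would get gradients."""
    return W @ x + (alpha / r) * (B @ (A @ x))

x = rng.normal(size=(d_in,))
# With B initialised to zeros, the adapted layer starts out identical to the base layer.
assert np.allclose(lora_forward(x), W @ x)

full_params = W.size           # 262,144 values updated by full fine-tuning
lora_params = A.size + B.size  # 8,192 values updated by LoRA
print(f"trainable share: {lora_params / full_params:.2%}")
```

Because B starts at zero, training begins from the base model's behavior, and for a fixed rank the trainable share shrinks further as layers get wider.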

Pros & cons

Before considering fine-tuning, it’s essential to weigh its advantages and limitations.

Advantages:

  • Fine-tuning enables the model to learn and retain significantly more information than can be provided through prompts alone.
  • It often yields higher accuracy than prompting alone, in some cases exceeding 90%.
  • During inference, it can reduce costs by enabling the use of smaller, task-specific models instead of larger, general-purpose ones.
  • Fine-tuned small models can often be deployed on-premises, eliminating reliance on cloud providers such as OpenAI or Anthropic. This approach reduces costs, enhances privacy, and minimizes dependency on external infrastructure.

Disadvantages:

  • Fine-tuning requires upfront investments for model training and data preparation.
  • It requires specific technical knowledge and may involve a steep learning curve.
  • The quality of results depends heavily on the availability of high-quality training data.

Since this project is focused on gaining knowledge, we will proceed with fine-tuning. However, in real-world scenarios, it’s important to evaluate whether the benefits of fine-tuning justify all the associated costs and efforts.

Execution

The next step is to plan how we will approach fine-tuning. After taking the "Improving Accuracy of LLM Applications" course, I’ve decided to try the Lamini platform for the following reasons:

  • It offers a simple one-line API call to fine-tune the model. It’s especially convenient since we’re just starting to learn a new technique.
  • Although it’s not free and can be quite expensive for toy projects (at $1 per tuning step), Lamini offers free credits upon registration, which are sufficient for initial testing.
  • Lamini has implemented a new approach, Lamini Memory Tuning, which promises zero loss of factual accuracy while preserving general capabilities. This is a significant claim, and it’s worth testing out. We will discuss this approach in more detail shortly.

Of course, there are lots of other fine-tuning options you can consider:

  • The Llama documentation provides numerous recipes for fine-tuning, which can be executed on a cloud server or even locally for smaller models.
  • There are many step-by-step guides available online, including the tutorial on how to fine-tune Llama on Kaggle from DataCamp.
  • Fine-tuning isn’t limited to open-source models: OpenAI also offers the capability to fine-tune its models.

Lamini Memory Tuning

As I mentioned earlier, Lamini released a new approach to fine-tuning, and I believe it’s worth discussing it in more detail.

Lamini introduced the Mixture of Memory Experts (MoME) approach, which enables LLMs to learn a vast amount of factual information with almost zero loss, all while maintaining generalization capabilities and requiring a feasible amount of computational resources.

To achieve this, Lamini extended a pre-trained LLM by adding a large number (on the order of 1 million) of LoRA adapters along with a cross-attention layer. Each LoRA adapter is a memory expert, functioning as a type of memory for the model. These memory experts specialize in different aspects, ensuring that the model retains faithful and accurate information from the data it was tuned on. Inspired by information retrieval, these million memory experts are equivalent to indices from which the model intelligently retrieves and routes.

At inference time, the model retrieves a subset of the most relevant experts at each layer and merges them back into the base model to generate a response to the user query.
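Lamini does not publish its routing code, so the toy NumPy sketch below only illustrates the general idea described above: a bank of LoRA experts, each with a routing key; a scoring step that picks the top-k experts for the current hidden state; and a weighted merge of their low-rank updates into the base weight. Plain dot-product scoring stands in for Lamini's cross-attention layer, and all sizes and names are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)

d, r = 64, 4                # hidden size and adapter rank (toy values)
n_experts, top_k = 1000, 8  # Lamini describes ~1M experts; 1,000 keeps this light

W_base = rng.normal(size=(d, d))                       # frozen base weight
expert_keys = rng.normal(size=(n_experts, d))          # one routing key per expert
expert_A = rng.normal(size=(n_experts, r, d)) * 0.01   # LoRA down-projections
expert_B = rng.normal(size=(n_experts, d, r)) * 0.01   # LoRA up-projections

def select_experts(query: np.ndarray, k: int = top_k):
    """Score every expert key against the query; keep the k best with softmax weights."""
    scores = expert_keys @ query / np.sqrt(d)
    idx = np.argsort(scores)[-k:][::-1]  # indices of the k highest scores
    weights = np.exp(scores[idx] - scores[idx].max())
    return idx, weights / weights.sum()

def merged_weight(query: np.ndarray) -> np.ndarray:
    """Merge the selected experts' low-rank updates into the base weight for this query."""
    idx, w = select_experts(query)
    delta = sum(wi * expert_B[i] @ expert_A[i] for i, wi in zip(idx, w))
    return W_base + delta

h = rng.normal(size=(d,))  # stand-in for a hidden state of the user query
out = merged_weight(h) @ h
print(out.shape)
```

Only the selected experts touch the computation for a given query, which is what keeps a very large expert bank affordable at inference time.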

Figure from the paper by Li et al. 2024 | source

Lamini Memory Tuning is said to be capable of achieving 95% accuracy. The key difference from traditional instruction fine-tuning is that instead of optimizing for average error across all tasks, this approach focuses on achieving zero error for the facts the model is specifically trained to remember.

Figure from the paper by Li et al. 2024 | source

So, this approach allows an LLM to preserve its ability to generalize with average error on everything else while recalling the important facts nearly perfectly.

For further details, you can refer to the research paper "Banishing LLM Hallucinations Requires Rethinking Generalization" by Li et al. (2024).

Lamini Memory Tuning holds great promise – let’s see if it delivers on its potential in practice.

Setup

As always, let’s begin by setting everything up. As we discussed, we’ll be using Lamini to fine-tune Llama, so the first step is to install the Lamini package.

pip install lamini

Additionally, we need to create a Lamini API key on their website and set it as an environment variable.

export LAMINI_API_KEY="<YOUR-LAMINI-API-KEY>"

As I mentioned above, we will be improving the SQL Agent, so we need a database. For this example, we’ll continue using ClickHouse, but feel free to choose any database that suits your needs. You can find more details on the ClickHouse setup and the database schema in the previous article.

Creating a training dataset

To fine-tune an LLM, we first need a dataset – in our case, a set of pairs of questions and answers (SQL queries). The task of putting together a dataset might seem daunting, but luckily, we can leverage LLMs to do it.

The key factors to consider while preparing the dataset:

  • The quality of the data is crucial, as we will ask the model to remember these facts.
  • Diversity in the examples is important so that a model can learn how to handle different cases.
  • It’s preferable to use real data rather than synthetically generated data since it better represents real-life questions.
  • The usual minimum size for a fine-tuning dataset is around 1,000 examples, but the more high-quality data, the better.

Generating examples

All the information required to create question-and-answer pairs is present in the database schema, so generating examples is a feasible task for an LLM. Additionally, I have a representative set of Q&A pairs that I used for the RAG approach, which we can present to the LLM as examples of valid queries (using the few-shot prompting technique). Let’s load the RAG dataset.

import json
import pandas as pd

# loading a set of examples
with open('rag_set.json', 'r') as f:
    rag_set = json.load(f)

rag_set_df = pd.DataFrame(rag_set)

# formatting each example as a single string for few-shot prompting
rag_set_df['qa_fmt'] = list(map(
    lambda x, y: "question: %s, sql_query: %s" % (x, y),
    rag_set_df.question,
    rag_set_df.sql_query
))

The idea is to iteratively provide the LLM with the schema information and a set of random examples (to ensure diversity in the questions) and ask it to generate a new, similar, but different Q&A pair.

Let’s create a system prompt that includes all the necessary details about the database schema.

generate_dataset_system_prompt = '''
You are a senior data analyst with more than 10 years of experience writing complex SQL queries. 
There are two tables in the database you're working with with the following schemas. 

Table: ecommerce.users 
Description: customers of the online shop
Fields: 
- user_id (integer) - unique identifier of customer, for example, 1000004 or 3000004
- country (string) - country of residence, for example, "Netherlands" or "United Kingdom"
- is_active (integer) - 1 if customer is still active and 0 otherwise
- age (integer) - customer age in full years, for example, 31 or 72

Table: ecommerce.sessions 
Description: sessions for online shop
Fields: 
- user_id (integer) - unique identifier of customer, for example, 1000004 or 3000004
- session_id (integer) - unique identifier of session, for example, 106 or 1023
- action_date (date) - session start date, for example, "2021-01-03" or "2024-12-02"
- session_duration (integer) - duration of session in seconds, for example, 125 or 49
- os (string) - operation system that customer used, for example, "Windows" or "Android"
- browser (string) - browser that customer used, for example, "Chrome" or "Safari"
- is_fraud (integer) - 1 if session is marked as fraud and 0 otherwise
- revenue (float) - income in USD (the sum of purchased items), for example, 0.0 or 1506.7

Write a query in ClickHouse SQL to answer the following question. 
Add "format TabSeparatedWithNames" at the end of the query to get data from ClickHouse database in the right format. 
'''

The next step is to create a template for the user query.

generate_dataset_qa_tmpl = '''
Considering the following examples, please, write question 
and SQL query to answer it, that is similar but different to provided below.

Examples of questions and SQL queries to answer them: 
{examples}
'''

Since we need a high-quality dataset, I prefer using a more advanced model, GPT-4o, rather than Llama. As usual, I’ll initialize the model and create a dummy tool for structured output.

from langchain_core.tools import tool

@tool
def generate_question_and_answer(comments: str, question: str, sql_query: str) -> str:
  """Returns the new question and SQL query 

  Args:
      comments (str): 1-2 sentences about the new question and answer pair,
      question (str): new question 
      sql_query (str): SQL query in ClickHouse syntax to answer the question
  """
  pass

from langchain_openai import ChatOpenAI

generate_qa_llm = ChatOpenAI(model="gpt-4o", temperature = 0.5)\
  .bind_tools([generate_question_and_answer])

Now, let’s combine everything into a function that will generate a Q&A pair and create a set of examples.

import tqdm

# helper function to combine system + user prompts
def get_openai_prompt(question, system):
    messages = [
        ("system", system),
        ("human", question)
    ]
    return messages

def generate_qa():
  # selecting 3 random examples 
  sample_set_df = rag_set_df.sample(3)
  examples = '\n\n'.join(sample_set_df.qa_fmt.values)

  # constructing prompt
  prompt = get_openai_prompt(
    generate_dataset_qa_tmpl.format(examples = examples), 
    generate_dataset_system_prompt)

  # calling LLM
  qa_res = generate_qa_llm.invoke(prompt)

  try:
      rec = qa_res.tool_calls[0]['args']
      rec['examples'] = examples
      return rec
  except (IndexError, KeyError):
      # the model didn't call the tool, so we skip this attempt
      return None

# executing function (failed attempts are dropped)
qa_tmp = []
for i in tqdm.tqdm(range(2000)):
  rec = generate_qa()
  if rec is not None:
    qa_tmp.append(rec)

new_qa_df = pd.DataFrame(qa_tmp)

I generated 2,000 examples, but in the end I used a much smaller dataset for this toy project, so I recommend limiting generation to 200–300 examples.

Cleaning the dataset

As we know, "garbage in, garbage out", so an essential step before fine-tuning is cleaning the data generated by the LLM.

The first – and most obvious – check is to ensure that each SQL query is valid.

def is_valid_output(s):
    if s.startswith('Database returned the following error:'):
        return 'error'
    if len(s.strip().split('\n')) >= 1000:
        return 'too many rows'
    return 'ok'

# get_clickhouse_data is the ClickHouse helper from the previous article
new_qa_df['output'] = new_qa_df.sql_query.map(get_clickhouse_data)
new_qa_df['is_valid_output'] = new_qa_df.output.map(is_valid_output)

There are no invalid SQL queries, but some questions return over 1,000 rows.

Although these cases are valid, we’re focusing on an OLAP scenario with aggregated stats, so I’ve retained only queries that return 100 or fewer rows.

# number of output lines, including the header row
new_qa_df['output_rows'] = new_qa_df.output.map(
  lambda x: len(x.strip().split('\n')))

filt_new_qa_df = new_qa_df[new_qa_df.output_rows <= 100]

I also eliminated cases with empty output – queries that return no rows or only the header.

filt_new_qa_df = filt_new_qa_df[filt_new_qa_df.output_rows > 1]

Another important check is for duplicate questions. The same question with different answers could confuse the model, as it won’t be able to tune to both solutions simultaneously. And in fact, we have such cases.

filt_new_qa_df = filt_new_qa_df[['question', 'sql_query']].drop_duplicates()
filt_new_qa_df['question'].value_counts().head(10)

To resolve these duplicates, I’ve kept only one answer for each question.

filt_new_qa_df = filt_new_qa_df.drop_duplicates('question') 

Although I generated around 2,000 examples, I’ve decided to use a smaller dataset of 200 question-and-answer pairs. Fine-tuning with a larger dataset would require more tuning steps and be more expensive.

sample_dataset_df = pd.read_csv('small_sample_for_finetuning.csv', sep = '\t')

You can find the final training dataset on GitHub.

Now that our training dataset is ready, we can move on to the most exciting part – fine-tuning.

Fine-tuning

The first iteration

The next step is to generate the sets of requests and responses for the LLM that we will use to fine-tune the model.

Since we’ll be working with the Llama model, let’s create a helper function to construct a prompt for it.

def get_llama_prompt(user_message, system_message=""):
    # building a prompt in the Llama 3 chat template format
    system_prompt = ""
    if system_message != "":
        system_prompt = (
            f"<|start_header_id|>system<|end_header_id|>\n\n{system_message}"
            f"<|eot_id|>"
        )
    prompt = (f"<|begin_of_text|>{system_prompt}"
              f"<|start_header_id|>user<|end_header_id|>\n\n"
              f"{user_message}"
              f"<|eot_id|>"
              f"<|start_header_id|>assistant<|end_header_id|>\n\n"
         )
    return prompt

For requests, we will use the following system prompt, which includes all the necessary information about the data schema.

generate_query_system_prompt = '''
You are a senior data analyst with more than 10 years of experience writing complex SQL queries. 
There are two tables in the database you're working with with the following schemas. 

Table: ecommerce.users 
Description: customers of the online shop
Fields: 
- user_id (integer) - unique identifier of customer, for example, 1000004 or 3000004
- country (string) - country of residence, for example, "Netherlands" or "United Kingdom"
- is_active (integer) - 1 if customer is still active and 0 otherwise
- age (integer) - customer age in full years, for example, 31 or 72

Table: ecommerce.sessions 
Description: sessions of usage the online shop
Fields: 
- user_id (integer) - unique identifier of customer, for example, 1000004 or 3000004
- session_id (integer) - unique identifier of session, for example, 106 or 1023
- action_date (date) - session start date, for example, "2021-01-03" or "2024-12-02"
- session_duration (integer) - duration of session in seconds, for example, 125 or 49
- os (string) - operation system that customer used, for example, "Windows" or "Android"
- browser (string) - browser that customer used, for example, "Chrome" or "Safari"
- is_fraud (integer) - 1 if session is marked as fraud and 0 otherwise
- revenue (float) - income in USD (the sum of purchased items), for example, 0.0 or 1506.7

Write a query in ClickHouse SQL to answer the following question. 
Add "format TabSeparatedWithNames" at the end of the query to get data from ClickHouse database in the right format. 
Answer questions following the instructions and providing all the needed information and sharing your reasoning. 
'''

Let’s create the responses in the format suitable for Lamini fine-tuning. We need to prepare a list of dictionaries with input and output keys.

formatted_responses = []

for rec in sample_dataset_df.to_dict('records'):
  formatted_responses.append(
    {
      'input': get_llama_prompt(rec['question'], 
        generate_query_system_prompt),
      'output': rec['sql_query']
    }
  )

Now, we are fully prepared for fine-tuning. We just need to select a model and initiate the process. We will be fine-tuning the Llama 3.1 8B model.

from lamini import Lamini
llm = Lamini(model_name="meta-llama/Meta-Llama-3.1-8B-Instruct")

finetune_args = {
    "max_steps": 50,
    "learning_rate": 0.0001
}

llm.train(
  data_or_dataset_id=formatted_responses,
  finetune_args=finetune_args,
)  

We can specify several hyperparameters, and you can find all the details in the Lamini documentation. For now, I’ve passed only the most essential ones to the function:

  • max_steps: This determines the number of tuning steps. The documentation recommends using 50 steps for experimentation to get initial results without spending too much money.
  • learning_rate: This parameter determines the step size of each iteration while moving toward a minimum of a loss function (Wikipedia). The default is 0.0009, but based on the guidance, I’ve decided to use a smaller value.

Now, we just need to wait for 10–15 minutes while the model trains, and then we can test it.

finetuned_llm = Lamini(model_name='<model_id>')
# you can find Model ID in the Lamini interface

question = '''How many customers made purchase in December 2024?'''
prompt = get_llama_prompt(question, generate_query_system_prompt)
finetuned_llm.generate(prompt, max_new_tokens=200)
# select uniqExact(s.user_id) as customers 
# from ecommerce.sessions s join ecommerce.users u 
# on s.user_id = u.user_id 
# where (toStartOfMonth(action_date) = '2024-12-01') and (revenue > 0) 
# format TabSeparatedWithNames

It’s worth noting that we’re using Lamini for inference as well and will have to pay for it. Up-to-date pricing is listed on the Lamini website.

At first glance, the result looks promising, but we need a more robust accuracy evaluation to confirm it.

Additionally, it’s worth noting that since we’ve fine-tuned the model for our specific task, it now consistently returns SQL queries, meaning we may no longer need to use tool calls for structured output.

Evaluating the quality

We’ve discussed LLM accuracy evaluation in detail in my previous article, so here I’ll provide a brief recap.

We use a golden set of question-and-answer pairs to evaluate the model’s quality. Since this is a toy example, I’ve limited the set to just 10 pairs, which you can review on GitHub.

The evaluation process consists of two parts:

  • SQL Query Validity: First, we check that the SQL query is valid, meaning ClickHouse doesn’t return errors during execution.
  • Query Correctness: Next, we ensure that the generated query is correct. We compare the outputs of the generated and true queries using LLMs to verify that they provide semantically identical results.
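The two-step check can be sketched as a small harness. This is a minimal sketch, not the code from the previous article: it assumes the ClickHouse outputs are already available as TSV strings (as returned by `get_clickhouse_data`), reuses the error prefix from `is_valid_output`, and replaces the LLM-based comparison with a hypothetical pluggable `same_result` function. The naive `exact_match` placeholder would, for example, flag reordered rows or renamed columns as incorrect, which is exactly why an LLM judge is used instead.

```python
from collections import Counter
from typing import Callable

ERROR_PREFIX = 'Database returned the following error:'

def classify_case(generated_output: str, golden_output: str,
                  same_result: Callable[[str, str], bool]) -> str:
    """Label one golden-set case as 'error', 'incorrect' or 'correct'."""
    if generated_output.startswith(ERROR_PREFIX):
        return 'error'      # step 1: the query failed to execute
    if not same_result(generated_output, golden_output):
        return 'incorrect'  # step 2: it ran, but the result differs from the golden one
    return 'correct'

def exact_match(a: str, b: str) -> bool:
    """Naive stand-in for the LLM judge: compare TSV outputs after trimming whitespace."""
    return a.strip() == b.strip()

# Hypothetical (generated, golden) outputs for three questions
cases = [
    ('users\n120\n', 'users\n120\n'),
    ('users\n95\n', 'users\n120\n'),
    (ERROR_PREFIX + ' <details>', 'users\n120\n'),
]
print(Counter(classify_case(g, t, exact_match) for g, t in cases))
```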

The initial results are far from ideal, but they are significantly better than the base Llama model (which produced zero valid SQL queries). Here’s what we found:

  • ClickHouse returned errors for two queries.
  • Three queries were executed, but the results were incorrect.
  • Five queries were correct.

No surprises – there’s no silver bullet, and it’s always an iterative process. Let’s investigate what went wrong.

Diving into the errors

The approach is straightforward. Let’s examine the errors one by one to understand why we got these results and how we can fix them. We’ll start with the first unsuccessful example.


Question: Which country had the highest number of first-time users in 2024?

Golden query:

select 
  country, 
  count(distinct user_id) as users 
from 
  (
    select user_id, min(action_date) as first_date 
    from ecommerce.sessions 
    group by user_id 
    having toStartOfYear(first_date) = '2024-01-01'
  ) as t 
  inner join ecommerce.users as u 
    on t.user_id = u.user_id 
group by country 
order by users desc 
limit 1 
format TabSeparatedWithNames 

Generated query:

select 
  country, 
  count(distinct u.user_id) as first_time_users 
from ecommerce.sessions s 
join ecommerce.users u 
  on s.user_id = u.user_id 
where (toStartOfYear(action_date) = '2024-01-01') 
  and (s.session_id = 1) 
group by country 
order by first_time_users desc 
limit 1 
format TabSeparatedWithNames 

The query is valid, but it returns an incorrect result. The issue lies in the model’s assumption that the first session for each user will always have session_id = 1. Since Lamini Memory Tuning allows the model to learn facts from the training data, let’s investigate why the model made this assumption. Most likely, the cause is in our training data.

Let’s review all the training examples that mention the word "first". I’ll use broad and simple search criteria to get a high-level view.

As we can see, there are no examples mentioning first-time users – only references to the first quarter. It’s no surprise that the model wasn’t able to capture this concept. The solution is straightforward: we just need to add a set of examples with questions and answers specifically about first-time users.
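The broad search mentioned above can be as simple as a case-insensitive substring match over the training questions. The helper below is a hypothetical sketch (the name `find_examples` and the toy rows are illustrative); on the real data you would call it on `sample_dataset_df`, which has the `question` and `sql_query` columns.

```python
import pandas as pd

def find_examples(df: pd.DataFrame, keyword: str) -> pd.DataFrame:
    """Return training pairs whose question contains `keyword` (case-insensitive, no regex)."""
    mask = df['question'].str.contains(keyword, case=False, regex=False)
    return df.loc[mask, ['question', 'sql_query']]

# Hypothetical training pairs, just to show the call
toy_df = pd.DataFrame({
    'question': ['What was the revenue in the first quarter of 2024?',
                 'How many active users are from France?'],
    'sql_query': ['select ...', 'select ...'],
})
print(find_examples(toy_df, 'first'))
```

A substring search deliberately over-matches (it also finds "first quarter"), which is useful here: we want to see every example that could have shaped the model's idea of "first".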


Let’s move on to the next problematic case.

Question: What was the fraud rate in 2023, expressed as a percentage?

Golden query:

select 
  100*uniqExactIf(user_id, is_fraud = 1)/uniqExact(user_id) as fraud_rate 
from ecommerce.sessions 
where (toStartOfYear(action_date) = '2023-01-01') 
format TabSeparatedWithNames 

Generated query:

select 
  100*countIf(is_fraud = 1)/count() as fraud_rate 
from ecommerce.sessions 
where (toStartOfYear(action_date) = '2023-01-01') 
format TabSeparatedWithNames 

Here’s another misconception: we assumed that the fraud rate is based on the share of users, while the model calculated it based on the share of sessions.
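The two definitions can diverge a lot. In this toy example with made-up data, one user generates three fraudulent sessions, so the session-based rate is noticeably higher than the user-based one. The plain-Python functions only mirror the two ClickHouse expressions for illustration:

```python
# Toy sessions as (user_id, is_fraud) pairs; the values are made up.
sessions = [
    (1, 1), (1, 1), (1, 1), (1, 0),  # user 1: three fraudulent sessions
    (2, 0), (2, 0),
    (3, 0),
    (4, 0),
]

def fraud_rate_by_sessions(rows):
    """Share of sessions marked as fraud, like 100*countIf(is_fraud = 1)/count()."""
    return 100 * sum(is_fraud for _, is_fraud in rows) / len(rows)

def fraud_rate_by_users(rows):
    """Share of users with at least one fraud session,
    like 100*uniqExactIf(user_id, is_fraud = 1)/uniqExact(user_id)."""
    fraud_users = {user for user, is_fraud in rows if is_fraud == 1}
    all_users = {user for user, _ in rows}
    return 100 * len(fraud_users) / len(all_users)

print(fraud_rate_by_sessions(sessions))  # 37.5
print(fraud_rate_by_users(sessions))     # 25.0
```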

Let’s check the examples related to the fraud rate in the data. There are two cases: one calculates the share of users, while the other calculates the share of sessions.

To fix this issue, I corrected the incorrect answer and added more accurate examples involving fraud rate calculations.


I’d like to discuss another incorrect case, as it will highlight an important aspect of the process of resolving these issues.

Question: What are the median and interquartile range (IQR) of purchase revenue for each country?

Golden query:

select 
  country, 
  median(revenue) as median_revenue, 
  quantile(0.25)(revenue) as percentile_25_revenue, 
  quantile(0.75)(revenue) as percentile_75_revenue 
from ecommerce.sessions AS s 
inner join ecommerce.users AS u 
  on u.user_id = s.user_id 
where (revenue > 0) 
group by country 
format TabSeparatedWithNames 

Generated query:

select 
  country, 
  median(revenue) as median_revenue, 
  quantile(0.25)(revenue) as percentile_25_revenue, 
  quantile(0.75)(revenue) as percentile_75_revenue 
from ecommerce.sessions s join ecommerce.users u 
  on s.user_id = u.user_id 
group by country 
format TabSeparatedWithNames 

When inspecting the problem, it’s crucial to focus on the model’s misconceptions or incorrect assumptions. For example, in this case, there may be a temptation to add examples similar to those in the golden dataset, but that would be too specific. Instead, we should address the actual root cause of the model’s misconception:

  • It understood the concepts of median and IQR quite well.
  • The split by country is also correct.
  • However, it misinterpreted the concept of "purchase revenue", including sessions where there was no purchase at all (revenue = 0).
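A toy example with made-up revenues shows why the missing `revenue > 0` filter matters so much: when most sessions end without a purchase, including them drags the median down to zero.

```python
from statistics import median

# Made-up session revenues for one country: most sessions end without a purchase.
revenues = [0.0, 0.0, 0.0, 0.0, 0.0, 40.0, 55.0, 120.0]

all_sessions_median = median(revenues)                  # what the generated query measured
purchase_median = median(r for r in revenues if r > 0)  # purchase revenue only (revenue > 0)
print(all_sessions_median, purchase_median)             # 0.0 55.0
```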

So, we need to ensure that our datasets contain enough information related to purchase revenue. Let’s take a look at what we have now. There’s only one example, and it’s incorrect.

Let’s fix this example and add more cases of purchase revenue calculations.


Using a similar approach, I’ve added more examples for the two remaining incorrect queries and compiled an updated, cleaned version of the training dataset. You can find it on GitHub. With this, our data is ready for the next iteration.

Another iteration of fine-tuning

Before diving into fine-tuning, it’s essential to double-check the quality of the training dataset by ensuring that all SQL queries are valid.

clean_sample_dataset_df = pd.read_csv(
  'small_sample_for_finetuning_cleaned.csv', sep = '\t', 
  on_bad_lines = 'warn')

# executing every SQL query and checking that it returns a valid result
clean_sample_dataset_df['output'] = clean_sample_dataset_df.sql_query\
  .map(lambda x: get_clickhouse_data(str(x)))
clean_sample_dataset_df['is_valid_output'] = clean_sample_dataset_df['output']\
  .map(is_valid_output)
print(clean_sample_dataset_df.is_valid_output.value_counts())

# is_valid_output
# ok    241

clean_formatted_responses = []
for rec in clean_sample_dataset_df.to_dict('records'):
  clean_formatted_responses.append(
    {
      'input': get_llama_prompt(
        rec['question'], 
        generate_query_system_prompt),
      'output': rec['sql_query']
    }
  )

Now that we’re confident in the data, we can proceed with fine-tuning. This time, I’ve decided to train it for 150 steps to achieve better accuracy.

finetune_args = {
      "max_steps": 150,
      "learning_rate": 0.0001
}

llm = Lamini(model_name="meta-llama/Meta-Llama-3.1-8B-Instruct")
llm.train(
  data_or_dataset_id=clean_formatted_responses,
  finetune_args=finetune_args
)

After waiting a bit longer than last time, we now have a new fine-tuned model with nearly zero loss after 150 tuning steps.

Running the evaluation again gives much better results, so our approach is working.

The result is astonishing, but it’s still worth examining the incorrect example to understand what went wrong. We got an incorrect result for the question we discussed earlier: "What are the median and interquartile range (IQR) of purchase revenue for each country?" However, this time, the model generated a query that is equivalent to the one in the golden set; the only difference is that the filter refers to s.revenue instead of revenue.

select 
  country, 
  median(revenue) as median_revenue, 
  quantile(0.25)(revenue) as percentile_25_revenue, 
  quantile(0.75)(revenue) as percentile_75_revenue 
from ecommerce.sessions AS s 
inner join ecommerce.users AS u 
  on u.user_id = s.user_id 
where (s.revenue > 0) 
group by country 
format TabSeparatedWithNames

So, the issue actually lies in our evaluation. If you execute this query several times, you’ll notice that the results differ slightly each time. The root cause is that ClickHouse’s quantile function (and median, which is an alias for it) computes approximate values using reservoir sampling, which is why we see varying results. We could have used quantileExact (and medianExact) instead to get deterministic numbers.
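To see why repeated runs differ, here is a small pure-Python simulation. It is not ClickHouse's implementation: it uses the classic Algorithm R reservoir sampler on made-up, exponentially distributed revenue data and a simple index-based quantile, assuming a reservoir of 8,192 values (the size ClickHouse documents for `quantile`; treat it as an assumption here). The exact quantile is the same on every run, while the reservoir-based estimates shift with the random sample.

```python
import random

def reservoir_sample(values, k, rng):
    """Algorithm R: keep a uniform random sample of at most k items from a stream."""
    sample = []
    for i, v in enumerate(values):
        if i < k:
            sample.append(v)
        else:
            j = rng.randint(0, i)  # uniform over 0..i (both bounds inclusive)
            if j < k:
                sample[j] = v      # keeps item i with probability k / (i + 1)
    return sample

def quantile_from(values, q):
    """Simple quantile: the value at position q * (n - 1) of the sorted data."""
    s = sorted(values)
    return s[int(q * (len(s) - 1))]

# Made-up, exponentially distributed revenue (mean 50) for 100,000 purchases
gen = random.Random(0)
revenue = [gen.expovariate(1 / 50) for _ in range(100_000)]

exact = quantile_from(revenue, 0.75)  # same value on every run, like quantileExact
approx = [quantile_from(reservoir_sample(revenue, 8192, random.Random(seed)), 0.75)
          for seed in range(3)]       # depends on the random sample, like quantile
print(exact, approx)
```

The approximate values land close to the exact one but not on it, which is enough to make an output-comparison check flag a semantically correct query as wrong.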

That said, this means that fine-tuning has allowed us to achieve 100% accuracy. Even though our toy golden dataset consists of just 10 questions, this is a tremendous achievement. We’ve progressed all the way from zero valid queries with vanilla Llama to 70% accuracy with RAG and self-reflection, and now, thanks to Lamini Memory Tuning, we’ve reached 100% accuracy.

You can find the full code on GitHub.

Summary

In this article, we continued exploring techniques to improve LLM accuracy:

  • After trying RAG and self-reflection in the previous article, we moved on to a more advanced technique – fine-tuning.
  • We experimented with Memory Tuning developed by Lamini, which enables a model to remember a large volume of facts with near-zero errors.
  • In our example, Memory Tuning performed exceptionally well, and we achieved 100% accuracy on our evaluation set of 10 questions.

Thank you for reading this article. I hope it was insightful. If you have any follow-up questions or comments, please leave them in the comments section.

Reference

All the images are produced by the author unless otherwise stated.

This article is inspired by the "Improving Accuracy of LLM Applications" short course from DeepLearning.AI.


Disclaimer: I am not affiliated with Lamini in any way. The views expressed in this article are solely my own, based on independent testing and evaluation of the Lamini platform. This post is intended for educational purposes and does not constitute an endorsement of any specific tool or service.

The post The Next Frontier in LLM Accuracy appeared first on Towards Data Science.

From Prototype to Production: Enhancing LLM Accuracy https://towardsdatascience.com/from-prototype-to-production-enhancing-llm-accuracy-791d79b0af9b/ Thu, 19 Dec 2024 20:32:55 +0000 https://towardsdatascience.com/from-prototype-to-production-enhancing-llm-accuracy-791d79b0af9b/ Implementing evaluation frameworks to optimize accuracy in real-world applications

The post From Prototype to Production: Enhancing LLM Accuracy appeared first on Towards Data Science.

Building a prototype for an LLM application is surprisingly straightforward. You can often create a functional first version within just a few hours. This initial prototype will likely produce results that look legitimate and serve as a good tool to demonstrate your approach. However, this is usually not enough for production use.

LLMs are probabilistic by nature, as they generate tokens based on the distribution of likely continuations. This means that in many cases, we get an answer close to the "correct" one from the distribution. Sometimes, this is acceptable – for example, it doesn’t matter whether the app says "Hello, John!" or "Hi, John!". In other cases, the difference is critical, such as between "The revenue in 2024 was 20M USD" and "The revenue in 2024 was 20M GBP".
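A toy Python sketch makes this concrete. The three candidate tokens and their logits below are invented for illustration; real models score tens of thousands of tokens at every step, but the mechanics are similar: raw scores are turned into probabilities with a softmax, and the next token is either sampled from that distribution or, with greedy decoding (temperature 0 here), taken as the single most likely option.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Turn raw scores into probabilities; lower temperature sharpens them."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def next_token(tokens, logits, temperature, rng):
    if temperature == 0:
        # greedy decoding: always pick the highest-scoring token
        return tokens[logits.index(max(logits))]
    probs = softmax(logits, temperature)
    return rng.choices(tokens, weights=probs, k=1)[0]


tokens = ["Hello", "Hi", "Hey"]
logits = [2.0, 1.6, 0.3]  # hypothetical scores for the greeting token

rng = random.Random(7)
print([next_token(tokens, logits, 1.0, rng) for _ in range(5)])  # may vary between greetings
print([next_token(tokens, logits, 0, rng) for _ in range(5)])    # always "Hello"
```

This is why the same prompt can yield "Hello, John!" one time and "Hi, John!" the next, and why the evaluation model later in this article is configured with temperature = 0.0.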

In many real-world business scenarios, precision is crucial, and "almost right" isn’t good enough. This is the case, for example, when your LLM application needs to execute API calls or when you’re summarizing financial reports. In my experience, ensuring the accuracy and consistency of results is far more complex and time-consuming than building the initial prototype.

In this article, I will discuss how to approach measuring and improving accuracy. We’ll build an SQL Agent, where precision is vital for ensuring that queries are executable. Starting with a basic prototype, we’ll explore methods to measure accuracy and test various techniques to enhance it, such as self-reflection and retrieval-augmented generation (RAG).

Setup

As usual, let’s begin with the setup. The core components of our SQL agent solution are the LLM model, which generates queries, and the SQL database, which executes them.

LLM model – Llama

For this project, we will use an open-source Llama model released by Meta. I’ve chosen Llama 3.1 8B because it is lightweight enough to run on my laptop while still being quite powerful (refer to the documentation for details).

If you haven’t installed it yet, you can find guides here. I run it locally on macOS via Ollama. The following command downloads the model.

ollama pull llama3.1:8b

We will use Ollama with LangChain, so let’s start by installing the required package.

pip install -qU langchain_ollama 

Now, we can run the Llama model and see the first results.

from langchain_ollama import OllamaLLM

llm = OllamaLLM(model="llama3.1:8b")
llm.invoke("How are you?")
# I'm just a computer program, so I don't have feelings or emotions 
# like humans do. I'm functioning properly and ready to help with 
# any questions or tasks you may have! How can I assist you today?

We would like to pass a system message alongside customer questions. So, following the Llama 3.1 model documentation, let’s put together a helper function to construct a prompt and test this function.

def get_llama_prompt(user_message, system_message=""):
  system_prompt = ""
  if system_message != "":
    system_prompt = (
      f"<|start_header_id|>system<|end_header_id|>\n\n{system_message}"
      f"<|eot_id|>"
    )
  prompt = (f"<|begin_of_text|>{system_prompt}"
            f"<|start_header_id|>user<|end_header_id|>\n\n"
            f"{user_message}"
            f"<|eot_id|>"
            f"<|start_header_id|>assistant<|end_header_id|>\n\n"
           )
  return prompt   

system_prompt = '''
You are Rudolph, the spirited reindeer with a glowing red nose, 
bursting with excitement as you prepare to lead Santa's sleigh 
through snowy skies. Your joy shines as brightly as your nose, 
eager to spread Christmas cheer to the world!
Please, answer questions concisely in 1-2 sentences.
'''
prompt = get_llama_prompt('How are you?', system_prompt)
llm.invoke(prompt)

# I'm feeling jolly and bright, ready for a magical night! 
# My shiny red nose is glowing brighter than ever, just perfect 
# for navigating through the starry skies. 

The new system prompt has changed the answer significantly, so it works. With this, our local LLM setup is ready to go.

Database – ClickHouse

I will use ClickHouse, an open-source database. I’ve chosen ClickHouse because it has a specific SQL dialect: LLMs have likely encountered fewer examples of it during training, which makes the task a bit more challenging. However, you can choose any other database.

Installing ClickHouse is pretty straightforward – just follow the instructions provided in the documentation.

We will be working with two tables: ecommerce.users and ecommerce.sessions. These tables contain fictional data, including customer personal information and their session activity on the e-commerce website.

You can find the code for generating synthetic data and uploading it on GitHub.

With that, the setup is complete, and we’re ready to move on to building the basic prototype.

The first prototype

As discussed, our goal is to build an SQL Agent – an application that generates SQL queries to answer customer questions. In the future, we can add another layer to this system: executing the SQL query, passing both the initial question and the database results back to the LLM, and asking it to generate a human-friendly answer. However, for this article, we’ll focus on the first step.

The best practice with LLM applications (as with any other complex task) is to start simple and then iterate. The most straightforward implementation is a single LLM call that shares all the necessary information (such as the schema description) in the system prompt. So, the first step is to put together the prompt.

generate_query_system_prompt = '''
You are a senior data analyst with more than 10 years of experience writing complex SQL queries. 
There are two tables in the database with the following schemas. 

Table: ecommerce.users 
Description: customers of the online shop
Fields: 
- user_id (integer) - unique identifier of customer, for example, 1000004 or 3000004
- country (string) - country of residence, for example, "Netherlands" or "United Kingdom"
- is_active (integer) - 1 if customer is still active and 0 otherwise
- age (integer) - customer age in full years, for example, 31 or 72

Table: ecommerce.sessions 
Description: sessions of usage of the online shop
Fields: 
- user_id (integer) - unique identifier of customer, for example, 1000004 or 3000004
- session_id (integer) - unique identifier of session, for example, 106 or 1023
- action_date (date) - session start date, for example, "2021-01-03" or "2024-12-02"
- session_duration (integer) - duration of session in seconds, for example, 125 or 49
- os (string) - operating system that customer used, for example, "Windows" or "Android"
- browser (string) - browser that customer used, for example, "Chrome" or "Safari"
- is_fraud (integer) - 1 if session is marked as fraud and 0 otherwise
- revenue (float) - income in USD (the sum of purchased items), for example, 0.0 or 1506.7

Write a query in ClickHouse SQL to answer the following question. 
Add "format TabSeparatedWithNames" at the end of the query to get data from ClickHouse database in the right format. 
'''

I’ve included example values for each field in the prompt to ensure that the LLM understands the data format.

And that’s it! With this, we have our first functional prototype for the SQL Agent. Now, it’s time to put it to the test.

prompt = get_llama_prompt('How many customers made purchase in December 2024?', 
  generate_query_system_prompt)
llm.invoke(prompt)

# To find out how many customers made a purchase in December 2024, 
# we need to join the `sessions` table with the `users` table on the `user_id`, 
# and then filter by `action_date` for sessions that took place in December 2024. 

# Here is the query:
# 
# ```sql
# SELECT COUNT(DISTINCT u.user_id)
# FROM ecommerce.sessions s
# JOIN ecommerce.users u ON s.user_id = u.user_id
# WHERE 
#     EXTRACT(YEAR FROM s.action_date) = 2024
#     AND EXTRACT(MONTH FROM s.action_date) = 12
#     AND revenue > 0;
# ```

# This query first joins the `sessions` table with the `users` table 
# on the `user_id`. Then it filters out sessions where no purchase 
# was made (revenue is greater than 0). Finally, it selects the distinct 
# user IDs of these customers.

# When you run this query in ClickHouse SQL, you can add "format TabSeparatedWithNames" 
# at the end to get the result in the right format:

# ```sql
# SELECT COUNT(DISTINCT u.user_id)
# FROM ecommerce.sessions s
# JOIN ecommerce.users u ON s.user_id = u.user_id
# WHERE 
#     EXTRACT(YEAR FROM s.action_date) = 2024
#     AND EXTRACT(MONTH FROM s.action_date) = 12
#     AND revenue > 0;
# format TabSeparatedWithNames;
# ```

The agent produced a fairly decent result, but there’s one issue – the LLM returned not only the SQL query but also some commentary. Since we plan to execute SQL queries later, this format is not suitable for our task. Let’s work on fixing it.
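One naive fix would be to pull the SQL out of the fenced code blocks with a regular expression. The sketch below is my own illustration, not part of the article’s pipeline, and the sample response is an abridged, simplified version of the output above. It also shows why this is brittle: the model returned two blocks, and the second one has a stray semicolon before the format clause, so we would need extra heuristics to decide which block to run and how to clean it up.

```python
import re

FENCE = "`" * 3  # built programmatically to avoid a literal fence in this snippet
SQL_BLOCK = re.compile(
    re.escape(FENCE) + r"sql\s*(.*?)" + re.escape(FENCE), re.DOTALL | re.IGNORECASE
)


def extract_sql_blocks(text):
    """Return the contents of every sql-tagged fenced block, in order."""
    return [block.strip() for block in SQL_BLOCK.findall(text)]


response = (
    "Here is the query:\n"
    + FENCE + "sql\nSELECT COUNT(DISTINCT u.user_id)\nFROM ecommerce.sessions s;\n" + FENCE + "\n"
    + "Add the format at the end:\n"
    + FENCE + "sql\nSELECT COUNT(DISTINCT u.user_id)\nFROM ecommerce.sessions s;\n"
    + "format TabSeparatedWithNames;\n" + FENCE + "\n"
)

blocks = extract_sql_blocks(response)
print(len(blocks))  # 2
print(blocks[-1])   # note the ';' before the format clause
```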

Fortunately, this problem has already been solved, and we don’t need to parse the SQL queries from the text manually. We can use the chat model ChatOllama. Unfortunately, it doesn’t support structured output, but we can leverage tool calling to achieve the same result.

To do this, we will define a dummy tool to execute the query and instruct the model in the system prompt always to call this tool. I’ve kept the comments in the output to give the model some space for reasoning, following the chain-of-thought pattern.

from langchain_ollama import ChatOllama
from langchain_core.tools import tool

@tool
def execute_query(comments: str, query: str) -> str:
  """Executes SQL query.

  Args:
      comments (str): 1-2 sentences describing the result SQL query 
          and what it does to answer the question,
      query (str): SQL query
  """
  pass 

chat_llm = ChatOllama(model="llama3.1:8b").bind_tools([execute_query])
result = chat_llm.invoke(prompt)
print(result.tool_calls)

# [{'name': 'execute_query',
#   'args': {'comments': 'SQL query returns number of customers who made a purchase in December 2024. The query joins the sessions and users tables based on user ID to filter out inactive customers and find those with non-zero revenue in December 2024.',
#   'query': 'SELECT COUNT(DISTINCT T2.user_id) FROM ecommerce.sessions AS T1 INNER JOIN ecommerce.users AS T2 ON T1.user_id = T2.user_id WHERE YEAR(T1.action_date) = 2024 AND MONTH(T1.action_date) = 12 AND T2.is_active = 1 AND T1.revenue > 0'},
#   'type': 'tool_call'}]

With the tool calling, we can now get the SQL query directly from the model. That’s an excellent result. However, the generated query is not entirely accurate:

  • It includes a filter for is_active = 1, even though we didn’t specify the need to filter out inactive customers.
  • The LLM missed specifying the format despite our explicit request in the system prompt.

Clearly, we need to focus on improving the model’s accuracy. But as Peter Drucker famously said, "You can’t improve what you don’t measure." So, the next logical step is to build a system for evaluating the model’s quality. This system will be a cornerstone for performance improvement iterations. Without it, we’d essentially be navigating in the dark.

Evaluating the accuracy

Evaluation basics

To ensure we’re improving, we need a robust way to measure accuracy. The most common approach is to create a "golden" evaluation set with questions and correct answers. Then, we can compare the model’s output with these "golden" answers and calculate the share of correct ones. While this approach sounds simple, there are a few nuances worth discussing.

First, you might feel overwhelmed at the thought of creating a comprehensive set of questions and answers. Building such a dataset can seem like a daunting task, potentially requiring weeks or months. However, we can start small by creating an initial set of 20–50 examples and iterating on it.

As always, quality is more important than quantity. Our goal is to create a representative and diverse dataset. Ideally, this should include:

  • Common questions. In most real-life cases, we can take the history of actual questions and use it as our initial evaluation set.
  • Challenging edge cases. It’s worth adding examples where the model tends to hallucinate. You can find such cases either while experimenting yourself or by gathering feedback from the first prototype.

Once the dataset is ready, the next challenge is how to score the generated results. We can consider several approaches:

  • Comparing SQL queries. The first idea is to compare the generated SQL query with the one in the evaluation set. However, this is tricky: similar-looking queries can yield completely different results, while queries that look different can lead to the same conclusions. Additionally, simply comparing SQL queries doesn’t verify whether the generated query is actually executable. Given these challenges, I wouldn’t consider this approach the most reliable solution for our case.
  • Exact matches. We can use old-school exact matching when answers in our evaluation set are deterministic. For example, if the question is, "How many customers are there?" and the answer is "592800", the model’s response must match precisely. However, this approach has its limitations. Suppose the model answers the same question with "There are 592,800 customers". While the answer is absolutely correct, an exact-match approach would flag it as invalid.
  • Using LLMs for scoring. A more robust and flexible approach is to leverage LLMs for evaluation. Instead of focusing on query structure, we can ask the LLM to compare the results of SQL executions. This method is particularly effective in cases where the query might differ but still yields correct outputs.
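To illustrate the gap between the first two options, here is a small Python sketch (my own, not used later in the article) that contrasts strict exact matching with a slightly more forgiving check that extracts the numbers from both answers and compares them in order. It handles the "592,800" case above, but it would still miss rounding differences, reordered rows, or answers that mention extra numbers, which is why LLM-based scoring remains the more flexible option.

```python
import math
import re

# integers or decimals, optionally with thousands separators, e.g. 592,800 or 1506.7
NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def exact_match(expected, actual):
    return expected.strip() == actual.strip()


def extract_numbers(text):
    # drop thousands separators (and any trailing comma picked up by the pattern)
    return [float(m.replace(",", "")) for m in NUMBER.findall(text)]


def numeric_match(expected, actual, rel_tol=1e-9):
    """True if both answers contain the same numbers in the same order."""
    exp_nums, act_nums = extract_numbers(expected), extract_numbers(actual)
    return len(exp_nums) == len(act_nums) and all(
        math.isclose(e, a, rel_tol=rel_tol) for e, a in zip(exp_nums, act_nums)
    )


golden = "592800"
answer = "There are 592,800 customers"
print(exact_match(golden, answer))    # False
print(numeric_match(golden, answer))  # True
```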

It’s worth keeping in mind that evaluation isn’t a one-time task; it’s a continuous process. To push our model’s performance further, we need to expand the dataset with examples causing the model’s hallucinations. In production mode, we can create a feedback loop. By gathering input from users, we can identify cases where the model fails and include them in our evaluation set.

In our example, we will be assessing only whether the result of execution is valid (SQL query can be executed) and correct. Still, you can look at other parameters as well. For example, if you care about efficiency, you can compare the execution times of generated queries against those in the golden set.
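If you want to track efficiency, one way is to run both queries several times and compare the median timings, since a single run can be noisy (and database caches may make repeated runs faster than the first one). The sketch below accepts any query-running callable, for example the get_clickhouse_data function defined later in this article; the helper names, the number of repeats, and the fake runner in the usage example are my own assumptions.

```python
import statistics
import time


def median_runtime(run_query, query, repeats=5):
    """Median wall-clock time (seconds) of run_query(query) over several runs."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        run_query(query)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def relative_runtime(run_query, generated_query, golden_query, repeats=5):
    """Values above 1 mean the generated query is slower than the golden one."""
    generated = median_runtime(run_query, generated_query, repeats)
    golden = median_runtime(run_query, golden_query, repeats)
    return generated / golden


# usage with a fake runner that just sleeps; swap in get_clickhouse_data for real queries
def fake_run(query):
    time.sleep(0.02 if "slow" in query else 0.005)


print(relative_runtime(fake_run, "slow query", "fast query", repeats=3))
```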

Evaluation set and validation

Now that we’ve covered the basics, we’re ready to put them into practice. I spent about 20 minutes putting together a set of 10 examples. While small, this set is sufficient for our toy task. It consists of a list of questions paired with their corresponding SQL queries, like this:

[
  {
    "question": "How many customers made purchase in December 2024?",
    "sql_query": "select uniqExact(user_id) as customers from ecommerce.sessions where (toStartOfMonth(action_date) = '2024-12-01') and (revenue > 0) format TabSeparatedWithNames"
  },
  {
    "question": "What was the fraud rate in 2023, expressed as a percentage?",
    "sql_query": "select 100*uniqExactIf(user_id, is_fraud = 1)/uniqExact(user_id) as fraud_rate from ecommerce.sessions where (toStartOfYear(action_date) = '2023-01-01') format TabSeparatedWithNames"
  },
  ...
]

You can find the full list on GitHub.

We can load the dataset into a DataFrame, making it ready for use in the code.

import json
import pandas as pd

with open('golden_set.json', 'r') as f:
  golden_set = json.load(f)

golden_df = pd.DataFrame(golden_set)
golden_df['id'] = list(range(golden_df.shape[0]))
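Since every later step relies on this file, it can be worth running a quick sanity check before generating anything. The sketch below is my own addition, not part of the original pipeline: it assumes each record has the question and sql_query keys shown above and flags missing fields, queries without the output-format clause, and duplicate questions. You could call it as check_golden_set(golden_set) right after loading the file.

```python
REQUIRED_KEYS = ("question", "sql_query")
FORMAT_CLAUSE = "format tabseparatedwithnames"


def check_golden_set(records):
    """Return human-readable problems found in the evaluation records."""
    problems = []
    seen_questions = set()
    for i, rec in enumerate(records):
        missing = [key for key in REQUIRED_KEYS if not str(rec.get(key, "")).strip()]
        if missing:
            problems.append(f"record {i}: missing {', '.join(missing)}")
            continue
        if FORMAT_CLAUSE not in rec["sql_query"].lower():
            problems.append(f"record {i}: query has no output format clause")
        normalized = rec["question"].strip().lower()
        if normalized in seen_questions:
            problems.append(f"record {i}: duplicate question")
        seen_questions.add(normalized)
    return problems


sample = [
    {"question": "How many customers made purchase in December 2024?",
     "sql_query": "select uniqExact(user_id) as customers from ecommerce.sessions "
                  "where (toStartOfMonth(action_date) = '2024-12-01') and (revenue > 0) "
                  "format TabSeparatedWithNames"},
    {"question": "How many customers are there?", "sql_query": "select count() from ecommerce.users"},
    {"question": "how many customers are there? ", "sql_query": ""},
]
print(check_golden_set(sample))
# ['record 1: query has no output format clause', 'record 2: missing sql_query']
```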

First, let’s generate the SQL queries for each question in the evaluation set.

def generate_query(question):
  prompt = get_llama_prompt(question, generate_query_system_prompt)
  result = chat_llm.invoke(prompt)
  try:
    generated_query = result.tool_calls[0]['args']['query']
  except (IndexError, KeyError):
    generated_query = ''
  return generated_query

import tqdm

tmp = []
for rec in tqdm.tqdm(golden_df.to_dict('records')):
  generated_query = generate_query(rec['question'])
  tmp.append(
    {
      'id': rec['id'],
      'generated_query': generated_query
    }
  )

eval_df = golden_df.merge(pd.DataFrame(tmp))

Before moving on to the LLM-based scoring of query outputs, it’s important to first ensure that the SQL query is valid. To do this, we need to execute the queries and examine the database output.

I’ve created a function that runs a query in ClickHouse. It also ensures that the output format is correctly specified, as this may be critical in business applications.

CH_HOST = 'http://localhost:8123' # default address 
import requests
import io

def get_clickhouse_data(query, host = CH_HOST, connection_timeout = 1500):
  # pushing the model to return data in the format that we want
  if 'format tabseparatedwithnames' not in query.lower():
    return "Database returned the following error:\n Please, specify the output format."

  r = requests.post(host, params = {'query': query}, 
    timeout = connection_timeout)
  if r.status_code == 200:
    return r.text
  # giving feedback to the LLM instead of raising an exception
  return 'Database returned the following error:\n' + r.text
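get_clickhouse_data returns ClickHouse’s raw text. With TabSeparatedWithNames, the first line holds the column names and every following line is one tab-separated row, so a rough parser takes only a few lines if you want to inspect results programmatically. This is a simplified sketch of my own: it ignores ClickHouse’s escaping of special characters (for example, tabs or line breaks inside string values), so treat it as a convenience for small numeric outputs rather than a full parser, and only apply it to outputs that passed the validity check.

```python
def parse_tsv_with_names(text):
    """Parse TabSeparatedWithNames output into a list of {column: value} dicts.

    Values stay as strings, and ClickHouse escape sequences are not decoded.
    """
    lines = text.rstrip("\n").split("\n")
    if lines == [""]:
        return []  # empty response
    header = lines[0].split("\t")
    return [dict(zip(header, line.split("\t"))) for line in lines[1:]]


print(parse_tsv_with_names("customers\n114032\n"))
# [{'customers': '114032'}]
print(parse_tsv_with_names("country\tcustomers\nNetherlands\t10\nUnited Kingdom\t25\n"))
# [{'country': 'Netherlands', 'customers': '10'}, {'country': 'United Kingdom', 'customers': '25'}]
```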

The next step is to execute both the generated and golden queries and then save their outputs.

tmp = []

for rec in tqdm.tqdm(eval_df.to_dict('records')):
  golden_output = get_clickhouse_data(rec['sql_query'])
  generated_output = get_clickhouse_data(rec['generated_query'])

  tmp.append(
    {
      'id': rec['id'],
      'golden_output': golden_output,
      'generated_output': generated_output
    }
  )

eval_df = eval_df.merge(pd.DataFrame(tmp))

Next, let’s check the output to see whether the SQL query is valid or not.

def is_valid_output(s):
  if s.startswith('Database returned the following error:'):
    return 'error'
  if len(s.strip().split('\n')) >= 1000:
    return 'too many rows'
  return 'ok'

eval_df['golden_output_valid'] = eval_df.golden_output.map(is_valid_output)
eval_df['generated_output_valid'] = eval_df.generated_output.map(is_valid_output)

Then, we can evaluate the SQL validity for both the golden and generated sets.

The initial results are not very promising; the LLM was unable to generate even a single valid query. Looking at the errors, it’s clear that the model failed to specify the right format despite it being explicitly defined in the system prompt. So, we definitely need to work more on the accuracy.

Checking the correctness

However, validity alone is not enough. It’s crucial that we not only generate valid SQL queries but also produce the correct results. Although we already know that all our queries are invalid, let’s now incorporate output evaluation into our process.

As discussed, we will use LLMs to compare the outputs of the SQL queries. I typically prefer using a more powerful model for evaluation, following the day-to-day logic where a senior team member reviews the work. For this task, I’ve chosen OpenAI’s GPT-4o mini.

Similar to our generation flow, I’ve set up all the building blocks necessary for accuracy assessment.

from langchain_openai import ChatOpenAI

accuracy_system_prompt = '''
You are a senior and very diligent QA specialist and your task is to compare data in datasets. 
They are similar if they are almost identical, or if they convey the same information. 
Disregard if column names specified in the first row have different names or in a different order.
Focus on comparing the actual information (numbers). If values in datasets are different, then it means that they are not identical.
Always execute tool to provide results.
'''

@tool
def compare_datasets(comments: str, score: int) -> str:
  """Stores info about datasets.
  Args:
      comments (str): 1-2 sentences about the comparison of datasets,
      score (int): 0 if dataset provides different values and 1 if it shows identical information
  """
  pass

accuracy_chat_llm = (
  ChatOpenAI(model="gpt-4o-mini", temperature = 0.0)
  .bind_tools([compare_datasets])
)

accuracy_question_tmp = '''
Here are the two datasets to compare delimited by ####
Dataset #1: 
####
{dataset1}
####
Dataset #2: 
####
{dataset2}
####
'''

def get_openai_prompt(question, system):
  messages = [
    ("system", system),
    ("human", question)
  ]
  return messages

Now, it’s time to test the accuracy assessment process.

prompt = get_openai_prompt(accuracy_question_tmp.format(
  dataset1 = 'customers\n114032\n', dataset2 = 'customers\n114031\n'),
  accuracy_system_prompt)

accuracy_result = accuracy_chat_llm.invoke(prompt)
accuracy_result.tool_calls[0]['args']
# {'comments': 'The datasets contain different customer counts: 114032 in Dataset #1 and 114031 in Dataset #2.',
#  'score': 0}

prompt = get_openai_prompt(accuracy_question_tmp.format(
  dataset1 = 'users\n114032\n', dataset2 = 'customers\n114032\n'),
  accuracy_system_prompt)
accuracy_result = accuracy_chat_llm.invoke(prompt)
accuracy_result.tool_calls[0]['args']
# {'comments': 'The datasets contain the same numerical value (114032) despite different column names, indicating they convey identical information.',
#  'score': 1}

Fantastic! It looks like everything is working as expected. Let’s now encapsulate this into a function.

def is_answer_accurate(output1, output2):
  prompt = get_openai_prompt(
    accuracy_question_tmp.format(dataset1 = output1, dataset2 = output2),
    accuracy_system_prompt
  )

  accuracy_result = accuracy_chat_llm.invoke(prompt)

  try:
    return accuracy_result.tool_calls[0]['args']['score']
  except (IndexError, KeyError):
    return None
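Every call to the judge model costs time and money, so one optional refinement (my suggestion, not something used in this article) is to skip the LLM entirely when the two outputs already contain exactly the same values. The sketch compares everything except the header line, since the system prompt above tells the judge to disregard column names anyway, and falls back to the model for the remaining cases. The judge argument is a placeholder for a callable such as is_answer_accurate; apply the shortcut only to outputs that passed the validity check, so that two identical error messages aren’t counted as a match.

```python
def data_rows(output):
    """Everything after the header line of a TabSeparatedWithNames result."""
    lines = output.strip("\n").split("\n")
    return lines[1:]


def is_answer_accurate_with_shortcut(output1, output2, judge):
    # identical values (ignoring column names): no need to ask the LLM
    if data_rows(output1) == data_rows(output2):
        return 1
    # otherwise (e.g. reordered rows or reformatted numbers), defer to the judge
    return judge(output1, output2)


calls = []

def fake_judge(output1, output2):
    # stands in for is_answer_accurate; records that it was called
    calls.append((output1, output2))
    return 0


print(is_answer_accurate_with_shortcut("users\n114032\n", "customers\n114032\n", fake_judge))      # 1
print(is_answer_accurate_with_shortcut("customers\n114032\n", "customers\n114031\n", fake_judge))  # 0
print(len(calls))  # 1
```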

Putting the evaluation approach together

As we discussed, building an LLM application is an iterative process, so we’ll need to run our accuracy assessment multiple times. It will be helpful to have all this logic encapsulated in a single function.

The function will take two arguments as input:

  • generate_query_func: a function that generates an SQL query for a given question.
  • golden_df: an evaluation dataset with questions and correct answers in the form of a pandas DataFrame.

As output, the function will return a DataFrame with all evaluation results and a couple of charts displaying the main KPIs.


import pandas as pd
import plotly.express as px
import tqdm

def evaluate_sql_agent(generate_query_func, golden_df):

  # generating SQL
  tmp = []
  for rec in tqdm.tqdm(golden_df.to_dict('records')):
    generated_query = generate_query_func(rec['question'])
    tmp.append(
      {
          'id': rec['id'],
          'generated_query': generated_query
      }
    )

  eval_df = golden_df.merge(pd.DataFrame(tmp))

  # executing SQL queries
  tmp = []
  for rec in tqdm.tqdm(eval_df.to_dict('records')):
    golden_output = get_clickhouse_data(rec['sql_query'])
    generated_output = get_clickhouse_data(rec['generated_query'])

    tmp.append(
      {
        'id': rec['id'],
        'golden_output': golden_output,
        'generated_output': generated_output
      }
    )

  eval_df = eval_df.merge(pd.DataFrame(tmp))

  # checking accuracy
  eval_df['golden_output_valid'] = eval_df.golden_output.map(is_valid_output)
  eval_df['generated_output_valid'] = eval_df.generated_output.map(is_valid_output)

  eval_df['correct_output'] = list(map(
    is_answer_accurate,
    eval_df['golden_output'],
    eval_df['generated_output']
  ))

  eval_df['accuracy'] = list(map(
    lambda x, y: 'invalid: ' + x if x != 'ok' else ('correct' if y == 1 else 'incorrect'),
    eval_df.generated_output_valid,
    eval_df.correct_output
  ))

  valid_stats_df = (eval_df.groupby('golden_output_valid')[['id']].count().rename(columns = {'id': 'golden set'}).join(
    eval_df.groupby('generated_output_valid')[['id']].count().rename(columns = {'id': 'generated'}), how = 'outer')).fillna(0).T

  fig1 = px.bar(
    valid_stats_df.apply(lambda x: 100*x/valid_stats_df.sum(axis = 1)),
    orientation = 'h', 
    title = '<b>LLM SQL Agent evaluation</b>: query validity',
    text_auto = '.1f',
    color_discrete_map = {'ok': '#00b38a', 'error': '#ea324c', 'too many rows': '#f2ac42'},
    labels = {'index': '', 'variable': 'validity', 'value': 'share of queries, %'}
  )
  fig1.show()

  accuracy_stats_df = eval_df.groupby('accuracy')[['id']].count()
  accuracy_stats_df['share'] = accuracy_stats_df.id*100/accuracy_stats_df.id.sum()

  fig2 = px.bar(
    accuracy_stats_df[['share']],
    title = '<b>LLM SQL Agent evaluation</b>: query accuracy',
    text_auto = '.1f', orientation = 'h',
    color_discrete_sequence = ['#0077B5'],
    labels = {'index': '', 'variable': 'accuracy', 'value': 'share of queries, %'}
  )

  fig2.update_layout(showlegend = False)
  fig2.show()

  return eval_df
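If you only need the numbers (for example, to log them or compare runs without plotly), the accuracy breakdown can also be computed with the standard library. This is a sketch of my own that mirrors the labelling rule used in evaluate_sql_agent above: invalid outputs are labelled by their validity status, and valid ones as correct or incorrect based on the judge’s score. The inputs are plain lists standing in for the generated_output_valid and correct_output columns, and the example values are hypothetical.

```python
from collections import Counter


def accuracy_label(validity, score):
    # same rule as the lambda in evaluate_sql_agent
    if validity != 'ok':
        return 'invalid: ' + validity
    return 'correct' if score == 1 else 'incorrect'


def accuracy_shares(validities, scores):
    """Share of queries (in %) per accuracy label."""
    labels = [accuracy_label(v, s) for v, s in zip(validities, scores)]
    counts = Counter(labels)
    return {label: 100 * n / len(labels) for label, n in counts.items()}


validities = ['ok', 'ok', 'error', 'too many rows']  # hypothetical results
scores = [1, 0, 0, None]
print(accuracy_shares(validities, scores))
# {'correct': 25.0, 'incorrect': 25.0, 'invalid: error': 25.0, 'invalid: too many rows': 25.0}
```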

With that, we’ve completed the evaluation setup and can now move on to the core task of improving the model’s accuracy.

Improving accuracy: Self-reflection

Let’s do a quick recap. We’ve built and tested the first version of the SQL Agent. Unfortunately, all generated queries were invalid because they were missing the output format. Let’s address this issue.

One potential solution is self-reflection. We can make an additional call to the LLM, sharing the error and asking it to correct the bug. Let’s create a function to handle generation with self-reflection.

reflection_user_query_tmpl = '''
You've got the following question: "{question}". 
You've generated the SQL query: "{query}".
However, the database returned an error: "{output}". 
Please, revise the query to correct the mistake. 
'''

def generate_query_reflection(question):
  generated_query = generate_query(question) 
  print('Initial query:', generated_query)

  db_output = get_clickhouse_data(generated_query)
  is_valid_db_output = is_valid_output(db_output)
  if is_valid_db_output == 'too many rows':
    db_output = "Database unexpectedly returned more than 1000 rows."

  if is_valid_db_output == 'ok': 
    return generated_query

  reflection_user_query = reflection_user_query_tmpl.format(
    question = question,
    query = generated_query,
    output = db_output
  )

  reflection_prompt = get_llama_prompt(reflection_user_query, 
    generate_query_system_prompt) 
  reflection_result = chat_llm.invoke(reflection_prompt)

  try:
    reflected_query = reflection_result.tool_calls[0]['args']['query']
  except (IndexError, KeyError):
    reflected_query = ''
  print('Reflected query:', reflected_query)
  return reflected_query
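The function above gives the model a single chance to fix its query. A natural extension, which I sketch here but did not evaluate in this article, is to keep reflecting until the database accepts the query or a retry budget runs out. To keep the sketch self-contained, the generation, execution, validation, and reflection steps are passed in as callables; in this article’s setup they would wrap generate_query, get_clickhouse_data, is_valid_output, and the reflection prompt call, but those wrappers and the max_attempts parameter are assumptions.

```python
def generate_with_retries(question, generate, execute, validate, reflect, max_attempts=3):
    """Generate a query, then ask for fixes until it runs or attempts run out.

    generate(question) -> query
    execute(query) -> database output (error text on failure)
    validate(output) -> 'ok' or an error label
    reflect(question, query, output) -> revised query
    """
    query = generate(question)
    for _ in range(max_attempts):
        output = execute(query)
        if validate(output) == 'ok':
            return query
        query = reflect(question, query, output)
    return query  # budget exhausted: the latest revision, not re-checked


# toy stand-ins: the "database" only accepts queries with the format clause
def fake_generate(question):
    return "select count() from ecommerce.users"

def fake_execute(query):
    if 'format tabseparatedwithnames' in query.lower():
        return "count()\n42\n"
    return "Database returned the following error:\n Please, specify the output format."

def fake_validate(output):
    return 'error' if output.startswith('Database returned the following error:') else 'ok'

def fake_reflect(question, query, output):
    return query + " format TabSeparatedWithNames"


print(generate_with_retries("How many customers are there?", fake_generate,
                            fake_execute, fake_validate, fake_reflect))
# select count() from ecommerce.users format TabSeparatedWithNames
```

Each extra round costs one more LLM call and one more database query, so it makes sense to keep the budget small.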

Now, let’s use our evaluation function to check whether the quality has improved. Assessing the next iteration has become effortless.

refl_eval_df = evaluate_sql_agent(generate_query_reflection, golden_df)

Wonderful! We’ve achieved better results – 50% of the queries are now valid, and all format issues have been resolved. So, self-reflection is pretty effective.

However, self-reflection has its limitations. When we examine the accuracy, we see that the model returns the correct answer for only one question. So, our journey is not over yet.

Improving accuracy: RAG

Another approach to improving accuracy is using RAG (retrieval-augmented generation). The idea is to identify question-and-answer pairs similar to the customer query and include them in the system prompt, enabling the LLM to generate a more accurate response.

RAG consists of the following stages:

  • Loading documents: importing data from available sources.
  • Splitting documents: creating smaller chunks.
  • Storage: using vector stores to process and store data efficiently.
  • Retrieval: extracting documents that are relevant to the query.
  • Generation: passing a question and relevant documents to LLM to generate the final answer.

If you’d like a refresher on RAG, you can check out my previous article, "RAG: How to Talk to Your Data."

We will use the Chroma database as local vector storage to store and retrieve embeddings.

Vector stores use embeddings to find chunks that are similar to the query. For this purpose, we will use OpenAI embeddings, so we create the embedding model first and then pass it to Chroma.

from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma

embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
vector_store = Chroma(embedding_function=embeddings)
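Under the hood, similarity search boils down to comparing vectors. The toy sketch below ranks documents by cosine similarity to the query vector; the three-dimensional vectors are invented for illustration (text-embedding-3-large produces vectors with thousands of dimensions). Chroma’s default distance metric is L2 rather than cosine unless configured otherwise, but since OpenAI embeddings are normalized to unit length, both metrics should produce the same ranking here.

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query_vector, documents, k):
    """Return the k document texts whose vectors are most similar to the query."""
    ranked = sorted(
        documents.items(),
        key=lambda item: cosine_similarity(query_vector, item[1]),
        reverse=True,
    )
    return [text for text, _ in ranked[:k]]


# hypothetical 3-dimensional "embeddings"
documents = {
    "share of users using Windows yesterday": [0.9, 0.1, 0.0],
    "share of users using Android yesterday": [0.7, 0.3, 0.1],
    "fraud rate in 2023": [0.0, 0.2, 0.9],
}
query = [0.85, 0.15, 0.05]
print(top_k(query, documents, 2))
# ['share of users using Windows yesterday', 'share of users using Android yesterday']
```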

Since we can’t use examples from our evaluation set (as they are already being used to assess quality), I’ve created a separate set of question-and-answer pairs for RAG. You can find it on GitHub.

Now, let’s load the set and create a list of pairs in the following format: Question: %s; Answer: %s.

with open('rag_set.json', 'r') as f:
    rag_set = json.loads(f.read())
rag_set_df = pd.DataFrame(rag_set)

rag_set_df['formatted_txt'] = list(map(
    lambda x, y: 'Question: %s; Answer: %s' % (x, y),
    rag_set_df.question,
    rag_set_df.sql_query
))

rag_string_data = '\n\n'.join(rag_set_df.formatted_txt)

Next, I used LangChain’s character text splitter to create chunks, with each question-and-answer pair as a separate chunk. Since every chunk is a complete, self-contained pair, no overlap is necessary.

from langchain_text_splitters import CharacterTextSplitter

text_splitter = CharacterTextSplitter(
    separator="\n\n",
    chunk_size=1, # every pair exceeds this size, so pairs are never merged (LangChain may log a warning)
    chunk_overlap=0,
    length_function=len,
    is_separator_regex=False,
)

texts = text_splitter.create_documents([rag_string_data])

The final step is to load the chunks into our vector storage.

document_ids = vector_store.add_documents(documents=texts)
print(vector_store._collection.count())
# 32

Now, we can test the retrieval to see the results. They look quite similar to the customer question.

question = 'What was the share of users using Windows yesterday?'
retrieved_docs = vector_store.similarity_search(question, 3)
context = "\n\n".join(map(lambda x: x.page_content, retrieved_docs))
print(context)

# Question: What was the share of users using Windows the day before yesterday?; 
# Answer: select 100*uniqExactIf(user_id, os = 'Windows')/uniqExact(user_id) as windows_share from ecommerce.sessions where (action_date = today() - 2) format TabSeparatedWithNames
# Question: What was the share of users using Windows in the last week?; 
# Answer: select 100*uniqExactIf(user_id, os = 'Windows')/uniqExact(user_id) as windows_share from ecommerce.sessions where (action_date >= today() - 7) and (action_date < today()) format TabSeparatedWithNames
# Question: What was the share of users using Android yesterday?; 
# Answer: select 100*uniqExactIf(user_id, os = 'Android')/uniqExact(user_id) as android_share from ecommerce.sessions where (action_date = today() - 1) format TabSeparatedWithNames
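Conceptually, similarity_search embeds the question and returns the stored chunks whose embeddings are closest to it. The sketch below illustrates only that ranking step, using made-up 3-dimensional vectors and cosine similarity; the real vectors come from text-embedding-3-large, and the vector store may use a different distance metric depending on its configuration. The helper names and data are hypothetical.

```python
import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def top_k(query_vec, docs, k=3):
    """Return the k chunk texts whose vectors are most similar to query_vec."""
    ranked = sorted(docs, key=lambda d: cosine_similarity(query_vec, d[1]), reverse=True)
    return [text for text, _ in ranked[:k]]

# Hypothetical toy embeddings: (chunk text, vector)
docs = [
    ("Windows share, day before yesterday", [0.9, 0.1, 0.0]),
    ("Android share, yesterday", [0.6, 0.5, 0.1]),
    ("Average session duration", [0.0, 0.2, 0.9]),
]
question_vec = [0.8, 0.3, 0.0]  # stands in for the embedded question
print(top_k(question_vec, docs, k=2))
# ['Windows share, day before yesterday', 'Android share, yesterday']
```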

Let’s adjust the system prompt to include the examples we retrieved.

generate_query_system_prompt_with_examples_tmpl = '''
You are a senior data analyst with more than 10 years of experience writing complex SQL queries. 
There are two tables in the database you're working with with the following schemas. 

Table: ecommerce.users 
Description: customers of the online shop
Fields: 
- user_id (integer) - unique identifier of customer, for example, 1000004 or 3000004
- country (string) - country of residence, for example, "Netherlands" or "United Kingdom"
- is_active (integer) - 1 if customer is still active and 0 otherwise
- age (integer) - customer age in full years, for example, 31 or 72

Table: ecommerce.sessions 
Description: sessions of usage the online shop
Fields: 
- user_id (integer) - unique identifier of customer, for example, 1000004 or 3000004
- session_id (integer) - unique identifier of session, for example, 106 or 1023
- action_date (date) - session start date, for example, "2021-01-03" or "2024-12-02"
- session_duration (integer) - duration of session in seconds, for example, 125 or 49
- os (string) - operation system that customer used, for example, "Windows" or "Android"
- browser (string) - browser that customer used, for example, "Chrome" or "Safari"
- is_fraud (integer) - 1 if session is marked as fraud and 0 otherwise
- revenue (float) - income in USD (the sum of purchased items), for example, 0.0 or 1506.7

Write a query in ClickHouse SQL to answer the following question. 
Add "format TabSeparatedWithNames" at the end of the query to get data from ClickHouse database in the right format. 
Answer questions following the instructions and providing all the needed information and sharing your reasoning. 

Examples of questions and answers: 
{examples}
'''

Once again, let’s create the generate query function with RAG.

def generate_query_rag(question):
  # retrieving the three most similar question-and-answer examples
  retrieved_docs = vector_store.similarity_search(question, 3)
  context = "\n\n".join(map(lambda x: x.page_content, retrieved_docs))

  prompt = get_llama_prompt(question, 
    generate_query_system_prompt_with_examples_tmpl.format(examples = context))
  result = chat_llm.invoke(prompt)

  try:
    generated_query = result.tool_calls[0]['args']['query']
  except Exception:
    # the model didn't return a usable tool call
    generated_query = ''
  return generated_query

As usual, let’s use our evaluation function to test the new approach.

rag_eval_df = evaluate_sql_agent(generate_query_rag, golden_df)

We can see a significant improvement, increasing from 1 to 6 correct answers out of 10. It’s still not ideal, but we’re moving in the right direction.

We can also experiment with combining two approaches: RAG and self-reflection.

def generate_query_rag_with_reflection(question):
  generated_query = generate_query_rag(question) 

  db_output = get_clickhouse_data(generated_query)
  is_valid_db_output = is_valid_output(db_output)
  if is_valid_db_output == 'too many rows':
      db_output = "Database unexpectedly returned more than 1000 rows."

  if is_valid_db_output == 'ok': 
      return generated_query

  reflection_user_query = reflection_user_query_tmpl.format(
    question = question,
    query = generated_query,
    output = db_output
  )

  reflection_prompt = get_llama_prompt(reflection_user_query, generate_query_system_prompt) 
  reflection_result = chat_llm.invoke(reflection_prompt)

  try:
    reflected_query = reflection_result.tool_calls[0]['args']['query']
  except Exception:
    # the model didn't return a usable tool call
    reflected_query = ''
  return reflected_query

rag_refl_eval_df = evaluate_sql_agent(generate_query_rag_with_reflection, 
  golden_df)

We can see another slight improvement: we’ve completely eliminated invalid SQL queries (thanks to self-reflection) and increased the number of correct answers to 7 out of 10.

That’s it. It’s been quite a journey. We started with 0 valid SQL queries and have now achieved 70% accuracy.

You can find the complete code on GitHub.

Summary

In this article, we explored the iterative process of improving accuracy for LLM applications.

  • We built an evaluation set and the scoring criteria that allowed us to compare different iterations and understand whether we were moving in the right direction.
  • We leveraged self-reflection to allow the LLM to correct its mistakes and significantly reduce the number of invalid SQL queries.
  • Additionally, we implemented Retrieval-Augmented Generation (RAG) to further enhance the quality, achieving an accuracy rate of 60–70%.

While this is a solid result, it still falls short of the 90%+ accuracy threshold typically expected for production applications. To achieve such a high bar, we need to use fine-tuning, which will be the topic of the next article.

Thank you for reading. I hope you found this article insightful. If you have any follow-up questions or comments, please leave them in the comments section.

Reference

All the images are produced by the author unless otherwise stated.

This article is inspired by the "Improving Accuracy of LLM Applications" short course from DeepLearning.AI.

The post From Prototype to Production: Enhancing LLM Accuracy appeared first on Towards Data Science.

Linear Optimisations in Product Analytics https://towardsdatascience.com/linear-optimisations-in-product-analytics-ace19e925677/ Wed, 18 Dec 2024 11:01:52 +0000 https://towardsdatascience.com/linear-optimisations-in-product-analytics-ace19e925677/ Solving the knapsack problem

The post Linear Optimisations in Product Analytics appeared first on Towards Data Science.

It might be surprising, but in this article, I would like to talk about the knapsack problem, the classic optimisation problem that has been studied for over a century. According to Wikipedia, the problem is defined as follows:

Given a set of items, each with a weight and a value, determine which items to include in the collection so that the total weight is less than or equal to a given limit and the total value is as large as possible.

While product analysts may not physically pack knapsacks, the underlying mathematical model is highly relevant to many of our tasks. There are numerous real-world applications of the knapsack problem in product analytics. Here are a few examples:

  • Marketing Campaigns: The marketing team has a limited budget and capacity to run campaigns across different channels and regions. Their goal is to maximize a KPI, such as the number of new users or revenue, all while adhering to existing constraints.
  • Retail Space Optimization: A retailer with limited physical space in their stores seeks to optimize product placement to maximize revenue.
  • Product Launch Prioritization: When launching a new product, the operations team’s capacity might be limited, requiring prioritization of specific markets.

Tasks like these are quite common, and many analysts encounter them regularly. So, in this article, I’ll explore different approaches to solving them, ranging from naive, simple techniques to more advanced methods such as Linear Programming.

Another reason I chose this topic is that linear programming is one of the most powerful and popular tools in prescriptive analytics – a type of analysis that focuses on providing stakeholders with actionable options to make informed decisions. As such, I believe it is an essential skill for any analyst to have in their toolkit.

Case

Let’s dive straight into the case we’ll be exploring. Imagine we’re part of a marketing team planning activities for the upcoming month. Our objective is to maximize key performance indicators (KPIs), such as the number of acquired users and revenue, while operating within a limited marketing budget.

We’ve estimated the expected outcomes of various marketing activities across different countries and channels. Here is the data we have:

  • country – the market where we can do some promotional activities;
  • channel – the acquisition method, such as social networks or influencer campaigns;
  • users – the expected number of users acquired within a month of the promo campaign;
  • cs_contacts – the incremental Customer Support contacts generated by the new users;
  • marketing_spending – the investment required for the activity;
  • revenue – the first-year LTV generated from acquired customers.

Note that the dataset is synthetic and randomly generated, so don’t try to infer any market-related insights from it.

First, I’ve calculated the high-level statistics to get a view of the numbers.

Let’s determine the optimal set of marketing activities that maximizes revenue while staying within the $30M marketing budget.

Brute force approach

At first glance, the problem may seem straightforward: we could calculate all possible combinations of marketing activities and select the optimal one. However, the number of combinations grows extremely quickly.

With 62 segments, there are 2⁶² possible combinations, as each segment can either be included or excluded. This results in approximately 4.6 × 10¹⁸ combinations – an astronomical number.

To better understand the computational feasibility, let’s consider a smaller subset of 15 segments and estimate the time required for one iteration.

import itertools
import pandas as pd
import tqdm

# reading data (tab-separated file)
df = pd.read_csv('marketing_campaign_estimations.csv', sep = '\t')
df['segment'] = df.country + ' - ' + df.channel

# calculating all subsets of the first 15 segments
combinations = []
segments = list(df.segment.values)[:15]
print('number of segments: ', len(segments))

for num_items in range(len(segments) + 1):
    combinations.extend(
        itertools.combinations(segments, num_items)
    )
print('number of combinations: ', len(combinations))

tmp = []
for selected in tqdm.tqdm(combinations):
    tmp_df = df[df.segment.isin(selected)]
    tmp.append(
        {
        'selected_segments': ', '.join(selected),
        'users': tmp_df.users.sum(),
        'cs_contacts': tmp_df.cs_contacts.sum(),
        'marketing_spending': tmp_df.marketing_spending.sum(),
        'revenue': tmp_df.revenue.sum()
        }
    )

# number of segments:  15
# number of combinations:  32768

It took approximately 4 seconds to process 15 segments, allowing us to handle around 7,000 iterations per second. Using this estimate, let’s calculate the execution time for the full set of 62 segments.

2**62 / 7000 / 3600 / 24 / 365
# 20 890 800.6

Using brute force, it would take around 20.9 million years to get the answer to our question – clearly not a feasible option.

Execution time is entirely determined by the number of segments: every additional segment doubles the number of combinations, so removing just one segment halves the execution time. With this in mind, let’s explore possible ways to merge segments.

Typically, there are many more small segments than large ones, so merging the small ones is a logical step. However, this approach may reduce accuracy, since multiple segments are aggregated into one. Despite this, it could still yield a solution that is "good enough."

To simplify, let’s merge all segments that contribute less than 0.1% of revenue.

df['share_of_revenue'] = df.revenue/df.revenue.sum() * 100
df['segment_group'] = list(map(
    lambda x, y: x if y >= 0.1 else 'other',
    df.segment,
    df.share_of_revenue
))

print(df[df.segment_group == 'other'].share_of_revenue.sum())
# 0.53
print(df.segment_group.nunique())
# 52

With this approach, we reduce the number of segments from 62 to 52 by merging the smallest ones into a single "other" group, which represents 0.53% of the total revenue (the potential margin of error). With 52 segments, we could obtain the solution in just 20.4K years. While this is a significant improvement, it’s still not sufficient.
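Both estimates follow from the fact that n segments give 2ⁿ subsets. Assuming the rough rate of 7,000 combinations per second measured on the 15-segment run, the brute-force runtime can be sketched as follows (brute_force_years is an illustrative helper, not part of the original code):

```python
SECONDS_PER_YEAR = 3600 * 24 * 365
RATE = 7000  # combinations evaluated per second (rough estimate from the 15-segment run)

def brute_force_years(num_segments: int, rate: float = RATE) -> float:
    """Years needed to evaluate all 2**num_segments include/exclude combinations."""
    return 2 ** num_segments / rate / SECONDS_PER_YEAR

for n in (52, 62):
    print(n, f"{brute_force_years(n):,.1f} years")
# 52 20,401.2 years
# 62 20,890,800.6 years

# Each extra segment doubles the number of subsets, so dropping one halves the runtime
print(brute_force_years(61) / brute_force_years(62))  # 0.5
```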

You may consider other heuristics tailored to your specific task. For instance, if your constraint is a ratio (e.g., contact rate = CS contacts / users ≤ 5%), you could group all segments where the constraint holds true, as the optimal solution will include all of them. In our case, however, I don’t see any additional strategies to reduce the number of segments, so brute force seems impractical.

That said, if the number of combinations is relatively small and brute force can be executed within a reasonable time, it can be an ideal approach. It’s simple to develop and provides accurate results.

Naive approach: looking at top-performing segments

Since brute force is not feasible for calculating all combinations, let’s consider a simpler algorithm to address this problem.

One possible approach is to focus on the top-performing segments. We can evaluate segment performance by calculating revenue per dollar spent, then sort all activities based on this ratio and select the top performers that fit within the marketing budget. Let’s implement it.

df['revenue_per_spend'] = df.revenue / df.marketing_spending 
df = df.sort_values('revenue_per_spend', ascending = False)
df['spend_cumulative'] = df.marketing_spending.cumsum()
selected_df = df[df.spend_cumulative <= 30000000]
print(selected_df.shape[0])
# 48 
print(selected_df.revenue.sum()/1000000)
# 107.92

With this approach, we selected 48 activities and got $107.92M in revenue.

Unfortunately, although the logic seems reasonable, it is not the optimal solution for maximizing revenue. Let’s look at a simple example with just three marketing activities.

Using the top markets approach, we would select France and achieve $68M in revenue. However, by choosing the two other markets, we could achieve significantly better results – $97.5M. The key point is that the greedy approach ranks segments one at a time by efficiency and never checks whether a different combination would make better use of the whole budget. Therefore, this approach will not always yield the best results, especially considering its inability to account for multiple constraints.
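The failure mode is easy to reproduce with made-up numbers. In the sketch below (hypothetical data, not the three markets above), greedy_by_ratio mirrors the cumulative-sum selection used earlier, and best_subset exhaustively checks every subset of three items to find the true optimum:

```python
import itertools

# Hypothetical toy data (not the article's dataset): (segment, spend, revenue)
items = [("A", 6, 60), ("B", 5, 45), ("C", 5, 40)]
BUDGET = 10

def greedy_by_ratio(items, budget):
    """Rank by revenue per unit of spend and keep items while cumulative spend fits."""
    chosen, spent = [], 0
    for name, spend, _ in sorted(items, key=lambda x: x[2] / x[1], reverse=True):
        if spent + spend > budget:
            break  # like the cumulative-sum filter: everything after this point is dropped
        chosen.append(name)
        spent += spend
    return chosen

def best_subset(items, budget):
    """Exhaustively check every subset (fine for 3 items, hopeless for 62)."""
    best, best_revenue = [], 0
    for r in range(len(items) + 1):
        for combo in itertools.combinations(items, r):
            spend = sum(s for _, s, _ in combo)
            revenue = sum(v for _, _, v in combo)
            if spend <= budget and revenue > best_revenue:
                best, best_revenue = [name for name, _, _ in combo], revenue
    return best, best_revenue

def total_revenue(names, items):
    return sum(v for name, _, v in items if name in names)

greedy = greedy_by_ratio(items, BUDGET)
print(greedy, total_revenue(greedy, items))  # ['A'] 60
print(best_subset(items, BUDGET))            # (['B', 'C'], 85)
```

The greedy pick uses only 6 of the 10 budget units and then stops, while the best pair spends the whole budget on two slightly less efficient segments.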

Linear Programming

Since all simple approaches have failed, we must return to the fundamentals and explore the theory behind this problem. Fortunately, the knapsack problem has been studied for many years, and we can apply optimization techniques to solve it in seconds rather than years.

The problem we’re trying to solve is an example of Integer Programming, which is actually a subdomain of Linear Programming.

We’ll discuss this shortly, but first, let’s align on the key concepts of the optimization process. Each optimisation problem consists of:

  • Decision variables: Parameters that can be adjusted in the model, typically representing the levers or decisions we want to make.
  • Objective function: The target variable we aim to maximize or minimize. It goes without saying that it must depend on the decision variables.
  • Constraints: Conditions placed on the decision variables that define their possible values. For example, ensuring the team cannot work a negative number of hours.

With these basic concepts in mind, we can define Linear Programming as a scenario where the following conditions hold:

  • The objective function is linear.
  • All constraints are linear.
  • Decision variables are real-valued.

Integer Programming is very similar to Linear Programming, with one key difference: some or all decision variables must be integers. While this may seem like a minor change, it significantly impacts the solution approach, requiring more complex methods than those used in Linear Programming. One common technique is branch-and-bound. We won’t dive deeper into the theory here, but you can always find more detailed explanations online.
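For the marketing case, the problem can be written as a 0/1 integer program. In the notation below (introduced here for illustration), x_i equals 1 if segment i is selected and 0 otherwise, r_i is its expected revenue, c_i its marketing spending, and B the budget ($30M):

```latex
\begin{aligned}
\max_{x} \quad & \sum_{i=1}^{n} r_i x_i \\
\text{s.t.} \quad & \sum_{i=1}^{n} c_i x_i \le B, \\
& x_i \in \{0, 1\}, \quad i = 1, \dots, n.
\end{aligned}
```

Both the objective and the constraint are linear; only the requirement that each x_i be binary turns this from a linear program into an integer program.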

For linear optimization, I prefer the widely used Python package PuLP. However, there are other options available, such as Python MIP or Pyomo. Let’s install PuLP via pip.

! pip install pulp

Now, it’s time to define our task as a mathematical optimisation problem. This involves the following steps:

  • Define the set of decision variables (levers we can adjust).
  • Align on the objective function (a variable that we will be optimising for).
  • Formulate constraints (the conditions that must hold true during optimisations).

Let’s go through the steps one by one. But first, we need to create the problem object and set the objective – maximization in our case.

from pulp import *
problem = LpProblem("Marketing_campaign", LpMaximize)

The next step is defining the decision variables – parameters that we can change during optimisation. Our main decision is either to run a marketing campaign or not. So, we can model it as a set of binary variables (0 or 1) for each segment. Let’s do it with the PuLP library.

segments = range(df.shape[0])  
selected = LpVariable.dicts("Selected", segments, cat="Binary")

After that, it’s time to align on the objective function. As discussed, we want to maximise revenue. The total revenue is the sum of revenue from all selected segments (where the decision variable equals 1). Therefore, we can define it as the sum of each segment’s expected revenue multiplied by its binary decision variable.

problem += lpSum(
  selected[i] * df['revenue'].values[i] 
  for i in segments
)

The final step is to add constraints. Let’s start with a simple constraint: our marketing spending must not exceed $30M.

problem += lpSum(
    selected[i] * df['marketing_spending'].values[i]
    for i in segments
) <= 30 * 10**6

Hint: you can print(problem) to double-check the objective function and constraints.

Now that we’ve defined everything, we can run the optimization and analyze the results.

problem.solve()

It takes less than a second to run the optimization, a dramatic improvement compared to the millions of years that brute force would require.

Result - Optimal solution found

Objective value:                110162662.21000001
Enumerated nodes:               4
Total iterations:               76
Time (CPU seconds):             0.02
Time (Wallclock seconds):       0.02

Let’s save the results of the model execution – the decision variables indicating whether each segment was selected or not – into our dataframe.

df['selected'] = list(map(lambda x: x.value(), selected.values()))
print(df[df.selected == 1].revenue.sum()/10**6)
# 110.16

It works like magic, allowing you to obtain the solution quickly. Additionally, note that we achieved higher revenue compared to our naive approach: $110.16M versus $107.92M.

We’ve tested integer programming with a simple example featuring just one constraint, but we can extend it further. For instance, we can add additional constraints for our CS contacts to ensure that our Operations team can handle the demand in a healthy way:

  • The number of additional CS contacts ≤ 5K
  • Contact rate (CS contacts/users) ≤ 0.042
# define the problem
problem_v2 = LpProblem("Marketing_campaign_v2", LpMaximize)

# decision variables
segments = range(df.shape[0]) 
selected = LpVariable.dicts("Selected", segments, cat="Binary")

# objective function
problem_v2 += lpSum(
  selected[i] * df['revenue'].values[i] 
  for i in segments
)

# Constraints
problem_v2 += lpSum(
    selected[i] * df['marketing_spending'].values[i]
    for i in segments
) <= 30 * 10**6

problem_v2 += lpSum(
    selected[i] * df['cs_contacts'].values[i]
    for i in segments
) <= 5000

problem_v2 += lpSum(
    selected[i] * df['cs_contacts'].values[i]
    for i in segments
) <= 0.042 * lpSum(
    selected[i] * df['users'].values[i]
    for i in segments
)

# run the optimisation
problem_v2.solve()

The code is straightforward, with the only tricky part being the transformation of the ratio constraint into a simpler linear form.
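Writing s_i for the CS contacts and u_i for the users of segment i (notation introduced here), the contact-rate condition is a ratio of two linear expressions, which is not itself linear. Because the denominator (total users of the selected segments) is non-negative, multiplying both sides by it gives an equivalent linear constraint whenever at least one segment is selected:

```latex
\frac{\sum_i s_i x_i}{\sum_i u_i x_i} \le 0.042
\quad\Longleftrightarrow\quad
\sum_i s_i x_i \le 0.042 \sum_i u_i x_i
\qquad \text{for } \textstyle\sum_i u_i x_i > 0
```

If nothing is selected, the linear form holds trivially (0 ≤ 0), so it never excludes the empty selection by accident.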

Another potential constraint you might consider is limiting the number of selected options, for example, to 10. This constraint could be pretty helpful in prescriptive analytics, for example, when you need to select the top-N most impactful focus areas.

# define the problem
problem_v3 = LpProblem("Marketing_campaign_v3", LpMaximize)

# decision variables
segments = range(df.shape[0]) 
selected = LpVariable.dicts("Selected", segments, cat="Binary")

# objective function
problem_v3 += lpSum(
  selected[i] * df['revenue'].values[i] 
  for i in segments
)

# constraints
problem_v3 += lpSum(
    selected[i] * df['marketing_spending'].values[i]
    for i in segments
) <= 30 * 10**6

problem_v3 += lpSum(
    selected[i] for i in segments
) <= 10

# run the optimisation
problem_v3.solve()
df['selected'] = list(map(lambda x: x.value(), selected.values()))
print(df.selected.sum())
# 10

Another possible option to tweak our problem is to change the objective function. We’ve been optimising for the revenue, but imagine we want to maximise both revenue and new users at the same time. For that, we can slightly change our objective function.

Let’s consider the best approach. We could calculate the sum of revenue and new users and aim to maximize it. However, since revenue is, on average, 1000 times higher, the results might be skewed toward maximizing revenue. To make the metrics more comparable, we can normalize both revenue and users based on their total sums. Then, we can define the objective function as a weighted sum of these ratios. I would use equal weights (0.5) for both metrics, but you can adjust the weights to give more value to one of them.

# define the problem
problem_v4 = LpProblem("Marketing_campaign_v4", LpMaximize)

# decision variables
segments = range(df.shape[0]) 
selected = LpVariable.dicts("Selected", segments, cat="Binary")

# objective Function
problem_v4 += (
    0.5 * lpSum(
        selected[i] * df['revenue'].values[i] / df['revenue'].sum()
        for i in segments
    )
    + 0.5 * lpSum(
        selected[i] * df['users'].values[i] / df['users'].sum()
        for i in segments
    )
)

# constraints
problem_v4 += lpSum(
    selected[i] * df['marketing_spending'].values[i]
    for i in segments
) <= 30 * 10**6

# run the optimisation
problem_v4.solve()
df['selected'] = list(map(lambda x: x.value(), selected.values()))

We obtained the optimal objective function value of 0.6131, with revenue at $104.36M and 136.37K new users.
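To see why the normalisation matters, the sketch below scores a selection with the same weighted formula as problem_v4, using hypothetical per-segment numbers where revenue is roughly 1,000 times larger than users. Each normalised term lies between 0 and 1, so neither metric dominates just because of its units. weighted_score is an illustrative helper, not part of PuLP.

```python
def weighted_score(selected, revenue, users, w_revenue=0.5, w_users=0.5):
    """Weighted sum of the selected shares of total revenue and total users."""
    revenue_share = sum(r for r, s in zip(revenue, selected) if s) / sum(revenue)
    users_share = sum(u for u, s in zip(users, selected) if s) / sum(users)
    return w_revenue * revenue_share + w_users * users_share

# Hypothetical segments: A is revenue-heavy, B is user-heavy
revenue = [3_000_000, 1_000_000]
users = [1_000, 3_000]

# A raw sum of revenue and users would strongly favour A
print(revenue[0] + users[0], revenue[1] + users[1])  # 3001000 1003000

# The normalised score treats both segments as equally valuable
print(weighted_score([1, 0], revenue, users))  # 0.5
print(weighted_score([0, 1], revenue, users))  # 0.5
print(weighted_score([1, 1], revenue, users))  # 1.0
```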

That’s it! We’ve learned how to use integer programming to solve various optimisation problems.

You can find the full code on GitHub.

Summary

In this article, we explored different methods for solving the knapsack problem and its analogues in product analytics.

  • We began with a brute-force approach but quickly realized it would take an unreasonable amount of time.
  • Next, we tried using common sense by naively selecting the top-performing segments, but this approach yielded suboptimal results.
  • Finally, we turned to Integer Programming, learning how to translate our product tasks into optimization models and solve them effectively.

With this, I hope you’ve gained another valuable analytical tool for your toolkit.

Thank you for reading. I hope you found this article insightful. If you have any follow-up questions or comments, please leave them in the comments section.

Reference

All the images are produced by the author unless otherwise stated.


From Basics to Advanced: Exploring LangGraph https://towardsdatascience.com/from-basics-to-advanced-exploring-langgraph-e8c1cf4db787/ Thu, 15 Aug 2024 18:36:10 +0000 https://towardsdatascience.com/from-basics-to-advanced-exploring-langgraph-e8c1cf4db787/ Building single- and multi-agent workflows with human-in-the-loop interactions

The post From Basics to Advanced: Exploring LangGraph appeared first on Towards Data Science.

Image by DALL-E 3

LangChain is one of the leading frameworks for building applications powered by Large Language Models. With the LangChain Expression Language (LCEL), defining and executing step-by-step action sequences – also known as chains – becomes much simpler. In more technical terms, LangChain allows us to create DAGs (directed acyclic graphs).

As LLM applications, particularly LLM agents, have evolved, we’ve begun to use LLMs not just for execution but also as reasoning engines. This shift has introduced interactions that frequently involve repetition (cycles) and complex conditions. In such scenarios, LCEL is not sufficient, so LangChain implemented a new module – LangGraph.

LangGraph (as you might guess from the name) models all interactions as cyclical graphs. These graphs enable the development of advanced workflows and interactions with multiple loops and if-statements, making it a handy tool for creating both agent and multi-agent workflows.

In this article, I will explore LangGraph’s key features and capabilities, including multi-agent applications. We’ll build a system that can answer different types of questions and dive into how to implement a human-in-the-loop setup.

In the previous article, we tried using CrewAI, another popular framework for multi-agent systems. LangGraph, however, takes a different approach. While CrewAI is a high-level framework with many predefined features and ready-to-use components, LangGraph operates at a lower level, offering extensive customization and control.

With that introduction, let’s dive into the fundamental concepts of LangGraph.

LangGraph basics

LangGraph is part of the LangChain ecosystem, so we will continue using well-known concepts like prompt templates, tools, etc. However, LangGraph brings a bunch of additional concepts. Let’s discuss them.

LangGraph is designed for defining cyclical graphs. Graphs consist of the following elements:

  • Nodes represent actual actions and can be either LLMs, agents or functions. Also, a special END node marks the end of execution.
  • Edges connect nodes and determine the execution flow of your graph. There are basic edges that simply link one node to another and conditional edges that incorporate if-statements and additional logic.

Another important concept is the state of the graph. The state serves as a foundational element for collaboration among the graph’s components. It represents a snapshot of the graph that any part – whether nodes or edges – can access and modify during execution to retrieve or update information.

Additionally, the state plays a crucial role in persistence. When the graph is compiled with a checkpointer, the state is saved after each step, allowing you to pause and resume execution at any point. This feature supports the development of more complex applications, such as those requiring error correction or incorporating human-in-the-loop interactions.
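To make nodes, edges and state concrete, here is a deliberately tiny, framework-free sketch of the kind of loop LangGraph executes: nodes are functions that return state updates, a conditional edge picks the next node, and execution stops at END. All names are hypothetical, and this is not LangGraph’s implementation; for instance, it simply overwrites keys on every update, whereas LangGraph also supports custom reducers for merging updates.

```python
END = "__end__"

def run_graph(nodes, edges, conditional_edges, entry, state, max_steps=20):
    """Execute nodes until END; each node returns a dict of updates merged into state."""
    current = entry
    for _ in range(max_steps):
        if current == END:
            return state
        update = nodes[current](state)
        state = {**state, **update}  # naive merge: new values overwrite old ones
        if current in conditional_edges:
            router, mapping = conditional_edges[current]
            current = mapping[router(state)]
        else:
            current = edges[current]
    raise RuntimeError("step limit reached without hitting END")

# Toy "agent": the llm node asks for a tool until it has seen two tool results
def llm(state):
    return {"wants_tool": state["tool_results"] < 2}

def function(state):
    return {"tool_results": state["tool_results"] + 1}

final = run_graph(
    nodes={"llm": llm, "function": function},
    edges={"function": "llm"},
    conditional_edges={"llm": (lambda s: s["wants_tool"], {True: "function", False: END})},
    entry="llm",
    state={"tool_results": 0, "wants_tool": False},
)
print(final)  # {'tool_results': 2, 'wants_tool': False}
```

The cycle llm → function → llm is exactly what a DAG-only chain cannot express.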

Single-agent workflow

Building agent from scratch

Let’s start simple and try using LangGraph for a basic use case – an agent with tools.

I will try to build similar applications to those we did with CrewAI in the previous article. Then, we will be able to compare the two frameworks. For this example, let’s create an application that can automatically generate documentation for a table in the database. It can save us quite a lot of time when creating documentation for our data sources.

As usual, we will start by defining the tools for our agent. Since I will use the ClickHouse database in this example, I’ve defined a function to execute any query. You can use a different database if you prefer, as we won’t rely on any database-specific features.

import requests

CH_HOST = 'http://localhost:8123' # default address 

def get_clickhouse_data(query, host = CH_HOST, connection_timeout = 1500):
    r = requests.post(host, params = {'query': query}, 
        timeout = connection_timeout)
    if r.status_code == 200:
        return r.text
    else: 
        # returning the error as text so that the LLM can react to it
        return 'Database returned the following error:\n' + r.text

It’s crucial to make LLM tools reliable and error-tolerant. If the database returns an error, I pass this feedback to the LLM rather than throwing an exception and halting execution. This gives the LLM agent an opportunity to fix the error and call the function again.
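The same idea can be applied to any tool: wrap it so that an exception is turned into a text message the LLM can read, instead of crashing the agent loop. A minimal sketch, where return_errors_as_text and divide are hypothetical names used only for illustration:

```python
import functools

def return_errors_as_text(func):
    """Decorator: return exceptions as readable text instead of raising them."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as exc:  # broad on purpose: the LLM should see every failure
            return f"Tool {func.__name__} failed with {type(exc).__name__}: {exc}"
    return wrapper

@return_errors_as_text
def divide(a: float, b: float) -> float:
    return a / b

print(divide(6, 3))  # 2.0
print(divide(1, 0))  # returns an error string instead of raising
```

functools.wraps keeps the original name and docstring, which matters because LLM tool frameworks often use them to describe the tool.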

Let’s define one tool named execute_sql, which enables the execution of any SQL query. We use pydantic to specify the tool’s structure, ensuring that the LLM agent has all the information it needs to use the tool effectively.

from langchain_core.tools import tool
from pydantic.v1 import BaseModel, Field
from typing import Optional

class SQLQuery(BaseModel):
  query: str = Field(description="SQL query to execute")

@tool(args_schema = SQLQuery)
def execute_sql(query: str) -> str:
  """Returns the result of SQL query execution"""
  return get_clickhouse_data(query)

We can print the parameters of the created tool to see what information is passed to the LLM.

print(f'''
name: {execute_sql.name}
description: {execute_sql.description}
arguments: {execute_sql.args}
''')

# name: execute_sql
# description: Returns the result of SQL query execution
# arguments: {'query': {'title': 'Query', 'description': 
#   'SQL query to execute', 'type': 'string'}}

Everything looks good. We’ve set up the necessary tool and can now move on to defining an LLM agent. As we discussed above, the cornerstone of the agent in LangGraph is its state, which enables the sharing of information between different parts of our graph.

Our current example is relatively straightforward. So, we will only need to store the history of messages. Let’s define the agent state.

# useful imports
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated
import operator
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage

# defining agent state
class AgentState(TypedDict):
    messages: Annotated[list[AnyMessage], operator.add]

We’ve defined a single parameter in AgentState, messages, which is a list of objects of the class AnyMessage. Additionally, we annotated it with operator.add (a reducer). This annotation ensures that each time a node returns messages, they are appended to the existing list in the state. Without this reducer, each new value would replace the previous one rather than being added to the list.
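To see what the reducer changes, here is a small stdlib-only sketch that mimics the merge rule: if a field is annotated with a reducer, the update is combined with the old value; otherwise, it overwrites it. The apply_update helper and both state classes are hypothetical; they illustrate the semantics and are not how LangGraph works internally.

```python
import operator
from typing import Annotated, TypedDict, get_origin, get_type_hints

class WithReducer(TypedDict):
    messages: Annotated[list[str], operator.add]  # same pattern as AgentState

class WithoutReducer(TypedDict):
    messages: list[str]

def apply_update(state: dict, update: dict, schema: type) -> dict:
    """Merge a node's update into state using the schema's Annotated reducers, if any."""
    hints = get_type_hints(schema, include_extras=True)
    merged = dict(state)
    for key, value in update.items():
        hint = hints[key]
        if get_origin(hint) is Annotated:
            reducer = hint.__metadata__[0]  # e.g. operator.add
            merged[key] = reducer(state[key], value)
        else:
            merged[key] = value  # no reducer: the new value replaces the old one
    return merged

start = {"messages": ["user: hi"]}
update = {"messages": ["ai: hello"]}
print(apply_update(start, update, WithReducer)["messages"])     # ['user: hi', 'ai: hello']
print(apply_update(start, update, WithoutReducer)["messages"])  # ['ai: hello']
```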

The next step is to define the agent itself. Let’s start with the __init__ method. We will specify three arguments for the agent: model, list of tools and system prompt.

class SQLAgent:
  # initialising the object
  def __init__(self, model, tools, system_prompt = ""):
    self.system_prompt = system_prompt

    # initialising graph with a state 
    graph = StateGraph(AgentState)

    # adding nodes 
    graph.add_node("llm", self.call_llm)
    graph.add_node("function", self.execute_function)
    graph.add_conditional_edges(
        "llm",
        self.exists_function_calling,
        {True: "function", False: END}
    )
    graph.add_edge("function", "llm")

    # setting starting point
    graph.set_entry_point("llm")

    self.graph = graph.compile()
    self.tools = {t.name: t for t in tools}
    self.model = model.bind_tools(tools)

In the initialisation function, we’ve outlined the structure of our graph, which includes two nodes: llm and function. Nodes are actual actions, so we have functions associated with them. We will define these functions a bit later.

Additionally, we have one conditional edge that determines whether we need to execute the function or generate the final answer. For this edge, we need to specify the previous node (in our case, llm), a function that decides the next step, and a mapping of the subsequent steps based on the function’s output (formatted as a dictionary). If exists_function_calling returns True, we move to the function node. Otherwise, execution concludes at the special END node, which marks the end of the process.

We’ve added an edge between function and llm. It just links these two steps and will be executed without any conditions.

With the main structure defined, it’s time to create all the functions outlined above. The first one is call_llm. This function will execute LLM and return the result.

The agent state is passed to the function automatically, while the system prompt and model are available as attributes of the agent object (self).

class SQLAgent:
  <...>

  def call_llm(self, state: AgentState):
    messages = state['messages']
    # adding system prompt if it's defined
    if self.system_prompt:
        messages = [SystemMessage(content=self.system_prompt)] + messages

    # calling LLM
    message = self.model.invoke(messages)

    return {'messages': [message]}

As a result, our function returns a dictionary that will be used to update the agent state. Since we used operator.add as a reducer for our state, the returned message will be appended to the list of messages stored in the state.

The next function we need is execute_function, which will run our tools. If the LLM decides to call a tool, we will see it in the message.tool_calls attribute.

class SQLAgent:
  <...>  

  def execute_function(self, state: AgentState):
    tool_calls = state['messages'][-1].tool_calls

    results = []
    for t in tool_calls:
      # checking whether the tool name is correct
      if t['name'] not in self.tools:
        # returning an error to the agent
        result = "Error: There's no such tool, please, try again"
      else:
        # getting the result from the tool
        result = self.tools[t['name']].invoke(t['args'])

      results.append(
        ToolMessage(
          tool_call_id=t['id'],
          name=t['name'],
          content=str(result)
        )
      )
    return {'messages': results}

In this function, we iterate over the tool calls returned by the LLM and either invoke these tools or return an error message. In the end, our function returns a dictionary with a single key, messages, that will be used to update the graph state.
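Stripped of LangChain classes, the dispatch logic can be sketched in plain Python. Here plain dicts replace ToolMessage objects, and `dispatch_tool_calls` and `fake_execute_sql` are hypothetical stand-ins. The key point is that every call gets a result keyed by its id, including calls to unknown tools. OpenAI's chat API expects a tool message for each tool_call_id of the preceding assistant message, so answering every call keeps the conversation valid.

```python
from typing import Any, Callable, Dict, List


def dispatch_tool_calls(tool_calls: List[dict],
                        tools: Dict[str, Callable[..., Any]]) -> List[dict]:
    """Run each requested tool and return one result record per call."""
    results = []
    for call in tool_calls:
        if call["name"] not in tools:
            # unknown tool: report the error back to the model instead of raising
            content = "Error: There's no such tool, please, try again"
        else:
            content = str(tools[call["name"]](**call["args"]))
        results.append({"tool_call_id": call["id"], "name": call["name"], "content": content})
    return results


def fake_execute_sql(query: str) -> str:
    # hypothetical stand-in for the article's execute_sql tool
    return f"ran: {query}"


calls = [
    {"name": "execute_sql", "args": {"query": "DESCRIBE ecommerce_db.users;"}, "id": "call_1"},
    {"name": "drop_everything", "args": {}, "id": "call_2"},
]
for r in dispatch_tool_calls(calls, {"execute_sql": fake_execute_sql}):
    print(r["tool_call_id"], r["content"])
# call_1 ran: DESCRIBE ecommerce_db.users;
# call_2 Error: There's no such tool, please, try again
```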

There’s only one function left: the one for the conditional edge, which defines whether we need to execute a tool or provide the final result. It’s pretty straightforward. We just need to check whether the last message contains any tool calls.

class SQLAgent:
  <...>  

  def exists_function_calling(self, state: AgentState):
    result = state['messages'][-1]
    return len(result.tool_calls) > 0

It’s time to create an agent and an LLM for it. I will use OpenAI’s new GPT-4o mini model, since it’s cheaper and performs better than GPT-3.5.

import os

# setting up credentials
os.environ["OPENAI_MODEL_NAME"]='gpt-4o-mini'  
os.environ["OPENAI_API_KEY"] = '<your_api_key>'

# system prompt
prompt = '''You are a senior expert in SQL and data analysis. 
So, you can help the team to gather needed data to power their decisions. 
You are very accurate and take into account all the nuances in data.
Your goal is to provide the detailed documentation for the table in database 
that will help users.'''

model = ChatOpenAI(model="gpt-4o-mini")
doc_agent = SQLAgent(model, [execute_sql], system_prompt=prompt)

LangGraph provides a handy feature to visualise graphs. To use it, you need to install pygraphviz.

Installing it is a bit tricky on Macs with M1/M2 chips, so here is a workaround (source):

! brew install graphviz
! python3 -m pip install -U --no-cache-dir \
    --config-settings="--global-option=build_ext" \
    --config-settings="--global-option=-I$(brew --prefix graphviz)/include/" \
    --config-settings="--global-option=-L$(brew --prefix graphviz)/lib/" \
    pygraphviz

After figuring out the installation, here’s our graph.

from IPython.display import Image
Image(doc_agent.graph.get_graph().draw_png())

As you can see, our graph has cycles. Implementing something like this with LCEL (LangChain Expression Language) would be quite challenging.

Finally, it’s time to execute our agent. We need to pass the initial set of messages with our question as a HumanMessage.

messages = [HumanMessage(content="What info do we have in ecommerce_db.users table?")]
result = doc_agent.graph.invoke({"messages": messages})

In the result variable, we can observe all the messages generated during execution. The process worked as expected:

  • The agent decided to call the function with the query DESCRIBE ecommerce_db.users.
  • The LLM then processed the information from the tool and provided a user-friendly answer.

result['messages']

# [
#   HumanMessage(content='What info do we have in ecommerce_db.users table?'), 
#   AIMessage(content='', tool_calls=[{'name': 'execute_sql', 'args': {'query': 'DESCRIBE ecommerce_db.users;'}, 'id': 'call_qZbDU9Coa2tMjUARcX36h0ax', 'type': 'tool_call'}]), 
#   ToolMessage(content='user_id\tUInt64\t\t\t\t\t\ncountry\tString\t\t\t\t\t\nis_active\tUInt8\t\t\t\t\t\nage\tUInt64\t\t\t\t\t\n', name='execute_sql', tool_call_id='call_qZbDU9Coa2tMjUARcX36h0ax'), 
#   AIMessage(content='The `ecommerce_db.users` table contains the following columns: <...>')
# ]

Here’s the final result. It looks pretty decent.

print(result['messages'][-1].content)

# The `ecommerce_db.users` table contains the following columns:
# 1. **user_id**: `UInt64` - A unique identifier for each user.
# 2. **country**: `String` - The country where the user is located.
# 3. **is_active**: `UInt8` - Indicates whether the user is active (1) or inactive (0).
# 4. **age**: `UInt64` - The age of the user.

Using prebuilt agents

We’ve learned how to build an agent from scratch. However, we can leverage LangGraph’s built-in functionality for simpler tasks like this one.

We can use a prebuilt ReAct agent to get a similar result: an agent that can work with tools.

from langgraph.prebuilt import create_react_agent
prebuilt_doc_agent = create_react_agent(model, [execute_sql],
  state_modifier = prompt)

It’s essentially the same agent as the one we built previously. We will try it out in a second, but first, we need to understand two other important concepts: persistence and streaming.

Persistence and streaming

Persistence refers to the ability to maintain context across different interactions. It’s essential for agentic use cases where an application can get additional input from the user.

LangGraph automatically saves the state after each step, allowing you to pause or resume execution. This capability supports the implementation of advanced business logic, such as error recovery or human-in-the-loop interactions.

The easiest way to add persistence is to use an in-memory SQLite database.

from langgraph.checkpoint.sqlite import SqliteSaver
memory = SqliteSaver.from_conn_string(":memory:")

For the off-the-shelf agent, we can pass memory as an argument while creating an agent.

prebuilt_doc_agent = create_react_agent(model, [execute_sql], 
  checkpointer=memory)

If you’re working with a custom agent, you need to pass memory as a checkpointer while compiling the graph.

class SQLAgent:
  def __init__(self, model, tools, system_prompt = ""):
    <...>
    self.graph = graph.compile(checkpointer=memory)
    <...>

Let’s execute the agent and explore another feature of LangGraph: streaming. With streaming, we can receive results from each step of execution as a separate event in a stream. This feature is crucial for production applications when multiple conversations (or threads) need to be processed simultaneously.

LangGraph supports not only event streaming but also token-level streaming. The only use case I have in mind for token streaming is to display the answers in real-time word by word (similar to ChatGPT implementation).
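Event-level and token-level streaming can be illustrated with two plain-Python generators. These are hypothetical helpers, not LangGraph's API. The only part taken from the article is the event shape: one dictionary per executed step, keyed by node name, which is why the loops in this article iterate over event.values(). Splitting on whitespace is a crude stand-in for real model tokens.

```python
from typing import Callable, Dict, Iterator, List, Tuple


def stream_events(steps: List[Tuple[str, Callable[[dict], dict]]],
                  state: dict) -> Iterator[Dict[str, dict]]:
    """Yield one {node_name: update} event per executed step (event-level streaming)."""
    for name, node in steps:
        update = node(state)
        state.update(update)
        yield {name: update}


def stream_tokens(text: str) -> Iterator[str]:
    """Yield an answer word by word (a crude stand-in for token-level streaming)."""
    for word in text.split():
        yield word


steps = [
    ("llm", lambda s: {"tool_call": "DESCRIBE ecommerce_db.users;"}),
    ("function", lambda s: {"tool_result": f"ran {s['tool_call']}"}),
    ("llm", lambda s: {"answer": "The table has 4 columns."}),
]
for event in stream_events(steps, {}):
    print(event)

for word in stream_tokens("The table has 4 columns."):
    print(word, end=" ", flush=True)
print()
```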

Let’s try using streaming with our new prebuilt agent. I will also use the pretty_print method of messages to make the result more readable.


# defining thread
thread = {"configurable": {"thread_id": "1"}}
messages = [HumanMessage(content="What info do we have in ecommerce_db.users table?")]

for event in prebuilt_doc_agent.stream({"messages": messages}, thread):
    for v in event.values():
        v['messages'][-1].pretty_print()

# ================================== Ai Message ==================================
# Tool Calls:
#  execute_sql (call_YieWiChbFuOlxBg8G1jDJitR)
#  Call ID: call_YieWiChbFuOlxBg8G1jDJitR
#   Args:
#     query: SELECT * FROM ecommerce_db.users LIMIT 1;
# ================================= Tool Message =================================
# Name: execute_sql
# 1000001 United Kingdom 0 70
# 
# ================================== Ai Message ==================================
# 
# The `ecommerce_db.users` table contains at least the following information for users:
# 
# - **User ID** (e.g., `1000001`)
# - **Country** (e.g., `United Kingdom`)
# - **Some numerical value** (e.g., `0`)
# - **Another numerical value** (e.g., `70`)
# 
# The specific meaning of the numerical values and additional columns 
# is not clear from the single row retrieved. Would you like more details 
# or a broader query?

Interestingly, the agent wasn’t able to provide a good enough result. Since the agent didn’t look up the table schema, it struggled to guess the meaning of all columns. We can improve the result by asking follow-up questions in the same thread.


followup_messages = [HumanMessage(content="I would like to know the column names and types. Maybe you could look it up in database using describe.")]

for event in prebuilt_doc_agent.stream({"messages": followup_messages}, thread):
    for v in event.values():
        v['messages'][-1].pretty_print()

# ================================== Ai Message ==================================
# Tool Calls:
#   execute_sql (call_sQKRWtG6aEB38rtOpZszxTVs)
#  Call ID: call_sQKRWtG6aEB38rtOpZszxTVs
#   Args:
#     query: DESCRIBE ecommerce_db.users;
# ================================= Tool Message =================================
# Name: execute_sql
# 
# user_id UInt64     
# country String     
# is_active UInt8     
# age UInt64     
# 
# ================================== Ai Message ==================================
# 
# The `ecommerce_db.users` table has the following columns along with their data types:
# 
# | Column Name | Data Type |
# |-------------|-----------|
# | user_id     | UInt64    |
# | country     | String    |
# | is_active   | UInt8     |
# | age         | UInt64    |
# 
# If you need further information or assistance, feel free to ask!

This time, we got the full answer from the agent. Since we provided the same thread, the agent was able to get the context from the previous discussion. That’s how persistence works.

Let’s try to change the thread and ask the same follow-up question.

new_thread = {"configurable": {"thread_id": "42"}}
followup_messages = [HumanMessage(content="I would like to know the column names and types. Maybe you could look it up in database using describe.")]

for event in prebuilt_doc_agent.stream({"messages": followup_messages}, new_thread):
    for v in event.values():
        v['messages'][-1].pretty_print()

# ================================== Ai Message ==================================
# Tool Calls:
#   execute_sql (call_LrmsOGzzusaLEZLP9hGTBGgo)
#  Call ID: call_LrmsOGzzusaLEZLP9hGTBGgo
#   Args:
#     query: DESCRIBE your_table_name;
# ================================= Tool Message =================================
# Name: execute_sql
# 
# Database returned the following error:
# Code: 60. DB::Exception: Table default.your_table_name does not exist. (UNKNOWN_TABLE) (version 23.12.1.414 (official build))
# 
# ================================== Ai Message ==================================
# 
# It seems that the table `your_table_name` does not exist in the database. 
# Could you please provide the actual name of the table you want to describe?

It was not surprising that the agent lacked the context needed to answer our question. Threads are designed to isolate different conversations, ensuring that each thread maintains its own context.
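Thread isolation can be pictured as a store keyed by thread_id. The toy class below is a hypothetical illustration, not SqliteSaver. It reuses the same config shape ({"configurable": {"thread_id": ...}}) and stores plain strings instead of full checkpoints, which is a simplification.

```python
from collections import defaultdict
from typing import Dict, List


class InMemoryThreads:
    """Toy checkpointer: keeps a separate message history per thread_id."""

    def __init__(self) -> None:
        self._history: Dict[str, List[str]] = defaultdict(list)

    def add(self, config: dict, message: str) -> List[str]:
        thread_id = config["configurable"]["thread_id"]
        self._history[thread_id].append(message)
        # the agent would only ever see the history of *this* thread
        return list(self._history[thread_id])


toy_memory = InMemoryThreads()
thread_a = {"configurable": {"thread_id": "1"}}
thread_b = {"configurable": {"thread_id": "42"}}

toy_memory.add(thread_a, "What info do we have in ecommerce_db.users table?")
history_a = toy_memory.add(thread_a, "Maybe you could look it up using describe.")
history_b = toy_memory.add(thread_b, "Maybe you could look it up using describe.")

print(history_a)  # both messages: the follow-up has context
print(history_b)  # only the follow-up: no context about which table
```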

In real-life applications, managing memory is essential. Conversations might become pretty lengthy, and at some point, it won’t be practical to pass the whole history to the LLM every time. Therefore, it’s worth trimming or filtering messages. We won’t go deep into the specifics here, but you can find guidance on it in the LangGraph documentation. Another option for compressing the conversational history is summarization (example).
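A minimal trimming strategy keeps the system prompt plus the most recent messages. The sketch below assumes messages are simple (role, content) tuples rather than LangChain message objects. `trim_history` is a hypothetical helper, and LangChain also ships a trim_messages utility that is worth checking in its docs.

```python
from typing import List, Tuple

Message = Tuple[str, str]  # (role, content): a simplification of LangChain messages


def trim_history(messages: List[Message], max_messages: int) -> List[Message]:
    """Keep the first system prompt (if any) plus the last `max_messages` other messages."""
    system = [m for m in messages if m[0] == "system"][:1]
    rest = [m for m in messages if m[0] != "system"]
    # note: rest[-0:] would return the whole list, so handle 0 explicitly
    recent = rest[-max_messages:] if max_messages > 0 else []
    return system + recent


history = [("system", "You are an SQL expert.")] + [("human", f"q{i}") for i in range(5)]
print(trim_history(history, 2))
# [('system', 'You are an SQL expert.'), ('human', 'q3'), ('human', 'q4')]
```

A real implementation should also avoid cutting between an AI message with tool calls and the tool messages that answer it, since an orphaned tool result (or an unanswered tool call) can make the next LLM request fail.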

We’ve learned how to build systems with single agents using LangGraph. The next step is to combine multiple agents in one application.

Multi-Agent Systems

As an example of a multi-agent workflow, I would like to build an application that can handle questions from various domains. We will have a set of expert agents, each specializing in different types of questions, and a router agent that will find the best-suited expert to address each query. Such an application has numerous potential use cases: from automating customer support to answering questions from colleagues in internal chats.

First, we need to create the agent state – the information that will help agents to solve the question together. I will use the following fields:

  • question – initial customer request;
  • question_type – the category that defines which agent will be working on the request;
  • answer – the proposed answer to the question;
  • feedback – a field for future use that will gather some feedback.

class MultiAgentState(TypedDict):
    question: str
    question_type: str
    answer: str
    feedback: str

I don’t use any reducers, so our state will store only the latest version of each field.

Then, let’s create a router node. It will be a simple LLM call that defines the category of the question (database, LangChain or general questions).

question_category_prompt = '''You are a senior specialist of analytical support. Your task is to classify the incoming questions. 
Depending on your answer, question will be routed to the right team, so your task is crucial for our team. 
There are 3 possible question types: 
- DATABASE - questions related to our database (tables or fields)
- LANGCHAIN - questions related to LangGraph or LangChain libraries
- GENERAL - general questions
Return in the output only one word (DATABASE, LANGCHAIN or GENERAL).
'''

def router_node(state: MultiAgentState):
  messages = [
    SystemMessage(content=question_category_prompt), 
    HumanMessage(content=state['question'])
  ]
  model = ChatOpenAI(model="gpt-4o-mini")
  response = model.invoke(messages)
  return {"question_type": response.content}

Now that we have our first node – the router – let’s build a simple graph to test the workflow.

memory = SqliteSaver.from_conn_string(":memory:")

builder = StateGraph(MultiAgentState)
builder.add_node("router", router_node)

builder.set_entry_point("router")
builder.add_edge('router', END)

graph = builder.compile(checkpointer=memory)

Let’s test our workflow with different types of questions to see how it performs in action. This will help us evaluate whether the router agent correctly assigns questions to the appropriate expert agents.

thread = {"configurable": {"thread_id": "1"}}
for s in graph.stream({
    'question': "Does LangChain support Ollama?",
}, thread):
    print(s)

# {'router': {'question_type': 'LANGCHAIN'}}

thread = {"configurable": {"thread_id": "2"}}
for s in graph.stream({
    'question': "What info do we have in ecommerce_db.users table?",
}, thread):
    print(s)
# {'router': {'question_type': 'DATABASE'}}

thread = {"configurable": {"thread_id": "3"}}
for s in graph.stream({
    'question': "How are you?",
}, thread):
    print(s)

# {'router': {'question_type': 'GENERAL'}}

It’s working well. I recommend building complex graphs incrementally and testing each step independently. With such an approach, you can ensure that each iteration works as expected, which can save you a significant amount of debugging time.

Next, let’s create nodes for our expert agents. We will use the ReAct agent with the SQL tool we previously built as the database agent.

# database expert
sql_expert_system_prompt = '''
You are an expert in SQL, so you can help the team 
to gather needed data to power their decisions. 
You are very accurate and take into account all the nuances in data. 
You use SQL to get the data before answering the question.
'''

def sql_expert_node(state: MultiAgentState):
    model = ChatOpenAI(model="gpt-4o-mini")
    sql_agent = create_react_agent(model, [execute_sql],
        state_modifier = sql_expert_system_prompt)
    messages = [HumanMessage(content=state['question'])]
    result = sql_agent.invoke({"messages": messages})
    return {'answer': result['messages'][-1].content}

For LangChain-related questions, we will also use a ReAct agent. To enable the agent to answer questions about the library, we will equip it with a search engine tool. I chose Tavily for this purpose, as it provides search results optimised for LLM applications.

If you don’t have an account, you can register to use Tavily for free (up to 1K requests per month). To get started, you will need to specify the Tavily API key in an environment variable.

# search expert 
from langchain_community.tools.tavily_search import TavilySearchResults
os.environ["TAVILY_API_KEY"] = 'tvly-...'
tavily_tool = TavilySearchResults(max_results=5)

search_expert_system_prompt = '''
You are an expert in LangChain and other technologies. 
Your goal is to answer questions based on results provided by search.
You don't add anything yourself and provide only information backed by other sources. 
'''

def search_expert_node(state: MultiAgentState):
    model = ChatOpenAI(model="gpt-4o-mini")
    search_agent = create_react_agent(model, [tavily_tool],
        state_modifier = search_expert_system_prompt)
    messages = [HumanMessage(content=state['question'])]
    result = search_agent.invoke({"messages": messages})
    return {'answer': result['messages'][-1].content}

For general questions, we will leverage a simple LLM model without specific tools.

# general model
general_prompt = '''You're a friendly assistant and your goal is to answer general questions.
Please, don't provide any unchecked information and just tell that you don't know if you don't have enough info.
'''

def general_assistant_node(state: MultiAgentState):
    messages = [
        SystemMessage(content=general_prompt), 
        HumanMessage(content=state['question'])
    ]
    model = ChatOpenAI(model="gpt-4o-mini")
    response = model.invoke(messages)
    return {"answer": response.content}

The last missing bit is a conditional function for routing. This will be quite straightforward: we just need to return the question type that the router node saved in the state.

def route_question(state: MultiAgentState):
    return state['question_type']
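route_question passes the router’s raw output straight into the mapping of the conditional edge, so a reply like "Database." or "DATABASE\n" would not match any key. A hypothetical, more defensive variant could normalise the label and fall back to the general assistant. It keeps the same signature, so it could be used in place of route_question, but it is my own addition, not part of the original workflow.

```python
VALID_TYPES = {"DATABASE", "LANGCHAIN", "GENERAL"}


def route_question_safe(state: dict) -> str:
    """Normalise the router's output and fall back to GENERAL if it is unexpected."""
    raw = state.get("question_type", "")
    # drop surrounding whitespace, trailing periods and case differences
    label = raw.strip().strip(".").upper()
    return label if label in VALID_TYPES else "GENERAL"


print(route_question_safe({"question_type": "database\n"}))   # DATABASE
print(route_question_safe({"question_type": "LangChain."}))   # LANGCHAIN
print(route_question_safe({"question_type": "I think SQL"}))  # GENERAL
```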

Now, it’s time to create our graph.

builder = StateGraph(MultiAgentState)
builder.add_node("router", router_node)
builder.add_node('database_expert', sql_expert_node)
builder.add_node('langchain_expert', search_expert_node)
builder.add_node('general_assistant', general_assistant_node)
builder.add_conditional_edges(
    "router", 
    route_question,
    {'DATABASE': 'database_expert', 
     'LANGCHAIN': 'langchain_expert', 
     'GENERAL': 'general_assistant'}
)

builder.set_entry_point("router")
builder.add_edge('database_expert', END)
builder.add_edge('langchain_expert', END)
builder.add_edge('general_assistant', END)
graph = builder.compile(checkpointer=memory)

Now, we can test the setup on a couple of questions to see how well it performs.

thread = {"configurable": {"thread_id": "2"}}
results = []
for s in graph.stream({
  'question': "What info do we have in ecommerce_db.users table?",
}, thread):
  print(s)
  results.append(s)
print(results[-1]['database_expert']['answer'])

# The `ecommerce_db.users` table contains the following columns:
# 1. **User ID**: A unique identifier for each user.
# 2. **Country**: The country where the user is located.
# 3. **Is Active**: A flag indicating whether the user is active (1 for active, 0 for inactive).
# 4. **Age**: The age of the user.
# Here are some sample entries from the table:
# 
# | User ID | Country        | Is Active | Age |
# |---------|----------------|-----------|-----|
# | 1000001 | United Kingdom  | 0         | 70  |
# | 1000002 | France         | 1         | 87  |
# | 1000003 | France         | 1         | 88  |
# | 1000004 | Germany        | 1         | 25  |
# | 1000005 | Germany        | 1         | 48  |
# 
# This gives an overview of the user data available in the table.

Good job! It gives a relevant result for the database-related question. Let’s try asking about LangChain.


thread = {"configurable": {"thread_id": "42"}}
results = []
for s in graph.stream({
    'question': "Does LangChain support Ollama?",
}, thread):
    print(s)
    results.append(s)

print(results[-1]['langchain_expert']['answer'])

# Yes, LangChain supports Ollama. Ollama allows you to run open-source 
# large language models, such as Llama 2, locally, and LangChain provides 
# a flexible framework for integrating these models into applications. 
# You can interact with models run by Ollama using LangChain, and there are 
# specific wrappers and tools available for this integration.
# 
# For more detailed information, you can visit the following resources:
# - [LangChain and Ollama Integration](https://js.langchain.com/v0.1/docs/integrations/llms/ollama/)
# - [ChatOllama Documentation](https://js.langchain.com/v0.2/docs/integrations/chat/ollama/)
# - [Medium Article on Ollama and LangChain](https://medium.com/@abonia/ollama-and-langchain-run-llms-locally-900931914a46)

Fantastic! Everything is working well, and Tavily’s search looks effective for LLM applications.

Adding human-in-the-loop interactions

We’ve done an excellent job creating a tool to answer questions. However, in many cases, it’s beneficial to keep a human in the loop to approve proposed actions or provide additional feedback. Let’s add a step where we can collect feedback from a human before returning the final result to the user.

The simplest approach is to add two additional nodes:

  • A human node to gather feedback,
  • An editor node to revisit the answer, taking into account the feedback.

Let’s create these nodes:

  • Human node: This will be a dummy node, and it won’t perform any actions.
  • Editor node: This will be an LLM model that receives all the relevant information (customer question, draft answer and provided feedback) and revises the final answer.

def human_feedback_node(state: MultiAgentState):
    pass

editor_prompt = '''You're an editor and your goal is to provide the final answer to the customer, taking into account the feedback. 
You don't add any information on your own. You use friendly and professional tone.
In the output please provide the final answer to the customer without additional comments.
Here's all the information you need.

Question from customer: 
----
{question}
----
Draft answer:
----
{answer}
----
Feedback: 
----
{feedback}
----
'''

def editor_node(state: MultiAgentState):
  messages = [
    SystemMessage(content=editor_prompt.format(question = state['question'], answer = state['answer'], feedback = state['feedback']))
  ]
  model = ChatOpenAI(model="gpt-4o-mini")
  response = model.invoke(messages)
  return {"answer": response.content}

Let’s add these nodes to our graph. Additionally, we need to introduce an interruption before the human node to ensure that the process pauses for human feedback.

builder = StateGraph(MultiAgentState)
builder.add_node("router", router_node)
builder.add_node('database_expert', sql_expert_node)
builder.add_node('langchain_expert', search_expert_node)
builder.add_node('general_assistant', general_assistant_node)
builder.add_node('human', human_feedback_node)
builder.add_node('editor', editor_node)

builder.add_conditional_edges(
  "router", 
  route_question,
  {'DATABASE': 'database_expert', 
  'LANGCHAIN': 'langchain_expert', 
  'GENERAL': 'general_assistant'}
)

builder.set_entry_point("router")

builder.add_edge('database_expert', 'human')
builder.add_edge('langchain_expert', 'human')
builder.add_edge('general_assistant', 'human')
builder.add_edge('human', 'editor')
builder.add_edge('editor', END)
graph = builder.compile(checkpointer=memory, interrupt_before = ['human'])

Now, when we run the graph, the execution will be stopped before the human node.

thread = {"configurable": {"thread_id": "2"}}

for event in graph.stream({
    'question': "What are the types of fields in ecommerce_db.users table?",
}, thread):
    print(event)

# {'question_type': 'DATABASE', 'question': 'What are the types of fields in ecommerce_db.users table?'}
# {'router': {'question_type': 'DATABASE'}}
# {'database_expert': {'answer': 'The `ecommerce_db.users` table has the following fields:\n\n1. **user_id**: UInt64\n2. **country**: String\n3. **is_active**: UInt8\n4. **age**: UInt64'}}

Let’s collect the human feedback and update the state with it.

user_input = input("Do I need to change anything in the answer?")
# Do I need to change anything in the answer? 
# It looks wonderful. Could you only make it a bit friendlier please?

graph.update_state(thread, {"feedback": user_input}, as_node="human")

We can check the state to confirm that the feedback has been populated and that the next node in the sequence is editor.

print(graph.get_state(thread).values['feedback'])
# It looks wonderful. Could you only make it a bit friendlier please?

print(graph.get_state(thread).next)
# ('editor',)

We can just continue the execution. Passing None as input will resume the process from the point where it was paused.

for event in graph.stream(None, thread, stream_mode="values"):
  print(event)

print(event['answer'])

# Hello! The `ecommerce_db.users` table has the following fields:
# 1. **user_id**: UInt64
# 2. **country**: String
# 3. **is_active**: UInt8
# 4. **age**: UInt64
# Have a nice day!

The editor took our feedback into account and added some polite words to our final message. That’s a fantastic result!
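The pause-and-resume mechanics can be modelled with a toy linear workflow. `PausableRun` is hypothetical: LangGraph actually persists checkpoints through the checkpointer, while this sketch only keeps the state and position in memory. It mirrors the three steps used above: stop before the human node, record feedback as if the human node had produced it, then resume by passing None.

```python
from typing import Callable, List, Optional, Tuple


class PausableRun:
    """Toy linear workflow that pauses before the nodes listed in `interrupt_before`."""

    def __init__(self, nodes: List[Tuple[str, Callable[[dict], dict]]],
                 interrupt_before: List[str]) -> None:
        self.nodes = nodes  # ordered (name, function) pairs
        self.interrupt_before = set(interrupt_before)
        self.state: dict = {}
        self.position = 0  # index of the next node to run

    def stream(self, inputs: Optional[dict]) -> List[str]:
        """Run until the end or an interrupt; passing None resumes from the saved position."""
        resuming = inputs is None
        if not resuming:
            self.state, self.position = dict(inputs), 0
        ran: List[str] = []
        while self.position < len(self.nodes):
            name, fn = self.nodes[self.position]
            # pause before an interrupt node, unless we are resuming right at it
            if name in self.interrupt_before and not (resuming and not ran):
                break
            self.state.update(fn(self.state))
            ran.append(name)
            self.position += 1
        return ran

    def update_state(self, values: dict, as_node: str) -> None:
        """Merge `values` as if `as_node` had produced them, then move past that node."""
        self.state.update(values)
        names = [name for name, _ in self.nodes]
        self.position = names.index(as_node) + 1


run = PausableRun(
    nodes=[
        ("expert", lambda s: {"answer": "draft"}),
        ("human", lambda s: {}),  # dummy node, like human_feedback_node
        ("editor", lambda s: {"answer": f"{s['answer']} (edited per: {s['feedback']})"}),
    ],
    interrupt_before=["human"],
)
first = run.stream({"question": "What are the types of fields?"})
run.update_state({"feedback": "friendlier"}, as_node="human")
second = run.stream(None)
print(first, second, run.state["answer"])
# ['expert'] ['editor'] draft (edited per: friendlier)
```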

We can implement human-in-the-loop interactions in a more agentic way by equipping our editor with the Human tool.

Let’s adjust our editor. I’ve slightly changed the prompt and added the tool to the agent.

from langchain_community.tools import HumanInputRun
human_tool = HumanInputRun()

editor_agent_prompt = '''You're an editor and your goal is to provide the final answer to the customer, taking into account the initial question.
If you need any clarifications or need feedback, please, use human. Always reach out to human to get the feedback before final answer.
You don't add any information on your own. You use friendly and professional tone. 
In the output please provide the final answer to the customer without additional comments.
Here's all the information you need.

Question from customer: 
----
{question}
----
Draft answer:
----
{answer}
----
'''

model = ChatOpenAI(model="gpt-4o-mini")
editor_agent = create_react_agent(model, [human_tool])

# reusing the question and draft answer saved for the thread in the previous run
state = graph.get_state(thread).values
messages = [SystemMessage(content=editor_agent_prompt.format(question = state['question'], answer = state['answer']))]
editor_result = editor_agent.invoke({"messages": messages})

# Is the draft answer complete and accurate for the customer's question about the types of fields in the ecommerce_db.users table?
# Yes, but could you please make it friendlier.

print(editor_result['messages'][-1].content)
# The `ecommerce_db.users` table has the following fields:
# 1. **user_id**: UInt64
# 2. **country**: String
# 3. **is_active**: UInt8
# 4. **age**: UInt64
# 
# If you have any more questions, feel free to ask!

So, the editor reached out to the human with the question, "Is the draft answer complete and accurate for the customer’s question about the types of fields in the ecommerce_db.users table?". After receiving feedback, the editor refined the answer to make it more user-friendly.

Let’s update our main graph to use the new agent instead of the two separate nodes. With this approach, we don’t need interruptions anymore.

def editor_agent_node(state: MultiAgentState):
  model = ChatOpenAI(model="gpt-4o-mini")
  editor_agent = create_react_agent(model, [human_tool])
  messages = [SystemMessage(content=editor_agent_prompt.format(question = state['question'], answer = state['answer']))]
  result = editor_agent.invoke({"messages": messages})
  return {'answer': result['messages'][-1].content}

builder = StateGraph(MultiAgentState)
builder.add_node("router", router_node)
builder.add_node('database_expert', sql_expert_node)
builder.add_node('langchain_expert', search_expert_node)
builder.add_node('general_assistant', general_assistant_node)
builder.add_node('editor', editor_agent_node)

builder.add_conditional_edges(
  "router", 
  route_question,
  {'DATABASE': 'database_expert', 
   'LANGCHAIN': 'langchain_expert', 
   'GENERAL': 'general_assistant'}
)

builder.set_entry_point("router")

builder.add_edge('database_expert', 'editor')
builder.add_edge('langchain_expert', 'editor')
builder.add_edge('general_assistant', 'editor')
builder.add_edge('editor', END)

graph = builder.compile(checkpointer=memory)

thread = {"configurable": {"thread_id": "42"}}
results = []

for event in graph.stream({
  'question': "What are the types of fields in ecommerce_db.users table?",
}, thread):
  print(event)
  results.append(event)

This graph will work similarly to the previous one. I personally prefer this approach since it leverages tools, making the solution more flexible. For example, agents can reach out to humans multiple times and refine questions as needed.
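The "ask as many times as needed" behaviour can be sketched as a loop with the human injected as a callable, which makes it easy to script or test. `edit_with_feedback` and its string-based revision are hypothetical stand-ins: in the real editor agent, the LLM decides when to call the Human tool and rewrites the answer itself.

```python
from typing import Callable


def edit_with_feedback(draft: str, ask_human: Callable[[str], str],
                       max_rounds: int = 3) -> str:
    """Ask for feedback repeatedly until the human approves or rounds run out."""
    answer = draft
    for _ in range(max_rounds):
        feedback = ask_human(f"Is this answer OK?\n{answer}")
        if feedback.strip().lower() in {"ok", "yes", "approve"}:
            break
        # stand-in for an LLM rewrite that takes the feedback into account
        answer = f"{answer} [revised: {feedback}]"
    return answer


replies = iter(["make it friendlier", "ok"])  # scripted human replies
final = edit_with_feedback("users table has 4 columns", lambda prompt: next(replies))
print(final)  # users table has 4 columns [revised: make it friendlier]
```

Capping the loop with `max_rounds` keeps a never-satisfied reviewer from stalling the workflow.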

That’s it. We’ve built a multi-agent system that can answer questions from different domains and take into account human feedback.

You can find the complete code on GitHub.

Summary

In this article, we’ve explored the LangGraph library and its application for building single and multi-agent workflows. We’ve examined a range of its capabilities, and now it’s time to summarise its strengths and weaknesses. Also, it will be useful to compare LangGraph with CrewAI, which we discussed in my previous article.

Overall, I find LangGraph quite a powerful framework for building complex LLM applications:

  • LangGraph is a low-level framework that offers extensive customisation options, allowing you to build precisely what you need.
  • Since LangGraph is built on top of LangChain, it’s seamlessly integrated into its ecosystem, making it easy to leverage existing tools and components.

However, there are areas where LangGraph could be improved:

  • The flexibility of LangGraph comes with a higher entry barrier. While you can understand the concepts of CrewAI within 15–30 minutes, it takes some time to get comfortable and up to speed with LangGraph.
  • LangGraph provides you with a higher level of control, but it misses some cool prebuilt features of CrewAI, such as collaboration or ready-to-use RAG tools.
  • LangGraph doesn’t enforce best practices like CrewAI does (for example, role-playing or guardrails), which can lead to poorer results.

I would say that CrewAI is a better framework for newbies and common use cases because it helps you get good results quickly and provides guidance to prevent mistakes.

If you want to build an advanced application and need more control, LangGraph is the way to go. Keep in mind that you’ll need to invest time in learning LangGraph and will be fully responsible for the final solution, as the framework won’t provide guidance to help you avoid common mistakes.

Thank you for reading this article. I hope it was insightful for you. If you have any follow-up questions or comments, please leave them in the comments section.

Reference

This article is inspired by the "AI Agents in LangGraph" short course from DeepLearning.AI.

The post From Basics to Advanced: Exploring LangGraph appeared first on Towards Data Science.

]]>
Multi AI Agent Systems 101 https://towardsdatascience.com/multi-ai-agent-systems-101-bac58e3bcc47/ Sun, 16 Jun 2024 17:03:32 +0000 https://towardsdatascience.com/multi-ai-agent-systems-101-bac58e3bcc47/ Automating Routine Tasks in Data Source Management with CrewAI

The post Multi AI Agent Systems 101 appeared first on Towards Data Science.

]]>
When ChatGPT first appeared, we used simple prompts to get answers to our questions. Then, we encountered issues with hallucinations and began using RAG (Retrieval Augmented Generation) to provide more context to LLMs. After that, we started experimenting with AI agents, where LLMs act as a reasoning engine and can decide what to do next, which tools to use, and when to return the final answer.

The next evolutionary step is to create teams of such agents that can collaborate with each other. This approach is logical as it mirrors human interactions. We work in teams where each member has a specific role:

  • The product manager proposes the next project to work on.
  • The designer creates its look and feel.
  • The software engineer develops the solution.
  • The analyst examines the data to ensure it performs as expected and identifies ways to improve the product for customers.

Similarly, we can create a team of AI agents, each focusing on one domain. They can collaborate and reach a final conclusion together. Just as specialization enhances performance in real life, it could also benefit the performance of AI agents.

Another advantage of this approach is increased flexibility. Each agent can operate with its own prompt, set of tools and even LLM. For instance, we can use different models for different parts of our system. You can use GPT-4 for the agent that needs more reasoning and GPT-3.5 for the one that does only simple extraction. We can even fine-tune the model for small specific tasks and use it in our crew of agents.
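To make this flexibility concrete, here is a minimal, framework-agnostic sketch in which each agent bundles its own prompt, tools and model. The `AgentSpec` dataclass and the model assignments are hypothetical illustrations, not CrewAI's API.

```python
from dataclasses import dataclass, field


@dataclass
class AgentSpec:
    """Hypothetical per-agent configuration: each agent owns its prompt, tools and model."""
    role: str
    system_prompt: str
    model: str
    tools: list[str] = field(default_factory=list)


# A reasoning-heavy agent gets a stronger model; a simple extractor gets a cheaper one.
team = [
    AgentSpec(
        role="Analyst",
        system_prompt="You reason carefully about data and explain your conclusions.",
        model="gpt-4",
        tools=["sql_query"],
    ),
    AgentSpec(
        role="Extractor",
        system_prompt="You extract the requested fields and return them as JSON.",
        model="gpt-3.5-turbo",
    ),
]

for agent in team:
    print(f"{agent.role}: model={agent.model}, tools={agent.tools}")
```

In a real framework, these three fields map to the agent definition, which is why swapping a model for one agent doesn't affect the others.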

The potential drawbacks of this approach are time and cost. Multiple interactions and knowledge sharing between agents require more LLM calls and consume additional tokens. This could result in longer wait times and higher expenses.
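A back-of-the-envelope sketch shows why: the cost scales with the number of LLM calls and with the context re-sent on each call. The token counts and the per-token price below are made-up placeholders, not real benchmarks or pricing.

```python
def estimate_cost(calls: int, tokens_per_call: int, price_per_1k_tokens: float) -> float:
    """Rough cost estimate: total tokens across all LLM calls times the price per 1K tokens."""
    total_tokens = calls * tokens_per_call
    return total_tokens / 1000 * price_per_1k_tokens


# Placeholder numbers for illustration only.
single_prompt = estimate_cost(calls=1, tokens_per_call=2_000, price_per_1k_tokens=0.01)
multi_agent = estimate_cost(calls=12, tokens_per_call=3_000, price_per_1k_tokens=0.01)
print(f"single prompt: ${single_prompt:.2f}, multi-agent crew: ${multi_agent:.2f}")
```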

There are several frameworks available for multi-agent systems today. Here are some of the most popular ones:

  • AutoGen: Developed by Microsoft, AutoGen uses a conversational approach and was one of the earliest frameworks for multi-agent systems.
  • LangGraph: While not strictly a multi-agent framework, LangGraph allows for defining complex interactions between actors using a graph structure. So, it can also be adapted to create multi-agent systems.
  • CrewAI: Positioned as a high-level framework, CrewAI facilitates the creation of "crews" consisting of role-playing agents capable of collaborating in various ways.

I’ve decided to start experimenting with multi-agent frameworks with CrewAI since it’s widely popular and user-friendly, which makes it a good option to begin with.

In this article, I will walk you through how to use CrewAI. As analysts, we’re the domain experts responsible for documenting various data sources and addressing related questions. We’ll explore how to automate these tasks using multi-agent frameworks.

Setting up the environment

Let’s start with setting up the environment. First, we need to install the CrewAI main package and an extension to work with tools.

pip install crewai
pip install 'crewai[tools]'

CrewAI was developed to work primarily with the OpenAI API, but I would also like to try it with a local model. According to the Chatbot Arena Leaderboard, the best model you can run on your laptop is Llama 3 (8B parameters), so it's the most feasible option for our use case.

We can access Llama models using Ollama. Installation is pretty straightforward. You need to download Ollama from the website and then go through the installation process. That’s it.

Now, you can test the model in the CLI by running the following command.

ollama run llama3

For example, you can ask something like this.

Let’s create a custom Ollama model to use later in CrewAI.

We will start with a ModelFile (documentation). I only specified the base model (llama3), temperature and stop sequence. However, you might add more settings. For example, you can define the system message using the SYSTEM keyword.

FROM llama3

# set parameters
PARAMETER temperature 0.5
PARAMETER stop Result

I’ve saved it into a Llama3ModelFile file.

Let’s create a shell script that pulls the base model for Ollama and creates the custom model we defined in the ModelFile.

#!/bin/zsh

# define variables
model_name="llama3"
custom_model_name="crewai-llama3"

# pull the base model
ollama pull "$model_name"

# create the custom model from the ModelFile
ollama create "$custom_model_name" -f ./Llama3ModelFile

Let’s execute this file.

chmod +x ./llama3_setup.sh
./llama3_setup.sh

You can find both files on GitHub: Llama3ModelFile and llama3_setup.sh

We need to initialise the following environment variables to use the local Llama model with CrewAI.

import os

# Ollama serves an OpenAI-compatible API on port 11434
os.environ["OPENAI_API_BASE"] = 'http://localhost:11434/v1'

# custom_model_name from the setup script
os.environ["OPENAI_MODEL_NAME"] = 'crewai-llama3'

# Ollama doesn't check the key, but the client expects one to be set
os.environ["OPENAI_API_KEY"] = "NA"

We’ve finished the setup and are ready to continue our journey.

Use cases: working with documentation

As analysts, we often play the role of subject matter experts for data and some data-related tools. In my previous team, we used to have a channel with almost 1K participants, where we were answering lots of questions about our data and the ClickHouse database we used as storage. It took us quite a lot of time to manage this channel. It would be interesting to see whether such tasks can be automated with LLMs.

For this example, I will use the ClickHouse database. If you’re interested, you can learn more about ClickHouse and how to set it up locally in my previous article. However, we won’t use any ClickHouse-specific features, so feel free to stick to the database you know.

I’ve created a pretty simple data model to work with. There are just two tables in our DWH (Data Warehouse): ecommerce_db.users and ecommerce_db.sessions. As you might guess, the first table contains information about the users of our service.

The ecommerce_db.sessions table stores information about user sessions.

Regarding data source management, analysts typically handle tasks like writing and updating documentation and answering questions about the data. So, we will use an LLM to write documentation for the tables in the database and teach it to answer questions about the data or ClickHouse.

But before moving on to the implementation, let’s learn more about the CrewAI framework and its core concepts.

CrewAI basic concepts

The cornerstone of a multi-agent framework is the concept of an agent. In CrewAI, agents are powered by role-playing. Role-playing is a tactic in which you ask an agent to adopt a persona and behave like a top-notch backend engineer or a helpful customer support agent. So, when creating a CrewAI agent, you need to specify each agent’s role, goal, and backstory so that the LLM knows enough to play this role.

The agents’ capabilities are limited without tools (functions that agents can execute and get results). With CrewAI, you can use one of the predefined tools (for example, to search the Internet, parse a website, or do RAG on a document), create a custom tool yourself or use LangChain tools. So, it’s pretty easy to create a powerful agent.

Let’s move on from agents to the work they are doing. Agents are working on tasks (specific assignments). For each task, we need to define a description, expected output (definition of done), set of available tools and assigned agent. I really like that these frameworks follow the managerial best practices like a clear definition of done for the tasks.

The next question is how to define the execution order for tasks: which one to work on first, which ones can run in parallel, etc. CrewAI implements processes to orchestrate the tasks. It provides a couple of options:

  • Sequential: the most straightforward approach, where tasks are executed one after another.
  • Hierarchical: a manager (specified as an LLM) creates tasks and delegates them to the agents.
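Conceptually, the sequential process is a loop in which each task receives the output of the previous one. Here is a library-free sketch of that idea (not CrewAI's internal implementation); the plain Python functions stand in for LLM-powered agents.

```python
from typing import Callable

# A task here is any function that takes the previous output (context) and returns text.
TaskFn = Callable[[str], str]


def run_sequential(tasks: list[TaskFn], initial_input: str = "") -> str:
    """Run tasks one after another, passing each output as context to the next task."""
    context = initial_input
    for task in tasks:
        context = task(context)
    return context


def describe_table(_: str) -> str:
    # Stand-in for the database specialist's task.
    return "columns: user_id, country, is_active, age"


def write_docs(context: str) -> str:
    # Stand-in for the technical writer's task, which builds on the previous output.
    return f"# users table\n\nBased on: {context}"


result = run_sequential([describe_table, write_docs])
print(result)
```

In the hierarchical process, the fixed loop is replaced by a manager LLM that decides which agent gets which task.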

Also, CrewAI is working on a consensual process. In such a process, agents will be able to make decisions collaboratively with a democratic approach.

There are other levers you can use to tweak the process of tasks’ execution:

  • You can mark tasks as "asynchronous", then they will be executed in parallel, so you will be able to get an answer faster.
  • You can use the "human input" flag on a task, and then the agent will ask for human approval before finalising the output of this task. This lets you add human oversight to the process.
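The "asynchronous" flag boils down to running independent tasks concurrently. A minimal asyncio sketch of that idea follows, with `asyncio.sleep` standing in for a slow LLM call. It illustrates the concept only, not how CrewAI schedules tasks.

```python
import asyncio
import time


async def fake_llm_task(name: str, seconds: float) -> str:
    """Stand-in for an LLM-backed task; sleeping simulates waiting for the model."""
    await asyncio.sleep(seconds)
    return f"{name} done"


async def main() -> list[str]:
    # Independent tasks run concurrently, so the total time is close to the slowest
    # task rather than the sum of all tasks.
    return await asyncio.gather(
        fake_llm_task("users docs", 0.2),
        fake_llm_task("sessions docs", 0.2),
    )


start = time.perf_counter()
results = asyncio.run(main())
elapsed = time.perf_counter() - start
print(results, f"{elapsed:.1f}s")
```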

We’ve defined all the primary building blocks and can discuss the holy grail of CrewAI: the crew concept. The crew represents the team of agents and the set of tasks they will be working on. The approach for collaboration (the processes we discussed above) can also be defined at the crew level.

Also, we can set up the memory for a crew. Memory is crucial for efficient collaboration between the agents. CrewAI supports three levels of memory:

  • Short-term memory stores information related to the current execution. It helps agents to work together on the current task.
  • Long-term memory is data about the previous executions stored in the local database. This type of memory allows agents to learn from earlier iterations and improve over time.
  • Entity memory captures and structures information about entities (like personas, cities, etc.).

Right now, you can only switch all types of memory on or off for a crew at once, without any further customisation. Also, memory doesn’t work with the Llama models.

We’ve learned enough about the CrewAI framework, so it’s time to start using this knowledge in practice.

Use case: writing documentation

Let’s start with a simple task: putting together the documentation for our DWH. As we discussed before, there are two tables in our DWH, and I would like to create a detailed description for them using LLMs.

First approach

In the beginning, we need to think about the team structure. Think of this as a typical managerial task. Who would you hire for such a job?

I would break this task into two parts: retrieving data from a database and writing documentation. So, we need a database specialist and a technical writer. The database specialist needs access to a database, while the writer won’t need any special tools.

Now, we have a high-level plan. Let’s create the agents.

For each agent, I’ve specified the role, goal and backstory. I’ve tried my best to provide agents with all the needed context.

from crewai import Agent, Task, Crew

database_specialist_agent = Agent(
  role = "Database specialist",
  goal = "Provide data to answer business questions using SQL",
  backstory = '''You are an expert in SQL, so you can help the team 
  to gather needed data to power their decisions. 
  You are very accurate and take into account all the nuances in data.''',
  allow_delegation = False,
  verbose = True
)

tech_writer_agent = Agent(
  role = "Technical writer",
  goal = '''Write engaging and factually accurate technical documentation 
    for data sources or tools''',
  backstory = ''' 
  You are an expert in both technology and communications, so you can easily explain even sophisticated concepts.
  You base your work on the factual information provided by your colleagues.
  Your texts are concise and can be easily understood by a wide audience. 
  You use a professional but rather informal style in your communication.
  ''',
  allow_delegation = False,
  verbose = True
)

We will use a simple sequential process, so there’s no need for agents to delegate tasks to each other. That’s why I specified allow_delegation = False.

The next step is setting the tasks for agents. But before moving to them, we need to create a custom tool to connect to the database.

First, I put together a function to execute ClickHouse queries using HTTP API.

import requests

CH_HOST = 'http://localhost:8123' # default address 

def get_clickhouse_data(query, host = CH_HOST, connection_timeout = 1500):
  # send the query to ClickHouse over its HTTP interface
  r = requests.post(host, params = {'query': query}, 
    timeout = connection_timeout)
  if r.status_code == 200:
    return r.text
  else: 
    # return the error text instead of raising, so the LLM can react to it
    return 'Database returned the following error:\n' + r.text

When working with LLM agents, it’s important to make tools fault-tolerant. For example, if the database returns an error (status_code != 200), my code won’t throw an exception. Instead, it will return the error description to the LLM so it can attempt to resolve the issue.
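We can push the same idea a bit further. `requests.post` can still raise on network problems (for example, a timeout or a refused connection), and that would crash the agent's run. Here is a minimal sketch of a decorator that turns any exception into a text message the LLM can read; `fault_tolerant` is a hypothetical helper, not part of CrewAI.

```python
import functools
from typing import Callable


def fault_tolerant(func: Callable[..., str]) -> Callable[..., str]:
    """Wrap a tool function so that exceptions are returned to the LLM as text instead of raised."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> str:
        try:
            return func(*args, **kwargs)
        except Exception as e:  # report any failure back to the agent
            return f"Tool failed with {type(e).__name__}: {e}"
    return wrapper


@fault_tolerant
def flaky_query(query: str) -> str:
    # Stand-in for get_clickhouse_data when the database is unreachable.
    raise ConnectionError("could not connect to localhost:8123")


print(flaky_query("select 1"))
```

Since `get_clickhouse_data` already returns strings, it could be decorated the same way without other changes.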

To create a CrewAI custom tool, we need to derive our class from crewai_tools.BaseTool, implement the _run method and then create an instance of this class.

from crewai_tools import BaseTool

class DatabaseQuery(BaseTool):
  name: str = "Database Query"
  description: str = "Returns the result of SQL query execution"

  def _run(self, sql_query: str) -> str:
      # pass the SQL generated by the agent to ClickHouse and return the result as text
      return get_clickhouse_data(sql_query)

database_query_tool = DatabaseQuery()

Now, we can set the tasks for the agents. Again, providing clear instructions and all the context to LLM is crucial.

table_description_task = Task(
  description = '''Provide the comprehensive overview for the data 
  in table {table}, so that it's easy to understand the structure 
  of the data. This task is crucial to put together the documentation 
  for our database''',
  expected_output = '''The comprehensive overview of {table} in the md format. 
  Include 2 sections: columns (list of columns with their types) 
  and examples (the first 30 rows from table).''',
  tools = [database_query_tool],
  agent = database_specialist_agent
)

table_documentation_task = Task(
  description = '''Using provided information about the table, 
  put together the detailed documentation for this table so that 
  people can use it in practice''',
  expected_output = '''Well-written detailed documentation describing 
  the data scheme for the table {table} in markdown format, 
  that gives the table overview in 1-2 sentences and then 
  describes each column. Structure the columns description 
  as a markdown table with column name, type and description.''',
  tools = [],
  output_file="table_documentation.md",
  agent = tech_writer_agent
)

You might have noticed that I’ve used the {table} placeholder in the tasks’ descriptions. We will pass table as an input variable when executing the crew, and its value will be inserted into all the placeholders.
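This works like ordinary string templating. Here is a rough stand-in using Python's `str.format` (CrewAI's own interpolation code may differ) that shows what the task description looks like once `{'table': 'ecommerce_db.users'}` is passed in.

```python
# Template text shortened from the task description above.
description_template = (
    "Provide the comprehensive overview for the data in table {table}, "
    "so that it's easy to understand the structure of the data."
)

inputs = {"table": "ecommerce_db.users"}

# Every {table} placeholder is replaced with the value from the inputs dict.
description = description_template.format(**inputs)
print(description)
```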

Also, I’ve specified the output file for the table documentation task to save the final result locally.

We have all we need. Now, it’s time to create a crew and execute the process, specifying the table we are interested in. Let’s try it with the users table.

crew = Crew(
  agents = [database_specialist_agent, tech_writer_agent],
  tasks = [table_description_task,  table_documentation_task],
  verbose = 2
)

result = crew.kickoff({'table': 'ecommerce_db.users'})

It’s an exciting moment, and I’m really looking forward to seeing the result. Don’t worry if execution takes some time. Agents make multiple LLM calls, so it’s perfectly normal for it to take a few minutes. It took 2.5 minutes on my laptop.

We asked the LLM to return the documentation in markdown format. We can use the following code to see the formatted result in Jupyter Notebook.

from IPython.display import Markdown
Markdown(result)

At first glance, it looks great. We’ve got a valid markdown file describing the users table.

But wait, it’s incorrect. Let’s see what data we have in our table.

The columns listed in the documentation are completely different from what we have in the database. It’s a case of LLM hallucinations.

We’ve set verbose = 2 to get the detailed logs from CrewAI. Let’s read through the execution logs to identify the root cause of the problem.

First, the database specialist couldn’t query the database due to complications with quotes.

The specialist didn’t manage to resolve this problem. Eventually, CrewAI terminated the chain with the following output: Agent stopped due to iteration limit or time limit.

This means the technical writer didn’t receive any factual information about the data. However, the agent continued and produced completely fake results. That’s how we ended up with incorrect documentation.

Fixing the issues

Even though our first iteration wasn’t successful, we’ve learned a lot. We have (at least) two areas for improvement:

  • Our database tool is too difficult for the model, and the agent struggles to use it. We can make the tool more tolerant by removing quotes from the beginning and end of the queries. This solution is not ideal since valid SQL can end with a quote, but let’s try it.
  • Our technical writer isn’t basing its output on the input from the database specialist. We need to tweak the prompt to highlight the importance of providing only factual information.

So, let’s try to fix these problems. First, we will fix the tool – we can leverage strip to eliminate quotes.

import requests

CH_HOST = 'http://localhost:8123' # default address 

def get_clickhouse_data(query, host = CH_HOST, connection_timeout = 1500):
  # remove quotes that Llama 3 sometimes wraps around the query
  r = requests.post(host, params = {'query': query.strip('"').strip("'")}, 
    timeout = connection_timeout)
  if r.status_code == 200:
    return r.text
  else: 
    return 'Database returned the following error:\n' + r.text
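Here is why the `strip` approach is not ideal: it also removes a legitimate closing quote, so a query like `... where country = 'France'` loses its final `'`. A slightly safer sketch, using a hypothetical `unwrap_quotes` helper, removes quotes only when a matching pair wraps the whole query.

```python
def unwrap_quotes(query: str) -> str:
    """Remove one pair of matching quotes that wrap the entire query, and nothing else."""
    query = query.strip()
    if len(query) >= 2 and query[0] == query[-1] and query[0] in ("'", '"'):
        return query[1:-1].strip()
    return query


broken = "select * from users where country = 'France'"
print(broken.strip('"').strip("'"))  # the naive approach drops the closing quote
print(unwrap_quotes(broken))         # unchanged: the quotes don't wrap the whole query
print(unwrap_quotes('"select 1"'))   # wrapping quotes are removed
```

It's still a heuristic. A query that legitimately starts and ends with a single quote would be mangled, but that is much rarer than a query ending with a string literal.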

Then, it’s time to update the prompt. I’ve included statements emphasizing the importance of sticking to the facts in both the agent and task definitions.


tech_writer_agent = Agent(
  role = "Technical writer",
  goal = '''Write engaging and factually accurate technical documentation 
  for data sources or tools''',
  backstory = ''' 
  You are an expert in both technology and communications, so you 
  can easily explain even sophisticated concepts.
  Your texts are concise and can be easily understood by a wide audience. 
  You use a professional but rather informal style in your communication.
  You base your work on the factual information provided by your colleagues. 
  You stick to the facts in the documentation and use ONLY 
  information provided by the colleagues not adding anything.''',
  allow_delegation = False,
  verbose = True
)

table_documentation_task = Task(
  description = '''Using provided information about the table, 
  put together the detailed documentation for this table so that 
  people can use it in practice''',
  expected_output = '''Well-written detailed documentation describing 
  the data scheme for the table {table} in markdown format, 
  that gives the table overview in 1-2 sentences and then 
  describes each column. Structure the columns description 
  as a markdown table with column name, type and description.
  The documentation is based ONLY on the information provided 
  by the database specialist without any additions.''',
  tools = [],
  output_file = "table_documentation.md",
  agent = tech_writer_agent
)

Let’s execute our crew once again and see the results.

We’ve achieved a slightly better result. Our database specialist was able to execute queries and view the data, which is a significant win for us. Additionally, we can see all the relevant fields in the result table, though there are lots of other fields as well. So, it’s still not entirely correct.

I once again looked through the CrewAI execution log to figure out what went wrong. The issue lies in getting the list of columns. The query doesn’t filter by database, so it also returns columns of users tables from other databases, and these unrelated columns ended up in the result.

SELECT column_name 
FROM information_schema.columns 
WHERE table_name = 'users'
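A query that avoids this problem filters on the database as well as the table. The minimal sketch below builds such a query for a fully qualified name like `ecommerce_db.users`. `columns_query` is a hypothetical helper, and it assumes the `table_schema` column of `information_schema.columns` holds the database name, which is how ClickHouse maps it.

```python
def columns_query(full_table_name: str) -> str:
    """Build a column-listing query restricted to one database and one table.

    Expects a "database.table" name, e.g. "ecommerce_db.users".
    """
    database, table = full_table_name.split(".", 1)
    # Names are interpolated directly into SQL, so only pass trusted identifiers.
    return (
        "SELECT column_name, data_type\n"
        "FROM information_schema.columns\n"
        f"WHERE table_schema = '{database}' AND table_name = '{table}'"
    )


print(columns_query("ecommerce_db.users"))
```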

Also, after looking at multiple attempts, I noticed that the database specialist occasionally executes a select * from <table> query. This could cause issues in production, since on a large table it might return lots of data and send all of it to the LLM.
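One simple mitigation is to have the tool cap result sizes itself. The `enforce_limit` helper below is hypothetical and deliberately naive. It appends a `LIMIT` only when the query has none, using a regex check instead of a real SQL parser, so it would misplace the clause for ClickHouse queries ending in a `FORMAT` or `SETTINGS` clause.

```python
import re


def enforce_limit(query: str, max_rows: int = 100) -> str:
    """Append a LIMIT clause to queries that don't already have one (naive regex check)."""
    query = query.strip().rstrip(";")
    if re.search(r"\blimit\s+\d+", query, flags=re.IGNORECASE):
        return query
    return f"{query} LIMIT {max_rows}"


print(enforce_limit("select * from ecommerce_db.users"))
print(enforce_limit("select * from ecommerce_db.users limit 10"))
```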

More specialised tools

We can provide our agent with more specialised tools to improve our solution. Currently, the agent has a tool to execute any SQL query, which is flexible and powerful but prone to errors. We can create more focused tools, such as getting table structure and top-N rows from the table. Hopefully, it will reduce the number of mistakes.

class TableStructure(BaseTool):
  name: str = "Table structure"
  description: str = "Returns the list of columns and their types"

  def _run(self, table: str) -> str:
    table = table.strip('"').strip("'")
    return get_clickhouse_data(
      'describe {table} format TabSeparatedWithNames'
        .format(table = table)
    )

class TableExamples(BaseTool):
  name: str = "Table examples"
  description: str = "Returns the first N rows from the table"

  def _run(self, table: str, n: int = 30) -> str:
    table = table.strip('"').strip("'")
    return get_clickhouse_data(
      'select * from {table} limit {n} format TabSeparatedWithNames'
        .format(table = table, n = n)
    )

table_structure_tool = TableStructure()
table_examples_tool = TableExamples()
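Both tools paste the table name straight into SQL, so a malformed name from the LLM reaches the database unchecked. Here is a minimal sketch of stricter query builders that you could call from the `_run` methods above. The helpers are hypothetical, and they assume names always look like `database.table`. Inside `_run`, you would catch the `ValueError` and return its message to the agent, in line with the fault-tolerance principle from earlier.

```python
import re

# Accept only "database.table" made of letters, digits and underscores.
TABLE_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*\.[A-Za-z_][A-Za-z0-9_]*$")


def clean_table_name(table: str) -> str:
    """Strip wrapping quotes/spaces and check that the name looks like database.table."""
    table = table.strip().strip('"').strip("'")
    if not TABLE_NAME_RE.match(table):
        raise ValueError(f"Invalid table name: {table!r}")
    return table


def describe_query(table: str) -> str:
    return f"describe {clean_table_name(table)} format TabSeparatedWithNames"


def examples_query(table: str, n: int = 30, max_n: int = 100) -> str:
    # Never fetch more than max_n rows, whatever the agent asks for.
    n = max(1, min(int(n), max_n))
    return f"select * from {clean_table_name(table)} limit {n} format TabSeparatedWithNames"


print(describe_query("'ecommerce_db.users'"))
print(examples_query("ecommerce_db.users", n=1000))
```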

Now, we need to specify these tools in the task and re-run our script. After the first attempt, I got the following output from the Technical Writer.

Task output: This final answer provides a detailed and factual description 
of the ecommerce_db.users table structure, including column names, types, 
and descriptions. The documentation adheres to the provided information 
from the database specialist without any additions or modifications.

More focused tools helped the database specialist retrieve the correct table information. However, even though the writer had all the necessary information, we didn’t get the expected result.

As we know, LLMs are probabilistic, so I gave it another try. And hooray, this time, the result was pretty good.
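Since a single run can fail at random, a pragmatic option is to retry automatically until the output passes a cheap sanity check. Here is a minimal sketch. `run` stands in for something like `lambda: crew.kickoff({'table': 'ecommerce_db.users'})`, and `looks_like_docs` is a hypothetical check for a markdown table.

```python
from typing import Callable


def run_with_retries(run: Callable[[], str], is_valid: Callable[[str], bool],
                     max_attempts: int = 3) -> str:
    """Call `run` until its output passes `is_valid`, up to max_attempts times."""
    last_output = ""
    for attempt in range(1, max_attempts + 1):
        last_output = run()
        if is_valid(last_output):
            return last_output
        print(f"Attempt {attempt} failed the check, retrying...")
    raise RuntimeError(f"No valid output after {max_attempts} attempts:\n{last_output}")


def looks_like_docs(text: str) -> bool:
    # Cheap heuristic: the documentation should contain a markdown table.
    return "| Column" in text or "|---" in text


# Simulated runs: the first returns a summary without a table, the second a table.
outputs = iter([
    "Task output: This final answer provides a detailed description...",
    "| Column | Type | Description |\n|---|---|---|\n| user_id | ... | ... |",
])
print(run_with_retries(lambda: next(outputs), looks_like_docs))
```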

It’s not perfect since it still includes some irrelevant comments and lacks the overall description of the table. However, providing more specialised tools has definitely paid off. It also helped to prevent issues when the agent tried to load all the data from the table.

Quality assurance specialist

We’ve achieved pretty good results, but let’s see if we can improve them further. A common practice in multi-agent setups is quality assurance, which adds the final review stage before finalising the results.

Let’s create a new agent – a Quality Assurance Specialist, who will be in charge of review.

qa_specialist_agent = Agent(
  role = "Quality Assurance specialist",
  goal = """Ensure the highest quality of the documentation we provide 
  (that it's correct and easy to understand)""",
  backstory = '''
  You work as a Quality Assurance specialist, checking the work 
  from the technical writer and ensuring that it's in line 
  with our highest standards.
  You need to check that the technical writer provides full, complete 
  answers and makes no assumptions. 
  Also, you need to make sure that the documentation addresses 
  all the questions and is easy to understand.
  ''',
  allow_delegation = False,
  verbose = True
)

Now, it’s time to describe the review task. I’ve used the context parameter to specify that this task requires outputs from both table_description_task and table_documentation_task.

qa_review_task = Task(
  description = '''
  Review the draft documentation provided by the technical writer.
  Ensure that the documentation fully answers all the questions: 
  the purpose of the table and its structure in the form of a table. 
  Make sure that the documentation is consistent with the information 
  provided by the database specialist. 
  Double check that there are no irrelevant comments in the final version 
  of documentation.
  ''',
  expected_output = '''
  The final version of the documentation in markdown format 
  that can be published. 
  The documentation should fully address all the questions, be consistent 
  and follow our professional but informal tone of voice.
  ''',
  tools = [],
  context = [table_description_task, table_documentation_task],
  output_file="checked_table_documentation.md",
  agent = qa_specialist_agent
)

Let’s update our crew and run it.

full_crew = Crew(
  agents=[database_specialist_agent, tech_writer_agent, qa_specialist_agent],
  tasks=[table_description_task,  table_documentation_task, qa_review_task],
  verbose = 2,
  memory = False # doesn't work with Llama
)

full_result = full_crew.kickoff({'table': 'ecommerce_db.users'})

We now have more structured and detailed documentation thanks to the addition of the QA stage.

Delegation

With the addition of the QA specialist, it would be interesting to test the delegation mechanism. The QA specialist agent might have questions or requests that it could delegate to other agents.

I tried delegation with Llama 3, but it didn’t go well. Llama 3 struggled to call the co-worker tool correctly: it couldn’t specify the correct co-worker’s name.

We achieved pretty good results with a local model that can run on any laptop, but now it’s time to switch gears and use a much more powerful model: GPT-4o.

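To do so, we need to update the environment variables. Since OPENAI_API_BASE still points to the local Ollama server, we also remove it so that requests go to the default OpenAI endpoint.

import os

os.environ.pop("OPENAI_API_BASE", None) # stop routing requests to localhost:11434
os.environ["OPENAI_MODEL_NAME"] = 'gpt-4o'  
os.environ["OPENAI_API_KEY"] = config['OPENAI_API_KEY'] # your OpenAI key (config is your own settings dict)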
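If you switch between the local Llama model and GPT-4o regularly, it can help to keep all three variables in one place so they never get out of sync. Here is a minimal sketch with a hypothetical `use_llm` helper; the values are the ones used in this article.

```python
import os


def use_llm(local: bool, api_key: str = "NA") -> None:
    """Set the environment variables used in this article to pick the LLM."""
    if local:
        os.environ["OPENAI_API_BASE"] = "http://localhost:11434/v1"  # Ollama endpoint
        os.environ["OPENAI_MODEL_NAME"] = "crewai-llama3"  # custom model from the setup script
        os.environ["OPENAI_API_KEY"] = "NA"
    else:
        # Drop the local endpoint so the client falls back to OpenAI's default URL.
        os.environ.pop("OPENAI_API_BASE", None)
        os.environ["OPENAI_MODEL_NAME"] = "gpt-4o"
        os.environ["OPENAI_API_KEY"] = api_key


use_llm(local=True)
use_llm(local=False, api_key="sk-...")  # placeholder, use your own key
```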

To switch on the delegation, we should specify allow_delegation = True for the QA specialist agent.

Also, we can use the handy memory functionality with OpenAI models, as mentioned above. Memory will allow our agents to share their knowledge with each other during execution and leverage long-term memory to get information from previous executions.

full_crew = Crew(
  agents = [database_specialist_agent, tech_writer_agent, qa_specialist_agent],
  tasks = [table_description_task,  table_documentation_task, qa_review_task],
  verbose = 2,
  memory = True 
)

full_result = full_crew.kickoff({'table': 'ecommerce_db.users'})

Let’s see the CrewAI execution logs to understand how delegation works. Here are all the logs for the QA specialist. We can see that it reached out to the database specialist to double-check the information.

Thought: I need to review the provided draft documentation to ensure it 
fully answers all the questions and is complete, correct, and easy 
to understand. Specifically, I need to ensure that it explains the purpose 
of the table, its structure, and includes relevant examples without 
any irrelevant comments.

First, I will compare the two drafts to identify any discrepancies and 
missing information. Then, I will verify the accuracy and completeness of 
the information provided.

**Comparison:**
1. **Introduction:**
   - First draft: Includes an introduction and examples.
   - Second draft: Adds a more detailed introduction about the purpose and importance of the table.

2. **Columns:**
   - Both drafts have the same column structure and descriptions.

3. **Additional Information:**
   - Both drafts include similar additional information, but the second draft is more concise.

4. **Examples:**
   - The first draft includes examples, which are missing in the second draft.

5. **Final Checks:**
   - Ensure there are no irrelevant comments.
   - Verify consistency with information provided by the database specialist.

Action: Ask question to co-worker
Action Input: {
"question": "Can you confirm the accuracy of the provided column descriptions 
  and examples for the `ecommerce_db.users` table?", 
"context": "We are finalizing the documentation for the `ecommerce_db.users`
 table. The table includes columns `user_id`, `country`, `is_active`, and 
  `age`. Here are the examples provided:

  | user_id | country        | is_active | age |
  |---------|----------------|-----------|-----|
  | 1000001 | United Kingdom | 0         | 70  |
  | 1000002 | France         | 1         | 87  |
  | 1000003 | France         | 1         | 88  |
  | 1000004 | Germany        | 1         | 25  |
  | 1000005 | Germany        | 1         | 48  |
  | 1000006 | France         | 1         | 78  |
  | 1000007 | Netherlands    | 1         | 65  |
  | 1000008 | Germany        | 1         | 31  |
  | 1000009 | Germany        | 1         | 66  |
  | 1000010 | Germany        | 1         | 73  |", 
"coworker": "database specialist"}

<-- omitted logs from the database specialist -->

Based on the provided context and examples for the `ecommerce_db.users` table, 
the column descriptions and examples appear accurate. Here are the details:
- `user_id`: This column contains a unique identifier for each user. 
  The provided examples (e.g., 1000001, 1000002) align with this description 
  as they are unique numeric identifiers.

- `country`: This column specifies the country of the user. The provided 
  examples (e.g., United Kingdom, France, Germany, Netherlands) are 
  all valid country names.

- `is_active`: This column indicates whether the user is active (1) or not (0). The provided examples (e.g., 0, 1) correctly represent this binary status.

- `age`: This column shows the age of the user. The provided examples (e.g., 70, 
  87, 88, 25, 48, 78, 65, 31, 66, 73) are all valid ages.

Therefore, the column descriptions and examples provided for the 
 `ecommerce_db.users` table are accurate and meet the expected criteria.

Thought: Based on the observation from the database specialist, 
the column descriptions and examples provided for the `ecommerce_db.users` 
table are accurate. Now, I will consolidate the best elements from both drafts,
 ensuring the documentation is complete, clear, and free of irrelevant comments.

Action: Finalize the documentation in markdown format, incorporating 
the detailed introduction, examples, and ensuring the tone is professional 
but informal.

When I tried the delegation for the first time, I didn’t enable memory, which led to incorrect results. The data specialist and the technical writer initially returned the correct information. However, when the QA specialist returned with the follow-up questions, they started to hallucinate. So, it looks like delegation works better when memory is enabled.

Here’s the final output from GPT-4o. The result looks pretty nice now. We can definitely use LLMs to automate documentation.

So, the first task has been solved!

I used the same script to generate documentation for the ecommerce_db.sessions table as well. It will be handy for our next task. So, let’s not waste any time and move on.

Use case: answering questions

Our next task is answering questions based on the documentation, since this is a common task for many data analysts (and other specialists).

We will start simple and will create just two agents:

  • The documentation support specialist will be answering questions based on the docs,
  • The support QA agent will review the answer before sharing it with the customer.

We will need to empower the documentation specialist with a couple of tools that will allow them to see all the files stored in the directory and read the files. It’s pretty straightforward since CrewAI has implemented such tools.

from crewai_tools import DirectoryReadTool, FileReadTool

documentation_directory_tool = DirectoryReadTool(
    directory = '~/crewai_project/ecommerce_documentation')

base_file_read_tool = FileReadTool()
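For intuition, these two tools do something conceptually similar to the standard-library sketch below: list the files in the documentation folder and read one of them. It's only an illustration (`list_docs` and `read_doc` are hypothetical helpers, not CrewAI code). Note the explicit `expanduser`, since Python's file APIs don't expand `~` on their own.

```python
import tempfile
from pathlib import Path


def list_docs(directory: str) -> list[str]:
    """Return the paths of all files under the directory."""
    root = Path(directory).expanduser()
    return sorted(str(p) for p in root.rglob("*") if p.is_file())


def read_doc(file_path: str) -> str:
    """Return a file's text, tolerating quotes around the path."""
    return Path(file_path.strip().strip('"').strip("'")).expanduser().read_text(encoding="utf-8")


# Demo on a temporary folder instead of ~/crewai_project/ecommerce_documentation.
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "ecommerce_db.users.md").write_text("# users table\n", encoding="utf-8")
    files = list_docs(tmp)
    content = read_doc(f'"{files[0]}"')
    print(files)
    print(content)
```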

However, since Llama 3 keeps struggling with quotes when calling tools, I had to create a custom tool on top of the FileReadTool to overcome this issue.

from crewai_tools import BaseTool

class FileReadToolUPD(BaseTool):
    name: str = "Read a file's content"
    description: str = "A tool that can be used to read a file's content."

    def _run(self, file_path: str) -> str:
        # strip quotes that Llama 3 adds around the path, then delegate to the built-in tool
        return base_file_read_tool._run(file_path = file_path.strip('"').strip("'"))

file_read_tool = FileReadToolUPD()

Next, as we did before, we need to create agents, tasks and crew.

data_support_agent = Agent(
  role = "Senior Data Support Agent",
  goal = "Be the most helpful support for your colleagues",
  backstory = '''You work as a support for data-related questions 
  in the company. 
  Even though you're a big expert in our data warehouse, you double check 
  all the facts in documentation. 
  Our documentation is absolutely up-to-date, so you can fully rely on it 
  when answering questions (you don't need to check the actual data 
  in database).
  Your work is very important for the team success. However, remember 
  that examples of table rows don't show all the possible values. 
  You need to ensure that you provide the best possible support: answering 
  all the questions, making no assumptions and sharing only the factual data.
  Be creative and try your best to solve the customer problem. 
  ''',
  allow_delegation = False,
  verbose = True
)

qa_support_agent = Agent(
  role = "Support Quality Assurance Agent",
  goal = """Ensure the highest quality of the answers we provide 
  to the customers""",
  backstory = '''You work as a Quality Assurance specialist, checking the work 
  from support agents and ensuring that it's in line with our highest standards.
  You need to check that the agent provides full and complete answers 
  and makes no assumptions. 
  Also, you need to make sure that the documentation addresses all 
  the questions and is easy to understand.
  ''',
  allow_delegation = False,
  verbose = True
)

draft_data_answer = Task(
  description = '''Very important customer {customer} reached out to you 
  with the following question:

{question}


  Your task is to provide the best answer to all the points in the question 
  using all available information and not making any assumptions. 
  If you don't have enough information to answer the question, just say 
  that you don't know.''',
  expected_output = '''The detailed informative answer to the customer's 
  question that addresses all the points mentioned. 
  Make sure that the answer is complete and sticks to the facts 
  (without any additional information not based on the factual data)''',
  tools = [documentation_directory_tool, file_read_tool], 
  agent = data_support_agent
)

answer_review = Task(
  description = '''
  Review the draft answer provided by the support agent.
  Ensure that it fully answers all the questions mentioned 
  in the initial inquiry. 
  Make sure that the answer is consistent and doesn't include any assumptions.
  ''',
  expected_output = '''
  The final version of the answer in markdown format that can be shared 
  with the customer. 
  The answer should fully address all the questions, be consistent 
  and follow our professional but informal tone of voice. 
  We are a very chill and friendly company, so don't forget to include 
  all the polite phrases.
  ''',
  tools = [], 
  agent = qa_support_agent
)

qna_crew = Crew(
  agents = [data_support_agent, qa_support_agent],
  tasks = [draft_data_answer,  answer_review],
  verbose = 2,
  memory = False # memory doesn't work with Llama
)

Let’s see how it works in practice.

result = qna_crew.kickoff(
  {'customer': "Max", 
   'question': """Hey team, I hope you're doing well. I need to find 
    the numbers before our CEO presentation tomorrow, so I will really 
    appreciate your help.
    I need to calculate the number of sessions from our Windows users in 2023. I've tried to find the table with such data in our data warehouse, but wasn't able to. 
    Do you have any ideas whether we store the needed data somewhere, 
    so that I can query it? """
  }
)

We’ve got a polite, practical and helpful answer in return. That’s really great.

**Hello Max,**

Thank you for reaching out with your question! I'm happy to help you 
find the number of sessions from Windows users in 2023. 
After reviewing our documentation, I found that we do store data 
related to sessions and users in our ecommerce database, specifically in 
the `ecommerce_db.sessions` table.

To answer your question, I can provide you with a step-by-step guide 
on how to query this table using SQL. First, you can use the `session_id` 
column along with the `os` column filtering for "Windows" and 
the `action_date` column filtering for dates in 2023. 
Then, you can group the results by `os` using the `GROUP BY` clause 
to count the number of sessions that meet these conditions.

Here's a sample SQL query that should give you the desired output:

```sql
SELECT COUNT(*) 
FROM ecommerce_db.sessions 
WHERE os = 'Windows' 
AND action_date BETWEEN '2023-01-01' AND '2023-12-31'
GROUP BY os;
```

This query will return the total number of sessions from Windows users in 2023. I hope this helps! If you have any further questions or need more assistance, please don’t hesitate to ask.


Let's complicate the task a bit. Suppose we can get not only questions about our data but also about our tool (ClickHouse). So, we will have another agent in the crew - ClickHouse Guru. To give our CH agent some knowledge, I will share a documentation website with it.

from crewai_tools import ScrapeWebsiteTool, WebsiteSearchTool

ch_documenation_tool = ScrapeWebsiteTool(
  'https://clickhouse.com/docs/en/guides/creating-tables')

If you need to work with a lengthy document, you might try using RAG (Retrieval-Augmented Generation) via WebsiteSearchTool. It will calculate embeddings and store them locally in ChromaDB. In our case, we will stick to a simple website scraper tool.
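To make the retrieval idea behind RAG more concrete, here is a toy, dependency-free sketch of the retrieval step. It splits a document into chunks, scores each chunk against the question, and keeps only the best matches to pass to the LLM. It is a stand-in built on assumptions, not how WebsiteSearchTool works internally: word-count vectors with cosine similarity replace learned embeddings, and a plain list replaces ChromaDB.

```python
import math
import re
from collections import Counter


def vectorise(text: str) -> Counter:
    # Toy "embedding": lower-cased word counts instead of a learned embedding model
    return Counter(re.findall(r"[a-z0-9_]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[word] * b[word] for word in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve(chunks: list[str], question: str, top_k: int = 1) -> list[str]:
    # Index the chunks (a vector store would persist these), then rank them by similarity
    index = [(chunk, vectorise(chunk)) for chunk in chunks]
    query = vectorise(question)
    ranked = sorted(index, key=lambda item: cosine(query, item[1]), reverse=True)
    return [chunk for chunk, _ in ranked[:top_k]]


if __name__ == "__main__":
    docs = [
        "CREATE TABLE defines a new table and its columns.",
        "WITH TOTALS adds a row with totals to a GROUP BY query.",
        "Use ALTER TABLE to add or drop columns.",
    ]
    print(retrieve(docs, "How do I add totals to a query?"))
```

Only the retrieved chunk, rather than the whole document, would then be added to the prompt.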

Now that we have two subject matter experts, we need to decide who will be working on the questions. So, it’s time to use a hierarchical process and add a manager to orchestrate all the tasks.

CrewAI provides the manager implementation, so we only need to specify the LLM. I’ve picked GPT-4o.

from langchain_openai import ChatOpenAI
from crewai import Process

complext_qna_crew = Crew(
  agents = [ch_support_agent, data_support_agent, qa_support_agent],
  tasks = [draft_ch_answer, draft_data_answer, answer_review],
  verbose = 2,
  manager_llm = ChatOpenAI(model='gpt-4o', temperature=0),  
  process = Process.hierarchical,  
  memory = False 
)

At this point, I had to switch from Llama 3 to OpenAI models again to run a hierarchical process, since it didn’t work for me with Llama (similar to this issue).

Now, we can try our new crew with different types of questions (either related to our data or ClickHouse database).

ch_result = complext_qna_crew.kickoff(
  {'customer': "Maria", 
   'question': """Good morning, team. I'm using ClickHouse to calculate 
   the number of customers. 
   Could you please remind whether there's an option to add totals 
   in ClickHouse?"""
  }
)

doc_result = complext_qna_crew.kickoff(
  {'customer': "Max", 
   'question': """Hey team, I hope you're doing well. I need to find 
    the numbers before our CEO presentation tomorrow, so I will really 
    appreciate your help.
    I need to calculate the number of sessions from our Windows users 
    in 2023. I've tried to find the table with such data 
    in our data warehouse, but wasn't able to. 
    Do you have any ideas whether we store the needed data somewhere, 
    so that I can query it. """
  }
)

If we look at the final answers and logs (I’ve omitted them here since they are quite lengthy, but you can find them on GitHub), we will see that the manager was able to orchestrate correctly and delegate tasks to co-workers with the relevant knowledge to address the customer’s question. For the first (ClickHouse-related) question, we got a detailed answer with examples and possible implications of using the WITH TOTALS functionality. For the data-related question, the models returned roughly the same information as we’ve seen above.

So, we’ve built a crew that can answer various types of questions based on the documentation, whether from a local file or a website. I think it’s an excellent result.

You can find all the code on GitHub.

Summary

In this article, we’ve explored using the CrewAI multi-agent framework to create a solution for writing documentation based on tables and answering related questions.

Given the extensive functionality we’ve utilised, it’s time to summarise the strengths and weaknesses of this framework.

Overall, I find CrewAI to be an incredibly useful framework for multi-agent systems:

  • It’s straightforward, and you can build your first prototype quickly.
  • Its flexibility allows you to solve quite sophisticated business problems.
  • It encourages good practices like role-playing.
  • It provides many handy tools out of the box, such as RAG and a website parser.
  • The support of different types of memory enhances the agents’ collaboration.
  • Built-in guardrails help prevent agents from getting stuck in repetitive loops.

However, there are areas that could be improved:

  • While the framework is simple and easy to use, it’s not very customisable. For instance, you currently can’t create your own LLM manager to orchestrate the processes.
  • Sometimes, it’s quite challenging to get the full detailed information from the documentation. For example, it’s clear that CrewAI implemented some guardrails to prevent repetitive function calls, but the documentation doesn’t fully explain how it works.
  • Another improvement area is transparency. I like to understand how frameworks work under the hood. For example, in LangChain, you can use langchain.debug = True to see all the LLM calls. However, I haven’t figured out how to get the same level of detail with CrewAI.
  • Full support for local models would be a great addition, as the current implementation either lacks some features or is difficult to get working properly.

The domain and tools for LLMs are evolving rapidly, so I’m hopeful that we’ll see a lot of progress in the near future.

Thank you a lot for reading this article. I hope this article was insightful for you. If you have any follow-up questions or comments, please leave them in the comments section.

Reference

This article is inspired by the "Multi AI Agent Systems with CrewAI" short course from DeepLearning.AI.

The post Multi AI Agent Systems 101 appeared first on Towards Data Science.

]]>
From Code to Insights: Software Engineering Best Practices for Data Analysts https://towardsdatascience.com/from-code-to-insights-software-engineering-best-practices-for-data-analysts-0dd6a2aaadfc/ Thu, 06 Jun 2024 19:24:17 +0000 https://towardsdatascience.com/from-code-to-insights-software-engineering-best-practices-for-data-analysts-0dd6a2aaadfc/ Top 10 engineering lessons every data analyst should know

The post From Code to Insights: Software Engineering Best Practices for Data Analysts appeared first on Towards Data Science.

]]>
The data analyst job combines skills from different domains:

  • We need to have business understanding and domain knowledge to be able to solve actual business problems and take into account all the details.
  • Maths, statistics, and fundamental machine learning skills help us perform rigorous analyses and reach reliable conclusions from data.
  • Visualisation skills and storytelling allow us to deliver our message and influence the product.
  • Last but not least, computer science and the basics of software engineering are key to our efficiency.

I learned a lot about computer science at university. I’ve tried at least a dozen programming languages (from low-level assembler and CUDA to high-level Java and Scala) and countless tools. My very first job offer was for a backend engineer role. I decided not to pursue this path, but all this knowledge and these principles have been beneficial in my analytical career. So, I would like to share the main principles with you in this article.

Code is not for computers. It’s for people

I’ve heard this mantra from software engineers many times. It’s well explained in one of the programming bibles, "Clean Code".

Indeed, the ratio of time spent reading versus writing is well over 10 to 1. We are constantly reading old code as part of the effort to write new code.

In most cases, engineers prefer more verbose code that is easy to understand over an idiomatic one-liner.

I must confess that I sometimes break this rule and write extra-long pandas one-liners. For example, let’s look at the code below. Do you have any idea what this code is doing?

# ad-hoc only code
(df.groupby(['month', 'feature'])[['user_id']].nunique()
  .rename(columns = {'user_id': 'users'})
  .join(df.groupby(['month'])[['user_id']].nunique()
  .rename(columns = {'user_id': 'total_users'})).apply(
    lambda x: 100*x['users']/x['total_users'], axis = 1)
  .reset_index().rename(columns = {0: 'users_share'})
  .pivot(index = 'month', columns = 'feature', values = 'users_share'))

Honestly, a month from now it would probably take me a while to get up to speed with this code. To make it more readable, we can split it into steps.

# maintainable code
monthly_features_df = (df.groupby(['month', 'feature'])[['user_id']].nunique()
    .rename(columns = {'user_id': 'users'}))

monthly_total_df = (df.groupby(['month'])[['user_id']].nunique()
    .rename(columns = {'user_id': 'total_users'}))

monthly_df = monthly_features_df.join(monthly_total_df).reset_index()
monthly_df['users_share'] = 100*monthly_df.users/monthly_df.total_users

monthly_df.pivot(index = 'month', columns = 'feature', values = 'users_share')

Hopefully, it’s now easier to follow the logic and see that this code shows the percentage of customers who use each feature every month. Future me would definitely be way happier to see code like this and would appreciate the effort.
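A cheap way to confirm that a refactoring like this didn't change the result is to run both versions on a small made-up dataset and compare the outputs. In the sketch below, the toy data frame and its values are invented for illustration. The two functions wrap the two versions of the code above, with each multi-line chain enclosed in parentheses so it is valid Python.

```python
import pandas as pd

# Toy data: which users used which feature in which month (values are made up)
toy_df = pd.DataFrame({
    'month': ['2024-01', '2024-01', '2024-01', '2024-02', '2024-02'],
    'feature': ['A', 'B', 'A', 'A', 'A'],
    'user_id': [1, 1, 2, 3, 4],
})


def share_one_liner(df):
    return (df.groupby(['month', 'feature'])[['user_id']].nunique()
        .rename(columns = {'user_id': 'users'})
        .join(df.groupby(['month'])[['user_id']].nunique()
        .rename(columns = {'user_id': 'total_users'})).apply(
            lambda x: 100*x['users']/x['total_users'], axis = 1)
        .reset_index().rename(columns = {0: 'users_share'})
        .pivot(index = 'month', columns = 'feature', values = 'users_share'))


def share_step_by_step(df):
    monthly_features_df = (df.groupby(['month', 'feature'])[['user_id']].nunique()
        .rename(columns = {'user_id': 'users'}))
    monthly_total_df = (df.groupby(['month'])[['user_id']].nunique()
        .rename(columns = {'user_id': 'total_users'}))
    monthly_df = monthly_features_df.join(monthly_total_df).reset_index()
    monthly_df['users_share'] = 100*monthly_df.users/monthly_df.total_users
    return monthly_df.pivot(index = 'month', columns = 'feature', values = 'users_share')


# Raises an AssertionError if the two versions disagree
pd.testing.assert_frame_equal(share_one_liner(toy_df), share_step_by_step(toy_df))
print(share_step_by_step(toy_df))
```

In January, both users used feature A (100%) and one of the two used feature B (50%). Nobody used B in February, so that cell is NaN.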

Automate repetitive tasks

If you have monotonous tasks that you repeat frequently, I recommend you consider automation. Let me share some examples from my experience that you might find helpful.

The most common way for analysts to automate tasks is to create a dashboard instead of calculating numbers manually every time. Self-serve tools (configurable dashboards where stakeholders can change filters and investigate the data) can save a lot of time and allow us to focus on more sophisticated and impactful research.

If a dashboard is not an option, there are other ways to automate. I was doing weekly reports and sending them to stakeholders via e-mail. After some time, it became a pretty tedious task, and I started to think about automation. At this point, I used the most basic tool: cron on a virtual machine. I scheduled a Python script that calculated up-to-date numbers and sent an e-mail.

When you have a script, you just need to add one line to the cron file. For example, the line below will execute analytical_script.py every Monday at 9:10 AM.

10 9 * * 1 python analytical_script.py
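For illustration, a scheduled script of this kind might look roughly like the sketch below. Everything in it is an assumption. The file name `analytical_script.py` comes from the cron line, but the metric calculation, the SMTP host and the e-mail addresses are hypothetical placeholders. It uses only the standard library (`smtplib` and `email.message.EmailMessage`) and defaults to a dry run that prints the e-mail instead of sending it.

```python
# analytical_script.py: a hypothetical weekly report job
import smtplib
from datetime import date
from email.message import EmailMessage

SMTP_HOST = "smtp.example.com"           # placeholder, replace with your mail server
SENDER = "analytics@example.com"         # placeholder addresses
RECIPIENTS = ["stakeholders@example.com"]
DRY_RUN = True                           # set to False once the settings above are real


def calculate_metrics() -> dict:
    # Placeholder: a real script would query the data warehouse here
    return {"active_users": 1234, "conversion_rate": 0.042}


def build_report(metrics: dict, report_date: date) -> EmailMessage:
    msg = EmailMessage()
    msg["Subject"] = f"Weekly report for {report_date.isoformat()}"
    msg["From"] = SENDER
    msg["To"] = ", ".join(RECIPIENTS)
    # One "metric: value" line per metric in the plain-text body
    msg.set_content("\n".join(f"{name}: {value}" for name, value in metrics.items()))
    return msg


def send_report(msg: EmailMessage) -> None:
    with smtplib.SMTP(SMTP_HOST) as server:
        server.send_message(msg)


if __name__ == "__main__":
    report = build_report(calculate_metrics(), date.today())
    if DRY_RUN:
        print(report)
    else:
        send_report(report)
```

Because cron runs jobs with a minimal environment, it is usually safer to use absolute paths to the Python interpreter and the script in the crontab line.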

Cron is a basic but still sustainable solution. Other tools that can be used to schedule scripts are Airflow, DBT, and Jenkins. You might know Jenkins as a CI/CD (continuous integration & continuous delivery) tool that engineers often use. It might surprise you, but it’s customisable enough to execute analytical scripts as well.

If you need even more flexibility, it’s time to think about web applications. In my first team, we didn’t have an A/B test tool, so for a long time, analysts had to analyse each update manually. Finally, we wrote a Flask web application so that engineers could self-serve. Now, there are lightweight solutions for web applications, such as Gradio or Streamlit, that you can learn in a couple of days.

You can find a detailed guide for Gradio in one of my previous articles.

Master your tools

Tools you use every day at work play a significant role in your efficiency and final results. So it’s worth mastering them.

Of course, you can use a default text editor to write code, but most people use IDEs (Integrated Development Environment). You will be spending a lot of your working time on this application, so it’s worth assessing your options.

You can find the most popular IDEs for Python from the JetBrains 2021 survey.

Chart by author, data from the JetBrains survey

I usually use Python and Jupyter Notebooks for my day-to-day work. In my opinion, the best IDE for such tasks is JupyterLab. However, I’m trying other options right now to be able to use AI assistants. The benefits of auto-completion, which eliminates lots of boilerplate code, are invaluable for me, so I’m ready to take on switching costs. I encourage you to investigate different options and see what suits your work best.

The other helpful hack is shortcuts. You can do your tasks way faster with shortcuts than with a mouse, and it looks cool. I would start with Googling shortcuts for your IDE since you usually use this tool the most. From my practice, the most valuable commands are creating a new cell in a Notebook, running this cell, deleting it, and converting the cell into markdown.

If you have other tools that you use pretty often (such as Google Sheets or Slack), you can also learn commands for them.

The main trick with learning shortcuts is "practice, practice, practice" – you need to repeat it a hundred times to start doing it automatically. There are even plugins that push you to use shortcuts more (for example, this one from JetBrains).

Last but not least is CLI (command-line interface). It might look intimidating in the beginning, but basic knowledge of CLI usually pays off. I use CLI even to work with GitHub since it gives me a clear understanding of what’s going on exactly.

Moreover, there are situations when it’s almost impossible to avoid using the CLI, such as when working on a remote server. To interact confidently with a server, you need to learn fewer than ten commands. This article can help you gain basic knowledge about CLI.

Manage your environment

Continuing the topic of tools, setting up your environment is always a good idea. I have a Python virtual environment for day-to-day work with all the libraries I usually use.

Creating a new virtual environment is as easy as a couple of lines of code in your terminal (an excellent opportunity to start using CLI).

# creating venv
python -m venv routine_venv

# activating venv
source routine_venv/bin/activate

# installing ALL packages you need (including Jupyter itself)
pip install pandas plotly notebook

# starting Jupyter Notebooks
jupyter notebook

You can start your Jupyter from this environment or use it in your IDE.

It’s a good practice to have a separate environment for big projects. I usually do it only if I need an unusual stack (like PyTorch or yet another new LLM framework) or face some issues with library compatibility.

Another way to capture your environment is to use Docker containers. I use them for more production-like things, such as web apps running on a server.

Think about program performance

To tell the truth, analysts often don’t need to think much about performance. When I got my first job in data analytics, my lead shared a practical approach to performance optimisation (and I have been using it ever since): when you’re thinking about performance, weigh the total time saved against the effort. Suppose I have a MapReduce script that runs for 4 hours. Should I optimise it? It depends.

  • If I need to run it only once or twice, there’s not much sense in spending 1 hour to optimise this script to calculate numbers in just 1 hour.
  • If I plan to run it daily, it’s worth the effort to make it faster and stop wasting computational resources (and money).

Since the majority of my tasks are one-time research, in most cases, I don’t need to optimise my code. However, it’s worth following some basic rules to avoid waiting for hours. Small tricks can lead to tremendous results. Let’s discuss such an example.

Starting from the basics, the cornerstone of performance is big O notation. Simply put, big O notation shows the relation between execution time and the number of elements you work with. So, if my program is O(n), it means that if I increase the amount of data 10 times, execution will be ~10 times longer.

When writing code, it’s worth understanding the complexity of your algorithm and the main data structures. For example, finding out if an element is in a list takes O(n) time, but it only takes O(1) time in a set. Let’s see how it can affect our code.

I have 2 data frames with Q1 and Q2 user transactions, and for each transaction in the Q1 data frame, I would like to understand whether this customer was retained or not. Our data frames are relatively small – around 300-400K rows.

As you can see, performance differs a lot.

  • The first approach is the worst one because, on each iteration (for each row in the Q1 dataset), we calculate the list of unique user_ids. Then, we look up the element in the list with O(n) complexity. This operation takes 13 minutes.
  • The second approach, where we calculate the list once before the loop, is a bit better, but it still takes almost 6 minutes.
  • If we pre-calculate the list of user_ids and convert it into a set, we get the result in the blink of an eye.
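The text above describes the three approaches without showing them, so here is a rough reconstruction under stated assumptions. `q1_df` and `q2_df` are hypothetical transaction tables with a `user_id` column, filled with random synthetic data. A Q1 transaction counts as retained if its `user_id` appears in Q2. The tables are deliberately small so the slow versions finish quickly. No timings are claimed, since they depend on the machine.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
# Synthetic data: hypothetical Q1 and Q2 transactions (kept small so the slow versions finish)
q1_df = pd.DataFrame({'user_id': rng.integers(0, 5_000, size=2_000)})
q2_df = pd.DataFrame({'user_id': rng.integers(0, 5_000, size=2_000)})

# Approach 1: rebuild the list of Q2 users for every Q1 row, then scan it (O(n) work per row)
retained_1 = q1_df.user_id.apply(lambda uid: uid in list(q2_df.user_id.unique()))

# Approach 2: build the list once, but every lookup still scans it (O(n) per lookup)
q2_users_list = list(q2_df.user_id.unique())
retained_2 = q1_df.user_id.apply(lambda uid: uid in q2_users_list)

# Approach 3: build a set once; every lookup is a hash lookup (O(1) on average)
q2_users_set = set(q2_df.user_id.unique())
retained_3 = q1_df.user_id.apply(lambda uid: uid in q2_users_set)

# All three approaches must agree; only their speed differs
assert retained_1.equals(retained_2) and retained_2.equals(retained_3)
print(f"Retained share of Q1 transactions: {retained_3.mean():.1%}")
```

In pandas, `q1_df.user_id.isin(q2_df.user_id)` gives the same boolean Series without an explicit Python-level loop.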

As you can see, we can make our code more than 10K times faster with just basic knowledge. It’s a game-changer.

Another general piece of advice is to avoid plain Python where possible and prefer more performant data structures, such as those in pandas or NumPy. These libraries are faster because they use vectorised operations on arrays, which are implemented in C. Usually, NumPy shows slightly better performance, since pandas is built on top of NumPy but adds functionality that slows it down a bit.
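Here is a minimal illustration of the vectorisation point. The sketch computes the same quantity, the sum of squared values (chosen arbitrarily), in two ways: with a plain Python loop over a list, and with a single NumPy call over the whole array. The array size is arbitrary and timings vary by machine, so none are quoted. The point is that the NumPy version runs its loop in compiled code rather than in the Python interpreter.

```python
import time

import numpy as np

values = np.random.default_rng(0).random(1_000_000)


def sum_of_squares_loop(items) -> float:
    # Plain Python: the interpreter processes every element one by one
    total = 0.0
    for x in items:
        total += x * x
    return total


def sum_of_squares_vectorised(arr) -> float:
    # NumPy: one call; the loop over elements happens inside compiled code
    return float(np.dot(arr, arr))


# The loop version gets a plain Python list, the vectorised one the NumPy array
inputs = {sum_of_squares_loop: values.tolist(), sum_of_squares_vectorised: values}
for func, data in inputs.items():
    start = time.perf_counter()
    result = func(data)
    print(f"{func.__name__}: {result:.2f} in {time.perf_counter() - start:.4f} s")
```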

Don’t forget the DRY principle.

DRY stands for "Don’t Repeat Yourself" and is self-explanatory. This principle promotes structured, modular code that you can easily reuse.

If you’re copy-pasting a chunk of code for the third time, it’s a sign to think about the code structure and how to encapsulate this logic.

The standard analytical task is data wrangling, and we usually follow the procedural paradigm. So, the most apparent way to structure the code is functions. However, you might follow object-oriented programming and create classes. In my previous article, I shared an example of the object-oriented approach to simulations.

The benefits of modular code are better readability, faster development and easier changes. For example, if you want to change your visualisation from a line chart to an area plot, you can do it in one place and re-run your code.
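Here is a small sketch of this idea applied to the feature-share calculation from earlier. Instead of copy-pasting the groupby logic for every new breakdown, the logic lives in one function that takes the dimension as a parameter. The function name `get_monthly_share` and the column names are hypothetical.

```python
import pandas as pd


def get_monthly_share(df: pd.DataFrame, dimension: str, user_col: str = 'user_id') -> pd.DataFrame:
    """Share of monthly users (in %) for each value of `dimension` (hypothetical helper)."""
    users = df.groupby(['month', dimension])[user_col].nunique()
    total_users = df.groupby('month')[user_col].nunique()
    # Dividing by a Series indexed by month aligns on the shared 'month' level
    share = users.div(total_users, level='month') * 100
    # One row per month, one column per value of the dimension
    return share.unstack(dimension)


# Made-up events table for illustration
events = pd.DataFrame({
    'month': ['2024-01', '2024-01', '2024-02'],
    'feature': ['A', 'B', 'A'],
    'platform': ['iOS', 'Android', 'iOS'],
    'user_id': [1, 2, 3],
})

# The same function covers both breakdowns, so a change to the logic happens in one place
print(get_monthly_share(events, 'feature'))
print(get_monthly_share(events, 'platform'))
```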

If you have a bunch of functions related to one particular domain, you can create a Python package for it to interact with these functions as with any other Python library. Here’s a detailed guide on how to do it.

Leverage testing

The other topic that is, in my opinion, undervalued in the analytical world is testing. Software engineers often have KPIs on the test coverage, which might also be useful for analysts. However, in many cases, our tests will be related to the data rather than the code itself.

The trick I’ve learned from one of my colleagues is to add tests on the data recency. We have multiple scripts for quarterly and annual reports that we run pretty rarely. So, he added a check to see whether the latest rows in the tables we’re using are after the end of the reporting period (it shows whether the table has been updated). In Python, you can use an assert statement for this.

import datetime
assert last_record_time >= datetime.date(2023, 5, 31)  # last_record_time: latest date (a datetime.date) in the table

If the condition is fulfilled, then nothing will happen. Otherwise, you will get an AssertionError. It’s a quick and easy check that can help you spot problems early.

The other thing I prefer to validate is sum statistics. For example, if you’re slicing, dicing and transforming your data, it’s worth checking that the overall number of requests and metrics stays the same. Some common mistakes are:

  • duplicates that emerged because of joins,
  • filtered-out None values when you’re using pandas.groupby function,
  • filtered-out dimensions because of inner joins.

Also, I always check data for duplicates. If you expect that each row will represent one user, then the number of rows should be equal to df.user_id.nunique(). If it’s not, something is wrong with your data and needs investigation.
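Some of these checks can be written in a few lines, as in the sketch below. The tables, column names and values are hypothetical. The sketch shows a uniqueness assertion on `user_id`, a join that silently duplicates rows (visible when comparing totals), and `groupby` dropping rows whose key is missing unless you pass `dropna=False`.

```python
import pandas as pd

# Hypothetical inputs: one row per user, and a lookup table that accidentally has a duplicate key
users = pd.DataFrame({'user_id': [1, 2, 3], 'country': ['DE', None, 'FR'], 'revenue': [10, 20, 30]})
countries = pd.DataFrame({'country': ['DE', 'DE', 'FR'], 'region': ['EU', 'EU', 'EU']})

# 1. Uniqueness: exactly one row per user is expected
assert len(users) == users.user_id.nunique(), "duplicate users"

# 2. Totals before and after a join: the duplicate 'DE' key doubles user 1's revenue
joined = users.merge(countries, on='country', how='left')
print(joined.revenue.sum(), 'vs', users.revenue.sum())  # 70 vs 60

# 3. groupby drops the missing country by default, so the total shrinks
print(users.groupby('country').revenue.sum().sum())                # 40
print(users.groupby('country', dropna=False).revenue.sum().sum())  # 60
```

pandas can also enforce the expected relationship up front: `users.merge(countries, on='country', how='left', validate='many_to_one')` raises an error here because the right-hand keys are not unique.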

The trickiest and most helpful test is the sense check. Let’s discuss some possible approaches to it.

  • First, I would check whether the results make sense overall. For example, if 1-month retention equals 99% or I got 1 billion customers in Europe, there’s likely a bug in the code.
  • Secondly, I will look for other data sources or previous research on this topic to validate that my results are feasible.
  • If you don’t have other similar research (for example, you’re estimating your potential revenue after launching the product in a new market), I would recommend you compare your numbers to those of other existing segments. For example, if your incremental effect on revenue after launching your product in yet another market equals 5x current income, I would say it’s a bit too optimistic and worth revisiting assumptions.

I hope this mindset will help you achieve more feasible results.

Encourage the team to use Version Control Systems

Engineers use version control systems even for tiny projects they work on alone. At the same time, I often see analysts using Google Sheets to store their queries. Since I’m a great proponent and advocate of keeping all the code in a repository, I can’t miss a chance to share my thoughts with you.

Why have I been using a repository for 10+ years of my data career? Here are the main benefits:

  • Reproducibility. Quite often, we need to tweak the previous research (for example, add one more dimension or narrow research down to a specific segment) or just repeat the earlier calculations. If you store all the code in a structured way, you can quickly reproduce your prior work. It usually saves a lot of time.
  • Transparency. Linking code to the results of your research allows your colleagues to understand the methodology to the tiniest detail, which brings more trust and naturally helps to spot bugs or potential improvements.
  • Knowledge sharing. If you have a catalogue that is easy to navigate (or you link your code to Task Trackers), it makes it super-easy for your colleagues to find your code and not start an investigation from scratch.
  • Rolling back. Have you ever been in a situation when your code was working yesterday, but then you changed something, and now it’s completely broken? I’ve been there many times before I started committing my code regularly. Version control systems allow you to see the whole version history, compare versions of the code or roll back to the previous working version.
  • Collaboration. If you’re working on the code in collaboration with others, you can leverage version control systems to track and merge the changes.

I hope you can see its potential benefits now. Let me briefly share my usual setup to store code:

  • I use Git + GitHub as my version control system. I’m one of those dinosaurs who still use the command-line interface for Git (it gives me a soothing feeling of control), but you can use the GitHub app or the functionality of your IDE.
  • Most of my work is research (code, numbers, charts, comments, etc.), so I store 95% of my code as Jupyter Notebooks.
  • I link my code to the Jira tickets. I usually have a tasks folder in my repository and name subfolders as ticket keys (for example, ANALYTICS-42). Then, I place all the files related to the task in this subfolder. With such an approach, I can find code related to (almost) any task in seconds.

There are a bunch of nuances of working with Jupyter Notebooks in GitHub that are worth noting.

First, think about the output. When committing a Jupyter Notebook to the repository, you save input cells (your code or comments) and output. So, it’s worth being conscious about whether you actually want to share the output. It might contain PII or other sensitive data that I wouldn’t advise committing. Also, the output might be pretty big and non-informative, so it will just clutter your repository. When you commit a 10+ MB Jupyter Notebook with some random data output, all your colleagues will download this data to their computers with the next git pull.
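One way to control what gets committed is to clear outputs first. Jupyter's own menus can do this, and dedicated tools exist. As a dependency-free sketch: an `.ipynb` file is JSON, so a short script can clear the `outputs` and `execution_count` of every code cell. The script name, the assumption of the nbformat 4 layout (a top-level `cells` list), and the choice to overwrite files in place are all assumptions, so try it on a copy first.

```python
import json
import sys
from pathlib import Path


def strip_outputs(notebook: dict) -> dict:
    """Remove outputs and execution counts from all code cells (nbformat 4 layout assumed)."""
    for cell in notebook.get('cells', []):
        if cell.get('cell_type') == 'code':
            cell['outputs'] = []
            cell['execution_count'] = None
    return notebook


def strip_file(path: Path) -> None:
    notebook = json.loads(path.read_text(encoding='utf-8'))
    stripped = strip_outputs(notebook)
    # Keep the JSON indented so git diffs stay line-based and readable
    path.write_text(json.dumps(stripped, indent=1, ensure_ascii=False) + '\n', encoding='utf-8')


if __name__ == '__main__':
    # usage (hypothetical file name): python strip_outputs.py analysis.ipynb other.ipynb
    for arg in sys.argv[1:]:
        strip_file(Path(arg))
```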

Charts in output might be especially problematic. We all like excellent interactive Plotly charts. Unfortunately, they are not rendered in the GitHub UI, so your colleagues likely won’t see them. To overcome this obstacle, you might switch the output type for Plotly to PNG or JPEG.

import plotly.io as pio
pio.renderers.default = "jpeg"

You can find more details about Plotly renderers in the documentation.

Last but not least, Jupyter Notebook diffs are usually tricky. You would often like to understand the difference between 2 versions of the code. However, the default GitHub view won’t give you much helpful info because there is too much clutter due to changes in notebook metadata (like in the example below).

Actually, GitHub has almost solved this issue. The rich diffs feature, available in feature preview, can make your life way easier; you just need to switch it on in settings.

With this feature, we can easily see that there were just a couple of changes. I’ve changed the default renderer and parameters for retention curves (so a chart has been updated as well).

Ask for a code review

Engineers do peer reviews for (almost) all changes to the code. This process allows one to spot bugs early, stop bad actors or effectively share knowledge in the team.

Of course, it’s not a silver bullet: reviewers can miss bugs, or a bad actor might introduce a breach into the popular open-source project. For example, there was quite a scary story of how a backdoor was planted into a compression tool widely used in popular Linux distributions.

However, there is evidence that code review actually helps. McConnell shares the following stats in his iconic book "Code Complete".

… software testing alone has limited effectiveness – the average defect detection rate is only 25 percent for unit testing, 35 percent for function testing, and 45 percent for integration testing. In contrast, the average effectiveness of design and code inspections are 55 and 60 percent.

Despite all these benefits, analysts often don’t use code review at all. I can understand why it might be challenging:

  • Analytical teams are usually smaller, and spending limited resources on double-checking might not sound reasonable.
  • Quite often, analysts work in different domains, and you might end up being the only person who knows this domain well enough to do a code review.

However, I really encourage you to do code reviews, at least for critical work, to mitigate risks. Here are the cases when I ask colleagues to double-check my code and assumptions:

  • When I’m using data in a new domain, it’s always a good idea to ask an expert to review the assumptions used;
  • All tasks related to customer communications or interventions, since errors there might have a significant impact (for example, communicating wrong information to customers or deactivating the wrong people);
  • High-stakes decisions: if you plan to invest six months of the team’s effort into the project, it’s worth double- and triple-checking;
  • When results are unexpected: the first hypothesis to test when I see surprising results is to check for an error in code.

Of course, it’s not an exhaustive list, but I hope you can see my reasoning and use common sense to define when to reach out for code review.

Stay up-to-date

The famous Lewis Carroll quote represents the current state of the tech domain quite well.

… it takes all the running you can do, to keep in the same place. If you want to get somewhere else, you must run at least twice as fast as that.

Our field is constantly evolving: new papers are published every day, libraries are updated, new tools emerge and so on. It’s the same story for software engineers, data analysts, data scientists, etc.

There are so many sources of information right now that finding it is no problem:

  • weekly e-mails from Towards Data Science and some other subscriptions,
  • following experts on LinkedIn and X (formerly Twitter),
  • subscribing to e-mail updates for the tools and libraries I use,
  • attending local meet-ups.

The trickier part is not drowning in all this information. I try to focus on one thing at a time to avoid too much distraction.

Summary

That’s it with the software engineering practices that can be helpful for analysts. Let me quickly recap them all here:

  • Code is not for computers. It’s for people.
  • Automate repetitive tasks.
  • Master your tools.
  • Manage your environment.
  • Think about program performance.
  • Don’t forget the DRY principle.
  • Leverage testing.
  • Encourage the team to use Version Control Systems.
  • Ask for a code review.
  • Stay up-to-date.

Data analytics combines skills from different domains, so I believe we can benefit greatly from learning the best practices of software engineers, product managers, designers, etc. By adopting the tried-and-true techniques of our colleagues, we can improve our effectiveness and efficiency. I highly encourage you to explore these adjacent domains as well.

Thank you for reading this article. I hope it was insightful. If you have any follow-up questions or comments, please leave them in the comments section.

Reference

All the images are produced by the author unless otherwise stated.

Acknowledgements

I can’t miss a chance to express my heartfelt thanks to my partner, who has been sharing his engineering wisdom with me for ages and has reviewed all my articles.

The post From Code to Insights: Software Engineering Best Practices for Data Analysts appeared first on Towards Data Science.

Practical Computer Simulations for Product Analysts
Part 3: Modelling Ops queues
https://towardsdatascience.com/practical-computer-simulations-for-product-analysts-fe61e2b577f5/ (Fri, 24 May 2024)

The post Practical Computer Simulations for Product Analysts appeared first on Towards Data Science.

Image by DALL-E 3

Today, I would like to show you an example of the discrete-event simulation approach. We will model the Customer Support team and decide what strategy to use to improve its performance. But first, let me share a bit of my personal story.

I first learned about discrete simulations at university. One of my subjects was queueing theory, and to get a final grade for it, I had to implement an airport simulation and calculate some KPIs. Unfortunately, I missed all the seminars because I was already working full-time, so I had no idea about the theory behind this topic and how to approach it.

I was determined to get an excellent mark, so I found a book, read it, understood the basics and spent a couple of evenings on implementation. It was pretty challenging since I hadn’t been coding for some time, but I figured it out and got my A grade.

At this point (as often happens with students), I had a feeling that this information wouldn’t be helpful for my future work. However, later, I realised that many analytical tasks can be solved with this approach. So, I would like to share it with you.

One of the most apparent use cases for discrete-event simulations is operational analytics. Most products have customer support where clients can get help. A CS team often looks at metrics such as:

  • average resolution time – how much time passes from the moment a customer reaches out to CS until they get the first answer,
  • queue size – how many tasks are currently in the backlog.
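To make these two KPIs concrete, here is a minimal sketch that computes them from ticket timestamps. The ticket data and the helper names (average_resolution_minutes, queue_size_at) are hypothetical. For simplicity, a ticket counts as part of the backlog from its creation until it gets its first answer.

```python
from datetime import datetime

# Hypothetical tickets: (creation_time, first_answer_time); None means still waiting
tickets = [
    (datetime(2024, 4, 8, 9, 0), datetime(2024, 4, 8, 9, 20)),
    (datetime(2024, 4, 8, 9, 5), datetime(2024, 4, 8, 9, 45)),
    (datetime(2024, 4, 8, 9, 30), None),
]

def average_resolution_minutes(tickets):
    """Mean minutes from creation to first answer, over answered tickets only."""
    durations = [(answered - created).total_seconds() / 60
                 for created, answered in tickets if answered is not None]
    return sum(durations) / len(durations) if durations else None

def queue_size_at(tickets, moment):
    """Number of tickets created by `moment` that are still unanswered at `moment`."""
    return sum(1 for created, answered in tickets
               if created <= moment and (answered is None or answered > moment))

print(average_resolution_minutes(tickets))                     # 30.0
print(queue_size_at(tickets, datetime(2024, 4, 8, 9, 30)))     # 2
```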

Without a proper model, it may be tricky to understand how our changes (e.g. introducing night shifts or just increasing the number of agents) will affect the KPIs. Simulations will help us do it.

So, let’s not waste our time and move on.

Basics of simulations and modelling

Let’s start from the very beginning. We will be modelling a system. A system is a collection of entities (for example, people, servers or even mechanical tools) that interact with each other to achieve some logical goal (e.g. answering a customer question or passing border control in an airport).

You can define the system at whatever level of granularity your research goal requires. For example, in our case, we would like to investigate how changes to agents’ efficiency and schedules could affect the average CS ticket resolution time. So, the system will be just a set of agents. However, if we wanted to model the possibility of outsourcing some tickets to different outsourcing companies, we would need to include these partners in our model.

The system is described by a set of variables, for example, the number of tickets in a queue or the number of agents working at a given moment. These variables define the system state.

There are two types of systems:

  • discrete – when the system state changes instantaneously, for example, a new ticket has been added to a queue or an agent has finished their shift.
  • continuous – when the system is constantly evolving. One such example is a flying plane, in which coordinates, velocity, height, and other parameters change all the time during flight.

For our task, we can treat the system as discrete and use the discrete-event simulation approach. It applies when the system can change only at a countable number of points in time. These time points are where events occur and instantly change the system state.

So, the whole approach is based on events. We will generate and process events one by one to simulate how the system works. We can use the concept of a timeline to structure events.

Since this process is dynamic, we need to keep track of the current value of simulated time and be able to advance it from one value to another. The variable in a simulation model that shows the current time is often called the simulation clock.

We also need a mechanism to advance simulated time. There are two approaches to advance time:

  • next-event time advance – we are moving from one event timestamp to the next one,
  • fixed-increment time advance – we select a period, for example, 1 minute, and advance the clock by this period at each step.

I think the first approach is easier to understand, implement and debug. So, I will stick to it for this article.
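To see why next-event time advance is usually cheaper, here is a small sketch that counts how many clock updates each approach needs for the same three events. The timestamps and helper names are made up for illustration. With the fixed-increment approach, the clock also ticks through empty minutes where nothing happens.

```python
from datetime import datetime, timedelta

# Hypothetical event timestamps within one morning
event_times = [datetime(2024, 4, 8, 9, 15),
               datetime(2024, 4, 8, 9, 30),
               datetime(2024, 4, 8, 11, 5)]

def next_event_clock_values(times):
    """Next-event time advance: the clock jumps directly from one event to the next."""
    return sorted(times)

def fixed_increment_ticks(times, start, step=timedelta(minutes=1)):
    """Fixed-increment time advance: the clock moves by `step` until it reaches
    the last event; at every tick we would process events in (clock - step, clock]."""
    clock, ticks = start, 0
    last_event = max(times)
    while clock < last_event:
        clock += step
        ticks += 1
    return ticks

start = datetime(2024, 4, 8, 9, 0)
print(len(next_event_clock_values(event_times)))   # 3 clock updates
print(fixed_increment_ticks(event_times, start))   # 125 clock updates
```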

Let’s review a simple example to understand how it works. We will discuss a simplified case of the CS tickets queue.

We start the simulation, initialising the simulation clock. Sometimes, people use zero as the initial value. I prefer to use real-life data and the actual date times.

Here’s the initial state of our system. We have two events on our timeline related to two incoming customer requests.

The next step is to advance the simulation clock to the first event on our timeline – the customer request at 9:15.

It’s time to process this event. We should find an agent to work on this request, assign the request to them, and generate an event to finish the task. Events are the main drivers of our simulation, so it’s okay if one event creates another one.

Looking at the updated timeline, we can see that the most imminent event is not the second customer request but the completion of the first task.

So, we need to advance our clock to 9:30 and process the next event. The completion of the request won’t create new events, so after that, we will move to the second customer request.

We will repeat this process of moving from one event to another until the end of the simulation.

To avoid never-ending processes, we need to define the stopping criteria. In this case, we can use the following logic: if no more events are on the timeline, we should stop the simulation. In this simplified example, our simulation will stop after finishing the second task.
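The walkthrough above can be sketched in a few lines of Python. This is a simplified, single-agent illustration, not the implementation we build later. Its assumptions are:

- the timeline is a heap ordered by timestamp (heapq);
- every request takes a hypothetical 15 minutes;
- in the usage example, the second request arrives at a made-up 9:40.

```python
import heapq
from datetime import datetime, timedelta

def run_toy_simulation(request_times, handling_time=timedelta(minutes=15)):
    """Single-agent toy: process events in time order until the timeline is empty."""
    timeline = []  # heap of (time, sequence_no, event_type, request_id)
    seq = 0
    for request_id, t in enumerate(request_times, start=1):
        heapq.heappush(timeline, (t, seq, 'new_request', request_id))
        seq += 1

    agent_busy = False
    queue = []  # FIFO of waiting request ids
    log = []    # (clock, event_type, request_id) for inspection

    while timeline:  # stopping criterion: no events left on the timeline
        # advance the simulation clock to the most imminent event
        clock, _, event_type, request_id = heapq.heappop(timeline)
        log.append((clock.strftime('%H:%M'), event_type, request_id))
        if event_type == 'new_request':
            if agent_busy:
                queue.append(request_id)
                continue
            started = request_id
        else:  # 'finish_request': the agent becomes free
            agent_busy = False
            if not queue:
                continue
            started = queue.pop(0)
        agent_busy = True
        # processing one event creates another one: the task's completion
        heapq.heappush(timeline, (clock + handling_time, seq, 'finish_request', started))
        seq += 1
    return log
```

Calling run_toy_simulation([datetime(2024, 4, 8, 9, 15), datetime(2024, 4, 8, 9, 40)]) processes four events:

- the two requests;
- their completions at 9:30 and 9:55.

After that, the timeline is empty and the loop stops.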

We’ve discussed the theory of discrete event simulations and understood how it works. Now, it’s time to practice and implement this approach in code.

The program architecture

Object-oriented programming

In my day-to-day job, I usually use a procedural programming paradigm. I create functions for some repetitive tasks, but otherwise my code is quite linear. It’s a pretty standard approach for data-wrangling tasks.

In this example, we will use Object-Oriented Programming (OOP). So, let’s spend some time revisiting this topic in case you haven’t used classes in Python before or need a refresher.

OOP is based on the concept of objects. Objects consist of data (some features that are called attributes) and actions (functions or methods). The whole program describes the interactions between different objects. For example, if we have an object representing a CS agent, it can have the following properties:

  • attributes: name, date when an agent started working, average time they spend on tasks or current status ("out of office", "working on task" or "free").
  • methods: return the name, update the status or start processing a customer request.

To represent such an object, we can use Python classes. Let’s write a simple class for a CS agent.

class CSAgent:
  # initialising class
  def __init__(self, name, average_handling_time):
      # saving parameters mentioned during object creation
      self.name = name  
      self.average_handling_time = average_handling_time
      # specifying constant value
      self.role = 'CS agent'
      print('Created %s with name %s' % (self.role, self.name))

  def get_name(self):
    return self.name

  def get_handling_time(self):
    return self.average_handling_time

  def update_handling_time(self, average_handling_time):
    print('Updating time from %.2f to %.2f' % (self.average_handling_time, 
      average_handling_time))
    self.average_handling_time = average_handling_time

This class defines each agent’s name, average handling time, and role. I’ve also added a couple of methods that return internal variables, following the encapsulation pattern. Finally, the update_handling_time method allows us to update the agent’s performance.

We’ve created a class (a blueprint that describes any CS agent). Let’s make an instance of it: the agent John Doe.

john_agent = CSAgent('John Doe', 12.3)
# Created CS agent with name John Doe

When we created an instance of the class, the __init__ method was executed. We can use the __dict__ attribute to get the object’s fields as a dictionary. This is often handy, for example, if you want to convert a list of objects into a data frame.

print(john_agent.__dict__)
# {'name': 'John Doe', 'average_handling_time': 12.3, 'role': 'CS agent'}

We can try to execute a method and update the agent’s performance.

john_agent.update_handling_time(5.4)
# Updating time from 12.30 to 5.40

print(john_agent.get_handling_time())
# 5.4

One of the fundamental concepts of OOP that we will use today is inheritance. Inheritance allows us to have a high-level ancestor class and use its features in the descendant classes. Imagine we want to have not only CS agents but also KYC agents. We can create a high-level Agent class with common functionality and define it only once for both KYC and CS agents.

class Agent:
  # initialising class
  def __init__(self, name, average_handling_time, role):
    # saving parameters mentioned during object creation
    self.name = name  
    self.average_handling_time = average_handling_time
    self.role = role
    print('Created %s with name %s' % (self.role, self.name))

  def get_name(self):
    return self.name

  def get_handling_time(self):
    return self.average_handling_time

  def update_handling_time(self, average_handling_time):
    print('Updating time from %.2f to %.2f' % (self.average_handling_time, 
      average_handling_time))
    self.average_handling_time = average_handling_time

Now, we can create separate classes for these agent types, each with a slightly different __init__ method and its own get_job_description method.

class KYCAgent(Agent):
  def __init__(self, name, average_handling_time):
    super().__init__(name, average_handling_time, 'KYC agent')

  def get_job_description(self):
    return 'KYC (Know Your Customer) agents help to verify documents'

class CSAgent(Agent):
  def __init__(self, name, average_handling_time):
    super().__init__(name, average_handling_time, 'CS agent')

  def get_job_description(self):
    return 'CS (Customer Support) answer customer questions and help resolving their problems'

To specify inheritance, we put the base class name in parentheses after the new class name. With super(), we can call the base class methods, for example, __init__, to create an object with a custom role value.

Let’s create objects and check whether they work as expected.

marie_agent = KYCAgent('Marie', 25)
max_agent = CSAgent('Max', 10)

print(marie_agent.__dict__)
# {'name': 'Marie', 'average_handling_time': 25, 'role': 'KYC agent'}
print(max_agent.__dict__)
# {'name': 'Max', 'average_handling_time': 10, 'role': 'CS agent'}

Let’s update Marie’s handling time. Even though we haven’t implemented this function in the KYCAgent class, it uses the implementation from the base class and works quite well.

marie_agent.update_handling_time(22.5)
# Updating time from 25.00 to 22.50

We can also call the methods we defined in the classes.

print(marie_agent.get_job_description())
# KYC (Know Your Customer) agents help to verify documents

print(max_agent.get_job_description())
# CS (Customer Support) answer customer questions and help resolving their problems

So, we’ve covered the basics of the object-oriented paradigm and Python classes. I hope it was a helpful refresher.

Now, it’s time to return to our task and the model we need for our Simulation.

Architecture: classes

If you haven’t used OOP a lot before, switching your mindset from procedures to objects might be challenging. It takes some time to make this mindset shift.

One life hack is to use real-world analogies (e.g. it’s pretty clear that an agent is an object with some features and actions).

Also, don’t be afraid to make mistakes. Some program architectures are better than others, for example, easier to read and support over time. However, there are a lot of debates about best practices, even among experienced software engineers, so I wouldn’t spend too much effort trying to make it perfect for ad-hoc analytical research.

Let’s think about what objects we need in our simulation:

  • System – the most high-level concept we have in our task. The system will represent the current state and execute the simulation.
  • As we discussed before, the system is a collection of entities. So, the next object we need is Agent. This class will describe agents working on tasks.
  • Each agent will have its schedule: hours when this agent is working, so I’ve isolated it into a separate class Schedule.
  • Our agents will be working on customer requests, so it’s a no-brainer that we need to represent them in our system (CustomerRequest). We will also store a list of processed requests in the System object to get the final stats after the simulation.
  • If no free agent picks up a new customer request, it will be put into a queue. So, we will have a RequestQueue as an object to store all customer requests with the FIFO logic (First In, First Out).
  • Another important concept is TimeLine, which represents the set of events we need to process, ordered by time.
  • TimeLine will include events, so we will also create a class Event for them. Since we will have a bunch of different event types that we need to process differently, we can leverage the OOP inheritance. We will discuss event types in more detail in the next section.

That’s it. I’ve put all the classes and links between them into a diagram to clarify it. I use such charts to have a high-level view of the system before starting the implementation – it helps to think about the architecture early on.

As you might have noticed, the diagram is not super detailed. For example, it doesn’t include all field names and methods. It’s intentional. This schema will be used as a helicopter view to guide the development. So, I don’t want to spend too much time writing down all the field and method names because these details might change during the implementation.

Architecture: event types

We’ve covered the program architecture, and now it’s time to think about the main drivers of our simulation – events.

Let’s discuss what events we need to generate to keep our system working.

  • The event I will start with is the "Agent Ready" event. It shows that an agent starts their work and is ready to pick up a task (if we have any waiting in the queue).
  • We need to know when agents start working. These working hours can depend on an agent and the day of the week. Potentially, we might even want to change the schedules during the simulation. It’s pretty challenging to create all "Agent Ready" events when we initialise the system (especially since we don’t know how much time we need to finish the simulation). So, I propose a recurrent "Plan Agents Schedule" event to create ready-to-work events for the next day.
  • The other essential event we need is a "New Customer Request" – an event that shows that we got a new CS contact, and we need to either start working on it or put it in a queue.
  • The last event is "Agent Finished Task", which shows that the agent has finished the task they were working on and is potentially ready to pick up a new one.

That’s it. These four events are enough to run the whole simulation.

Similar to classes, there are no right or wrong answers for system modelling. You might use a slightly different set of events. For example, you can add a "Start Task" event to have it explicitly.

Implementation

You can find the full implementation on GitHub.

We’ve defined the high-level structure of our solution, so we are ready to start implementing it. Let’s start with the heart of our simulation – the system class.

Initialising the system

Let’s start with the __init__ method for the system class.

First, let’s think about the parameters we would like to specify for the simulation:

  • agents – set of agents that will be working in the CS team,
  • queue – the current queue of customer requests (if we have any),
  • initial_date – since we agreed to use the actual timestamps instead of relative ones, I will specify the date when we start simulations,
  • logging – flag that defines whether we would like to print some info for debugging,
  • customer_requests_df – data frame with information about the set of customer requests we would like to process.

Besides input parameters, we will also create the following internal fields:

  • current_time – the simulation clock that we will initialise as 00:00:00 of the initial date specified,
  • timeline object that we will use to define the order of events,
  • processed_requests – an empty list where we will store the processed customer requests to get the data after the simulation.

It’s time to take the necessary actions to initialise a system. There are only two steps left:

  • Plan agents’ work for the first day. I’ll generate and process a corresponding event with the initial timestamp.
  • Load customer requests by adding corresponding "New Customer Request" events to the timeline.

Here’s the code that does all these actions to initialise the system.

import datetime

class System:
  def __init__(self, agents, queue, initial_date,
  customer_requests_df, logging = True):
    initial_time = datetime.datetime(initial_date.year, initial_date.month, 
      initial_date.day, 0, 0, 0)
    self.agents = agents
    self.queue = RequestQueue(queue)
    self.logging = logging
    self.current_time = initial_time

    self._timeline = TimeLine()
    self.processed_requests = []

    initial_event = PlanScheduleEvent('plan_agents_schedule', initial_time)
    initial_event.process(self)
    self.load_customer_request_events(customer_requests_df)

It’s not working yet since it has links to non-implemented classes and methods, but we will cover it all one by one.

Timeline

Let’s start with the classes we used in the system definition. The first one is TimeLine. Its only field is the list of events. It also implements a few methods:

  • adding events (and ensuring that they are ordered chronologically),
  • returning the next event and deleting it from the list,
  • telling how many events are left.

class TimeLine:
  def __init__(self):
    self.events = []

  # 'Event' is quoted (a forward reference) because the class is defined further below
  def add_event(self, event: 'Event'):
    self.events.append(event)
    self.events.sort(key = lambda x: x.time)

  def get_next_item(self):
    if len(self.events) == 0:
      return None
    return self.events.pop(0)

  def get_remaining_events(self):
    return len(self.events)
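Sorting the whole list on every insertion is perfectly fine for a few thousand events. If the timeline grows large, a binary heap is a common alternative. The sketch below is a hypothetical drop-in variant (HeapTimeLine) with the same three methods. A counter breaks ties, so events with equal timestamps come out in insertion order. This mirrors the stable list.sort used above and avoids comparing Event objects directly.

```python
import heapq
import itertools

class HeapTimeLine:
    """Hypothetical heap-based alternative to TimeLine with the same interface."""
    def __init__(self):
        self._heap = []
        self._counter = itertools.count()  # tie-breaker: FIFO order for equal times

    def add_event(self, event):
        # O(log n) insertion; Event objects are never compared thanks to the counter
        heapq.heappush(self._heap, (event.time, next(self._counter), event))

    def get_next_item(self):
        if not self._heap:
            return None
        return heapq.heappop(self._heap)[2]

    def get_remaining_events(self):
        return len(self._heap)
```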

Customer requests queue

The other class we used in initialisation is RequestQueue.

There are no surprises: the request queue consists of customer requests. Let’s start with this building block. We know each request’s creation time and how much time an agent will need to work on it.

class CustomerRequest:
  def __init__(self, id, handling_time_secs, creation_time):
    self.id = id
    self.handling_time_secs = handling_time_secs
    self.creation_time = creation_time

  def __str__(self):
    return f'Customer Request {self.id}: {self.creation_time.strftime("%Y-%m-%d %H:%M:%S")}'

It’s a simple data class that contains only parameters. The only new thing here is that I’ve overridden the __str__ method to change the output of a print function. It’s pretty handy for debugging. You can compare it yourself.

test_object = CustomerRequest(1, 600, datetime.datetime(2024, 5, 1, 9, 42, 1))
# without defining __str__
print(test_object)
# <__main__.CustomerRequest object at 0x280209130>

# with custom __str__
print(test_object)
# Customer Request 1: 2024-05-01 09:42:01
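One caveat when debugging: printing a list of objects shows each element’s __repr__, not its __str__. A common trick, sketched below on a trimmed copy of CustomerRequest, is to reuse __str__ as __repr__ so that lists of requests are readable too.

```python
import datetime

class CustomerRequest:
    def __init__(self, id, handling_time_secs, creation_time):
        self.id = id
        self.handling_time_secs = handling_time_secs
        self.creation_time = creation_time

    def __str__(self):
        return f'Customer Request {self.id}: {self.creation_time:%Y-%m-%d %H:%M:%S}'

    # containers such as lists format their elements with __repr__, so reuse __str__
    __repr__ = __str__

requests = [CustomerRequest(1, 600, datetime.datetime(2024, 5, 1, 9, 42, 1)),
            CustomerRequest(2, 300, datetime.datetime(2024, 5, 1, 9, 50, 0))]
print(requests)
# [Customer Request 1: 2024-05-01 09:42:01, Customer Request 2: 2024-05-01 09:50:00]
```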

Now, we can move on to the request queue. Similarly to the timeline, we’ve implemented methods to add new requests, count the requests in the queue and get the next request from the queue.

class RequestQueue:
  def __init__(self, queue = None):
    if queue is None:
      self.requests = []
    else: 
      self.requests = queue

  def get_requests_in_queue(self):
    return len(self.requests)

  def add_request(self, request):
    self.requests.append(request)

  def get_next_item(self):
    if len(self.requests) == 0:
      return None
    return self.requests.pop(0)
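list.pop(0) shifts every remaining element, so its cost grows with the queue length. For long queues, collections.deque offers popleft() in constant time. Here is a hypothetical DequeRequestQueue with the same interface as RequestQueue. Note that it copies the initial requests into a new deque instead of keeping a reference to the list that was passed in.

```python
from collections import deque

class DequeRequestQueue:
    """Hypothetical FIFO queue with O(1) pops from the front."""
    def __init__(self, queue=None):
        # copy the initial requests (if any) into a deque
        self.requests = deque(queue or [])

    def get_requests_in_queue(self):
        return len(self.requests)

    def add_request(self, request):
        self.requests.append(request)

    def get_next_item(self):
        if not self.requests:
            return None
        return self.requests.popleft()  # First In, First Out
```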

Agents

The other thing we need to initialise the system is agents. First, each agent has a schedule: the hours they work on each weekday.

class Schedule:
  def __init__(self, time_periods):
    self.time_periods = time_periods

  def is_within_working_hours(self, dt):
    weekday = dt.strftime('%A')

    if weekday not in self.time_periods:
      return False

    hour = dt.hour
    time_periods = self.time_periods[weekday]
    for period in time_periods:
      if (hour >= period[0]) and (hour < period[1]):
        return True
    return False

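The only method Schedule has checks whether the agent is working at the specified moment. Note that the end hour of each period is exclusive: (9, 12) covers the time from 9:00 up to, but not including, 12:00.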
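Here is a quick usage check of the schedule logic. It uses a slightly different version of Schedule that looks up the weekday name with dt.weekday() rather than strftime('%A'). The reason is that %A returns the locale’s weekday name, which would not match English keys such as 'Monday' under a non-English locale. WEEKDAYS is a helper list introduced only for this example.

```python
import datetime

WEEKDAYS = ['Monday', 'Tuesday', 'Wednesday', 'Thursday',
            'Friday', 'Saturday', 'Sunday']  # index matches dt.weekday()

class Schedule:
    def __init__(self, time_periods):
        # e.g. {'Monday': [(9, 12), (13, 18)], ...}; end hours are exclusive
        self.time_periods = time_periods

    def is_within_working_hours(self, dt):
        weekday = WEEKDAYS[dt.weekday()]  # locale-independent weekday name
        for start_hour, end_hour in self.time_periods.get(weekday, []):
            if start_hour <= dt.hour < end_hour:
                return True
        return False

monday_only = Schedule({'Monday': [(9, 12), (13, 18)]})
print(monday_only.is_within_working_hours(datetime.datetime(2024, 4, 8, 9, 0)))    # True
print(monday_only.is_within_working_hours(datetime.datetime(2024, 4, 8, 12, 30)))  # False (lunch)
print(monday_only.is_within_working_hours(datetime.datetime(2024, 4, 8, 18, 0)))   # False
print(monday_only.is_within_working_hours(datetime.datetime(2024, 4, 13, 10, 0)))  # False (Saturday)
```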

Let’s define the agent class. Each agent will have the following attributes:

  • id and name – primarily for logging and debugging purposes,
  • schedule – the agent’s schedule object we’ve just defined,
  • request_in_work – link to customer request object that shows whether an agent is occupied right now or not.
  • effectiveness – a coefficient applied to the expected handling time of a task: values below 1 mean the agent works faster than expected, values above 1 mean slower.

We have the following methods implemented for agents:

  • understanding whether they can take on a new task (whether they are free and still working),
  • start and finish processing the customer request.

class Agent:
  def __init__(self, id, name, schedule, effectiveness = 1):
    self.id = id
    self.schedule = schedule
    self.name = name
    self.request_in_work = None
    self.effectiveness = effectiveness

  def is_ready_for_task(self, dt):
    if (self.request_in_work is None) and (self.schedule.is_within_working_hours(dt)):
      return True
    return False

  def start_task(self, customer_request):
    self.request_in_work = customer_request
    customer_request.handling_time_secs = int(round(self.effectiveness * customer_request.handling_time_secs))

  def finish_task(self):
    self.request_in_work = None
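The effectiveness coefficient is easiest to understand with numbers. The hypothetical helper below repeats the arithmetic from Agent.start_task. Keep in mind that start_task overwrites handling_time_secs on the request itself. If you later want to compare expected and actual handling times, you would need to store the original value separately. Also, Python’s round uses round-half-to-even, so an exact half-second is rounded to the nearest even number.

```python
def scaled_handling_time(expected_secs, effectiveness):
    """Same arithmetic as Agent.start_task: scale the expected time by the coefficient."""
    return int(round(effectiveness * expected_secs))

# effectiveness < 1: faster than expected; > 1: slower than expected
print(scaled_handling_time(600, 0.8))   # 480
print(scaled_handling_time(600, 1.25))  # 750
# round-half-to-even: 0.5 * 601 = 300.5 is rounded to 300
print(scaled_handling_time(601, 0.5))   # 300
```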

Loading initial customer requests to the timeline

The only thing still missing from the System __init__ method (besides the event processing that we will discuss in detail a bit later) is the implementation of load_customer_request_events. It’s pretty straightforward. We just need to add it to our System class.

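class System:
  # method to add to the System class defined above (a new 'class System:' statement would replace it)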
  def load_customer_request_events(self, df):
    # filter requests before the start of simulation
    filt_df = df[df.creation_time >= self.current_time]
    if filt_df.shape[0] != df.shape[0]:
      if self.logging:
        print('Attention: %d requests have been filtered out since they are outdated' % (df.shape[0] - filt_df.shape[0]))

    # create new customer request events for each record
    for rec in filt_df.sort_values('creation_time').to_dict('records'):
      customer_request = CustomerRequest(rec['id'], rec['handling_time_secs'], 
        rec['creation_time'])

      self.add_event(NewCustomerRequestEvent(
        'new_customer_request', rec['creation_time'],
         customer_request
      ))

Cool, we’ve figured out the primary classes. So, let’s move on to the implementation of the events.

Processing events

As discussed, I will use the inheritance approach and create an Event class. For now, it implements only __init__ and __str__ functions, but potentially, it can help us provide additional functionality for all events.

class Event:
  def __init__(self, event_type, time):
    self.type = event_type
    self.time = time

  def __str__(self):
    if self.type == 'agent_ready_for_task':
      return '%s (%s) - %s' % (self.type, self.agent.name, self.time)
    return '%s - %s' % (self.type, self.time)

Then, I implement a separate subclass for each event type, since their initialisation differs slightly. For example, the AgentReady event also stores an Agent object. In addition, each Event subclass implements a process method that takes the system as an input.


class AgentReadyEvent(Event):
  def __init__(self, event_type, time, agent):
    super().__init__(event_type, time)
    self.agent = agent

  def process(self, system: System):
    # get next request from the queue
    next_customer_request = system.queue.get_next_item()

    # start processing request if we had some
    if next_customer_request is not None:
      self.agent.start_task(next_customer_request)
      next_customer_request.start_time = system.current_time
      next_customer_request.agent_name = self.agent.name
      next_customer_request.agent_id = self.agent.id

      if system.logging:
        print('<%s> Agent %s started to work on request %d' % (system.current_time, 
          self.agent.name, next_customer_request.id))

      # schedule finish processing event
      system.add_event(FinishCustomerRequestEvent('finish_handling_request', 
        system.current_time + datetime.timedelta(seconds = next_customer_request.handling_time_secs), 
        next_customer_request, self.agent)) 

class PlanScheduleEvent(Event):
  def __init__(self, event_type, time):
    super().__init__(event_type, time)

  def process(self, system: System):     
    if system.logging:
        print('<%s> Scheduled agents for today' % (system.current_time))
    current_weekday = system.current_time.strftime('%A')

    # create agent ready events for all agents working on this weekday
    for agent in system.agents:
      if current_weekday not in agent.schedule.time_periods:
        continue

      for time_periods in agent.schedule.time_periods[current_weekday]:
        system.add_event(AgentReadyEvent('agent_ready_for_task', 
          datetime.datetime(system.current_time.year, system.current_time.month, 
          system.current_time.day, time_periods[0], 0, 0), 
          agent))

    # schedule next planning
    system.add_event(PlanScheduleEvent('plan_agents_schedule', system.current_time + datetime.timedelta(days = 1)))

class FinishCustomerRequestEvent(Event):
  def __init__(self, event_type, time, customer_request, agent):
    super().__init__(event_type, time)
    self.customer_request = customer_request
    self.agent = agent

  def process(self, system):
    self.agent.finish_task()
    # log finish time
    self.customer_request.finish_time = system.current_time
    # save processed request
    system.processed_requests.append(self.customer_request)

    if system.logging:
      print('<%s> Agent %s finished request %d' % (system.current_time, self.agent.name, self.customer_request.id))

    # pick up the next request if the agent is still working and we have something in the queue
    if self.agent.is_ready_for_task(system.current_time):
      next_customer_request = system.queue.get_next_item()
      if next_customer_request is not None:
        self.agent.start_task(next_customer_request)
        next_customer_request.start_time = system.current_time
        next_customer_request.agent_name = self.agent.name
        next_customer_request.agent_id = self.agent.id

        if system.logging:
            print('<%s> Agent %s started to work on request %d' % (system.current_time, 
              self.agent.name, next_customer_request.id))
        system.add_event(FinishCustomerRequestEvent('finish_handling_request', 
          system.current_time + datetime.timedelta(seconds = next_customer_request.handling_time_secs), 
          next_customer_request, self.agent)) 

class NewCustomerRequestEvent(Event):
  def __init__(self, event_type, time, customer_request):
    super().__init__(event_type, time)
    self.customer_request = customer_request

  def process(self, system: System):
    # check whether we have a free agent
    assigned_agent = system.get_free_agent(self.customer_request)

    # if not put request in a queue
    if assigned_agent is None:
      system.queue.add_request(self.customer_request)
      if system.logging:
          print('<%s> Request %d put in a queue' % (system.current_time, self.customer_request.id))
    # if yes, start processing it
    else:
      assigned_agent.start_task(self.customer_request)
      self.customer_request.start_time = system.current_time
      self.customer_request.agent_name = assigned_agent.name
      self.customer_request.agent_id = assigned_agent.id
      if system.logging:
          print('<%s> Agent %s started to work on request %d' % (system.current_time, assigned_agent.name, self.customer_request.id))
      system.add_event(FinishCustomerRequestEvent('finish_handling_request', 
        system.current_time + datetime.timedelta(seconds = self.customer_request.handling_time_secs), 
        self.customer_request, assigned_agent))

That’s actually it with the events processing business logic. The only bit we need to finish is to put everything together to run our simulation.

Putting it all together in the System class

As we discussed, the System class will be in charge of running the simulations. So, we will put the remaining nuts and bolts there.

Here’s the remaining code. Let me briefly walk you through the main points:

  • is_simulation_finished defines the stopping criteria for our simulation – no requests are in the queue, and no events are in the timeline.
  • process_next_event gets the next event from the timeline and executes process for it. There’s a slight nuance here: we might end up in a situation where our simulation never ends because of recurring "Plan Agents Schedule" events. That’s why, in case of processing such an event type, I check whether there are any other events in the timeline and if not, I don’t process it since we don’t need to schedule agents anymore.
  • run_simulation is the function that rules our world, but since we have quite a decent architecture, it’s just a couple of lines: we check whether we can finish the simulation, and if not, we process the next event.

class System:
  # methods to add to the System class defined above (a new 'class System:' statement would replace it)
  # defines the stopping criteria
  def is_simulation_finished(self):
    if self.queue.get_requests_in_queue() > 0: 
      return False
    if self._timeline.get_remaining_events() > 0:
      return False
    return True

  # wrappers for timeline methods to encapsulate this logic
  def add_event(self, event):
    self._timeline.add_event(event)

  def get_next_event(self):
    return self._timeline.get_next_item()

  # returns free agent if we have one
  def get_free_agent(self, customer_request):
    for agent in self.agents:
      if agent.is_ready_for_task(self.current_time):
        return agent

  # finds and processes the next event
  def process_next_event(self):
    event = self.get_next_event()
    if self.logging:
      print('# Processing event: ' + str(event))
    if (event.type == 'plan_agents_schedule') and self.is_simulation_finished():
      if self.logging:
        print("FINISH")
    else:
      self.current_time = event.time        
      event.process(self)

  # main function
  def run_simulation(self):
    while not self.is_simulation_finished():
      self.process_next_event()

It was a long journey, but we’ve made it. Amazing job! Now we have all the logic we need. Let’s move on to the fun part and use our model for analysis.

You can find the full implementation on GitHub.

Analysis

I will use a synthetic Customer Requests dataset to simulate different Ops setups.

First of all, let’s run our system and look at metrics. I will start with 15 agents who are working regular hours.

# initialising agents
regular_work_week = Schedule(
  {
    'Monday': [(9, 12), (13, 18)],
    'Tuesday': [(9, 12), (13, 18)],
    'Wednesday': [(9, 12), (13, 18)],
    'Thursday': [(9, 12), (13, 18)],
    'Friday': [(9, 12), (13, 18)]
  }
)

agents = []
for i in range(15):
  agents.append(Agent(i + 1, 'Agent %s' % (i + 1), regular_work_week))

# initial date
system_initial_date = datetime.date(2024, 4, 8)

# initialising the system 
system = System(agents, [], system_initial_date, backlog_df, logging = False)

# running the simulation 
system.run_simulation()

As a result of the execution, we get all the stats in system.processed_requests. Let’s put together a couple of helper functions to make analysing the results easier.

import datetime
import pandas as pd

# convert results to data frame and calculate timings
def get_processed_results(system):
  processed_requests_df = pd.DataFrame(list(map(lambda x: x.__dict__, system.processed_requests)))
  processed_requests_df = processed_requests_df.sort_values('creation_time')
  processed_requests_df['creation_time_hour'] = processed_requests_df.creation_time.map(
      lambda x: x.strftime('%Y-%m-%d %H:00:00')
  )

  # strftime('%s') is platform-specific (it fails on Windows), so subtract the datetimes directly
  processed_requests_df['resolution_time_secs'] = (
      processed_requests_df.finish_time - processed_requests_df.creation_time
  ).dt.total_seconds().astype(int)

  processed_requests_df['waiting_time_secs'] = processed_requests_df.resolution_time_secs - processed_requests_df.handling_time_secs

  processed_requests_df['waiting_time_mins'] = processed_requests_df['waiting_time_secs']/60
  processed_requests_df['handling_time_mins'] = processed_requests_df.handling_time_secs/60
  processed_requests_df['resolution_time_mins'] = processed_requests_df.resolution_time_secs/60
  return processed_requests_df

# calculating queue size with 5 mins granularity
def get_queue_stats(processed_requests_df):
  queue_stats = []

  current_time = datetime.datetime(system_initial_date.year, system_initial_date.month, system_initial_date.day, 0, 0, 0)
  while current_time <= processed_requests_df.creation_time.max() + datetime.timedelta(seconds = 300):
    queue_size = processed_requests_df[(processed_requests_df.creation_time <= current_time) & (processed_requests_df.start_time > current_time)].shape[0]
    queue_stats.append(
      {
          'time': current_time,
          'queue_size': queue_size
      }
    )

    current_time = current_time + datetime.timedelta(seconds = 300)

  return pd.DataFrame(queue_stats)

Also, let’s make a couple of charts and calculate weekly metrics.

import plotly.express as px

def analyse_results(system, show_charts = True):
  processed_requests_df = get_processed_results(system)
  queue_stats_df = get_queue_stats(processed_requests_df)

  stats_df = processed_requests_df.groupby('creation_time_hour').aggregate(
      {'id': 'count', 'handling_time_mins': 'mean', 'resolution_time_mins': 'mean',
       'waiting_time_mins': 'mean'}
  )

  if show_charts:
    fig = px.line(stats_df[['id']], 
      labels = {'value': 'requests', 'creation_time_hour': 'request creation time'},
      title = '<b>Number of requests created</b>')
    fig.update_layout(showlegend = False)
    fig.show()

    fig = px.line(stats_df[['waiting_time_mins', 'handling_time_mins', 'resolution_time_mins']], 
      labels = {'value': 'time in mins', 'creation_time_hour': 'request creation time'},
      title = '<b>Resolution time</b>')
    fig.show()

    fig = px.line(queue_stats_df.set_index('time'), 
      labels = {'value': 'number of requests in queue'},
      title = '<b>Queue size</b>')
    fig.update_layout(showlegend = False)
    fig.show()

  processed_requests_df['period'] = processed_requests_df.creation_time.map(
      lambda x: (x - datetime.timedelta(x.weekday())).strftime('%Y-%m-%d')
  )
  queue_stats_df['period'] = queue_stats_df['time'].map(
      lambda x: (x - datetime.timedelta(x.weekday())).strftime('%Y-%m-%d')
  )

  # parentheses are needed to chain methods across several lines
  period_stats_df = (
    processed_requests_df.groupby('period')
      .aggregate({'id': 'count', 'handling_time_mins': 'mean',
        'waiting_time_mins': 'mean',
        'resolution_time_mins': 'mean'})
      .join(queue_stats_df.groupby('period')[['queue_size']].mean())
  )

  return period_stats_df

# execution
analyse_results(system)

Now, we can use this function to analyse the simulation results. Apparently, 15 agents are not enough for our product: after three weeks, we have 4K+ requests in the queue and an average resolution time of around ten days. Customers would be very unhappy with our service if we had just 15 agents.

Let’s find out how many agents we need to cope with the demand. We can run a bunch of simulations with different numbers of agents and compare the results.

import tqdm

tmp_dfs = []

for num_agents in tqdm.tqdm(range(15, 105, 5)):
  agents = []
  for i in range(num_agents):
    agents.append(Agent(i + 1, 'Agent %s' % (i + 1), regular_work_week))
  system = System(agents, [], system_initial_date, backlog_df, logging = False)
  system.run_simulation()

  tmp_df = analyse_results(system, show_charts = False)
  tmp_df['num_agents'] = num_agents
  tmp_dfs.append(tmp_df)

We can see that from ~25–30 agents, metrics for different weeks are roughly the same, so there’s enough capacity to handle incoming requests and the queue doesn’t grow week after week.

If we model the situation with 30 agents, we can see that the queue is empty from 13:50 until the end of the working day from Tuesday to Friday. Agents spend Monday processing the huge queue that builds up over the weekend.

With such a setup, the average resolution time is 500.67 minutes (about 8.3 hours), and the average queue length is 259.39 requests.

Let’s try to think about the possible improvements for our Operations team:

  • we can hire another five agents,
  • we can start leveraging LLMs and reduce handling time by 30%,
  • we can shift agents’ schedules to provide coverage during weekends and late hours.

Since we now have a model, we can easily estimate the impact of each option and pick the most feasible one.

The first two approaches are straightforward. Let’s discuss how we can shift the agents’ schedules. Currently, all our agents work from Monday to Friday, 9 to 18. Let’s try to distribute their coverage more evenly.

First, we can cover earlier and later hours by splitting agents into two groups: one working from 7 to 16 and the other from 11 to 20.
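To see what this split does to coverage, here’s a minimal sketch that counts how many agents are on shift in each hour. It assumes the same {day: [(start_hour, end_hour)]} dictionary format passed to Schedule above, a hypothetical 15/15 split between the two groups, and no lunch breaks (breaks for the new shifts aren’t specified); hourly_coverage is a hypothetical helper, not part of the simulation code.

```python
# Hypothetical shift definitions in the {day: [(start_hour, end_hour)]} format
# used for regular_work_week; lunch breaks are omitted for simplicity.
WEEKDAYS = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday']
early_shift = {day: [(7, 16)] for day in WEEKDAYS}
late_shift = {day: [(11, 20)] for day in WEEKDAYS}


def hourly_coverage(schedules, day):
    """Return 24 counts: how many agents are on shift during each hour of the day."""
    coverage = [0] * 24
    for schedule in schedules:
        for start, end in schedule.get(day, []):
            for hour in range(start, end):
                coverage[hour] += 1
    return coverage


# Assumed split: 15 agents on the early shift and 15 on the late shift
schedules = [early_shift] * 15 + [late_shift] * 15
monday = hourly_coverage(schedules, 'Monday')
for hour in range(6, 21):
    print(f'{hour:02d}:00  {monday[hour]} agents')
```

With this split, someone is available from 7 to 20, and all 30 agents overlap between 11 and 16.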

Second, we can spread agents more evenly across all days of the week, including weekends. I used quite a straightforward approach for this.
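The exact allocation isn’t spelled out here, so below is one possible scheme, offered purely as an illustrative assumption: each agent works five consecutive days, and start days rotate through the week so that every day gets roughly the same number of agents. rotating_work_days is a hypothetical helper.

```python
DAYS = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']


def rotating_work_days(agent_index, days_per_week=5):
    """Hypothetical rotation: agent i works 5 consecutive days starting on day i % 7."""
    start = agent_index % len(DAYS)
    return [DAYS[(start + offset) % len(DAYS)] for offset in range(days_per_week)]


num_agents = 30
agents_per_day = {day: 0 for day in DAYS}
for i in range(num_agents):
    for day in rotating_work_days(i):
        agents_per_day[day] += 1

print(agents_per_day)
```

With 30 agents, this rotation puts 20 to 22 agents on every day of the week instead of 30 on weekdays and none on weekends.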

In reality, you can go even further and allocate fewer agents to weekends, since demand is much lower then. This can improve your metrics even more, although the additional effect will be marginal.

Surprisingly, if we run simulations for all these scenarios, we see that KPIs improve far more when we simply change agents’ schedules. Hiring five more people or improving agents’ performance by 30% doesn’t achieve such a significant improvement.

Let’s see how changes in agents’ schedules affect our KPIs. Resolution time grows only for cases created outside working hours (from 20 to 7), and the queue size never reaches 200 cases.

That’s an excellent result. Our simulation model has helped us prioritise operational changes over hiring more people or investing in LLM tool development.

We’ve discussed the basics of this approach in this article. If you want to dig deeper and use it in practice, here are a couple more suggestions that might be useful:

  • Before starting to use such models in production, it’s worth testing them. The most straightforward way is to model your current situation and compare the main KPIs. If they differ a lot, then your system doesn’t represent the real world well enough, and you need to make it more accurate before using it for decision-making.
  • The current metrics are customer-focused: I’ve used average resolution time as the primary KPI for decisions. In business, we also care about costs, so it’s worth looking at this task from an operational perspective as well, e.g., measuring the percentage of time when agents don’t have tasks to work on (which means we are paying them for idle time).
  • In real life, there might be spikes (e.g., the number of customer requests doubles because of a bug in your product), so I recommend using such models to ensure that your CS team can handle such situations.
  • Last but not least, the model I’ve used was entirely deterministic (it returns the same result on every run) because handling time was fixed for each customer request. To better understand the variability of your metrics, you can specify a distribution of handling times (depending on the task type, day of the week, etc.) for each agent and draw the handling time from this distribution at each iteration. Then, you can run the simulation multiple times and calculate confidence intervals for your metrics.
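Here’s a minimal sketch of that last idea on a toy model rather than the full System class: a single agent handles requests in FIFO order, handling times are drawn from an exponential distribution (an assumption for illustration), and the arrival pattern is made up. Because each run draws new handling times, repeating the simulation gives a spread of outcomes instead of a single number.

```python
import numpy as np


def simulate_mean_wait(arrival_times, rng, mean_handling_mins=10.0):
    """Toy single-agent FIFO queue; returns the mean waiting time in minutes.

    Handling times are sampled from an exponential distribution (assumption);
    in practice you might fit a distribution per task type or agent.
    """
    handling = rng.exponential(mean_handling_mins, size=len(arrival_times))
    agent_free_at = 0.0
    waits = []
    for arrival, duration in zip(arrival_times, handling):
        start = max(arrival, agent_free_at)  # wait until the agent is free
        waits.append(start - arrival)
        agent_free_at = start + duration
    return float(np.mean(waits))


rng = np.random.default_rng(2024)
# Hypothetical arrivals: one request every 12 minutes over an 8-hour day
arrivals = np.arange(0, 480, 12.0)

# Run the stochastic simulation many times and look at the spread of the metric
results = [simulate_mean_wait(arrivals, rng) for _ in range(500)]
lower, upper = np.percentile(results, [2.5, 97.5])
print(f'mean wait: {np.mean(results):.1f} mins, '
      f'2.5th-97.5th percentile across runs: ({lower:.1f}, {upper:.1f})')
```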

Summary

So, let’s briefly summarise the main points we’ve discussed today:

  • We’ve learned the basics of the discrete-event simulation approach that helps to model discrete systems with a countable number of events.
  • We’ve revisited object-oriented programming and classes in Python, since this paradigm suits this task better than the procedural code data analysts usually write.
  • We’ve built the model of the CS team and were able to estimate the impact of different potential improvements on our KPIs (resolution time and queue size).

Thank you very much for reading this article. If you have any follow-up questions or comments, please leave them in the comments section.

Reference

All the images are produced by the author unless otherwise stated.

The post Practical Computer Simulations for Product Analysts appeared first on Towards Data Science.

Practical Computer Simulations for Product Analysts. Part 2: Using bootstrap for observations and A/B tests

Published on Tue, 30 Apr 2024: https://towardsdatascience.com/practical-computer-simulations-for-product-analysts-4d3a17957f64/


In the first part of this series, we discussed the basic ideas of computer simulations and how you can leverage them to answer "what-if" questions. It’s impossible to talk about simulations without mentioning bootstrap.

Bootstrap in statistics is a practical computational method for estimating the statistics of probability distributions. It is based on repeatedly generating new samples from an existing sample (drawing with replacement) using the Monte Carlo method. This method allows for simple and fast estimation of various statistics (such as confidence intervals, variance, correlation, etc.) for complex models.

When I learned about bootstrap in the statistics course, it felt a bit hacky. Instead of learning multiple formulas and criteria for different cases, you can just write a couple of lines of code and get confidence interval estimations for any custom and complicated use case. It sounds like magic.

And it really is. Now that even your laptop can run thousands of simulations in minutes or even seconds, bootstrap is a powerful tool in your analytical toolkit that can help you in many situations. So, I believe it’s worth learning or refreshing your knowledge about it.

In this article, we will talk about the idea behind bootstrap, understand when you should use it, learn how to get confidence intervals for different metrics and analyse the results of A/B tests.

What is bootstrap?

Actually, bootstrap is exceptionally straightforward. We run simulations that draw elements from our sample with replacement, and then we draw conclusions from the resulting distribution.

Let’s look at a simple example where we have four elements: 1, 2, 3 and 4. We can simulate many other collections of 4 elements, where each element is 1, 2, 3 or 4 with equal probability, and use these simulations to understand, for example, how the mean value might vary.
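Here’s that example as a short NumPy sketch. Since there are only 4^4 = 256 equally likely resamples, we can enumerate all of them (the exact bootstrap distribution) and compare the result with the usual Monte Carlo approach of drawing random resamples. The seed and the 10,000 draws are arbitrary choices.

```python
import itertools

import numpy as np

sample = np.array([1, 2, 3, 4])

# Exact bootstrap: all 4**4 = 256 resamples with replacement are equally likely
all_resamples = np.array(list(itertools.product(sample, repeat=len(sample))))
exact_means = all_resamples.mean(axis=1)

# Monte Carlo bootstrap: draw random resamples with replacement
rng = np.random.default_rng(42)
mc_means = rng.choice(sample, size=(10_000, len(sample)), replace=True).mean(axis=1)

print(len(exact_means))                      # 256
print(exact_means.mean())                    # 2.5
print(exact_means.min(), exact_means.max())  # 1.0 4.0
print(exact_means.var(), mc_means.var())     # exact 0.3125 vs. a Monte Carlo estimate close to it
```

The variance of the bootstrap mean (0.3125) equals the variance of the sample itself with ddof=0 (1.25) divided by n = 4, which is exactly what treating the sample as the population implies.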

The statistical idea behind bootstrap is to assume that the actual population has exactly the same distribution as our sample (or, equivalently, that the population consists of an infinite number of copies of our sample). We then treat it as a known population and use it to understand the variability in our data.

Usually, with a classical statistical approach, we assume that our variable follows some known distribution (for example, normal). With bootstrap, however, we don’t need to make any assumptions about the nature of the distribution. It’s pretty handy and helps to analyse even very complex custom metrics.

Bootstrap estimations are hard to get wrong, so in many cases, I would prefer them to classical statistical methods. The only drawback is computational time: if you’re working with big data, simulations might take hours, while classical estimates take seconds.

However, there are cases when it’s pretty challenging to get estimations without bootstrap. Let’s discuss the best use cases for bootstrap:

  • if you have outliers or influential points in your data;
  • if your sample is relatively small (roughly fewer than 100 cases);
  • if your data distribution is quite far from normal or other theoretical distributions, for example, it has several modes;
  • if you’re working with custom metrics (for example, the share of cases closed within SLA or percentiles).

Bootstrap is a wonderful and powerful statistical concept. Let’s try to use it for descriptive statistics.

Working with observational data

Let’s start with observational data, using a synthetic dataset. Imagine we are helping a fitness club set up a new program that will help clients prepare for the London Marathon. We got a first trial group of 12 customers and measured their results.

Here is the data we have.

We collected just three fields for each of the 12 customers:

  • races_before – the number of races a customer had run before our program,
  • kms_during_program – kilometres the client ran during our program,
  • finished_marathon – whether the program was successful and the customer finished the London Marathon.

We aim to set up a fair, goal-focused program that incentivises our clients to train with us more and achieve better results. So, we would like to refund clients who ran at least 150 kilometres during the preparation but couldn’t complete the marathon. However, before launching this program, we would like to make some estimates: the distance clients cover during preparation and the expected share of refunds. We need these to ensure that our business is profitable and sustainable.

Estimating average

Let’s start with estimating the average distance. We can try to leverage our knowledge of mathematical statistics and use formulas for confidence intervals.

To do so, we need to make an assumption about the distribution of this variable. The most common choice is the normal distribution. Let’s try it.

import numpy as np
from scipy.stats import norm, t

def get_normal_confidence_interval(data, confidence=0.95):
    # Calculate sample mean and standard deviation
    sample_mean = np.mean(data)
    sample_std = np.std(data, ddof=1)  
    n = len(data)

    # Calculate the critical value (z) based on the confidence level
    z = norm.ppf((1 + confidence) / 2)

    # Calculate the margin of error using standard error
    margin_of_error = z * sample_std / np.sqrt(n)

    # Calculate the confidence interval
    lower_bound = sample_mean - margin_of_error
    upper_bound = sample_mean + margin_of_error

    return lower_bound, upper_bound

get_normal_confidence_interval(df.kms_during_program.values)
# (111.86, 260.55)

Another option often used with real-life data is Student’s t-distribution, which gives a broader confidence interval (since it has fatter tails than the normal distribution) and is designed for small samples where the standard deviation is estimated from the data.

def get_ttest_confidence_interval(data, confidence=0.95):
    # Calculate sample mean and standard deviation
    sample_mean = np.mean(data)
    sample_std = np.std(data, ddof=1)  
    n = len(data)

    # Calculate the critical value from Student's t-distribution with n - 1 degrees of freedom
    t_crit = t.ppf((1 + confidence) / 2, df=n - 1)

    # Calculate the margin of error using standard error
    margin_of_error = t_crit * sample_std / np.sqrt(n)

    # Calculate the confidence interval
    lower_bound = sample_mean - margin_of_error
    upper_bound = sample_mean + margin_of_error

    return lower_bound, upper_bound

get_ttest_confidence_interval(df.kms_during_program.values)
# (102.72, 269.69)
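To see where the extra width comes from, we can compare the two critical values for our 12 clients (11 degrees of freedom). The rest of the margin of error (sample_std / sqrt(n)) is identical in both functions, so the ratio of the critical values is exactly how much wider the t-based interval is.

```python
from scipy.stats import norm, t

n = 12  # number of clients in the trial group
confidence = 0.95

z_crit = norm.ppf((1 + confidence) / 2)         # about 1.96
t_crit = t.ppf((1 + confidence) / 2, df=n - 1)  # about 2.20

print(round(z_crit, 3), round(t_crit, 3))
# Ratio of critical values vs. ratio of interval widths reported above
print(round(t_crit / z_crit, 3), round((269.69 - 102.72) / (260.55 - 111.86), 3))
```

Both ratios come out at about 1.12, so the t-based interval is roughly 12% wider for this sample size; the gap shrinks as n grows.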

We have only a few examples in our sample. Also, there’s an outlier: a client with 12 races who managed to run almost 600 km preparing for the marathon, while most other clients ran less than 200 km.

So, it’s an excellent case to use the bootstrap technique to understand the distribution and confidence interval better.

Let’s create a function to calculate and visualise the confidence interval:

  • We run num_batches simulations, sampling with replacement and calculating the average distance for each sample.
  • Then, based on these simulated averages, we can get a 95% confidence interval: the 2.5% and 97.5% percentiles of this distribution.
  • Finally, we can visualise the distribution on a chart.
import tqdm
import pandas as pd
import matplotlib.pyplot as plt

def get_kms_confidence_interval(num_batches, confidence = 0.95):
    # Running simulations
    tmp = []
    for i in tqdm.tqdm(range(num_batches)):
        tmp_df = df.sample(df.shape[0], replace = True)
        tmp.append(
            {
                'iteration': i,
                'mean_kms': tmp_df.kms_during_program.mean()
            }
        )
    # Saving data
    bootstrap_df = pd.DataFrame(tmp)

    # Calculating confidence interval
    lower_bound = bootstrap_df.mean_kms.quantile((1 - confidence)/2)
    upper_bound = bootstrap_df.mean_kms.quantile(1 - (1 - confidence)/2)

    # Creating a chart
    ax = bootstrap_df.mean_kms.hist(bins = 50, alpha = 0.6, 
        color = 'purple')
    ax.set_title('Average kms during the program, iterations = %d' % num_batches)

    plt.axvline(x=lower_bound, color='navy', linestyle='--', 
        label='lower bound = %.2f' % lower_bound)

    plt.axvline(x=upper_bound, color='navy', linestyle='--', 
        label='upper bound = %.2f' % upper_bound)

    ax.annotate('CI lower bound: %.2f' % lower_bound, 
                xy=(lower_bound, ax.get_ylim()[1]), 
                xytext=(-10, -20), 
                textcoords='offset points',  
                ha='center', va='top',  
                color='navy', rotation=90) 
    ax.annotate('CI upper bound: %.2f' % upper_bound, 
                xy=(upper_bound, ax.get_ylim()[1]), 
                xytext=(-10, -20), 
                textcoords='offset points',  
                ha='center', va='top',  
                color='navy', rotation=90) 
    plt.xlim(ax.get_xlim()[0] - 20, ax.get_xlim()[1] + 20)
    plt.show()

Let’s start with a small number of batches to see the first results quickly.

get_kms_confidence_interval(100)

With bootstrap, we got a slightly narrower confidence interval that is shifted to the right, which is in line with our actual (right-skewed) distribution: (139.31, 297.99) vs (102.72, 269.69).

However, with 100 bootstrap simulations, the distribution is not very clear. Let’s add more iterations. We can see that our distribution has multiple modes: for samples with one occurrence of the outlier, two occurrences, three, and so on.

With more iterations, we can see more modes (samples with many copies of the outlier are rare, so they only show up with enough iterations), but all the confidence intervals are pretty close.
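The modes come from how many copies of the single outlier land in a resample. With 12 clients and sampling with replacement, that count follows a Binomial(12, 1/12) distribution, so we can compute how often each mode appears:

```python
from math import comb

n = 12      # clients in the sample
p = 1 / n   # chance that any single draw picks the outlier

# Probability of getting exactly k copies of the outlier in one resample
probs = [comb(n, k) * p**k * (1 - p)**(n - k) for k in range(n + 1)]

for k in range(5):
    print(f'{k} copies of the outlier: {probs[k]:.3f}')
print(f'3+ copies: {sum(probs[3:]):.3f}')
```

Roughly a third of resamples contain no outlier at all, while resamples with three or more copies make up only about 7%, which is why the right-hand modes need many iterations to become visible.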

In the case of bootstrap, adding more iterations doesn’t lead to overfitting (because each iteration is independent). I think of it as increasing the resolution of an image.

Since our sample is small, running many simulations doesn’t take much time. Even 1 million bootstrap iterations take around 1 minute.
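If the per-iteration df.sample call becomes the bottleneck, one alternative (not the approach used above) is to draw all resample indices at once with NumPy and compute every bootstrap mean in a single vectorised step. The distances below are made-up values for illustration, not the article’s dataset, and bootstrap_mean_ci is a hypothetical helper.

```python
import numpy as np


def bootstrap_mean_ci(values, num_batches=100_000, confidence=0.95, seed=0):
    """Percentile bootstrap confidence interval for the mean, vectorised with NumPy."""
    values = np.asarray(values, dtype=float)
    rng = np.random.default_rng(seed)
    # Each row holds the indices of one resample drawn with replacement
    idx = rng.integers(0, len(values), size=(num_batches, len(values)))
    boot_means = values[idx].mean(axis=1)
    alpha = 1 - confidence
    lower, upper = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return lower, upper


# Hypothetical kilometres for 12 clients, including one outlier
kms = [120, 95, 150, 180, 110, 140, 160, 130, 105, 175, 190, 590]
print(bootstrap_mean_ci(kms))
```

The index matrix grows with num_batches × sample size, so for a million iterations it’s safer to process resamples in chunks (say, 100,000 at a time) and concatenate the means.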

Estimating custom metrics

As we discussed, bootstrap is handy when working with metrics that are not as straightforward as averages. For example, you might want to estimate the median or share of tasks closed within SLA.

You might even use bootstrap for something more unusual. Imagine you want to give customers discounts if your delivery is late: a 5% discount for a 15-minute delay, 10% for a 1-hour delay and 20% for a 3-hour delay.

Getting a theoretical confidence interval for such cases with plain statistics might be challenging, so bootstrap is extremely valuable here.
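Here’s a minimal sketch for the delivery discount example. It assumes the discount tiers don’t stack (a 3-hour delay gets 20%, not 5% + 10% + 20%), and the delays are hypothetical values generated only for illustration.

```python
import numpy as np


def discount_share(delay_mins):
    """Discount per order: 5% for 15+ mins late, 10% for 1+ hour, 20% for 3+ hours."""
    delay_mins = np.asarray(delay_mins, dtype=float)
    # np.select picks the first matching condition, so the largest threshold goes first
    return np.select(
        [delay_mins >= 180, delay_mins >= 60, delay_mins >= 15],
        [0.20, 0.10, 0.05],
        default=0.0,
    )


rng = np.random.default_rng(7)
# Hypothetical delays in minutes for 200 orders
delays = rng.exponential(20, size=200)

# Bootstrap the average discount per order
boot_means = []
for _ in range(5_000):
    resample = rng.choice(delays, size=len(delays), replace=True)
    boot_means.append(discount_share(resample).mean())

lower, upper = np.percentile(boot_means, [2.5, 97.5])
print(f'average discount: {discount_share(delays).mean():.3%}, '
      f'95% CI: ({lower:.3%}, {upper:.3%})')
```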

Let’s return to our running program and estimate the share of refunds (when a customer ran at least 150 km but didn’t manage to finish the marathon). We will use a similar function but will calculate the refund share for each iteration instead of the mean value.

import tqdm
import pandas as pd
import matplotlib.pyplot as plt

def get_refund_share_confidence_interval(num_batches, confidence = 0.95):
    # Running simulations
    tmp = []
    for i in tqdm.tqdm(range(num_batches)):
        tmp_df = df.sample(df.shape[0], replace = True)
        tmp_df['refund'] = list(map(
            lambda kms, passed: 1 if (kms >= 150) and (passed == 0) else 0,
            tmp_df.kms_during_program,
            tmp_df.finished_marathon
        ))

        tmp.append(
            {
                'iteration': i,
                'refund_share': tmp_df.refund.mean()
            }
        )

    # Saving data
    bootstrap_df = pd.DataFrame(tmp)

    # Calculating confidence interval
    lower_bound = bootstrap_df.refund_share.quantile((1 - confidence)/2)
    upper_bound = bootstrap_df.refund_share.quantile(1 - (1 - confidence)/2)

    # Creating a chart
    ax = bootstrap_df.refund_share.hist(bins = 50, alpha = 0.6, 
        color = 'purple')
    ax.set_title('Share of refunds, iterations = %d' % num_batches)
    plt.axvline(x=lower_bound, color='navy', linestyle='--',
        label='lower bound = %.2f' % lower_bound)
    plt.axvline(x=upper_bound, color='navy', linestyle='--', 
        label='upper bound = %.2f' % upper_bound)
    ax.annotate('CI lower bound: %.2f' % lower_bound, 
                xy=(lower_bound, ax.get_ylim()[1]), 
                xytext=(-10, -20), 
                textcoords='offset points',  
                ha='center', va='top',  
                color='navy', rotation=90) 
    ax.annotate('CI upper bound: %.2f' % upper_bound, 
                xy=(upper_bound, ax.get_ylim()[1]), 
                xytext=(-10, -20), 
                textcoords='offset points',  
                ha='center', va='top',  
                color='navy', rotation=90) 
    plt.xlim(-0.1, 1)
    plt.show()

Even with only 12 examples, we got a 2+ times smaller confidence interval. We can conclude with 95% confidence that fewer than 42% of customers will be eligible for a refund.

That’s a good result for such a small amount of data. However, we can go even further and try to estimate the effects of other factors.

Estimation of effects

We have data on the number of races each client ran before this marathon, so we can look at how this value is correlated with the distance covered during the program. We can use bootstrap for this as well: we only need to add a linear regression step to our current process.

import statsmodels.formula.api as smf

def get_races_coef_confidence_interval(num_batches, confidence = 0.95):
    # Running simulations
    tmp = []
    for i in tqdm.tqdm(range(num_batches)):
        tmp_df = df.sample(df.shape[0], replace = True)
        # Linear regression model
        model = smf.ols('kms_during_program ~ races_before', data = tmp_df).fit()

        tmp.append(
            {
                'iteration': i,
                'races_coef': model.params['races_before']
            }
        )

    # Saving data
    bootstrap_df = pd.DataFrame(tmp)

    # Calculating confidence interval
    lower_bound = bootstrap_df.races_coef.quantile((1 - confidence)/2)
    upper_bound = bootstrap_df.races_coef.quantile(1 - (1 - confidence)/2)

    # Creating a chart
    ax = bootstrap_df.races_coef.hist(bins = 50, alpha = 0.6, color = 'purple')
    ax.set_title('Coefficient between kms during the program and previous races, iterations = %d' % num_batches)
    plt.axvline(x=lower_bound, color='navy', linestyle='--', label='lower bound = %.2f' % lower_bound)
    plt.axvline(x=upper_bound, color='navy', linestyle='--', label='upper bound = %.2f' % upper_bound)
    ax.annotate('CI lower bound: %.2f' % lower_bound, 
                xy=(lower_bound, ax.get_ylim()[1]), 
                xytext=(-10, -20), 
                textcoords='offset points',  
                ha='center', va='top',  
                color='navy', rotation=90) 
    ax.annotate('CI upper bound: %.2f' % upper_bound, 
                xy=(upper_bound, ax.get_ylim()[1]), 
                xytext=(10, -20), 
                textcoords='offset points',  
                ha='center', va='top',  
                color='navy', rotation=90) 
    # plt.legend() 
    plt.xlim(ax.get_xlim()[0] - 5, ax.get_xlim()[1] + 5)
    plt.show()

    return bootstrap_df

We can look at the distribution. The confidence interval lies entirely above 0, so we can say there’s an effect with 95% confidence.

You can see that the distribution is bimodal, and each mode corresponds to one of the scenarios:

  • The component around 12 is related to samples without an outlier – it’s an estimation of the effect of previous races on the expected distance during the program if we disregard the outlier.
  • The second component corresponds to the samples when one or several outliers were in the dataset.

So, it’s super cool that by looking at the bootstrap distribution, we can even make estimates for different scenarios.

We’ve learned how to use bootstrap with observational data, but its bread and butter is A/B testing. So, let’s move on to our second example.

Simulations for A/B testing

Another everyday use case for bootstrap is designing and analysing A/B tests. Let’s look at an example. It will also be based on a synthetic dataset, this time showing the effect of a discount on customer retention. Imagine we are working on an e-grocery product and want to test whether our marketing campaign with a 20 EUR discount will affect customers’ spending.

For each customer, we know their country of residence, the number of family members who live with them, the average annual salary in the country, and how much money they spend on products in our store.

Power analysis

First, we need to design the experiment and understand how many clients we need in each experiment group to make conclusions confidently. This step is called power analysis.

Let’s quickly recap the basic statistical theory behind A/B tests and their main metrics. Every test is based on a null hypothesis (the current status quo). In our case, the null hypothesis is "the discount does not affect customers’ spending on our product". Then, we collect data on customers’ spending in the control and experiment groups and estimate the probability of seeing results as extreme as ours, or more extreme, if the null hypothesis is true. This probability is called the p-value, and if it’s small enough, we can reject the null hypothesis and conclude that the treatment affects customers’ spending or retention.
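The p-value definition can be made concrete with a permutation test, a simulation-based alternative to formula tests (not the approach used later in this article). If the null hypothesis holds, group labels are interchangeable, so we shuffle them many times and count how often the shuffled difference in means is at least as large as the observed one. The spending data below is synthetic and purely illustrative.

```python
import numpy as np


def permutation_p_value(control, treatment, num_permutations=10_000, seed=0):
    """One-sided p-value for 'treatment mean > control mean' via label shuffling."""
    rng = np.random.default_rng(seed)
    control = np.asarray(control, dtype=float)
    treatment = np.asarray(treatment, dtype=float)
    observed = treatment.mean() - control.mean()
    pooled = np.concatenate([control, treatment])
    n_treatment = len(treatment)
    count = 0
    for _ in range(num_permutations):
        shuffled = rng.permutation(pooled)
        diff = shuffled[:n_treatment].mean() - shuffled[n_treatment:].mean()
        if diff >= observed:  # "as extreme or more extreme" in the tested direction
            count += 1
    # The +1 terms count the observed split itself and keep the p-value above zero
    return (count + 1) / (num_permutations + 1)


rng = np.random.default_rng(1)
control = rng.normal(100, 30, size=300)    # hypothetical spending, EUR
treatment = rng.normal(120, 30, size=300)  # hypothetical +20 EUR shift
print(permutation_p_value(control, treatment))
```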

In this approach, there are three main metrics:

  • effect size – the minimal change in our metric we would like to be able to detect,
  • statistical significance (α) is the false positive rate: the probability of rejecting the null hypothesis when there is no effect. The most commonly used value is 5%. However, you might choose other values depending on your false-positive tolerance. For example, if implementing the change is expensive, you might want to use a lower significance threshold.
  • statistical power is the probability of rejecting the null hypothesis given that there is an actual effect equal to or larger than the effect size. People often use an 80% threshold, but in some cases (e.g., when you want to be more confident that there are no negative effects), you might use 90% or even 99%.

We need all these values to estimate the number of clients in the experiment. Let’s try to define them in our case to understand their meaning better.

We will start with effect size:

  • we expect the retention rate to change by at least 3 percentage points as a result of our campaign,
  • we would like to spot changes in customers’ spending of 20 EUR or more.

For statistical significance, I will use the default 5% threshold (so if there’s no actual effect, we will wrongly detect one in only 5% of cases). Let’s target a 90% statistical power threshold so that if there’s an actual effect equal to or bigger than the effect size, we will spot this change in 90% of cases.

Let’s start with statistical formulas that will allow us to get estimates quickly. Statistical formulas assume that our variable has a particular distribution, but they can usually help you estimate the order of magnitude of the required sample size. Later, we will use bootstrap to get more accurate results.

For retention, we can use the standard test of proportions. We need to know the baseline retention rate to estimate the normalised effect size, and we can get it from historical data before the experiment.

import statsmodels.stats.power as stat_power
import statsmodels.stats.proportion as stat_prop

base_retention = before_df.retention.mean()
ret_effect_size = stat_prop.proportion_effectsize(base_retention + 0.03, 
    base_retention)

sample_size = 2*stat_power.tt_ind_solve_power(
    effect_size = ret_effect_size,
    alpha = 0.05, power = 0.9,
    nobs1 = None, # we specified nobs1 as None to get an estimation for it
    alternative='larger'
)

# ret_effect_size = 0.0632, sample_size = 8573.86

We used a one-sided test because, from a business perspective, a negative effect and no effect lead to the same decision: we won’t implement this change. Using a one-sided test instead of a two-sided one increases statistical power.

We can similarly estimate the sample size for customer value, assuming a normal distribution. However, the distribution is actually not normal, so we should expect more precise results from bootstrap.

Let’s write the code.
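For intuition about these numbers: proportion_effectsize returns Cohen’s h, and for a one-sided two-sample test, the normal approximation needs about 2·((z(1 − α) + z(power)) / d)² customers per group. The sketch below uses only the standard library; the 50% baseline retention is a hypothetical value, since the real one comes from before_df.

```python
import math
from statistics import NormalDist


def cohens_h(p1, p2):
    """Cohen's h for two proportions: 2*arcsin(sqrt(p1)) - 2*arcsin(sqrt(p2))."""
    return 2 * math.asin(math.sqrt(p1)) - 2 * math.asin(math.sqrt(p2))


def total_sample_size(effect_size, alpha=0.05, power=0.9):
    """Normal-approximation total sample size (both groups) for a one-sided two-sample test."""
    z_alpha = NormalDist().inv_cdf(1 - alpha)
    z_power = NormalDist().inv_cdf(power)
    per_group = 2 * ((z_alpha + z_power) / effect_size) ** 2
    return 2 * per_group


# Hypothetical baseline retention of 50%: a +3 p.p. change gives h of about 0.06
print(round(cohens_h(0.53, 0.50), 4))
# Plugging in the retention effect size from above (0.0632) gives roughly 8.6K customers,
# close to the 8573.86 returned by tt_ind_solve_power
print(round(total_sample_size(0.0632)))
```

The small gap to the statsmodels result comes from using the normal distribution instead of the t-distribution and from rounding the effect size.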

val_effect_size = 20/before_df.customer_value.std()

sample_size = 2*stat_power.tt_ind_solve_power(
    effect_size = val_effect_size,
    alpha = 0.05, power = 0.9, 
    nobs1 = None, 
    alternative='larger'
)
# val_effect_size = 0.0527, sample_size = 12324.13

We got estimations for the needed sample sizes for each test. However, there are cases when you have a limited number of clients and want to understand the statistical power you can get.

Suppose we have only 5K customers (2.5K in each group). Then, we will be able to achieve 72.2% statistical power for retention analysis and 58.7% – for customer value (given the desired statistical significance and effect sizes).

The only difference in the code is that this time, we’ve specified nobs1 = 2500 and left power as None.

stat_power.tt_ind_solve_power(
    effect_size = ret_effect_size,
    alpha = 0.05, power = None,
    nobs1 = 2500, 
    alternative='larger'
)
# 0.7223

stat_power.tt_ind_solve_power(
    effect_size = val_effect_size,
    alpha = 0.05, power = None,
    nobs1 = 2500, 
    alternative='larger'
)
# 0.5867
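These power values can be sanity-checked with a closed-form normal approximation. This standard-library sketch is not what statsmodels runs internally (it uses the t distribution), but with 2,500 customers per group it lands within a fraction of a percentage point of the 72.2% and 58.7% above.

```python
from statistics import NormalDist

def power_one_sided(d, n_per_group, alpha=0.05):
    # With equal groups, the test statistic under H1 is centred at d * sqrt(n / 2);
    # power is the chance it exceeds the one-sided critical value
    z = NormalDist()
    return z.cdf(d * (n_per_group / 2) ** 0.5 - z.inv_cdf(1 - alpha))

print(round(power_one_sided(0.0632, 2500), 4))  # retention
print(round(power_one_sided(0.0527, 2500), 4))  # customer value
```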

Now, it’s time to use bootstrap for the power analysis, and we will start with the customer value test since it’s easier to implement.

Let’s discuss the basic idea and steps of power analysis using bootstrap. First, we need to define our goal clearly: we want to estimate statistical power as a function of sample size. In more practical terms, we want to know the share of cases in which, given a true 20 EUR increase in customer spending, we would reject the null hypothesis and roll this change out to production. So, we need to simulate many such experiments and calculate the share of them that show a statistically significant change in our metric.

Let’s look at one experiment and break it into steps. The first step is to generate the experimental data. For that, we need to get a random subset from the population equal to the sample size, randomly split these customers into control and experiment groups and add an effect equal to the effect size for the treatment group. All this logic is implemented in get_sample_for_value function below.

import numpy as np

def get_sample_for_value(pop_df, sample_size, effect_size):
  # getting sample of needed size (a copy, so new columns don't touch pop_df)
  sample_df = pop_df.sample(sample_size).copy()

  # randomly assign treatment (50/50)
  sample_df['treatment'] = sample_df.index.map(
    lambda x: 1 if np.random.uniform() > 0.5 else 0)

  # add effect for the treatment group
  sample_df['predicted_value'] = (sample_df['customer_value']
    + effect_size * sample_df.treatment)

  return sample_df

Now, we can treat this synthetic experiment data as we usually do with A/B test analysis, run a bunch of bootstrap simulations, estimate effects, and then get a confidence interval for this effect.

We will be using linear regression to estimate the effect of treatment. As discussed in the previous article, it’s worth adding features that explain the outcome variable (customers’ spending) to the linear regression. We will add the number of family members and the average annual earnings in the customer’s country, since both are positively correlated with spending.

import statsmodels.formula.api as smf
val_model = smf.ols('customer_value ~ num_family_members + country_avg_annual_earning', 
    data = before_df).fit(disp = 0)
val_model.summary().tables[1]

We will put all the logic of doing multiple bootstrap simulations and estimating treatment effects into the get_ci_for_value function.

def get_ci_for_value(df, boot_iters, confidence_level):
    tmp_data = []

    for iter in range(boot_iters):
        sample_df = df.sample(df.shape[0], replace = True)
        val_model = smf.ols('predicted_value ~ treatment + num_family_members + country_avg_annual_earning', 
          data = sample_df).fit(disp = 0)
        tmp_data.append(
            {
                'iteration': iter,
                'coef': val_model.params['treatment']
            }
        )

    coef_df = pd.DataFrame(tmp_data)
    return (coef_df.coef.quantile((1 - confidence_level)/2),
        coef_df.coef.quantile(1 - (1 - confidence_level)/2))

The next step is to put this logic together, run a bunch of such synthetic experiments, and save results.

import tqdm

def run_simulations_for_value(pop_df, sample_size, effect_size, 
    boot_iters, confidence_level, num_simulations):

    tmp_data = []

    for sim in tqdm.tqdm(range(num_simulations)):
        sample_df = get_sample_for_value(pop_df, sample_size, effect_size)
        num_users_treatment = sample_df[sample_df.treatment == 1].shape[0]
        value_treatment = sample_df[sample_df.treatment == 1].predicted_value.mean()
        num_users_control = sample_df[sample_df.treatment == 0].shape[0]
        value_control = sample_df[sample_df.treatment == 0].predicted_value.mean()

        ci_lower, ci_upper = get_ci_for_value(sample_df, boot_iters, confidence_level)

        tmp_data.append(
            {
                'experiment_id': sim,
                'num_users_treatment': num_users_treatment,
                'value_treatment': value_treatment,
                'num_users_control': num_users_control,
                'value_control': value_control,
                'sample_size': sample_size,
                'effect_size': effect_size,
                'boot_iters': boot_iters,
                'confidence_level': confidence_level,
                'ci_lower': ci_lower,
                'ci_upper': ci_upper
            }
        )

    return pd.DataFrame(tmp_data)

Let’s run this simulation for sample_size = 100 and see the results.

val_sim_df = run_simulations_for_value(before_df, sample_size = 100, 
    effect_size = 20, boot_iters = 1000, confidence_level = 0.95, 
    num_simulations = 20)
val_sim_df.set_index('experiment_id')[['sample_size', 'ci_lower', 'ci_upper']].head()

We’ve got the following data for 20 simulated experiments. We know the confidence interval for each experiment, and now we can estimate the power.

We would have rejected the null hypothesis if the lower bound of the confidence interval was above zero, so let’s calculate the share of such experiments.

val_sim_df['successful_experiment'] = val_sim_df.ci_lower.map(
  lambda x: 1 if x > 0 else 0)

val_sim_df.groupby(['sample_size', 'effect_size']).aggregate(
    {
        'successful_experiment': 'mean',
        'experiment_id': 'count'
    }
)
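To make the whole procedure concrete without pandas or statsmodels, here is a compact standard-library sketch of simulation-based power analysis. It is a simplified variant of the code above: it compares plain group means instead of fitting a regression, resamples each group separately for a percentile bootstrap, and runs on a hypothetical population, so its numbers won’t match the article’s.

```python
import random
import statistics

def bootstrap_ci_diff(control, treatment, boot_iters, confidence_level, rng):
    # Percentile bootstrap CI for mean(treatment) - mean(control),
    # resampling each group with replacement
    diffs = []
    for _ in range(boot_iters):
        c = rng.choices(control, k=len(control))
        t = rng.choices(treatment, k=len(treatment))
        diffs.append(statistics.fmean(t) - statistics.fmean(c))
    diffs.sort()
    lower_idx = round((1 - confidence_level) / 2 * (boot_iters - 1))
    upper_idx = round((1 + confidence_level) / 2 * (boot_iters - 1))
    return diffs[lower_idx], diffs[upper_idx]

def simulated_power(population, sample_size, effect_size, boot_iters=200,
                    confidence_level=0.95, num_simulations=50, seed=42):
    rng = random.Random(seed)
    successes = 0
    for _ in range(num_simulations):
        # 1. Draw a sample and split it 50/50 into control and treatment
        control, treatment = [], []
        for value in rng.sample(population, sample_size):
            if rng.random() < 0.5:
                treatment.append(value + effect_size)  # 2. add the effect
            else:
                control.append(value)
        # 3. Bootstrap CI for the effect; success if its lower bound is above 0
        ci_lower, _ = bootstrap_ci_diff(control, treatment, boot_iters,
                                        confidence_level, rng)
        successes += ci_lower > 0
    # 4. Power = share of simulated experiments that rejected H0
    return successes / num_simulations

# Hypothetical population of customer values (normally distributed for simplicity)
pop_rng = random.Random(0)
population = [pop_rng.gauss(100, 30) for _ in range(5000)]
print(simulated_power(population, sample_size=200, effect_size=10, num_simulations=20))
```

Looping this call over a list of sample sizes gives the same kind of power curve as the pandas version, just far more slowly in pure Python.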

We’ve started with just 20 simulated experiments and 1000 bootstrap iterations to estimate each confidence interval. Such a small number of simulations gives us a low-resolution picture quite quickly. Keeping in mind the estimate we got from classic statistics, we should expect sample sizes of around 10K to give us the desired statistical power.
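Why do 20 simulated experiments give only a low-resolution picture? Each simulated power estimate is a proportion of successes, so its Monte Carlo standard error is sqrt(p * (1 - p) / N), where N is the number of simulated experiments. This standard-library sketch uses a simple normal (Wald) interval, which is only a rough approximation for small N; the success counts are hypothetical.

```python
import math

def power_estimate_with_error(successes, num_simulations, z=1.96):
    # Simulated power is a binomial proportion: p_hat = successes / N
    p = successes / num_simulations
    se = math.sqrt(p * (1 - p) / num_simulations)
    # Approximate 95% interval, clipped to [0, 1]
    return p, se, (max(0.0, p - z * se), min(1.0, p + z * se))

print(power_estimate_with_error(18, 20))    # 20 simulated experiments
print(power_estimate_with_error(450, 500))  # 500 simulated experiments
```

At an estimated power of 90%, 20 experiments give a standard error of about 6.7 percentage points, while 500 experiments bring it down to about 1.3.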

tmp_dfs = []
for sample_size in [100, 250, 500, 1000, 2500, 5000, 10000, 25000]:
    print('Simulation for sample size = %d' % sample_size)
    tmp_dfs.append(
        run_simulations_for_value(before_df, sample_size = sample_size, effect_size = 20,
                              boot_iters = 1000, confidence_level = 0.95, num_simulations = 20)
    )

val_lowres_sim_df = pd.concat(tmp_dfs)

We got results similar to our theoretical estimations. Let’s rerun the estimations with more simulated experiments (100 and 500). With these, we can see that 12.5K clients will be enough to achieve 90% statistical power.

I’ve added all the power analysis results to the chart so that we can see the relation clearly.

As you might have noticed, bootstrap can take a significant amount of time. For example, accurately estimating power with 500 simulated experiments for just 3 sample sizes took me almost 2 hours.

Now, we can estimate the relationship between effect size and power for a 12.5K sample size.

tmp_dfs = []
for effect_size in [1, 5, 10, 15, 20, 25, 30, 40, 50, 60, 70, 80, 90, 100]:
    print('Simulation for effect size = %d' % effect_size)
    tmp_dfs.append(
        run_simulations_for_value(before_df, sample_size = 12500, effect_size = effect_size,
                              boot_iters = 1000, confidence_level = 0.95, num_simulations = 100)
    )

val_effect_size_sim_df = pd.concat(tmp_dfs)

We can see that if the actual effect on customers’ spending is higher than 20 EUR, we will get even higher statistical power, and we will be able to reject the null hypothesis in more than 90% of cases. But we will be able to spot the 10 EUR effect in less than 50% of cases.
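As a quick cross-check that doesn’t need hours of bootstrap, the normal approximation gives a similar picture. The sketch assumes the customer-value standard deviation implied by the earlier calculation (20 / 0.0527, about 380 EUR) and 6,250 customers per group; since the real distribution isn’t normal, treat it as a rough guide rather than a replacement for the simulation.

```python
from statistics import NormalDist

def power_for_effect(effect_eur, std_eur, n_per_group, alpha=0.05):
    # Standardised effect d = effect / std; one-sided test, equal groups
    z = NormalDist()
    d = effect_eur / std_eur
    return z.cdf(d * (n_per_group / 2) ** 0.5 - z.inv_cdf(1 - alpha))

std_eur = 20 / 0.0527  # implied by val_effect_size = 0.0527
for effect in [5, 10, 20, 30]:
    print(effect, round(power_for_effect(effect, std_eur, n_per_group=6250), 3))
```

Under these assumptions, a 10 EUR effect is detected in about 43% of cases and a 20 EUR effect in about 90%, in line with the simulation results.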

Let’s move on and conduct power analysis for retention as well. The complete code is structured similarly to the customer spending analysis. We will discuss nuances in detail below.

import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
import tqdm

def get_sample_for_retention(pop_df, sample_size, effect_size):
    base_ret_model = smf.logit('retention ~ num_family_members', data = pop_df).fit(disp = 0)
    tmp_pop_df = pop_df.copy()
    tmp_pop_df['predicted_retention_proba'] = base_ret_model.predict()
    sample_df = tmp_pop_df.sample(sample_size).copy()
    sample_df['treatment'] = sample_df.index.map(lambda x: 1 if np.random.uniform() > 0.5 else 0)
    sample_df['predicted_retention_proba'] = sample_df['predicted_retention_proba'] + effect_size * sample_df.treatment
    sample_df['retention'] = sample_df.predicted_retention_proba.map(lambda x: 1 if x >= np.random.uniform() else 0)
    return sample_df

def get_ci_for_retention(df, boot_iters, confidence_level):
    tmp_data = []

    for iter in range(boot_iters):
        sample_df = df.sample(df.shape[0], replace = True)
        ret_model = smf.logit('retention ~ treatment + num_family_members', data = sample_df).fit(disp = 0)
        tmp_data.append(
            {
                'iteration': iter,
                'coef': ret_model.params['treatment']
            }
        )

    coef_df = pd.DataFrame(tmp_data)
    return coef_df.coef.quantile((1 - confidence_level)/2), coef_df.coef.quantile(1 - (1 - confidence_level)/2)

def run_simulations_for_retention(pop_df, sample_size, effect_size, 
                                  boot_iters, confidence_level, num_simulations):
    tmp_data = []

    for sim in tqdm.tqdm(range(num_simulations)):
        sample_df = get_sample_for_retention(pop_df, sample_size, effect_size)
        num_users_treatment = sample_df[sample_df.treatment == 1].shape[0]
        retention_treatment = sample_df[sample_df.treatment == 1].retention.mean()
        num_users_control = sample_df[sample_df.treatment == 0].shape[0]
        retention_control = sample_df[sample_df.treatment == 0].retention.mean()

        ci_lower, ci_upper = get_ci_for_retention(sample_df, boot_iters, confidence_level)

        tmp_data.append(
            {
                'experiment_id': sim,
                'num_users_treatment': num_users_treatment,
                'retention_treatment': retention_treatment,
                'num_users_control': num_users_control,
                'retention_control': retention_control,
                'sample_size': sample_size,
                'effect_size': effect_size,
                'boot_iters': boot_iters,
                'confidence_level': confidence_level,
                'ci_lower': ci_lower,
                'ci_upper': ci_upper
            }
        )

    return pd.DataFrame(tmp_data)

First, since we have a binary outcome for retention (whether the customer returns next month or not), we will use a logistic regression model instead of linear regression. We can see that retention is correlated with the size of the family. It might be the case that when you buy many different types of products for family members, it’s more difficult to find another service that will cover all your needs.

base_ret_model = smf.logit('retention ~ num_family_members', data = before_df).fit(disp = 0)
base_ret_model.summary().tables[1]

Also, the function get_sample_for_retention has slightly trickier logic to adjust results for the treatment group. Let’s look at it step by step.

First, we fit a logistic regression on the whole population and use it to predict each customer’s probability of retention.

base_ret_model = (smf.logit('retention ~ num_family_members', data = pop_df)
  .fit(disp = 0))
tmp_pop_df = pop_df.copy()
tmp_pop_df['predicted_retention_proba'] = base_ret_model.predict()

Then, we take a random sample of the required size and split it into control and treatment groups.

sample_df = tmp_pop_df.sample(sample_size)
sample_df['treatment'] = sample_df.index.map(
  lambda x: 1 if np.random.uniform() > 0.5 else 0)

For the treatment group, we increase the probability of retaining by the expected effect size.

sample_df['predicted_retention_proba'] = (sample_df['predicted_retention_proba']
    + effect_size * sample_df.treatment)

The last step is to decide, based on this probability, whether the customer is retained. We use a uniform distribution (a random number between 0 and 1) for that:

  • if a random value from the uniform distribution is below the probability, the customer is retained (this happens exactly with the specified probability),
  • otherwise, the customer has churned.
sample_df['retention'] = sample_df.predicted_retention_proba.map(
    lambda x: 1 if x > np.random.uniform() else 0)
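Here is a quick standard-library check of this Bernoulli sampling step. With a hypothetical constant baseline probability of 0.64 and effect_size = 0.3, the simulated retention rates should land near 64% and 94%. The `min(..., 1.0)` cap is an addition for clarity: the article’s code doesn’t cap, but a probability above 1 always beats a uniform draw, so the outcome is the same.

```python
import random

def simulate_retention(base_probas, treatment_flags, effect_size, rng):
    # Add the effect to treated customers' probabilities, then draw a
    # Bernoulli outcome: retained if a Uniform(0, 1) draw falls below it
    outcomes = []
    for p, treated in zip(base_probas, treatment_flags):
        p_adj = min(1.0, p + effect_size * treated)  # explicit cap at 1
        outcomes.append(1 if rng.random() < p_adj else 0)
    return outcomes

rng = random.Random(1)
n = 100_000
base = [0.64] * n                                # hypothetical constant baseline
flags = [rng.random() > 0.5 for _ in range(n)]   # 50/50 treatment assignment
retained = simulate_retention(base, flags, 0.3, rng)
control = [r for r, f in zip(retained, flags) if not f]
treatment = [r for r, f in zip(retained, flags) if f]
print(sum(control) / len(control), sum(treatment) / len(treatment))
```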

You can run a few simulations to ensure our sampling function works as intended. For example, with the call below, we can see that retention is 64% for the control group (like in the population) and 93.7% for the treatment group (as expected with effect_size = 0.3).

(get_sample_for_retention(before_df, 10000, 0.3)
  .groupby('treatment', as_index = False).retention.mean())

# |    |   treatment |   retention |
# |---:|------------:|------------:|
# |  0 |           0 |    0.640057 |
# |  1 |           1 |    0.937648 |

We can also run simulations to find the sample size needed to reach 90% statistical power for retention. The 12.5K sample size turns out to be good enough for retention as well.

Analysing results

We can use linear or logistic regression to analyse results or leverage the functions we already have for bootstrap CI.

value_model = smf.ols(
  'customer_value ~ treatment + num_family_members + country_avg_annual_earning', 
  data = experiment_df).fit(disp = 0)
value_model.summary().tables[1]

So, we got a statistically significant effect on customer spending of 25.84 EUR with a 95% confidence interval of (16.82, 34.87).

With the bootstrap function, the CI will be pretty close.

get_ci_for_value(experiment_df.rename(
    columns = {'customer_value': 'predicted_value'}), 1000, 0.95)
# (16.28, 34.63)

Similarly, we can use logistic regression for retention analysis.

retention_model = smf.logit('retention ~ treatment + num_family_members',
    data = experiment_df).fit(disp = 0)
retention_model.summary().tables[1]

Again, the bootstrap approach gives close estimations for CI.

get_ci_for_retention(experiment_df, 1000, 0.95)
# (0.072, 0.187)

With logistic regression, the coefficient is on the log-odds scale, so it might be tricky to interpret. However, we can use a simple trick: for each customer in our dataset, use the model to predict the probability of retention as if the customer were in the treatment group and as if they were in the control group, and then look at the average difference between these probabilities.

# predict each customer's retention probability as if they were treated...
experiment_df['retention_proba_treatment'] = retention_model.predict(
    experiment_df.assign(treatment = 1))

# ...and as if they were in the control group
experiment_df['retention_proba_control'] = retention_model.predict(
    experiment_df.assign(treatment = 0))

experiment_df['proba_diff'] = (experiment_df.retention_proba_treatment
    - experiment_df.retention_proba_control)

experiment_df.proba_diff.mean()
# 0.0281

So, we can estimate the effect on retention at 2.8 percentage points.
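The same calculation on the probability scale (often called the average marginal effect of the treatment) can be written without pandas. The coefficients and family sizes below are hypothetical, chosen only to show how a log-odds coefficient of about 0.13 translates into a retention uplift of a few percentage points.

```python
import math

def logistic(x):
    return 1 / (1 + math.exp(-x))

def average_treatment_effect_proba(intercept, coef_treatment, coef_family, family_sizes):
    # For every customer, predict retention as if treated and as if not,
    # then average the differences (the same idea as the pandas code above)
    diffs = []
    for k in family_sizes:
        p_treated = logistic(intercept + coef_treatment * 1 + coef_family * k)
        p_control = logistic(intercept + coef_treatment * 0 + coef_family * k)
        diffs.append(p_treated - p_control)
    return sum(diffs) / len(diffs)

# Hypothetical coefficients and family sizes, for illustration only
ate = average_treatment_effect_proba(0.2, 0.13, 0.15, [1, 2, 3, 4, 5])
print(round(ate, 4))
```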

Congratulations! We’ve finally finished the full A/B test analysis and were able to estimate the effect both on average customer spending and retention. Our experiment is successful, so in real life, we would start thinking about rolling it out to production.

You can find the full code for this example on GitHub.

Summary

Let me quickly recap what we’ve discussed today:

  • The main idea of bootstrap is resampling with replacement from your sample, assuming that the general population has the same distribution as the data we have.
  • Bootstrap shines when you have few data points, when your data has outliers, or when it’s far from any theoretical distribution. Bootstrap can also help you estimate custom metrics.
  • You can use bootstrap to work with observational data, for example, to get confidence intervals for your values.
  • Also, bootstrap is broadly used for A/B testing analysis – both to estimate the impact of treatment and do a power analysis to design an experiment.

Thank you a lot for reading this article. If you have any follow-up questions or comments, please leave them in the comments section.

Reference

All the images are produced by the author unless otherwise stated.

This article was inspired by the book "Behavioral Data Analysis with R and Python" by Florent Buisson.

The post Practical Computer Simulations for Product Analysts appeared first on Towards Data Science.

Practical Computer Simulations for Product Analysts https://towardsdatascience.com/practical-computer-simulations-for-product-analysts-90b5deb6a54e/ Fri, 19 Apr 2024 23:54:20 +0000 https://towardsdatascience.com/practical-computer-simulations-for-product-analysts-90b5deb6a54e/ Part 1: Task-specific approaches for scenario forecasting

Image by DALL-E

In Product Analytics, we quite often get "what-if" questions. Our teams are constantly inventing different ways to improve the product and want to understand how it can affect our KPI or other metrics.

Let’s look at some examples:

  • Imagine we’re in the fintech industry and facing new regulations requiring us to check more documents from customers making the first donation or sending more than $100K to a particular country. We want to understand the effect of this change on our Ops demand and whether we need to hire more agents.
  • Let’s switch to another industry. We might want to incentivise our taxi drivers to work late or take long-distance rides by introducing a new reward scheme. Before launching this change, it would be crucial for us to estimate the expected size of rewards and conduct a cost vs. benefit analysis.
  • As the last example, let’s look at the main Customer Support KPIs. Usually, companies track the average waiting time. There are many possible ways to improve this metric. We can add night shifts, hire more agents or leverage LLMs to answer questions quickly. To prioritise these ideas, we will need to estimate their impact on our KPI.

When you see such questions for the first time, they look pretty intimidating.

If someone asks you to calculate monthly active users or 7-day retention, it’s straightforward. You just need to go to your database, write SQL and use the data you have.

Things become way more challenging (and exciting) when you need to calculate something that doesn’t exist. Computer simulations will usually be the best solution for such tasks. According to Wikipedia, simulation is an imitative representation of a process or system that could exist in the real world. So, we will try to imitate different situations and use them in our decision-making.

Simulation is a powerful tool that can help you in various situations. So, I would like to share with you practical examples of computer simulations in a series of articles:

  • In this article, we will discuss how to use simulations to estimate different scenarios. You will learn the basic idea of simulations and see how they can solve complex tasks.
  • In the second part, we will diverge from scenario analysis and will focus on the classic of computer simulations – bootstrap. Bootstrap can help you get confidence intervals for your metrics and analyse A/B tests.
  • I would like to devote the third part to agent-based models. We will model the CS agent behaviour to understand how our process changes can affect CS KPIs such as queue size or average waiting time.

So, it’s time to start and discuss the task we will solve in this article.

Our project: Launching tests for English courses

Suppose we are working on an edtech product that helps people learn the English language. We’ve been working on a test that could assess the student’s knowledge from different angles (reading, listening, writing and speaking). The test will give us and our students a clear understanding of their current level.

We agreed to launch it for all new students so that we can assess their initial level. Also, we will suggest that existing students take this test the next time they return to the service.

Our goal is to forecast the number of submitted tests over time. Since some parts of these tests (writing and speaking) will require manual review by our teachers, we would like to ensure that we have enough capacity to check these tests on time.

Let’s try to structure our problem. We have two groups of students:

  • The first group is existing students. It’s a good practice to be precise in analytics, so we will define them as students who started using our service before this launch. We will need to check them once at their next transaction, so we will have a substantial spike while processing them all. Later, the demand from this segment will be negligible (only rare reactivations).
  • New students will hopefully continue joining our courses. So, we should expect consistent demand from this group.

Now, it’s time to think about how we can estimate the demand for these two groups of customers.

The situation is pretty straightforward for new students – we need to predict the number of new customers weekly and use it to estimate demand. So, it’s a classic task of time series forecasting.

The task of predicting demand from existing customers might be more challenging. The direct approach would be to build a model to predict the week when students will return to the service next time and use it for estimations. It’s a possible solution, but it sounds a bit overcomplicated to me.

I would prefer another approach: simulate the situation as if we had launched this test some time ago, using historical data. In that case, we will have all the data after "this simulated launch" and will be able to calculate all the metrics. That’s the basic idea of scenario simulations.

Cool, we have a plan. Let’s move on to execution.

Modelling demand from new customers

Before jumping to analysis, let’s examine the data we have. We keep a record of lesson completion events. For each event, we know the user identifier, date, module, and lesson number. We will use weekly data to smooth out day-of-week fluctuations and capture meaningful trends.

Let me share some context about the educational process. Students primarily come to our service to learn English from scratch and pass six modules (from pre-A1 to C1). Each module consists of 100 lessons.

The data was generated explicitly for this use case, so we are working with a synthetic data set.

First, we need to calculate the metric we want to predict. We will offer students the opportunity to pass the initial evaluation test after completing the first demo lesson. So, we can simply calculate the number of customers who passed their first lesson, i.e. aggregate users by the date of their first activity.

new_users_df = (df.groupby('user_id', as_index = False).date.min()
  .rename(columns = {'date': 'cohort'}))

new_users_stats_df = (new_users_df.groupby('cohort')[['user_id']].count()
  .rename(columns = {'user_id': 'new_users'}))

We can look at the data and see an overall growing trend with some seasonal effects (i.e. fewer customers joining during the summer or Christmas time).

For forecasting, we will use Prophet – an open-source library from Meta. It works pretty well with business data since it can predict non-linear trends and automatically take into account seasonal effects. You can easily install it from PyPI.

pip install prophet

The Prophet library expects a data frame with two columns: ds with a timestamp and y with the metric we want to predict. Also, ds must be a datetime column. So, we need to transform our data into the expected format.

# weekly counts of new users, with cohort moved from the index to a column
pred_new_users_df = new_users_stats_df.reset_index()
pred_new_users_df = pred_new_users_df.rename(
  columns = {'new_users': 'y', 'cohort': 'ds'})
pred_new_users_df.ds = pd.to_datetime(pred_new_users_df.ds)

Now, we are ready to make predictions. As usual in ML, we need to initialise and fit a model.

from prophet import Prophet

m = Prophet()
m.fit(pred_new_users_df)

The next step is prediction. First, we need to create a future data frame specifying the number of periods and their frequency (in our case, weekly). Then, we need to call the predict function.

future = m.make_future_dataframe(periods= 52, freq = 'W')
forecast_df = m.predict(future)
forecast_df.tail()[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]

As a result, we get the forecast (yhat) and confidence interval (yhat_lower and yhat_upper).

It’s difficult to understand the result without charts. Let’s use Prophet functions to visualise the output better.

m.plot(forecast_df) # forecast
m.plot_components(forecast_df) # components

The forecast chart shows you the forecast with a confidence interval.

The components view lets you understand the split between trend and seasonal effects. For example, the second chart displays a seasonal drop-off during summer and an increase at the beginning of September (when people might be more motivated to start learning something new).

We can put all this forecasting logic into one function. It will be helpful for us later.

import plotly.express as px
import plotly.io as pio
pio.templates.default = 'simple_white'

def make_prediction(tmp_df, param, param_name = '', periods = 52):
    # pre-processing
    df = tmp_df.copy()
    date_param = df.index.name
    df.index = pd.to_datetime(df.index)

    train_df = df.reset_index().rename(columns = {date_param: 'ds', param: 'y'})

    # model
    m = Prophet()
    m.fit(train_df)

    future = m.make_future_dataframe(periods=periods, freq = 'W')
    forecast = m.predict(future)
    forecast = forecast[['ds', 'yhat']].rename(columns = {'ds': date_param, 'yhat': param + '_model'})

    # join to actual data
    forecast = forecast.set_index(date_param).join(df, how = 'outer')

    # visualisation
    fig = px.line(forecast, 
        title = '<b>Forecast:</b> ' + (param if param_name == '' else param_name),
        labels = {'value': param if param_name == '' else param_name},
        color_discrete_map = {param: 'navy', param + '_model': 'gray'}
    )
    fig.update_traces(mode='lines', line=dict(dash='dot'), 
        selector=dict(name=param + '_model'))
    fig.update_layout(showlegend = False)
    fig.show()

    return forecast

new_forecast_df = make_prediction(new_users_stats_df, 
  'new_users', 'new users', periods = 75)

I prefer to share with my stakeholders a more styled version of visualisation (especially for public presentations), so I’ve added it to the function as well.

In this example, we’ve used the default Prophet model and got quite a plausible forecast. However, in some cases, you might want to tweak parameters, so I advise you to read the Prophet docs to learn more about the possible levers.

For example, in our case, we assume that our audience will continue growing at the same rate. However, this might not be the case, and you might expect growth to saturate at around 100 new users per week. Let’s update our prediction for saturating growth.

# adding cap to the initial data
# it's not required to be constant
pred_new_users_df['cap'] = 100

#specifying logistic growth
m = Prophet(growth='logistic')
m.fit(pred_new_users_df)

# adding cap for the future
future = m.make_future_dataframe(periods= 52, freq = 'W')
future['cap'] = 100
forecast_df = m.predict(future)

We can see that the forecast has changed significantly, and the growth stops at ~100 new clients per week.
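What Prophet fits here is, roughly, a logistic curve. Ignoring changepoints (which let Prophet adjust the growth rate and offset over time), the trend is cap / (1 + exp(-k * (t - m))), where k is the growth rate and m is the midpoint. The standard-library sketch below uses hypothetical values of k and m to show the saturation at the cap of 100.

```python
import math

def logistic_trend(t, cap, k, m):
    # Saturating growth: approaches `cap` as t grows and equals cap / 2 at t = m
    return cap / (1 + math.exp(-k * (t - m)))

cap, k, m = 100, 0.05, 60   # hypothetical growth rate and midpoint (in weeks)
for week in [0, 30, 60, 90, 150]:
    print(week, round(logistic_trend(week, cap, k, m), 1))
```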

It’s also interesting to look at the components’ chart in this case. We can see that the seasonal effects stayed the same, while the trend has changed to logistic (as we specified).

We’ve learned a bit about the ability to tweak forecasts. However, for future calculations, we will use a basic model. Our business is still relatively small, and most likely, we haven’t reached saturation yet.

We’ve got all the needed estimations for new customers and are ready to move on to the existing ones.

Modelling demand from existing customers

The first version

The key point in our approach is to simulate the situation when we launched this test some time ago and calculate the demand using this data. Our solution is based on the idea that we can use the past data instead of predicting the future.

Since there’s significant yearly seasonality, I will use data from one year earlier to take these effects into account automatically. We want to launch this project at the beginning of April. So, I will use past data starting from the week of 2nd April 2023.

First, we need to filter the data related to existing customers at the beginning of April 2023. We’ve already forecasted demand from new users, so we don’t need to consider them in this estimation.

model_existing_users = df[df.date < '2023-04-02'].user_id.unique()
raw_existing_df = df[df.user_id.isin(model_existing_users)]

Then, we need to model the demand from these users. We will offer our existing students the chance to pass the test the next time they use our product. So, we need to define when each customer returned to our service after the launch and aggregate the number of customers by week. There’s no rocket science at all.

existing_model_df = (raw_existing_df[raw_existing_df.date >= '2023-04-02']
  .groupby('user_id', as_index = False).date.min()
  .groupby('date', as_index = False).user_id.count()
  .rename(columns = {'user_id': 'existing_users'}))
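Here is the same scenario logic in plain Python, which may help to see what the pandas chain does. The tiny activity log is hypothetical. Existing users are those active before the simulated launch, and each contributes one test in the first week they return afterwards. Users who first appear after the launch are excluded because the new-user forecast covers them.

```python
from collections import Counter
from datetime import date

def demand_from_existing(events, launch_date):
    """events: iterable of (user_id, week_start_date) activity records."""
    # Existing users: active at least once before the simulated launch
    existing = {user for user, week in events if week < launch_date}
    # Each existing user takes the test at their first visit after the launch
    first_return = {}
    for user, week in events:
        if user in existing and week >= launch_date:
            if user not in first_return or week < first_return[user]:
                first_return[user] = week
    # Aggregate the number of tests by week
    return dict(sorted(Counter(first_return.values()).items()))

# Hypothetical weekly activity log
events = [
    ("a", date(2023, 3, 19)), ("a", date(2023, 4, 2)), ("a", date(2023, 4, 9)),
    ("b", date(2023, 3, 26)), ("b", date(2023, 4, 9)),
    ("c", date(2023, 3, 12)),   # never returns after the launch
    ("d", date(2023, 4, 2)),    # new user, excluded
]
print(demand_from_existing(events, date(2023, 4, 2)))
```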

We got the first estimations. If we had launched this test in April 2023, we would have gotten around 1.3K tests in the first week, 0.3K in the second week, 80 cases in the third week, and even fewer afterwards.

We assumed that 100% of existing customers would finish the test, and we would need to check it. In real-life tasks, it’s worth taking conversion into account and adjusting the numbers. Here, we will continue using 100% conversion for simplicity.

So, we’ve done our first modelling. It wasn’t challenging at all. But is this estimation good enough?

Taking into account long-term trends

We are using data from the previous year. However, everything changes. Let’s look at the number of active customers over time.

active_users_df = (df.groupby('date')[['user_id']].nunique()
    .rename(columns = {'user_id': 'active_users'}))

We can see that it’s growing steadily. I would expect it to continue growing. So, it’s worth adjusting our forecast due to this YoY (Year-over-Year) growth. We can re-use our prediction function and calculate YoY using forecasted values to make it more accurate.


active_forecast_df = make_prediction(active_users_df, 
    'active_users', 'active users')

Let’s calculate YoY growth based on our forecast and adjust the model’s predictions.

import datetime

# calculating YoYs
active_forecast_df['active_user_prev_year'] = active_forecast_df.active_users.shift(52)
active_forecast_df['yoy'] = (active_forecast_df.active_users_model
  / active_forecast_df.active_user_prev_year)

existing_model_df = existing_model_df.rename(
  columns = {'date': 'model_date', 'existing_users': 'model_existing_users'})

# adjusting dates from 2023 to 2024 (364 days = 52 weeks, so weekdays are preserved)
existing_model_df['date'] = existing_model_df.model_date.map(
  lambda x: datetime.datetime.strptime(x, '%Y-%m-%d') + datetime.timedelta(364)
)

existing_model_df = (existing_model_df.set_index('date')
   .join(active_forecast_df[['yoy']]))

# updating estimations
existing_model_df['existing_users'] = list(map(
    lambda x, y: int(round(x*y)),
    existing_model_df.model_existing_users,
    existing_model_df.yoy
))
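The date shift and YoY scaling can be illustrated in isolation. In this standard-library sketch, the 1,300 and 300 tests echo the first-week and second-week estimates above, while the YoY factors are hypothetical. Shifting by 364 days (exactly 52 weeks) rather than 365 keeps every date on the same weekday, so the week of 2nd April 2023 maps to the week of 31st March 2024.

```python
from datetime import date, timedelta

def shift_and_scale(model_counts, yoy):
    """model_counts: {week_2023: tests}; yoy: {week_2024: growth factor}."""
    adjusted = {}
    for week, count in model_counts.items():
        # 364 days = exactly 52 weeks, so the shifted date keeps its weekday
        new_week = week + timedelta(days=364)
        adjusted[new_week] = int(round(count * yoy[new_week]))
    return adjusted

# Hypothetical YoY factors; counts follow the first two weeks estimated above
model_counts = {date(2023, 4, 2): 1300, date(2023, 4, 9): 300}
yoy = {date(2024, 3, 31): 1.8, date(2024, 4, 7): 1.75}
print(shift_and_scale(model_counts, yoy))
```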

We’ve finished the estimations for the existing students as well. So, we are ready to merge both parts and get the result.

Putting everything together

First results

Now, we can combine all our previous estimations and see the final chart. For that, we need to convert data to the common format and add segments so that we can distinguish demand between new and existing students.

# existing segment
existing_model_df = (existing_model_df.reset_index()[['date', 'existing_users']]
  .rename(columns = {'existing_users': 'users'}))
existing_model_df['segment'] = 'existing'

# new segment
new_model_df = (new_forecast_df.reset_index()[['cohort', 'new_users_model']]
  .rename(columns = {'cohort': 'date', 'new_users_model': 'users'}))
new_model_df = new_model_df[(new_model_df.date >= '2024-03-31') 
  & (new_model_df.date < '2025-04-07')]
new_model_df['users'] = new_model_df.users.map(lambda x: int(round(x)))
new_model_df['segment'] = 'new'

# combining everything
demand_model_df = pd.concat([existing_model_df, new_model_df])

# visualisation
px.area(demand_model_df.pivot(index = 'date', 
          columns = 'segment', values = 'users').head(15)[['new', 'existing']], 
    title = '<b>Demand</b>: modelling number of tests after launch',
    labels = {'value': 'number of tests'})

We should expect around 2.5K tests in the first week after launch, mostly from existing customers. Then, within four weeks, we will have processed the tests from existing users, and only ~100–130 cases per week will come from new joiners.

That’s wonderful. Now, we can share our estimations with colleagues so they can also plan their work.

What if we have demand constraints?

In real life, you will often face the problem of capacity constraints when it’s impossible to launch a new feature to 100% of customers. So, it’s time to learn how to deal with such situations.

Suppose we’ve found out that our teachers can check only 1K tests each week. Then, we need to stagger our demand to avoid a bad customer experience (when students have to wait for weeks to get their results).

Luckily, we can do it easily by rolling out tests to our existing customers in batches (or cohorts). We can switch the functionality on for all new joiners and X% of existing customers in the first week. Then, we can add another Y% of existing customers in the second week, etc. Eventually, we will evaluate all existing students and have ongoing demand only from new users.
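Under the simplifying assumption used later in this article (every batch follows the same weekly demand curve, scaled by its share of users), the total weekly demand from a staggered rollout is a sum of shifted, scaled copies of the 100% curve, i.e. a discrete convolution. Here is a minimal stdlib sketch; `staggered_demand` is a hypothetical name, and the curve 2440/520/140 is only inferred from the 10% example (244, 52, 14) given below.

```python
def staggered_demand(shares, full_curve, horizon):
    """Total weekly demand from existing users under a staggered rollout.

    shares[k]     -- share of existing users enabled in launch week k
    full_curve[w] -- tests expected w weeks after enabling 100% of users
    Each batch contributes share * full_curve, shifted by its launch week.
    """
    demand = [0.0] * horizon
    for launch_week, share in enumerate(shares):
        for w, tests in enumerate(full_curve):
            if launch_week + w < horizon:
                demand[launch_week + w] += share * tests
    return demand

curve = [2440, 520, 140]  # illustrative 100% curve
print([round(x) for x in staggered_demand([0.5, 0.5], curve, 4)])
# [1220, 1480, 330, 70]
```

Splitting the rollout in half already turns a single 2,440-test week into two peaks of 1,220 and 1,480, which is the effect the batching relies on.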

Let’s come up with a rollout plan without exceeding the 1K capacity threshold.

Since we definitely want to launch it for all new students, let’s start with them and add them to our plan. We will store all demand estimations by segments in the raw_demand_est_model_df data frame and initialise them with our new_model_df estimations that we got before.

raw_demand_est_model_df = new_model_df.copy()

Now, we can aggregate this data and calculate the remaining capacity.

capacity = 1000

demand_est_model_df = raw_demand_est_model_df.pivot(index = 'date', 
    columns = 'segment', values = 'users')

demand_est_model_df['total_demand'] = demand_est_model_df.sum(axis = 1)
demand_est_model_df['capacity'] = capacity
demand_est_model_df['remaining_capacity'] = (demand_est_model_df.capacity 
    - demand_est_model_df.total_demand)

demand_est_model_df.head()
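Stripped of pandas, this calculation is a per-week subtraction: sum demand across all segments, then subtract it from the fixed capacity. Here is a stdlib sketch with a hypothetical helper `remaining_capacity` and illustrative numbers (112 new-joiner tests in the first week is what the 888 spare capacity implies; the other values are made up for the example).

```python
def remaining_capacity(demand_by_segment, capacity=1000):
    """Return {week: capacity - total demand} across all segments.

    demand_by_segment maps a segment name to {week: tests}. Weeks missing
    from a segment count as zero, like pandas' sum(axis=1) skipping NaN.
    """
    weeks = sorted({w for seg in demand_by_segment.values() for w in seg})
    return {
        w: capacity - sum(seg.get(w, 0) for seg in demand_by_segment.values())
        for w in weeks
    }

plan = {
    'new': {'2024-03-31': 112, '2024-04-07': 110},
    'existing, cohort = 2024-03-31': {'2024-03-31': 854, '2024-04-07': 182},
}
print(remaining_capacity(plan))  # {'2024-03-31': 34, '2024-04-07': 708}
```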

Let’s put this logic into a separate function since we will need it to evaluate our estimations after each iteration.

import plotly.graph_objects as go

def get_total_demand_model(raw_demand_est_model_df, capacity = 1000):
    demand_est_model_df = raw_demand_est_model_df.pivot(index = 'date', 
        columns = 'segment', values = 'users')
    demand_est_model_df['total_demand'] = demand_est_model_df.sum(axis = 1)
    demand_est_model_df['capacity'] = capacity
    demand_est_model_df['remaining_capacity'] = (demand_est_model_df.capacity 
      - demand_est_model_df.total_demand)

    # plotting demand by segment (new joiners first, then existing-user cohorts)
    tmp_df = demand_est_model_df.drop(['total_demand', 'capacity', 
        'remaining_capacity'], axis = 1)
    fig = px.area(tmp_df,
                  title = '<b>Demand vs Capacity</b>',
                  category_orders={'segment': ['new'] + list(sorted(filter(lambda x: x != 'new', tmp_df.columns)))},
                  labels = {'value': 'tests'})
    fig.add_trace(go.Scatter(
        x=demand_est_model_df.index, y=demand_est_model_df.capacity, 
        name='capacity', line=dict(color='black', dash='dash'))
    )

    fig.show()
    return demand_est_model_df

demand_plan_df = get_total_demand_model(raw_demand_est_model_df)
demand_plan_df.head()

I’ve also added a chart to the output of this function so that we can assess our results at a glance.

Now, we can start planning the rollout for existing customers week by week.

First, let’s transform our current demand model for existing students. I would like it to be indexed by the sequence number of the week and to show the demand estimation for a 100% rollout. Then, I can easily get estimations for each batch by multiplying the demand by the batch’s share and calculating the dates from the launch date and week number.

existing_model_df['num_week'] = list(range(existing_model_df.shape[0]))
existing_model_df = (existing_model_df.set_index('num_week')
    .drop(['date', 'segment'], axis = 1))
existing_model_df.head()

So, for example, if we launch our evaluation test for 10% of random customers, we expect to get 244 tests in the first week, 52 tests in the second week, 14 in the third, etc.

I will use the same estimations for all batches, assuming that all batches of the same size will produce the same number of tests over the following weeks. So, I don’t take into account any seasonal effects related to each batch’s launch date.

This assumption simplifies the process quite a bit. And it’s pretty reasonable in our case because the rollout will take only 4–5 weeks, and there are no significant seasonal effects during this period. However, if you want to be more accurate (or have considerable seasonality), you can build demand estimations for each batch by repeating our previous process.

Let’s start with the week of 31st March 2024. As we saw before, we have a spare capacity for 888 tests. If we launch our test to 100% of existing customers, we will get ~2.4K tests to check in the first week. So, we are ready to roll out only to a portion of all customers. Let’s calculate it.

cohort = '2024-03-31'
demand_plan_df.loc[cohort].remaining_capacity/existing_model_df.iloc[0].users
# 0.3638

It’s easier to operate with round numbers, so let’s round the share to a multiple of 5%. I’ve rounded it down (rather than to the nearest multiple) to keep some buffer.

import math

full_demand_1st_week = existing_model_df.iloc[0].users
next_group_share = demand_plan_df.loc[cohort].remaining_capacity/full_demand_1st_week
next_group_share = math.floor(20*next_group_share)/20
# 0.35
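The same rounding can be written as a small reusable helper with the step as a parameter. This is a sketch, not the article's code: the name `floor_to_step` is hypothetical, and the small epsilon is an extra safeguard I'm adding so that a share that is already an exact multiple of the step is not pushed down a notch by floating-point noise in the division.

```python
import math

def floor_to_step(share, step=0.05):
    """Round a rollout share down to the nearest multiple of `step`.

    The epsilon protects exact multiples from landing just below an integer
    after division; round() strips trailing noise such as 0.35000000000000003.
    """
    return round(math.floor(share / step + 1e-9) * step, 4)

print(floor_to_step(888 / 2440))  # 0.35
```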

Since we will make several iterations, we need to track the percentage of existing customers for whom we’ve enabled the new feature. Also, it’s worth checking whether we’ve already processed all the customers to avoid double-counting.

enabled_user_share = 0

# if we can process more customers than are left, update the number
if next_group_share > 1 - enabled_user_share:
    print('exceeded')
    next_group_share = round(1 - enabled_user_share, 2)

enabled_user_share += next_group_share
# 0.35

Also, saving our rollout plan in a separate variable will be helpful.

rollout_plan = []
rollout_plan.append(
    {'launch_date': cohort, 'rollout_percent': next_group_share}
)

Now, we need to estimate the expected demand from this batch. Launching tests for 35% of customers on 31st March will lead to some demand not only in the first week but also in the subsequent weeks. So, we need to calculate the total demand from this batch and add it to our plans.

# copy the model
next_group_demand_df = existing_model_df.copy().reset_index()

# calculate the dates from cohort + week number
next_group_demand_df['date'] = next_group_demand_df.num_week.map(
    lambda x: (datetime.datetime.strptime(cohort, '%Y-%m-%d') 
        + datetime.timedelta(7*x))
)

# adjusting demand by weight
next_group_demand_df['users'] = (next_group_demand_df.users * next_group_share).map(lambda x: int(round(x)))

# labelling the segment
next_group_demand_df['segment'] = 'existing, cohort = %s' % cohort

# updating the plan
raw_demand_est_model_df = pd.concat([raw_demand_est_model_df, 
    next_group_demand_df.drop('num_week', axis = 1)])
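For readers without the pandas context, here is the same per-batch step as a stdlib sketch: dates are the cohort date plus 7 days per week number, and each week's demand is the 100% estimate times the batch share. The helper `cohort_demand` is hypothetical; unlike the article's code, it returns ISO date strings rather than datetime objects, and the curve 2440/520/140 is only inferred from the 10% example.

```python
import datetime

def cohort_demand(cohort, share, full_curve):
    """Demand rows for one rollout batch.

    cohort     -- launch week as 'YYYY-MM-DD'
    share      -- fraction of existing users enabled in that week
    full_curve -- expected tests per week after launch for a 100% rollout
    """
    start = datetime.datetime.strptime(cohort, '%Y-%m-%d')
    return [
        {
            'date': (start + datetime.timedelta(days=7 * week)).strftime('%Y-%m-%d'),
            'users': int(round(tests * share)),
            'segment': 'existing, cohort = %s' % cohort,
        }
        for week, tests in enumerate(full_curve)
    ]

rows = cohort_demand('2024-03-31', 0.35, [2440, 520, 140])
print([(r['date'], r['users']) for r in rows])
# [('2024-03-31', 854), ('2024-04-07', 182), ('2024-04-14', 49)]
```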

Now, we can re-use the function get_total_demand_model, which helps us analyse the current balance between demand and capacity.

demand_plan_df = get_total_demand_model(raw_demand_est_model_df)
demand_plan_df.head()

We are using most of our capacity in the first week. There’s still some free capacity, but keeping a buffer for sustainability was a conscious decision. We can also see that there’s almost no demand from this batch after three weeks.

With that, we’ve finished the first iteration and can move on to the following week, 7th April 2024. We can check an additional 706 cases during this week.

We can repeat the whole process for this week and then move on to the next one, iterating until we have launched our project to 100% of existing customers (enabled_user_share equals 1).
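The manual iterations can also be automated as a greedy loop. Here is a minimal stdlib sketch (the function `plan_rollout` is hypothetical, not the author's code) that follows the same rule: each week, enable the largest 5%-rounded share of existing users whose first-week demand fits into the remaining capacity, never exceeding 100% in total. The inputs are illustrative: 112 new-joiner tests every week (the first-week value implied by the 888 spare capacity, assumed constant) and a 100% curve of 2440, 520, 140 inferred from the 10% example. Because of these simplified inputs, the resulting plan differs from the article's.

```python
import math

def plan_rollout(new_demand, full_curve, capacity=1000, step=0.05, max_weeks=52):
    """Greedy week-by-week rollout plan (hypothetical helper).

    new_demand -- expected tests per week from new joiners (enabled for all from week 0)
    full_curve -- expected tests in weeks 0, 1, 2, ... after enabling 100% of
                  existing users; full_curve[0] must be positive
    Returns (plan, demand): plan is a list of (week, share) pairs and demand
    is the resulting total number of tests per week.
    """
    horizon = max(max_weeks, len(new_demand)) + len(full_curve)
    demand = [float(x) for x in new_demand] + [0.0] * (horizon - len(new_demand))
    enabled, plan = 0.0, []
    for week in range(max_weeks):
        if enabled >= 1 - 1e-9:  # every existing user already has the feature
            break
        remaining = max(capacity - demand[week], 0)
        # largest multiple of `step` whose first-week demand fits, rounded down
        share = math.floor(remaining / full_curve[0] / step + 1e-9) * step
        share = round(min(share, 1 - enabled), 4)  # never exceed 100% in total
        if share <= 0:
            continue
        enabled += share
        plan.append((week, share))
        # this batch adds demand in its launch week and the following weeks
        for offset, tests in enumerate(full_curve):
            demand[week + offset] += share * tests
    return plan, demand

plan, demand = plan_rollout([112] * 8, [2440, 520, 140])
print(plan)  # [(0, 0.35), (1, 0.25), (2, 0.25), (3, 0.15)]
print(round(max(demand)))  # 966
```

Like the manual process, this only checks the first-week demand of each new batch against capacity; the tails of earlier batches are accounted for when the following weeks are planned.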

We can roll out our tests to all customers without breaching the 1K tests per week capacity constraint within just four weeks. In the end, we will have the following weekly forecast.

We can also look at the rollout plan we’ve logged throughout our simulations. We need to launch the test for a randomly selected 35% of existing customers in the week of 31st March, then for another 20% the following week, followed by 25% and 20% of existing users in the remaining two weeks. By then, the feature will be enabled for all existing students.

rollout_plan
# [{'launch_date': '2024-03-31', 'rollout_percent': 0.35},
# {'launch_date': '2024-04-07', 'rollout_percent': 0.2},
# {'launch_date': '2024-04-14', 'rollout_percent': 0.25},
# {'launch_date': '2024-04-21', 'rollout_percent': 0.2}]

So, congratulations. We now have a plan for how to roll out our feature sustainably.

Tracking students’ performance over time

We’ve already done a lot to estimate demand. We’ve leveraged the idea of simulation by imitating the launch of our project a year ago, scaling it and assessing the consequences. So, this is definitely an example of simulation.

However, we mostly used the basic tools you use daily: some Pandas data wrangling and arithmetic operations. In the last part of the article, I would like to show you a slightly more complex case where we will need to simulate the process for each customer independently.

Product requirements often change over time, and this happened with our project. You and your team decided that it would be even better to let students track their progress over time (not only once at the very beginning). So, we would like to offer students a performance test after each module (if more than one month has passed since the previous test) or when a student returns to the service after three months of absence.

Now, the criteria for test assignment are pretty tricky. Still, we can use the same approach and look at the data for the previous year. This time, however, we need to look at each customer’s behaviour and determine at what point they would get a test.

We will take into account both new and existing customers since we want to estimate the effects of follow-up tests on all of them. We don’t need any data from before the launch because the first test will be assigned at the customer’s next activity, so the earlier history doesn’t matter. So, we can filter it out.

sim_df = df[df.date >= '2023-03-31']

Let’s also define a function that calculates the number of days between two date strings. It will be helpful for us in the implementation.

def days_diff(date1, date2):
    return (datetime.datetime.strptime(date2, '%Y-%m-%d')
        - datetime.datetime.strptime(date1, '%Y-%m-%d')).days

Let’s start with one user and discuss the logic in detail. First, we will filter the events related to this user and convert them into a list of dictionaries. It will be way easier for us to work with such data.

user_id = 4861
user_events = (sim_df[sim_df.user_id == user_id]
    .sort_values('date')
    .to_dict('records'))

# [{'user_id': 4861, 'date': '2023-04-09', 'module': 'pre-A1', 'lesson_num': 8},
# {'user_id': 4861, 'date': '2023-04-16', 'module': 'pre-A1', 'lesson_num': 9},
# {'user_id': 4861, 'date': '2023-04-23', 'module': 'pre-A1', 'lesson_num': 10},
# {'user_id': 4861, 'date': '2023-04-23', 'module': 'pre-A1', 'lesson_num': 11},
# {'user_id': 4861, 'date': '2023-04-30', 'module': 'pre-A1', 'lesson_num': 12},
# {'user_id': 4861, 'date': '2023-05-07', 'module': 'pre-A1', 'lesson_num': 13}]

To simulate our product logic, we will be processing user events one by one and, at each point, checking whether the customer is eligible for the evaluation.

Let’s discuss what variables we need to maintain to be able to tell whether the customer is eligible for the test or not. For that, let’s recap all the possible cases when a customer might get a test:

  • If there were no previous tests -> we need to know whether they passed a test before.
  • If the customer finished the module and more than one month has passed since the previous test -> we need to know the last test date.
  • If the customer returns after three months -> we need to store the date of the last lesson.

To check all these criteria, we need only two variables: the last test date (None if there was no test before) and the previous lesson date. We will also need to store all the generated tests so that we can count them later. Let’s initialise all the variables.

tmp_gen_tests = []
last_test_date = None
last_lesson_date = None

Now, we need to iterate over the events and check the criteria.

for rec in user_events:
  pass

Let’s go through all our criteria, starting from the initial test. In this case, last_test_date will be equal to None. It’s important for us to update the last_test_date variable after "assigning" the test.

if last_test_date is None: # initial test
    last_test_date = rec['date']
    # TBD saving the test info

In the case of a finished module, we need to check that it’s the last lesson in the module (lesson 100) and that at least 30 days have passed since the previous test.

if (rec['lesson_num'] == 100) and (days_diff(last_test_date, rec['date']) >= 30): 
    last_test_date = rec['date']
    # TBD saving the test info

The last case is that the customer hasn’t used our service for three months.

if (days_diff(last_lesson_date, rec['date']) >= 92): # ~three months without lessons
    last_test_date = rec['date']
    # TBD saving the test info

Finally, we need to update last_lesson_date at every iteration (whether or not a test was assigned) to keep it accurate.

We’ve discussed all the building blocks and are ready to combine them and do simulations for all our customers.

import tqdm
tmp_gen_tests = []

for user_id in tqdm.tqdm(sim_df.user_id.unique()):
    # initialising variables
    last_test_date = None
    last_lesson_date = None

    # events must be processed in chronological order
    for rec in sim_df[sim_df.user_id == user_id].sort_values('date').to_dict('records'):
        # initial test
        if last_test_date is None: 
            last_test_date = rec['date']
            tmp_gen_tests.append(
                {
                    'user_id': rec['user_id'],
                    'date': rec['date'],
                    'trigger': 'initial test'
                }
            )
        # finish module
        elif (rec['lesson_num'] == 100) and (days_diff(last_test_date, rec['date']) >= 30): 
            last_test_date = rec['date']
            tmp_gen_tests.append(
                {
                    'user_id': rec['user_id'],
                    'date': rec['date'],
                    'trigger': 'finished module'
                })
        # reactivation
        elif (days_diff(last_lesson_date, rec['date']) >= 92):
            last_test_date = rec['date']
            tmp_gen_tests.append(
                {
                    'user_id': rec['user_id'],
                    'date': rec['date'],
                    'trigger': 'reactivation'
                })
        last_lesson_date = rec['date']
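Wrapping the per-user logic in a function makes it easy to test in isolation. Here is a stdlib sketch of that refactor (`simulate_user_tests` is a hypothetical name; the thresholds are the ones used above, and lesson 100 as the end of a module is the same assumption as in the loop). As a sanity check, user 4861 from the example above should get only the initial test, since they never finish a module or take a long break.

```python
import datetime

def days_diff(date1, date2):
    """Number of days from date1 to date2 (both 'YYYY-MM-DD' strings)."""
    return (datetime.datetime.strptime(date2, '%Y-%m-%d')
            - datetime.datetime.strptime(date1, '%Y-%m-%d')).days

def simulate_user_tests(events, last_lesson_num=100,
                        module_gap_days=30, reactivation_days=92):
    """Replay one user's lesson events (sorted by date) and return their tests."""
    tests, last_test_date, last_lesson_date = [], None, None
    for rec in events:
        trigger = None
        if last_test_date is None:
            trigger = 'initial test'
        elif (rec['lesson_num'] == last_lesson_num
              and days_diff(last_test_date, rec['date']) >= module_gap_days):
            trigger = 'finished module'
        elif days_diff(last_lesson_date, rec['date']) >= reactivation_days:
            trigger = 'reactivation'
        if trigger is not None:
            last_test_date = rec['date']
            tests.append({'user_id': rec['user_id'], 'date': rec['date'],
                          'trigger': trigger})
        last_lesson_date = rec['date']  # updated on every event
    return tests

user_4861 = [
    {'user_id': 4861, 'date': '2023-04-09', 'module': 'pre-A1', 'lesson_num': 8},
    {'user_id': 4861, 'date': '2023-04-16', 'module': 'pre-A1', 'lesson_num': 9},
    {'user_id': 4861, 'date': '2023-04-23', 'module': 'pre-A1', 'lesson_num': 10},
    {'user_id': 4861, 'date': '2023-04-23', 'module': 'pre-A1', 'lesson_num': 11},
    {'user_id': 4861, 'date': '2023-04-30', 'module': 'pre-A1', 'lesson_num': 12},
    {'user_id': 4861, 'date': '2023-05-07', 'module': 'pre-A1', 'lesson_num': 13},
]
print(simulate_user_tests(user_4861))
# [{'user_id': 4861, 'date': '2023-04-09', 'trigger': 'initial test'}]
```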

Now, we can aggregate this data. Since we are again using the previous year’s data, I will scale the numbers up by the ~80% YoY growth we estimated before (a factor of 1.8).

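# collecting the simulated tests into a data frame
exist_model_upd = pd.DataFrame(tmp_gen_tests)

exist_model_upd_stats_df = exist_model_upd.pivot_table(
    index = 'date', columns = 'trigger', values = 'user_id', 
    aggfunc = 'nunique'
).fillna(0)

# scaling by ~80% YoY growth (DataFrame.map needs pandas >= 2.1; use applymap on older versions)
exist_model_upd_stats_df = (exist_model_upd_stats_df
    .map(lambda x: int(round(x * 1.8))))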
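The pivot table above counts unique users per date and trigger and then scales the counts by the growth factor. Here is the same aggregation as a stdlib sketch, with a hypothetical function name and made-up sample tests. One difference from the pandas version: combinations with no tests are simply absent from the dict rather than filled with 0.

```python
from collections import defaultdict

def weekly_trigger_counts(tests, growth=1.8):
    """Count unique users per (date, trigger) and scale by YoY growth.

    tests  -- dicts with 'user_id', 'date', 'trigger', as produced by the loop
    growth -- multiplier applied to last year's counts (1.8 = ~80% YoY growth)
    """
    users = defaultdict(set)
    for t in tests:
        users[(t['date'], t['trigger'])].add(t['user_id'])
    return {key: int(round(len(ids) * growth)) for key, ids in users.items()}

sample_tests = [
    {'user_id': 1, 'date': '2023-04-09', 'trigger': 'initial test'},
    {'user_id': 2, 'date': '2023-04-09', 'trigger': 'initial test'},
    {'user_id': 3, 'date': '2023-04-09', 'trigger': 'initial test'},
    {'user_id': 1, 'date': '2023-05-14', 'trigger': 'finished module'},
]
print(weekly_trigger_counts(sample_tests))
# {('2023-04-09', 'initial test'): 5, ('2023-05-14', 'finished module'): 2}
```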

We got a very similar estimation for the initial tests. In this case, the "initial test" segment equals the sum of new and existing demand in our previous estimations.

So, the other segments are much more interesting since they are incremental to our previous calculations. Starting in May, we can see around 30–60 cases per week from customers who finished modules.

There will be almost no cases of reactivation. In our simulation, we got 4 cases per year in total.

Congratulations! The case is solved, and we’ve found a nice approach that allows us to make precise estimations without advanced math, using only simulation. You can use similar

You can find the full code for this example on GitHub.

Summary

Let me quickly recap what we’ve discussed today:

  • The main idea of computer simulation is imitation based on your data.
  • In many cases, you can reframe the problem from predicting the future to using the data you already have and simulating the process you’re interested in. So, this approach is quite powerful.
  • In this article, we went through an end-to-end example of scenario estimations. We’ve seen how to structure complex problems and split them into smaller, better-defined ones. We’ve also learned to deal with constraints and plan a gradual rollout.

Thank you very much for reading this article. If you have any follow-up questions or comments, please leave them in the comments section.

Reference

All the images are produced by the author unless otherwise stated.

The post Practical Computer Simulations for Product Analysts appeared first on Towards Data Science.

]]>