The Art of Noise
Understanding and implementing a diffusion model from scratch with PyTorch
By Muhammad Ardi, Towards Data Science, Thu, 03 Apr 2025

Introduction

In my last several articles I talked about generative deep learning algorithms, mostly related to text generation tasks, so I think it would be interesting to switch to generative algorithms for image generation now. Nowadays there are plenty of deep learning models specialized for generating images, such as Autoencoder, Variational Autoencoder (VAE), Generative Adversarial Network (GAN) and Neural Style Transfer (NST). I have also posted some of my writings about these topics on Medium, and I provide the links at the end of this article if you want to read them.

In today’s article, I would like to discuss the so-called diffusion model, one of the most impactful models in the field of deep learning for image generation. The idea of this algorithm was first proposed in the paper titled Deep Unsupervised Learning using Nonequilibrium Thermodynamics written by Sohl-Dickstein et al. back in 2015 [1]. Their framework was then developed further by Ho et al. in 2020 in their paper titled Denoising Diffusion Probabilistic Models (DDPM) [2]. DDPM was later adapted by OpenAI and Google to develop DALL-E 2 and Imagen, two models known for their impressive capability to generate high-quality images.

How Diffusion Model Works

Generally speaking, a diffusion model works by generating an image from noise. We can think of it like an artist transforming a splash of paint on a canvas into a beautiful artwork. In order to do so, the diffusion model needs to be trained first. There are two main processes involved in training the model, namely forward diffusion and backward diffusion.

Figure 1. The forward and backward diffusion process [3].

As you can see in the above figure, forward diffusion is a process where Gaussian noise is applied to the original image iteratively. We keep adding noise until the image is completely unrecognizable, at which point we can say that the image now lies in the latent space. Unlike Autoencoders and GANs, where the latent space typically has a lower dimension than the original image, the latent space in DDPM has exactly the same dimensionality as the original image. This noising process follows the principle of a Markov chain, meaning that the image at timestep t depends only on the image at timestep t-1. Forward diffusion is considered easy since all we do is add a little noise at every step.
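To make the Markov property concrete, here is a minimal plain-Python sketch of the iterative noising (no PyTorch involved). It assumes the per-step update from the DDPM paper [2], x_t = sqrt(1 - β_t)·x_{t-1} + sqrt(β_t)·ε with ε drawn from a standard normal distribution, and a linearly spaced β schedule. The helper names are hypothetical, and the "image" is just a flat list of pixel values.

```python
import math
import random

def linear_betas(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Hypothetical helper: evenly spaced noise variances from beta_start to beta_end."""
    step = (beta_end - beta_start) / (num_timesteps - 1)
    return [beta_start + i * step for i in range(num_timesteps)]

def forward_step(x_prev, beta):
    """One Markov step: the new image depends only on the previous one."""
    keep = math.sqrt(1.0 - beta)   # slightly shrink the previous signal
    scale = math.sqrt(beta)        # standard deviation of the added noise
    return [keep * v + scale * random.gauss(0.0, 1.0) for v in x_prev]

def forward_diffusion(x0, betas):
    """Apply forward_step once per timestep, starting from the clean image x0."""
    x = list(x0)
    for beta in betas:
        x = forward_step(x, beta)
    return x

random.seed(0)
betas = linear_betas()
clean = [1.0] * (28 * 28)              # dummy all-white 28x28 "image", flattened
noisy = forward_diffusion(clean, betas)

mean = sum(noisy) / len(noisy)
var = sum((v - mean) ** 2 for v in noisy) / len(noisy)
print(f"mean={mean:.3f}, variance={var:.3f}")  # close to 0 and 1: the white image is gone
```

Each call to forward_step only sees the output of the previous call, which is exactly the "t depends only on t-1" property described above.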

The second training phase is called backward diffusion, where our objective is to remove the noise little by little until we obtain a clear image. This process follows the principle of the reverse Markov chain, where the image at timestep t-1 can only be obtained from the image at timestep t. Such a denoising process is really difficult since we need to guess which pixels are noise and which ones belong to the actual image content. Thus, we need to employ a neural network model to do so.

DDPM uses U-Net as the basis of the deep learning architecture for backward diffusion. However, instead of using the original U-Net model [4], we need to make several modifications to it so that it will be more suitable for our task. Later on, I am going to train this model on the MNIST Handwritten Digit dataset [5], and we will see whether it can generate similar images.

Well, those are pretty much all the fundamental concepts you need to know about diffusion models for now. In the next sections we are going to dig deeper into the details while implementing the algorithm from scratch.


PyTorch Implementation

We are going to start by importing the required modules. In case you’re not yet familiar with the imports below, both torch and torchvision are the libraries we’ll use for preparing the model and the dataset. Meanwhile, matplotlib and tqdm will help us display images and progress bars.

# Codeblock 1
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

Now that the modules have been imported, the next thing to do is to initialize some config parameters. Look at Codeblock 2 below for the details.

# Codeblock 2
IMAGE_SIZE     = 28     #(1)
NUM_CHANNELS   = 1      #(2)

BATCH_SIZE     = 2
NUM_EPOCHS     = 10
LEARNING_RATE  = 0.001

NUM_TIMESTEPS  = 1000   #(3)
BETA_START     = 0.0001 #(4)
BETA_END       = 0.02   #(5)
TIME_EMBED_DIM = 32     #(6)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  #(7) is_available must be called
DEVICE

# Codeblock 2 Output
device(type='cuda')

At the lines marked with #(1) and #(2), I set IMAGE_SIZE and NUM_CHANNELS to 28 and 1, which match the dimensions of the MNIST images (28×28 pixels, grayscale). The BATCH_SIZE, NUM_EPOCHS, and LEARNING_RATE variables are pretty straightforward, so I don’t think I need to explain them further.

At line #(3), the variable NUM_TIMESTEPS denotes the number of iterations in the forward and backward diffusion processes. Timestep 0 is the condition where the image is in its original state (the leftmost image in Figure 1). Since we set this parameter to 1000, timestep 999 is the condition where the image is completely unrecognizable (the rightmost image in Figure 1). Keep in mind that the choice of the number of timesteps involves a tradeoff between output quality and computational cost. If we assign a small value to NUM_TIMESTEPS, inference will be faster, but the resulting image may be worse since the model has fewer steps to refine it during backward diffusion. On the other hand, increasing NUM_TIMESTEPS slows down inference, but we can expect better output quality thanks to the more gradual denoising process, which allows a more precise reconstruction.

Next, the BETA_START (#(4)) and BETA_END (#(5)) variables control the amount of Gaussian noise added at each timestep, whereas TIME_EMBED_DIM (#(6)) determines the length of the feature vector that stores the timestep information. Lastly, at line #(7) I assign “cuda” to the DEVICE variable if PyTorch detects a CUDA-capable GPU on our machine, and “cpu” otherwise. I highly recommend running this project on a GPU since training a diffusion model is computationally expensive. Note that the values for NUM_TIMESTEPS, BETA_START and BETA_END are all adopted directly from the DDPM paper [2].
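A quick way to see what BETA_START and BETA_END do is to track how much of the original image survives at each timestep. Under the DDPM formulation [2], the forward process can also be written in closed form as x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 - ᾱ_t)·ε, where ᾱ_t is the running product of (1 - β_s) for s up to t. The plain-Python sketch below assumes the linear schedule between the two β values; the alpha_bars helper is hypothetical and only prints the signal and noise weights at a few timesteps.

```python
import math

NUM_TIMESTEPS = 1000
BETA_START = 0.0001
BETA_END = 0.02

def alpha_bars(num_timesteps, beta_start, beta_end):
    """Hypothetical helper: cumulative product of (1 - beta_t) for a linear schedule."""
    step = (beta_end - beta_start) / (num_timesteps - 1)
    result, running = [], 1.0
    for t in range(num_timesteps):
        beta_t = beta_start + t * step
        running *= 1.0 - beta_t        # each step keeps a fraction (1 - beta_t) of the signal variance
        result.append(running)
    return result

abar = alpha_bars(NUM_TIMESTEPS, BETA_START, BETA_END)
for t in (0, 100, 500, 999):
    signal = math.sqrt(abar[t])        # weight of the clean image x_0
    noise = math.sqrt(1.0 - abar[t])   # weight of the Gaussian noise
    print(f"t={t:4d}  signal={signal:.4f}  noise={noise:.4f}")
```

With these settings the signal weight shrinks monotonically and is close to zero by the last timestep, which is why timestep 999 can be treated as (almost) pure noise.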

The complete implementation will be done in several steps: constructing the U-Net model, preparing the dataset, defining the noise scheduler for the diffusion process, training, and inference. We are going to discuss each of those stages in the following sub-sections.


The U-Net Architecture: Time Embedding

As I mentioned earlier, the basis of a diffusion model is U-Net. This architecture is used because it produces a per-pixel output, making it well suited to represent an image. This makes sense, since it was originally introduced for image segmentation tasks. The following figure shows what the original U-Net architecture looks like.

Figure 2. The original U-Net model proposed in [4].

However, we need to modify this architecture so that it also takes the timestep information into account. Moreover, since we will only use the MNIST dataset, we can make the model smaller: a simple dataset like this usually does not require a large model.

In the figure below I show you the entire modified U-Net model. Here you can see that the time embedding tensor is injected into the model at every stage through element-wise summation, allowing the model to capture the timestep information. Next, instead of repeating each of the downsampling and upsampling stages four times like the original U-Net, in this case we will only repeat each of them twice. Additionally, it is worth noting that the stack of downsampling stages is also known as the encoder, whereas the stack of upsampling stages is often called the decoder.

Figure 3. The modified U-Net model for our diffusion task [3].

Now let’s start constructing the architecture by creating a class that generates the time embedding tensor, based on an idea similar to the positional embedding in the Transformer. See Codeblock 3 below for the details.

# Codeblock 3
class TimeEmbedding(nn.Module):
    def forward(self):
        time = torch.arange(NUM_TIMESTEPS, device=DEVICE).reshape(NUM_TIMESTEPS, 1)  # timesteps 0..999 as a column
        print(f"time\t\t: {time.shape}")
          
        i = torch.arange(0, TIME_EMBED_DIM, 2, device=DEVICE)
        denominator = torch.pow(10000, i/TIME_EMBED_DIM)
        print(f"denominator\t: {denominator.shape}")
          
        even_time_embed = torch.sin(time/denominator)  #(1)
        odd_time_embed  = torch.cos(time/denominator)  #(2)
        print(f"even_time_embed\t: {even_time_embed.shape}")
        print(f"odd_time_embed\t: {odd_time_embed.shape}")
          
        stacked = torch.stack([even_time_embed, odd_time_embed], dim=2)  #(3)
        print(f"stacked\t\t: {stacked.shape}")
        time_embed = torch.flatten(stacked, start_dim=1, end_dim=2)  #(4)
        print(f"time_embed\t: {time_embed.shape}")
          
        return time_embed

What we basically do in the above code is create a tensor of size NUM_TIMESTEPS × TIME_EMBED_DIM (1000×32), where every row contains the information of one timestep. In other words, each of the 1000 timesteps is represented by a feature vector of length 32. The values in the tensor are obtained from the two equations in Figure 4. In Codeblock 3 above, these two equations are implemented at lines #(1) and #(2), each forming a tensor of size 1000×16. Next, these two tensors are interleaved into a single 1000×32 tensor by the code at lines #(3) and #(4).

Here I also print out the tensor shape at every step of the above codeblock so that you can better understand what is actually being done in the TimeEmbedding class. If you want more explanation about the above code, feel free to read my previous post about the Transformer, which you can access through the link at the end of this article. Once you click the link, just scroll all the way down to the Positional Encoding section.

Figure 4. The sinusoidal positional encoding formula from the Transformer paper [6].
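Written with the timestep t in place of the token position, the formula in Figure 4 reads PE(t, 2i) = sin(t / 10000^(2i/d)) and PE(t, 2i+1) = cos(t / 10000^(2i/d)), where d = TIME_EMBED_DIM. The plain-Python sketch below (the sinusoidal_row helper is hypothetical) builds a single row of that table and shows the interleaved sine/cosine order that the torch.stack(..., dim=2) and torch.flatten calls in Codeblock 3 produce.

```python
import math

TIME_EMBED_DIM = 32

def sinusoidal_row(t, dim=TIME_EMBED_DIM):
    """Hypothetical helper: the embedding vector for a single timestep t."""
    row = []
    for two_i in range(0, dim, 2):              # 0, 2, 4, ..., like torch.arange(0, dim, 2)
        denominator = 10000 ** (two_i / dim)    # 10000^(2i/d)
        row.append(math.sin(t / denominator))   # position 2i   -> sine
        row.append(math.cos(t / denominator))   # position 2i+1 -> cosine
    return row

print(sinusoidal_row(0)[:4])      # timestep 0: every sine is 0 and every cosine is 1
print(len(sinusoidal_row(999)))   # every timestep maps to a 32-dimensional vector
```

Row t of the 1000×32 tensor returned by TimeEmbedding should contain these same values, up to float32 rounding.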

Now let’s check if the TimeEmbedding class works properly using the following testing code. The resulting output shows that it successfully produced a tensor of size 1000×32, which is exactly what we expected earlier.

# Codeblock 4
time_embed_test = TimeEmbedding()
out_test = time_embed_test()

# Codeblock 4 Output
time            : torch.Size([1000, 1])
denominator     : torch.Size([16])
even_time_embed : torch.Size([1000, 16])
odd_time_embed  : torch.Size([1000, 16])
stacked         : torch.Size([1000, 16, 2])
time_embed      : torch.Size([1000, 32])

The U-Net Architecture: DoubleConv

If you take a closer look at the modified architecture, you will see that there are many repeating patterns, such as the ones highlighted by the yellow boxes in the following figure.

Figure 5. The processes done inside the yellow boxes will be implemented in the DoubleConv class [3].

These five yellow boxes share the same structure: two convolution layers, with the time embedding tensor injected right after the first convolution operation. So, what we are going to do now is create another class named DoubleConv to reproduce this structure. Look at Codeblocks 5a and 5b below to see how I do that.

# Codeblock 5a
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):  #(1)
        super().__init__()
        
        self.conv_0 = nn.Conv2d(in_channels=in_channels,  #(2)
                                out_channels=out_channels, 
                                kernel_size=3, 
                                bias=False, 
                                padding=1)
        self.bn_0 = nn.BatchNorm2d(num_features=out_channels)  #(3)
        
        self.time_embedding = TimeEmbedding()  #(4)
        self.linear = nn.Linear(in_features=TIME_EMBED_DIM,  #(5)
                                out_features=out_channels)
        
        self.conv_1 = nn.Conv2d(in_channels=out_channels,  #(6)
                                out_channels=out_channels, 
                                kernel_size=3, 
                                bias=False, 
                                padding=1)
        self.bn_1 = nn.BatchNorm2d(num_features=out_channels)  #(7)
        
        self.relu = nn.ReLU(inplace=True)  #(8)

The two inputs of the __init__() method above give us the flexibility to configure the number of input and output channels (#(1)), so that the DoubleConv class can be used to instantiate all five yellow boxes simply by adjusting its arguments. As the name suggests, here we initialize two convolution layers (lines #(2) and #(6)), each followed by a batch normalization layer and a ReLU activation function. Keep in mind that the two normalization layers need to be initialized separately (lines #(3) and #(7)) since each of them has its own trainable parameters. Meanwhile, the ReLU activation function only needs to be initialized once (#(8)) because it contains no parameters, allowing it to be reused in different parts of the network. At line #(4), we initialize the TimeEmbedding layer we created earlier, whose output will later be passed to a standard linear layer (#(5)). This linear layer is responsible for adjusting the dimension of the time embedding tensor so that the result can be summed element-wise with the output of the first convolution layer.
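The reasoning about which layers can be shared is easy to confirm by counting parameters. This small sketch (the count_params helper is hypothetical) shows that nn.ReLU has none, while every nn.BatchNorm2d(64) instance owns a separate learnable scale and shift for each channel. Note that the running mean and variance tracked by batch normalization are stored as buffers rather than parameters, so they are not counted here.

```python
import torch.nn as nn

def count_params(module):
    """Hypothetical helper: total number of parameter values in a module."""
    return sum(p.numel() for p in module.parameters())

relu = nn.ReLU(inplace=True)
bn_0 = nn.BatchNorm2d(num_features=64)
bn_1 = nn.BatchNorm2d(num_features=64)

print(count_params(relu))                        # 0: nothing to learn, so one instance can be reused
print(count_params(bn_0), count_params(bn_1))    # 128 128: a scale and a shift for each of 64 channels
print(bn_0.weight is bn_1.weight)                # False: each layer owns separate parameters
```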

Now let’s take a look at Codeblock 5b below to better understand the flow of the DoubleConv block. Here you can see that the forward() method accepts two inputs: the (noisy) image x and the timestep information t, as shown at line #(1). We initially process the image with the first Conv-BN-ReLU sequence (#(2–4)). This Conv-BN-ReLU structure is commonly used in CNN-based models, even though the illustration does not explicitly show the batch normalization and ReLU layers. Next, for each image we take the row of the embedding tensor that corresponds to its timestep t (#(5)) and pass it through the linear layer (#(6)). We still need to expand the dimensions of the resulting tensor using the code at line #(7) before performing the element-wise summation at line #(8). Finally, we process the resulting tensor with the second Conv-BN-ReLU sequence (#(9–11)).

# Codeblock 5b
    def forward(self, x, t):  #(1)
        print(f'images\t\t\t: {x.size()}')
        print(f'timesteps\t\t: {t.size()}, {t}')
        
        x = self.conv_0(x)  #(2)
        x = self.bn_0(x)    #(3)
        x = self.relu(x)    #(4)
        print(f'\nafter first conv\t: {x.size()}')
        
        time_embed = self.time_embedding()[t]      #(5)
        print(f'\ntime_embed\t\t: {time_embed.size()}')
        
        time_embed = self.linear(time_embed)       #(6)
        print(f'time_embed after linear\t: {time_embed.size()}')
        
        time_embed = time_embed[:, :, None, None]  #(7)
        print(f'time_embed expanded\t: {time_embed.size()}')
        
        x = x + time_embed  #(8)
        print(f'\nafter summation\t\t: {x.size()}')
        
        x = self.conv_1(x)  #(9)
        x = self.bn_1(x)    #(10)
        x = self.relu(x)    #(11)
        print(f'after second conv\t: {x.size()}')
        
        return x

To see if our DoubleConv implementation works properly, we are going to test it with Codeblock 6 below. Here I want to simulate the very first instance of this block, which corresponds to the leftmost yellow box in Figure 5. To do so, we need to set the in_channels and out_channels parameters to 1 and 64, respectively (#(1)). Next, we initialize two input tensors, namely x_test and t_test. The x_test tensor has the size of 2×1×28×28, representing a batch of two grayscale images of size 28×28 (#(2)). Keep in mind that this is just a dummy tensor of random values which will be replaced with actual images from the MNIST dataset later in the training phase. Meanwhile, t_test is a tensor containing the timestep numbers of the corresponding images (#(3)). Its values are randomly selected between 0 and NUM_TIMESTEPS - 1 (999), since the upper bound of torch.randint() is exclusive. Note that the datatype of this tensor must be an integer since the numbers will be used for indexing, as shown at line #(5) back in Codeblock 5b; torch.randint() returns int64 by default. Lastly, at line #(4) we pass both x_test and t_test tensors to the double_conv_test layer.

By the way, I re-ran the previous codeblocks with the print() functions removed before running the following code so that the outputs look neater.

# Codeblock 6
double_conv_test = DoubleConv(in_channels=1, out_channels=64).to(DEVICE)  #(1)

x_test = torch.randn((BATCH_SIZE, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)).to(DEVICE)  #(2)
t_test = torch.randint(0, NUM_TIMESTEPS, (BATCH_SIZE,)).to(DEVICE)  #(3)

out_test = double_conv_test(x_test, t_test)  #(4)

# Codeblock 6 Output
images                  : torch.Size([2, 1, 28, 28])   #(1)
timesteps               : torch.Size([2]), tensor([468, 304], device='cuda:0')  #(2)

after first conv        : torch.Size([2, 64, 28, 28])  #(3)

time_embed              : torch.Size([2, 32])          #(4)
time_embed after linear : torch.Size([2, 64])
time_embed expanded     : torch.Size([2, 64, 1, 1])    #(5)

after summation         : torch.Size([2, 64, 28, 28])  #(6)
after second conv       : torch.Size([2, 64, 28, 28])  #(7)

The shape of our original input tensors can be seen at lines #(1) and #(2) in the above output. Specifically at line #(2), I also print out the two timesteps that were selected randomly. In this example we assume that the two images in the x tensor have already been noised with the noise levels of the 468th and 304th timesteps, respectively, before being fed into the network. We can see that the shape of the image tensor x changes to 2×64×28×28 after being passed through the first convolution layer (#(3)). Meanwhile, the size of our time embedding tensor becomes 2×32 (#(4)), which is obtained by extracting rows 468 and 304 from the original embedding of size 1000×32. In order to allow the element-wise summation to be performed (#(6)), we need to map the 32-dimensional time embedding vectors to 64 dimensions and expand their axes, resulting in a tensor of size 2×64×1×1 (#(5)) that can be broadcast to the 2×64×28×28 tensor. After the summation is done, we pass the tensor through the second convolution layer, which leaves the tensor dimensions unchanged (#(7)).
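The broadcasting step can also be checked in isolation. The sketch below uses random tensors in place of real feature maps and an untrained nn.Linear layer (assumptions for illustration only); it confirms that, after the [:, :, None, None] reshape, each image receives one offset per channel, added identically at every pixel position.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, channels, size, embed_dim = 2, 64, 28, 32

feature_map = torch.randn(batch, channels, size, size)  # stand-in for the first conv output
time_embed = torch.randn(batch, embed_dim)              # stand-in for two rows of the embedding table
linear = nn.Linear(embed_dim, channels)                 # maps 32 -> 64, as in DoubleConv

with torch.no_grad():
    offset = linear(time_embed)[:, :, None, None]       # (2, 64) -> (2, 64, 1, 1)
    summed = feature_map + offset                       # broadcast to (2, 64, 28, 28)

print(offset.shape, summed.shape)
```

Because the offset has size 1 along both spatial axes, PyTorch repeats it over the 28×28 grid, so the time embedding shifts a whole channel uniformly rather than individual pixels.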


The U-Net Architecture: Encoder

As we have successfully implemented the DoubleConv block, the next step to do is to implement the so-called DownSample block. In Figure 6 below, this corresponds to the parts enclosed in the red box.

Figure 6. The parts of the network highlighted in red are the so-called DownSample blocks [3].

The purpose of a DownSample block is to reduce the spatial dimensions of an image while, at the same time, increasing the number of channels. In order to achieve this, we can simply stack a DoubleConv block, which takes care of the channels, and a maxpooling operation, which shrinks the spatial size. In this case the pooling uses a 2×2 kernel with a stride of 2, which halves both the height and the width of its input. The implementation of this block can be seen in Codeblock 7 below.

# Codeblock 7
class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):  #(1)
        super().__init__()
        
        self.double_conv = DoubleConv(in_channels=in_channels,  #(2)
                                      out_channels=out_channels)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)    #(3)
    
    def forward(self, x, t):  #(4)
        print(f'original\t\t: {x.size()}')
        print(f'timesteps\t\t: {t.size()}, {t}')
        
        convolved = self.double_conv(x, t)   #(5)
        print(f'\nafter double conv\t: {convolved.size()}')
        
        maxpooled = self.maxpool(convolved)  #(6)
        print(f'after pooling\t\t: {maxpooled.size()}')
        
        return convolved, maxpooled          #(7)

Here I set the __init__() method to take the number of input and output channels so that we can use it to create both DownSample blocks highlighted in Figure 6 without writing them as separate classes (#(1)). Next, the DoubleConv and the maxpooling layers are initialized at lines #(2) and #(3), respectively. Remember that since the DoubleConv block accepts the image x and the corresponding timestep t as inputs, we also need the forward() method of this DownSample block to accept both of them (#(4)). The information contained in x and t is then combined as the two tensors are processed by the double_conv layer, whose output is stored in the variable named convolved (#(5)). Afterwards, we actually perform the downsampling with the maxpooling operation at line #(6), producing a tensor named maxpooled. It is important to note that both the convolved and maxpooled tensors are returned: maxpooled goes on to the next downsampling stage, whereas convolved is transferred directly to the corresponding upsampling stage in the decoder through a skip-connection.

Now let’s test the DownSample class using the Codeblock 8 below. The input tensors used here are exactly the same as the ones in Codeblock 6. Based on the resulting output, we can see that the pooling operation successfully converted the output of the DoubleConv block from 2×64×28×28 (#(1)) to 2×64×14×14 (#(2)), indicating that our DownSample class works properly.

# Codeblock 8
down_sample_test = DownSample(in_channels=1, out_channels=64).to(DEVICE)

x_test = torch.randn((BATCH_SIZE, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)).to(DEVICE)
t_test = torch.randint(0, NUM_TIMESTEPS, (BATCH_SIZE,)).to(DEVICE)

out_test = down_sample_test(x_test, t_test)

# Codeblock 8 Output
original          : torch.Size([2, 1, 28, 28])
timesteps         : torch.Size([2]), tensor([468, 304], device='cuda:0')

after double conv : torch.Size([2, 64, 28, 28])  #(1)
after pooling     : torch.Size([2, 64, 14, 14])  #(2)
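To see why a 2×2 window with stride 2 halves the height and width, here is a small plain-Python max pooling sketch on a 4×4 grid. The max_pool_2x2 helper is hypothetical and assumes even side lengths. The pooled_size helper follows the output-size formula documented for nn.MaxPool2d without padding, floor((H - kernel_size) / stride) + 1, which gives 28 → 14 here and 14 → 7 in the second DownSample block.

```python
def max_pool_2x2(grid):
    """Hypothetical helper: 2x2 max pooling with stride 2 on a list of lists."""
    pooled = []
    for r in range(0, len(grid), 2):
        row = []
        for c in range(0, len(grid[0]), 2):
            window = [grid[r][c], grid[r][c + 1], grid[r + 1][c], grid[r + 1][c + 1]]
            row.append(max(window))   # keep only the strongest activation in each window
        pooled.append(row)
    return pooled

def pooled_size(size, kernel_size=2, stride=2):
    """Output side length of a max pool without padding."""
    return (size - kernel_size) // stride + 1

grid = [
    [1, 3, 2, 0],
    [4, 2, 1, 5],
    [0, 1, 7, 2],
    [3, 6, 4, 1],
]
print(max_pool_2x2(grid))                 # [[4, 5], [6, 7]]
print(pooled_size(28), pooled_size(14))   # 14 7
```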

The U-Net Architecture: Decoder

We need to introduce the so-called UpSample block in the decoder, which is responsible for reverting the tensor in the intermediate layers to the original image dimension. In order to maintain a symmetrical structure, the number of UpSample blocks must match that of the DownSample blocks. Look at the Figure 7 below to see where the two UpSample blocks are placed.

Figure 7. The components inside the blue boxes are the so-called UpSample blocks [3].

Since both UpSample blocks are structurally identical, we can define a single class for them, just like the DownSample class we created earlier. Look at Codeblock 9 below to see how I implement it.

# Codeblock 9
class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        self.conv_transpose = nn.ConvTranspose2d(in_channels=in_channels,  #(1)
                                                 out_channels=out_channels, 
                                                 kernel_size=2, stride=2)  #(2)
        self.double_conv = DoubleConv(in_channels=in_channels,  #(3)
                                      out_channels=out_channels)
        
    def forward(self, x, t, connection):  #(4)
        print(f'original\t\t: {x.size()}')
        print(f'timesteps\t\t: {t.size()}, {t}')
        print(f'connection\t\t: {connection.size()}')
        
        x = self.conv_transpose(x)  #(5)
        print(f'\nafter conv transpose\t: {x.size()}')
        
        x = torch.cat([x, connection], dim=1)  #(6)
        print(f'after concat\t\t: {x.size()}')
        
        x = self.double_conv(x, t)  #(7)
        print(f'after double conv\t: {x.size()}')
        
        return x

In the __init__() method, we use nn.ConvTranspose2d to upsample the spatial dimension (#(1)). Both the kernel size and stride are set to 2 so that the output will be twice as large (#(2)). Next, the DoubleConv block will be employed to reduce the number of channels, while at the same time combining the timestep information from the time embedding tensor (#(3)).

The flow of this UpSample class is a bit more complicated than that of the DownSample class. If we take a closer look at the architecture, we’ll see that we also have a skip-connection coming directly from the encoder. Thus, we need the forward() method to accept another argument in addition to the image x and the timestep t, namely the skip-connection tensor connection (#(4)). The first thing we do inside this method is process x with the transpose convolution layer (#(5)). This layer not only upsamples the spatial size but also reduces the number of channels. However, the resulting tensor is then concatenated with connection along the channel axis (#(6)), which makes it look as if no channel reduction was performed. Keep in mind that at this point the two tensors are only stacked side by side, so their information is not yet combined. We finally feed the concatenated tensor to the double_conv layer (#(7)), allowing the two sources to share information with each other through the learnable parameters inside the convolution layers.
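The shape bookkeeping inside UpSample can be checked with simple arithmetic. The sketch below uses the output-size formula documented for nn.ConvTranspose2d, (H - 1)·stride - 2·padding + dilation·(kernel_size - 1) + output_padding + 1, and then adds the channels of the skip connection, as torch.cat(..., dim=1) does. The helper names are hypothetical; the first call uses the numbers of the second UpSample block.

```python
def conv_transpose_size(size, kernel_size=2, stride=2, padding=0, dilation=1, output_padding=0):
    """Output side length of nn.ConvTranspose2d (formula from the PyTorch docs)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

def upsample_shapes(in_channels, out_channels, size, skip_channels):
    """Hypothetical helper: trace (channels, side) through conv transpose -> concat -> DoubleConv."""
    after_transpose = (out_channels, conv_transpose_size(size))
    after_concat = (out_channels + skip_channels, after_transpose[1])
    # DoubleConv was built with in_channels inputs, so the concatenation must match it
    assert after_concat[0] == in_channels, "skip connection has the wrong number of channels"
    after_double_conv = (out_channels, after_concat[1])  # 3x3 convs with padding=1 keep the size
    return after_transpose, after_concat, after_double_conv

print(upsample_shapes(in_channels=128, out_channels=64, size=14, skip_channels=64))
# ((64, 28), (128, 28), (64, 28))
```

This is also why the UpSample class can pass in_channels to DoubleConv: the transposed convolution outputs out_channels, and the skip connection contributes another out_channels, which together equal in_channels in this architecture.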

The Codeblock 10 below shows how I test the UpSample class. The size of the tensors to be passed through are set according to the second upsampling block, i.e., the rightmost blue box in Figure 7.

# Codeblock 10
up_sample_test = UpSample(in_channels=128, out_channels=64).to(DEVICE)

x_test = torch.randn((BATCH_SIZE, 128, 14, 14)).to(DEVICE)
t_test = torch.randint(0, NUM_TIMESTEPS, (BATCH_SIZE,)).to(DEVICE)
connection_test = torch.randn((BATCH_SIZE, 64, 28, 28)).to(DEVICE)

out_test = up_sample_test(x_test, t_test, connection_test)

In the resulting output below, if we compare the input tensor (#(1)) with the final tensor shape (#(2)), we can clearly see that the number of channels was successfully reduced from 128 to 64, while the spatial dimensions increased from 14×14 to 28×28. This means that our UpSample class is now ready to be used in the main U-Net architecture.

# Codeblock 10 Output
original             : torch.Size([2, 128, 14, 14])   #(1)
timesteps            : torch.Size([2]), tensor([468, 304], device='cuda:0')
connection           : torch.Size([2, 64, 28, 28])

after conv transpose : torch.Size([2, 64, 28, 28])
after concat         : torch.Size([2, 128, 28, 28])
after double conv    : torch.Size([2, 64, 28, 28])    #(2)

The U-Net Architecture: Putting All Components Together

Once all U-Net components have been created, what we are going to do next is to wrap them together into a single class. Look at the Codeblock 11a and 11b below for the details.

# Codeblock 11a
class UNet(nn.Module):
    def __init__(self):
        super().__init__()
      
        self.downsample_0 = DownSample(in_channels=NUM_CHANNELS,  #(1)
                                       out_channels=64)
        self.downsample_1 = DownSample(in_channels=64,            #(2)
                                       out_channels=128)
      
        self.bottleneck   = DoubleConv(in_channels=128,           #(3)
                                       out_channels=256)
      
        self.upsample_0   = UpSample(in_channels=256,             #(4)
                                     out_channels=128)
        self.upsample_1   = UpSample(in_channels=128,             #(5)
                                     out_channels=64)
      
        self.output = nn.Conv2d(in_channels=64,                   #(6)
                                out_channels=NUM_CHANNELS,
                                kernel_size=1)

You can see in the __init__() method above that we initialize two downsampling (#(1–2)) and two upsampling (#(4–5)) blocks, whose numbers of input and output channels are set according to the architecture shown in the illustration. There are actually two additional components I haven’t explained yet, namely the bottleneck (#(3)) and the output layer (#(6)). The former is essentially just a DoubleConv block, which acts as the main connection between the encoder and the decoder. Look at Figure 8 below to see which components of the network belong to the bottleneck layer. Meanwhile, the output layer is a standard convolution layer responsible for turning the 64-channel tensor produced by the last UpSample stage into a 1-channel image. This operation uses a kernel of size 1×1, meaning that it combines information across all channels while operating independently at each pixel position.

Figure 8. The bottleneck layer (the lower part of the model) acts as the main bridge between the encoder and the decoder of U-Net [3].
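The claim that a 1×1 convolution mixes channels independently at each pixel can be verified numerically. In this sketch (random weights and inputs, not the trained model), the output of nn.Conv2d with kernel_size=1 is compared against an explicit weighted sum over the 64 input channels computed with torch.einsum.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 64, 28, 28)                     # stand-in for the last UpSample output
output_layer = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1)

with torch.no_grad():
    conv_out = output_layer(x)                     # shape (2, 1, 28, 28)
    weight = output_layer.weight[:, :, 0, 0]       # shape (1, 64): one weight per input channel
    # At every (h, w) position: sum over channels c of weight[o, c] * x[b, c, h, w], plus the bias
    manual = torch.einsum("oc,bchw->bohw", weight, x) + output_layer.bias[None, :, None, None]

print(conv_out.shape, torch.allclose(conv_out, manual, atol=1e-5))
```

Since the same 64 weights are applied at every pixel and no neighboring pixels are involved, the layer acts as a per-pixel linear projection from 64 channels down to 1.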

I guess the forward() method of the entire U-Net in the following codeblock is pretty straightforward, as what we essentially do here is pass the tensors from one layer to another — just don’t forget to include the skip connections between the downsampling and upsampling blocks.

# Codeblock 11b
    def forward(self, x, t):  #(1)
        print(f'original\t\t: {x.size()}')
        print(f'timesteps\t\t: {t.size()}, {t}')
            
        convolved_0, maxpooled_0 = self.downsample_0(x, t)
        print(f'\nmaxpooled_0\t\t: {maxpooled_0.size()}')
            
        convolved_1, maxpooled_1 = self.downsample_1(maxpooled_0, t)
        print(f'maxpooled_1\t\t: {maxpooled_1.size()}')
            
        x = self.bottleneck(maxpooled_1, t)
        print(f'after bottleneck\t: {x.size()}')
    
        upsampled_0 = self.upsample_0(x, t, convolved_1)
        print(f'upsampled_0\t\t: {upsampled_0.size()}')
            
        upsampled_1 = self.upsample_1(upsampled_0, t, convolved_0)
        print(f'upsampled_1\t\t: {upsampled_1.size()}')
            
        x = self.output(upsampled_1)
        print(f'final output\t\t: {x.size()}')
            
        return x

Now let’s see whether we have correctly constructed the U-Net class above by running the following testing code.

# Codeblock 12
unet_test = UNet().to(DEVICE)

x_test = torch.randn((BATCH_SIZE, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)).to(DEVICE)
t_test = torch.randint(0, NUM_TIMESTEPS, (BATCH_SIZE,)).to(DEVICE)

out_test = unet_test(x_test, t_test)

# Codeblock 12 Output
original         : torch.Size([2, 1, 28, 28])   #(1)
timesteps        : torch.Size([2]), tensor([468, 304], device='cuda:0')

maxpooled_0      : torch.Size([2, 64, 14, 14])  #(2)
maxpooled_1      : torch.Size([2, 128, 7, 7])   #(3)
after bottleneck : torch.Size([2, 256, 7, 7])   #(4)
upsampled_0      : torch.Size([2, 128, 14, 14])
upsampled_1      : torch.Size([2, 64, 28, 28])
final output     : torch.Size([2, 1, 28, 28])   #(5)

We can see in the above output that the two downsampling stages successfully converted the original tensor of size 1×28×28 (#(1)) into 64×14×14 (#(2)) and 128×7×7 (#(3)), respectively. This tensor is then passed through the bottleneck layer, which expands its number of channels to 256 without changing the spatial dimensions (#(4)). Lastly, we upsample the tensor twice before eventually shrinking the number of channels to 1 (#(5)). Based on this output, our model appears to be working properly, so it is now ready to be trained for our diffusion task.
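One consequence of this design is that the image side length must survive two rounds of halving and doubling so that each skip connection lines up with the upsampled tensor. The plain-Python check below (the unet_sizes_match helper is hypothetical) traces the spatial size through two DownSample and two UpSample blocks. A 28×28 input works because 28 → 14 → 7 → 14 → 28, whereas a 30×30 input would be pooled to 15 and then floored to 7, so upsampling gives 14, which no longer matches the 15×15 skip connection.

```python
def unet_sizes_match(size, depth=2):
    """Hypothetical check: do skip connections and upsampled tensors have equal sizes?"""
    skip_sizes = []
    for _ in range(depth):
        skip_sizes.append(size)   # DoubleConv output (the 'convolved' tensor) keeps the size
        size = size // 2          # 2x2 max pooling with stride 2 floors odd sizes
    for skip in reversed(skip_sizes):
        size = size * 2           # transposed conv with kernel 2, stride 2 doubles the size
        if size != skip:
            return False          # torch.cat(..., dim=1) would fail on the spatial mismatch
    return True

print(unet_sizes_match(28), unet_sizes_match(30))  # True False
```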


Dataset Preparation

As we have successfully created the entire U-Net architecture, the next thing to do is to prepare the MNIST Handwritten Digit dataset. Before actually loading it, we need to define the preprocessing steps using the transforms.Compose() method from Torchvision, as shown at line #(1) in Codeblock 13. There are two things we do here: converting the images into PyTorch tensors, which also scales the pixel values from 0–255 to 0–1 (#(2)), and normalizing them so that the final pixel values range between -1 and 1 (#(3)). Next, we download the dataset using datasets.MNIST() (#(4)). In this case, we are going to take the images from the training split, hence we use train=True (#(5)). Don’t forget to pass the transform variable we initialized earlier to the transform parameter (transform=transform) so that the images are automatically preprocessed as we load them (#(6)). Lastly, we employ DataLoader to load the images from mnist_dataset (#(7)). With shuffle=True, each iteration randomly picks BATCH_SIZE (2) images from the dataset, while drop_last=True discards an incomplete final batch so that every batch contains exactly BATCH_SIZE images, which matters because the training loop later samples exactly BATCH_SIZE timesteps per batch.
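The two preprocessing steps boil down to simple arithmetic: ToTensor() divides the uint8 pixel values by 255, and Normalize(mean, std) computes (x - mean) / std for each channel, so a mean and standard deviation of 0.5 map [0, 1] onto [-1, 1]. The sketch below applies that arithmetic directly to a few example pixel values with plain tensor operations (it does not call Torchvision), and also shows the inverse mapping, which is handy if you ever need pixel values in [0, 1] again.

```python
import torch

pixels = torch.tensor([0, 64, 128, 255], dtype=torch.uint8)  # example raw MNIST pixel values

scaled = pixels.float() / 255.0       # value mapping done by transforms.ToTensor()
normalized = (scaled - 0.5) / 0.5     # what transforms.Normalize((0.5,), (0.5,)) computes
restored = (normalized + 1.0) / 2.0   # inverse mapping back to [0, 1]

print(normalized)
```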

# Codeblock 13
transform = transforms.Compose([  #(1)
    transforms.ToTensor(),        #(2)
    transforms.Normalize((0.5,), (0.5,))  #(3)
])

mnist_dataset = datasets.MNIST(   #(4)
    root='./data', 
    train=True,           #(5)
    download=True, 
    transform=transform   #(6)
)

loader = DataLoader(mnist_dataset,  #(7)
                    batch_size=BATCH_SIZE,
                    drop_last=True, 
                    shuffle=True)

In the following codeblock, I try to load a batch of images from the dataset. In every iteration, loader provides both the images and the corresponding labels, hence we need to store them in two separate variables: images and labels.

# Codeblock 14
images, labels = next(iter(loader))

print('images\t\t:', images.shape)
print('labels\t\t:', labels.shape)
print('min value\t:', images.min())
print('max value\t:', images.max())

We can see in the resulting output below that the images tensor has the size of 2×1×28×28 (#(1)), indicating that two grayscale images of size 28×28 have been successfully loaded. We can also see that the length of the labels tensor is 2, which matches the number of loaded images (#(2)). Note that in this case the labels are going to be completely ignored: I want the model to generate digits that resemble those in the training dataset without being told which digit it is producing. Lastly, this output also shows that the preprocessing works properly, as the pixel values now range between -1 and 1.

# Codeblock 14 Output
images    : torch.Size([2, 1, 28, 28])  #(1)
labels    : torch.Size([2])             #(2)
min value : tensor(-1.)
max value : tensor(1.)

Run the following code if you want to see what the image we just loaded looks like.

# Codeblock 15   
plt.imshow(images[0].squeeze(), cmap='gray')
plt.show()
Figure 9. Output from Codeblock 15 [3].

Noise Scheduler

In this section we are going to talk about how the forward and backward diffusion are performed. Both processes essentially involve adding or removing noise little by little at each timestep. We want the noise to be added gradually and in a controlled way, such that in the forward diffusion the image becomes essentially pure noise at the last timestep (t = 999, since the 1000 timesteps are indexed from 0), while in the backward diffusion we end up with a clean image at t = 0. Hence, we need something to control the amount of noise at each timestep. Later in this section, I am going to implement a class named NoiseScheduler to do so. This will probably be the most math-heavy section of this article, as I’ll display many equations here. But don’t worry, since we’ll focus on implementing these equations rather than on their mathematical derivations.

Now let’s take a look at the equations in Figure 10 which I will implement in the __init__() method of the NoiseScheduler class below.

Figure 10. The equations we need to implement in the __init__() method of the NoiseScheduler class [3].
# Codeblock 16a
class NoiseScheduler:
    def __init__(self):
        self.betas = torch.linspace(BETA_START, BETA_END, NUM_TIMESTEPS)  #(1)
        self.alphas = 1. - self.betas
        self.alphas_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cum_prod = torch.sqrt(self.alphas_cum_prod)
        self.sqrt_one_minus_alphas_cum_prod = torch.sqrt(1. - self.alphas_cum_prod)

The above code creates several sequences of numbers, all of which are controlled by BETA_START (0.0001), BETA_END (0.02), and NUM_TIMESTEPS (1000). The first sequence we need to instantiate is betas itself, which is created using torch.linspace() (#(1)). It generates a 1-dimensional tensor of length 1000 running from 0.0001 to 0.02, where every element corresponds to a single timestep. The spacing between consecutive elements is constant, so the variance of the noise added at each step grows linearly from 0.0001 at the first timestep to 0.02 at the last. With this betas tensor, we then compute alphas, alphas_cum_prod, sqrt_alphas_cum_prod and sqrt_one_minus_alphas_cum_prod based on the four equations in Figure 10. Later on, these tensors will act as the basis for how noise is added or removed during the diffusion process.
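Because Figure 10 is an image, the relations it refers to are written out below. This assumes the figure follows the standard DDPM notation of Ho et al. [2]; each line matches one attribute computed in Codeblock 16a, with timesteps indexed from 0 to T - 1 as in the code (T = NUM_TIMESTEPS).

```latex
\begin{aligned}
\beta_t &= \beta_{\text{start}} + \frac{t}{T-1}\left(\beta_{\text{end}} - \beta_{\text{start}}\right), \quad t = 0, \dots, T-1 && \text{(betas)} \\
\alpha_t &= 1 - \beta_t && \text{(alphas)} \\
\bar{\alpha}_t &= \prod_{s=0}^{t} \alpha_s && \text{(alphas\_cum\_prod)} \\
& \sqrt{\bar{\alpha}_t} \quad \text{and} \quad \sqrt{1 - \bar{\alpha}_t} && \text{(sqrt\_alphas\_cum\_prod, sqrt\_one\_minus\_alphas\_cum\_prod)}
\end{aligned}
```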

Diffusion is normally done sequentially. However, each forward step simply adds Gaussian noise according to the fixed betas (nothing in this process is learned), so the composition of many steps is itself Gaussian and can be written in closed form. This allows us to obtain the noisy image at a specific timestep directly, without having to iteratively add noise from the very beginning. Figure 11 below shows what the closed form of the forward diffusion looks like, where x₀ represents the original image while epsilon (ϵ) denotes an image made up of random Gaussian noise. We can think of this equation as a weighted combination of the clean image and the noise, with weights determined by the timestep, resulting in an image with a specific amount of noise.

Figure 11. The closed form of the forward diffusion process [3].
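For reference, assuming Figure 11 uses the DDPM formulation [2], a single forward step and the closed form it leads to can be written as follows, with ᾱₜ being alphas_cum_prod from Figure 10:

```latex
\begin{aligned}
q(x_t \mid x_{t-1}) &= \mathcal{N}\!\left(x_t;\ \sqrt{\alpha_t}\,x_{t-1},\ \beta_t \mathbf{I}\right) \\
x_t &= \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})
\end{aligned}
```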

The implementation of this equation can be seen in Codeblock 16b. In this forward_diffusion() method, x₀ and ϵ are denoted as original and noise. Keep in mind that these two inputs are images, whereas sqrt_alphas_cum_prod_t and sqrt_one_minus_alphas_cum_prod_t hold one scalar per sample (a 1-dimensional tensor when t is a batch of timesteps). Thus, we need to reshape these coefficients to (batch, 1, 1, 1) (#(1) and #(2)) so that they broadcast over the channel and spatial axes in the operation at line #(3). The noisy_image variable, as its name suggests, is the output of this function.

# Codeblock 16b
    def forward_diffusion(self, original, noise, t):
        sqrt_alphas_cum_prod_t = self.sqrt_alphas_cum_prod[t]
        sqrt_alphas_cum_prod_t = sqrt_alphas_cum_prod_t.to(DEVICE).view(-1, 1, 1, 1)  #(1)
        
        sqrt_one_minus_alphas_cum_prod_t = self.sqrt_one_minus_alphas_cum_prod[t]
        sqrt_one_minus_alphas_cum_prod_t = sqrt_one_minus_alphas_cum_prod_t.to(DEVICE).view(-1, 1, 1, 1)  #(2)
        
        noisy_image = sqrt_alphas_cum_prod_t * original + sqrt_one_minus_alphas_cum_prod_t * noise  #(3)
        
        return noisy_image
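As a sanity check of the closed form, the sketch below compares it with actually adding noise one step at a time. It is standalone: it rebuilds the schedule from the article's BETA_START, BETA_END and NUM_TIMESTEPS values (in float64), and the closed_form() and iterative() helpers are illustrative names rather than part of NoiseScheduler. closed_form() uses the same view(-1, 1, 1, 1) reshaping as Codeblock 16b, so a batch of per-sample timesteps broadcasts over each image. Because both procedures are random, the comparison is statistical: starting from a constant pixel value x₀ = 0.5, both should produce samples with mean √ᾱₜ·x₀ and standard deviation √(1 - ᾱₜ).

```python
import torch

torch.manual_seed(0)

# Same schedule as NoiseScheduler (BETA_START, BETA_END, NUM_TIMESTEPS), in float64
T = 1000
betas = torch.linspace(0.0001, 0.02, T, dtype=torch.float64)
alphas = 1.0 - betas
alphas_cum_prod = torch.cumprod(alphas, dim=0)

def closed_form(x0, noise, t):
    # t has shape (B,): reshape the per-sample coefficients to (B, 1, 1, 1)
    # so they broadcast over the channel and spatial axes of each image
    a = alphas_cum_prod[t].sqrt().view(-1, 1, 1, 1)
    b = (1.0 - alphas_cum_prod[t]).sqrt().view(-1, 1, 1, 1)
    return a * x0 + b * noise

def iterative(x0, t_final):
    # One noising step per timestep: x_t = sqrt(alpha_t) * x_{t-1} + sqrt(beta_t) * eps_t
    x = x0
    for t in range(t_final + 1):
        x = alphas[t].sqrt() * x + betas[t].sqrt() * torch.randn_like(x)
    return x

n, t_final = 20000, 300
x0 = torch.full((n, 1, 1, 1), 0.5, dtype=torch.float64)  # constant "images" with pixel value 0.5
t = torch.full((n,), t_final, dtype=torch.long)

x_closed = closed_form(x0, torch.randn_like(x0), t)
x_iter = iterative(x0, t_final)

print(f'expected   : mean={0.5 * alphas_cum_prod[t_final].sqrt().item():.4f}, '
      f'std={(1.0 - alphas_cum_prod[t_final]).sqrt().item():.4f}')
print(f'closed form: mean={x_closed.mean().item():.4f}, std={x_closed.std().item():.4f}')
print(f'iterative  : mean={x_iter.mean().item():.4f}, std={x_iter.std().item():.4f}')
```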

Now let’s talk about backward diffusion. In fact, this one is a bit more complicated than the forward diffusion since we need three more equations here. Before I give you these equations, let me show you the implementation first. See the Codeblock 16c below.

# Codeblock 16c
    def backward_diffusion(self, current_image, predicted_noise, t):  #(1)
        denoised_image = (current_image - (self.sqrt_one_minus_alphas_cum_prod[t] * predicted_noise)) / self.sqrt_alphas_cum_prod[t]  #(2)
        denoised_image = 2 * (denoised_image - denoised_image.min()) / (denoised_image.max() - denoised_image.min()) - 1  #(3)
        
        current_prediction = current_image - ((self.betas[t] * predicted_noise) / (self.sqrt_one_minus_alphas_cum_prod[t]))  #(4)
        current_prediction = current_prediction / torch.sqrt(self.alphas[t])  #(5)
        
        if t == 0:  #(6)
            return current_prediction, denoised_image
        
        else:
            variance = (1 - self.alphas_cum_prod[t-1]) / (1. - self.alphas_cum_prod[t])  #(7)
            variance = variance * self.betas[t]  #(8)
            sigma = variance ** 0.5
            z = torch.randn(current_image.shape).to(DEVICE)
            current_prediction = current_prediction + sigma*z
            
            return current_prediction, denoised_image

Later in the inference phase, the backward_diffusion() method will be called inside a loop that iterates NUM_TIMESTEPS (1000) times, starting from t = 999, continuing with t = 998, and so on all the way down to t = 0. This function removes noise from the image step by step based on current_image (the image produced by the previous denoising step), predicted_noise (the noise that the U-Net predicts for current_image at the current timestep), and the timestep t (#(1)). In each iteration, noise removal is done using the equation shown in Figure 12, which corresponds to lines #(4–5) in Codeblock 16c.

Figure 12. The equation used for removing noise from the image [3].
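Assuming Figure 12 follows the DDPM sampling step [2], it can be written as below. Lines #(4–5) compute the mean term (everything in front of σₜz), and the σₜz term is added in the else branch of Codeblock 16c; ϵθ denotes the noise predicted by the U-Net.

```latex
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad z \sim \mathcal{N}(\mathbf{0}, \mathbf{I})
```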

As long as we haven’t reached t = 0, we compute the variance based on the equation in Figure 13 (#(7–8)). This variance is then used to add a controlled amount of fresh noise, because the equation in Figure 12 only gives the deterministic mean of the reverse step, whereas the backward diffusion process itself is stochastic. This is also why we skip the variance computation once we reach t = 0 (#(6)): at the final step we no longer add noise, since we want to return the clean image.

Figure 13. The equation used to calculate variance for introducing controlled noise [3].
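Written out (again assuming the DDPM notation [2]), the variance computed at lines #(7–8) and the standard deviation σₜ derived from it are:

```latex
\sigma_t^2 = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t, \qquad \sigma_t = \sqrt{\sigma_t^2}
```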

Unlike current_prediction, which estimates the image at the previous timestep (xₜ₋₁), the denoised_image tensor aims to reconstruct the original image (x₀). Because of these different objectives, we need a separate equation to compute denoised_image, which can be seen in Figure 14 below. The equation itself is implemented at line #(2), while line #(3) rescales the result to the range [-1, 1]. Note that the minimum and maximum used for this rescaling are taken over the whole batch rather than over each image individually.

Figure 14. The equation for reconstructing the original image [3].
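The Figure 14 equation is the closed form from Figure 11 solved for x₀: x̂₀ = (xₜ - √(1 - ᾱₜ)·ϵ̂) / √ᾱₜ. The standalone sketch below (which rebuilds the schedule from the article's BETA_START, BETA_END and NUM_TIMESTEPS values instead of using the NoiseScheduler class) checks that, if the predicted noise ϵ̂ were exactly the noise that was added, this reconstruction recovers the original image up to floating-point error. The reconstruct_x0 helper and the random stand-in image are illustrative only.

```python
import torch

torch.manual_seed(0)

# Same schedule as NoiseScheduler: BETA_START=0.0001, BETA_END=0.02, NUM_TIMESTEPS=1000
betas = torch.linspace(0.0001, 0.02, 1000)
alphas_cum_prod = torch.cumprod(1.0 - betas, dim=0)

def reconstruct_x0(x_t, predicted_noise, t):
    # Hypothetical helper mirroring line #(2) of Codeblock 16c (without the rescaling at #(3))
    return (x_t - (1.0 - alphas_cum_prod[t]).sqrt() * predicted_noise) / alphas_cum_prod[t].sqrt()

x0 = torch.rand(1, 1, 28, 28) * 2 - 1  # random stand-in "image" with values in [-1, 1]
noise = torch.randn_like(x0)
t = 500

# Forward diffusion in closed form (Figure 11)
x_t = alphas_cum_prod[t].sqrt() * x0 + (1.0 - alphas_cum_prod[t]).sqrt() * noise

# With a perfect noise prediction, the reconstruction matches x0
x0_hat = reconstruct_x0(x_t, noise, t)
print(f'max reconstruction error: {(x0_hat - x0).abs().max().item():.2e}')
```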

Now let’s test the NoiseScheduler class we created above. In the following codeblock, I instantiate a NoiseScheduler object and print out its attributes, which are all computed using the equations in Figure 10 based on the values stored in the betas attribute. Remember that the actual length of these tensors is NUM_TIMESTEPS (1000), but here I only print out the first 6 elements.

# Codeblock 17
noise_scheduler = NoiseScheduler()

print(f'betas\t\t\t\t: {noise_scheduler.betas[:6]}')
print(f'alphas\t\t\t\t: {noise_scheduler.alphas[:6]}')
print(f'alphas_cum_prod\t\t\t: {noise_scheduler.alphas_cum_prod[:6]}')
print(f'sqrt_alphas_cum_prod\t\t: {noise_scheduler.sqrt_alphas_cum_prod[:6]}')
print(f'sqrt_one_minus_alphas_cum_prod\t: {noise_scheduler.sqrt_one_minus_alphas_cum_prod[:6]}')

# Codeblock 17 Output
betas                          : tensor([1.0000e-04, 1.1992e-04, 1.3984e-04, 1.5976e-04, 1.7968e-04, 1.9960e-04])
alphas                         : tensor([0.9999, 0.9999, 0.9999, 0.9998, 0.9998, 0.9998])
alphas_cum_prod                : tensor([0.9999, 0.9998, 0.9996, 0.9995, 0.9993, 0.9991])
sqrt_alphas_cum_prod           : tensor([0.9999, 0.9999, 0.9998, 0.9997, 0.9997, 0.9996])
sqrt_one_minus_alphas_cum_prod : tensor([0.0100, 0.0148, 0.0190, 0.0228, 0.0264, 0.0300])

The above output indicates that our __init__() method works as expected; for instance, the second beta is 0.0001 + (0.02 - 0.0001)/999 ≈ 1.1992e-04, exactly as printed. Next, we are going to test the forward_diffusion() method. If you go back to Codeblock 16b, you will see that forward_diffusion() accepts three inputs: the original image, the noise image and the timestep. Let’s use the MNIST image we loaded earlier as the first input (#(1)) and random Gaussian noise of exactly the same size as the second one (#(2)). Run Codeblock 18 below to see what these two images look like.

# Codeblock 18
image = images[0]  #(1)
noise = torch.randn_like(image)  #(2)

plt.imshow(image.squeeze(), cmap='gray')
plt.show()
plt.imshow(noise.squeeze(), cmap='gray')
plt.show()
Figure 15. The two images to be used as the original (left) and the noise image (right). The one on the left is the same image I showed earlier in Figure 9 [3].

With the image and the noise ready, we can pass them to the forward_diffusion() method along with t. I ran Codeblock 19 below multiple times with t = 50, 100, 150, and so on up to t = 300. You can see in Figure 16 that the image becomes less clear as t increases. By the time t reaches 999, the image is essentially pure noise.

# Codeblock 19
noisy_image_test = noise_scheduler.forward_diffusion(image.to(DEVICE), noise.to(DEVICE), t=50)

plt.imshow(noisy_image_test[0].squeeze().cpu(), cmap='gray')
plt.show()
Figure 16. The result of the forward diffusion process at t=50, 100, 150, and so on until t=300 [3].
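To put numbers on what Figure 16 shows, the short sketch below rebuilds the same linear schedule (using the article's BETA_START, BETA_END and NUM_TIMESTEPS values directly) and prints the two weights of the Figure 11 closed form, √ᾱₜ on the original image and √(1 - ᾱₜ) on the noise, at the timesteps used above and at t = 999. The weight on the original image shrinks steadily, and at t = 999 it is below 0.01, which is why the final image is essentially pure noise.

```python
import torch

# Same schedule as NoiseScheduler (BETA_START, BETA_END, NUM_TIMESTEPS)
betas = torch.linspace(0.0001, 0.02, 1000)
alphas_cum_prod = torch.cumprod(1.0 - betas, dim=0)

timesteps = [50, 100, 150, 200, 250, 300, 999]
signal_weights = []
for t in timesteps:
    signal = alphas_cum_prod[t].sqrt().item()         # weight on the original image x0
    noise = (1.0 - alphas_cum_prod[t]).sqrt().item()  # weight on the Gaussian noise
    signal_weights.append(signal)
    print(f't={t:4d}  image weight={signal:.4f}  noise weight={noise:.4f}')
```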

Unfortunately, we cannot test the backward_diffusion() method yet, since this process requires a trained U-Net. So, let’s skip this part for now. I’ll show you how to actually use this function later in the inference phase.


Training

As the U-Net model, MNIST dataset, and the noise scheduler are ready, we can now prepare a function for training. Before we do that, I instantiate the model and the noise scheduler in Codeblock 20 below.

# Codeblock 20
model = UNet().to(DEVICE)
noise_scheduler = NoiseScheduler()

The entire training procedure is implemented in the train() function shown in Codeblock 21. Before doing anything, we first initialize the optimizer and the loss function, which in this case are Adam and MSE, respectively (#(1–2)). What we want to do here is to train the model to predict the noise contained in its input image; later on, the predicted noise will be used as the basis of the denoising process in the backward diffusion stage. To actually train the model, we first need to perform forward diffusion using the code at line #(6). This noising process is applied to the images tensor (#(3)) using the random noise generated at line #(4). Next, we sample a random timestep t between 0 and NUM_TIMESTEPS - 1 (999) for every image in the batch (#(5)), so that the model sees images with varying noise levels, which helps it generalize across all timesteps. Once the noisy images have been generated, we pass them through the U-Net model along with the chosen t (#(7)). The input t is useful for the model as it indicates the current noise level in the image. Lastly, the loss function we initialized earlier computes the difference between the actual noise and the noise the model predicted from the noisy images (#(8)). So, the objective of this training is to make the predicted noise as similar as possible to the noise we generated at line #(4).

# Codeblock 21
def train():
    optimizer = Adam(model.parameters(), lr=LEARNING_RATE)  #(1)
    loss_function = nn.MSELoss()  #(2)
    losses = []
    
    for epoch in range(NUM_EPOCHS):
        print(f'Epoch no {epoch}')
        
        for images, _ in tqdm(loader):
            
            optimizer.zero_grad()

            images = images.float().to(DEVICE)  #(3)
            noise = torch.randn_like(images)  #(4)
            t = torch.randint(0, NUM_TIMESTEPS, (BATCH_SIZE,))  #(5)

            noisy_images = noise_scheduler.forward_diffusion(images, noise, t).to(DEVICE)  #(6)
            predicted_noise = model(noisy_images, t)  #(7)
            loss = loss_function(predicted_noise, noise)  #(8)
            
            losses.append(loss.item())
            loss.backward()
            optimizer.step()

    return losses

Now let’s run the above training function using the codeblock below. Sit back and relax while the training runs. In my case, I used a Kaggle Notebook with an NVIDIA P100 GPU, and training took around 45 minutes to finish.

# Codeblock 22
losses = train()

If we take a look at the loss graph, it seems like our model learned well: the loss generally decreases over time, with a rapid drop in the early stages and a more stable (yet still decreasing) trend in the later stages. So, I think we can expect good results in the inference phase.

# Codeblock 23
plt.plot(losses)
plt.xlabel('iteration')
plt.ylabel('MSE loss')
plt.show()
Figure 17. How the loss value decreases as the training goes [3].

Inference

Now that our model is trained, we can use it for inference. Look at Codeblock 24 below to see how I implement the inference() function.

# Codeblock 24
def inference():

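    denoised_images = []  #(1)
    model.eval()  # disable training-only behavior (e.g., dropout or batch norm updates), if the U-Net uses any

    with torch.no_grad():  #(2)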
        current_prediction = torch.randn((64, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)).to(DEVICE)  #(3)
        
        for i in tqdm(reversed(range(NUM_TIMESTEPS))):  #(4)
            predicted_noise = model(current_prediction, torch.as_tensor(i).unsqueeze(0))  #(5)
            current_prediction, denoised_image = noise_scheduler.backward_diffusion(current_prediction, predicted_noise, torch.as_tensor(i))  #(6)

            if i%100 == 0:  #(7)
                denoised_images.append(denoised_image)
            
        return denoised_images

At the line marked with #(1) I initialize an empty list which will be used to store the denoising result, namely the reconstructed x₀ estimate returned as denoised_image, every 100 timesteps (#(7)). This will later allow us to see how the backward diffusion progresses. The actual inference process is encapsulated inside torch.no_grad() (#(2)), since we do not need gradients here. Remember that in diffusion models we generate images from completely random noise, which we treat as the image at t = 999. To implement this, we can simply use torch.randn() as shown at line #(3). Here we initialize a tensor of size 64×1×28×28, indicating that we are about to generate 64 images simultaneously. Next, we write a for loop that iterates backwards from 999 to 0 (#(4)). Inside this loop, we feed the current image and the timestep into the trained U-Net and let it predict the noise (#(5)). The actual backward diffusion step is then performed at line #(6). At the end of the loop, we should get new images similar to the ones in our dataset. Now let’s call the inference() function in the following codeblock.

# Codeblock 25
denoised_images = inference()
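The control flow of Codeblock 24 can be tried out without a trained model. The sketch below is a standalone DDPM sampling loop that re-implements the Figure 12 and Figure 13 updates inline and replaces the U-Net with a hypothetical dummy_noise_predictor that always predicts zero noise, so the images it produces are meaningless. Its only purpose is to show the loop structure and that saving a snapshot whenever t % 100 == 0 yields 10 snapshots, at t = 900, 800, ..., 0.

```python
import torch

torch.manual_seed(0)

T = 1000                                 # NUM_TIMESTEPS
betas = torch.linspace(0.0001, 0.02, T)  # BETA_START, BETA_END
alphas = 1.0 - betas
alphas_cum_prod = torch.cumprod(alphas, dim=0)

def dummy_noise_predictor(x, t):
    # Hypothetical stand-in for the trained U-Net: always predicts zero noise
    return torch.zeros_like(x)

def sample(num_images=4, snapshot_every=100):
    x = torch.randn(num_images, 1, 28, 28)  # pure noise, treated as the image at t = T - 1
    snapshot_steps = []
    for t in reversed(range(T)):
        eps = dummy_noise_predictor(x, t)
        # Mean of the reverse step (Figure 12)
        x = (x - betas[t] / (1.0 - alphas_cum_prod[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            # Add fresh noise with the variance from Figure 13
            variance = (1.0 - alphas_cum_prod[t - 1]) / (1.0 - alphas_cum_prod[t]) * betas[t]
            x = x + variance.sqrt() * torch.randn_like(x)
        if t % snapshot_every == 0:
            snapshot_steps.append(t)
    return x, snapshot_steps

final_images, snapshot_steps = sample()
print(final_images.shape, snapshot_steps)
```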

Once the inference has completed, we can see what the resulting images look like. Codeblock 26 below displays the first 42 of the 64 images we just generated.

# Codeblock 26
fig, axes = plt.subplots(ncols=7, nrows=6, figsize=(10, 8))

counter = 0

for i in range(6):
    for j in range(7):
        axes[i,j].imshow(denoised_images[-1][counter].squeeze().detach().cpu().numpy(), cmap='gray')  #(1)
        axes[i,j].get_xaxis().set_visible(False)
        axes[i,j].get_yaxis().set_visible(False)
        counter += 1

plt.show()
Figure 18. The images generated by the diffusion model trained on the MNIST Handwritten Digit dataset [3].

In the above codeblock, the index [-1] at line #(1) means that we only display the images from the last snapshot (which corresponds to timestep 0). This is why the images you see in Figure 18 are all free from noise. I do acknowledge that this might not be the best result, since not all the generated images are valid digits. But on the bright side, it suggests that the model is not merely reproducing images from the original dataset.

We can also visualize the backward diffusion process using Codeblock 27 below. You can see in the resulting output in Figure 19 that the leftmost image (the snapshot taken at timestep 900) is still dominated by noise, which gradually disappears as we move to the right.

# Codeblock 27
fig, axes = plt.subplots(ncols=10, figsize=(24, 8))

sample_no = 0
timestep_no = 0

for i in range(10):
    axes[i].imshow(denoised_images[timestep_no][sample_no].squeeze().detach().cpu().numpy(), cmap='gray')
    axes[i].get_xaxis().set_visible(False)
    axes[i].get_yaxis().set_visible(False)
    timestep_no += 1

plt.show()
Figure 19. What the image looks like at timestep 900, 800, 700 and so on until timestep 0 [3].

Ending

There are plenty of directions you can go from here. First, you may need to tweak the parameter configurations in Codeblock 2 if you want better results. Second, it is also possible to modify the U-Net model by adding attention layers alongside the stack of convolution layers we used in the downsampling and upsampling stages. This does not guarantee better results, especially for a simple dataset like this, but it’s definitely worth trying. Third, you can also try a more complex dataset if you want to challenge yourself.

When it comes to practical applications, there are lots of things you can do with diffusion models. One of the simplest is data augmentation, since a diffusion model can generate new images from a specific data distribution. For example, suppose we are working on an image classification project, but the number of images per class is imbalanced. To address this problem, we can train a diffusion model on the images from the minority class. The trained model can then generate as many new samples of that class as we want.

And well, that’s pretty much everything about the theory and implementation of diffusion models. Thanks for reading, I hope you learned something new today!

You can access the code used in this project through this link. Here are also the links to my previous articles about Autoencoder, Variational Autoencoder (VAE), Neural Style Transfer (NST), and Transformer.


References

[1] Jascha Sohl-Dickstein et al. Deep Unsupervised Learning using Nonequilibrium Thermodynamics. arXiv. https://arxiv.org/pdf/1503.03585 [Accessed December 27, 2024].

[2] Jonathan Ho et al. Denoising Diffusion Probabilistic Models. arXiv. https://arxiv.org/pdf/2006.11239 [Accessed December 27, 2024].

[3] Image created by the author.

[4] Olaf Ronneberger et al. U-Net: Convolutional Networks for Biomedical Image Segmentation. arXiv. https://arxiv.org/pdf/1505.04597 [Accessed December 27, 2024].

[5] Yann LeCun et al. The MNIST Database of Handwritten Digits. https://yann.lecun.com/exdb/mnist/ [Accessed December 30, 2024] (Creative Commons Attribution-Share Alike 3.0 license).

[6] Ashish Vaswani et al. Attention Is All You Need. arXiv. https://arxiv.org/pdf/1706.03762 [Accessed September 29, 2024].


Image Captioning, Transformer Mode On https://towardsdatascience.com/image-captioning-transformer-mode-on/ Fri, 07 Mar 2025 18:47:02 +0000 https://towardsdatascience.com/?p=599180 Implementing CPTR (CaPtion TransformeR) from scratch with PyTorch

The post Image Captioning, Transformer Mode On appeared first on Towards Data Science.

Introduction

In my previous article, I discussed one of the earliest Deep Learning approaches for image captioning. If you’re interested in reading it, you can find the link to that article at the end of this one.

Today, I would like to talk about image captioning again, but this time with a more advanced neural network architecture. The architecture I am going to discuss is the one proposed in the paper titled “CPTR: Full Transformer Network for Image Captioning,” written by Liu et al. back in 2021 [1]. Specifically, I will reproduce the model proposed in the paper and explain the theory behind the architecture. However, keep in mind that I won’t demonstrate the training process, since I only want to focus on the model architecture.

The idea behind CPTR

In fact, the main idea of the CPTR architecture is the same as that of the earlier image captioning model, as both use an encoder-decoder structure. Previously, in the paper titled “Show and Tell: A Neural Image Caption Generator” [2], the encoder and the decoder are GoogLeNet (a.k.a. Inception V1) and an LSTM, respectively. The model proposed in the Show and Tell paper is illustrated in the following figure.

Figure 1. The neural network architecture for image captioning proposed in the Show and Tell paper [2].

Despite having the same encoder-decoder structure, what makes CPTR different from the previous approach is the basis of the encoder and the decoder themselves. In CPTR, we combine the encoder part of the ViT (Vision Transformer) model with the decoder part of the original Transformer model. The use of transformer-based architecture for both components is essentially where the name CPTR comes from: CaPtion TransformeR.

Note that the discussions in this article are going to be highly related to ViT and Transformer, so I highly recommend you read my previous article about these two topics if you’re not yet familiar with them. You can find the links at the end of this article.

Figure 2 shows what the original ViT architecture looks like. Everything inside the green box is the encoder part of the architecture to be adopted as the CPTR encoder.

Figure 2. The Vision Transformer (ViT) architecture [3].

Next, Figure 3 displays the original Transformer architecture. The components enclosed in the blue box are the layers that we are going to implement in the CPTR decoder.

Figure 3. The original Transformer architecture [4].

If we combine the components inside the green and blue boxes above, we are going to obtain the architecture shown in Figure 4 below. This is exactly what the CPTR model we are going to implement looks like. The idea here is that the ViT Encoder (green) works by encoding the input image into a specific tensor representation which will then be used as the basis of the Transformer Decoder (blue) to generate the corresponding caption.

Figure 4. The CPTR architecture [5].

That’s pretty much everything you need to know for now. I’ll explain more about the details as we go through the implementation.

Module imports & parameter configuration

As always, the first thing we need to do in the code is to import the required modules. In this case, we only import torch and torch.nn since we are about to implement the model from scratch.

# Codeblock 1
import torch
import torch.nn as nn

Next, we are going to initialize some parameters in Codeblock 2. If you have read my previous article about image captioning with GoogLeNet and LSTM, you’ll notice that there are many more parameters to initialize here. In this article, I want to reproduce the CPTR model as closely as possible to the original one, so the parameters mentioned in the paper will be used in this implementation.

# Codeblock 2
BATCH_SIZE         = 1              #(1)

IMAGE_SIZE         = 384            #(2)
IN_CHANNELS        = 3              #(3)

SEQ_LENGTH         = 30             #(4)
VOCAB_SIZE         = 10000          #(5)

EMBED_DIM          = 768            #(6)
PATCH_SIZE         = 16             #(7)
NUM_PATCHES        = (IMAGE_SIZE//PATCH_SIZE) ** 2  #(8)
NUM_ENCODER_BLOCKS = 12             #(9)
NUM_DECODER_BLOCKS = 4              #(10)
NUM_HEADS          = 12             #(11)
HIDDEN_DIM         = EMBED_DIM * 4  #(12)
DROP_PROB          = 0.1            #(13)

The first parameter I want to explain is BATCH_SIZE, written at the line marked with #(1). The value assigned to this variable is not very important in our case since we are not actually going to train this model. It is set to 1 because PyTorch layers expect inputs with a leading batch dimension; here I assume that each batch contains only a single sample.

Next, remember that in image captioning we are dealing with images and texts simultaneously, which means that we need to set parameters for both. It is mentioned in the paper that the model accepts an RGB image of size 384×384 as the encoder input. Hence, we assign the values of the IMAGE_SIZE and IN_CHANNELS variables based on this information (#(2) and #(3)). On the other hand, the paper does not mention the parameters for the captions. So, here I assume that a caption is no longer than 30 words (#(4)), with a vocabulary size of 10000 unique words (#(5)).

The remaining parameters are related to the model configuration. Here we set the EMBED_DIM variable to 768 (#(6)). On the encoder side, this number indicates the length of the feature vector that represents each 16×16 image patch (#(7)). The same concept also applies to the decoder side, but there the feature vector represents a single word in the caption. We also use PATCH_SIZE to compute the total number of patches in the input image: since the image has the size of 384×384, there are 384 / 16 = 24 patches along each side, giving 24 × 24 = 576 patches in total (#(8)).
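The patch arithmetic can be checked with plain Python. The sketch below uses the values from Codeblock 2; note that the flattened length of one patch, 3×16×16, also comes out to 768, the same as EMBED_DIM, which becomes relevant when we look at the output of the patch embedding layer.

```python
IMAGE_SIZE = 384
IN_CHANNELS = 3
PATCH_SIZE = 16

patches_per_side = IMAGE_SIZE // PATCH_SIZE                  # 384 // 16 = 24
num_patches = patches_per_side ** 2                          # 24 * 24 = 576
flattened_patch_len = IN_CHANNELS * PATCH_SIZE * PATCH_SIZE  # 3 * 16 * 16 = 768

print(f'patches per side : {patches_per_side}')
print(f'number of patches: {num_patches}')
print(f'flattened patch  : {flattened_patch_len} values')
```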

When using an encoder-decoder architecture, we can specify the number of encoder and decoder blocks. Using more blocks typically improves accuracy, but in return it requires more computational power. The authors of this paper decided to stack 12 encoder blocks (#(9)) and 4 decoder blocks (#(10)). Next, since CPTR is a transformer-based model, we need to specify the number of attention heads within the attention blocks inside the encoder and the decoder; in this case, the authors use 12 attention heads (#(11)). The value for the HIDDEN_DIM parameter is not mentioned anywhere in the paper. However, following the ViT and Transformer papers, this parameter is configured to be 4 times larger than EMBED_DIM (#(12)). The dropout rate is not mentioned in the paper either, so I arbitrarily set DROP_PROB to 0.1 (#(13)).
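With EMBED_DIM = 768 and NUM_HEADS = 12, each attention head works in a 768 / 12 = 64-dimensional subspace, which is why the embedding size must be divisible by the number of heads. The sketch below illustrates this with PyTorch's built-in nn.MultiheadAttention layer. It is used here only to show shapes; it is not necessarily how the attention blocks are implemented later in this article.

```python
import torch
import torch.nn as nn

EMBED_DIM, NUM_HEADS, NUM_PATCHES = 768, 12, 576

assert EMBED_DIM % NUM_HEADS == 0  # required so the embedding splits evenly across heads
head_dim = EMBED_DIM // NUM_HEADS  # 64 dimensions per head

# Built-in attention layer, used only to illustrate input and output shapes
attention = nn.MultiheadAttention(embed_dim=EMBED_DIM, num_heads=NUM_HEADS, batch_first=True)
tokens = torch.randn(1, NUM_PATCHES, EMBED_DIM)  # (batch, sequence of patches, features)
out, weights = attention(tokens, tokens, tokens)  # self-attention over the patch sequence

print(head_dim, out.shape, weights.shape)
```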

Encoder

With the modules and parameters set up, we can now get into the encoder part of the network. In this section we are going to implement and explain every component inside the green box in Figure 4 one by one.

Patch embedding

Figure 5. Dividing the input image into patches and converting them into vectors [5].

You can see in Figure 5 above that the first step is dividing the input image into patches. This turns the image into a sequence, which allows ViT to capture global context by learning the relationships between patches, instead of focusing on local patterns as CNNs do. We can model this process with the Patcher class shown in Codeblock 3 below. For the sake of simplicity, I also include the patch embedding step within the same class.

# Codeblock 3
class Patcher(nn.Module):
    def __init__(self):
        super().__init__()

        #(1)
        self.unfold = nn.Unfold(kernel_size=PATCH_SIZE, stride=PATCH_SIZE)

        #(2)
        self.linear_projection = nn.Linear(in_features=IN_CHANNELS*PATCH_SIZE*PATCH_SIZE,
                                           out_features=EMBED_DIM)

    def forward(self, images):
        print(f'images\t\t: {images.size()}')
        images = self.unfold(images)  #(3)
        print(f'after unfold\t: {images.size()}')

        images = images.permute(0, 2, 1)  #(4)
        print(f'after permute\t: {images.size()}')

        features = self.linear_projection(images)  #(5)
        print(f'after lin proj\t: {features.size()}')

        return features

The patching itself is done using the nn.Unfold layer (#(1)). Here we need to set both the kernel_size and stride parameters to PATCH_SIZE (16) so that the resulting patches do not overlap with each other. This layer also automatically flattens these patches once it is applied to the input image. Meanwhile, the nn.Linear layer (#(2)) is employed to perform linear projection, i.e., the process done by the patch embedding block. By setting the out_features parameter to EMBED_DIM, this layer will map every single flattened patch into a feature vector of length 768.

The entire process should make more sense once you read the forward() method. You can see at line #(3) that the input image is directly processed by the unfold layer. Next, we use the permute() method (#(4)) to swap axes 1 and 2 of the resulting tensor before feeding it to the linear_projection layer (#(5)). Additionally, I print out the tensor dimension after each layer so that you can better understand the transformation made at each step.

To check whether our Patcher class works properly, we can simply pass a dummy tensor through it, as shown in Codeblock 4 below.

# Codeblock 4
patcher  = Patcher()

images   = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
features = patcher(images)
# Codeblock 4 Output
images         : torch.Size([1, 3, 384, 384])
after unfold   : torch.Size([1, 768, 576])  #(1)
after permute  : torch.Size([1, 576, 768])  #(2)
after lin proj : torch.Size([1, 576, 768])  #(3)

The tensor I passed above represents an RGB image of size 384×384. After the unfold operation, the tensor dimension changes to 1×768×576 (#(1)): each of the 576 patches (a 24×24 grid, since 384/16 = 24) is stored as a flattened 3×16×16 = 768 vector along axis 1. This layout does not match what we need. In ViT we treat the image patches as a sequence, and since we use the batch-first convention, the tensor should have the shape (batch, sequence length, features): axis 1 must index the patches (the "timesteps") and axis 2 must hold the feature vector of each patch. After the permute() operation, the tensor has the dimension 1×576×768 (#(2)). Lastly, we pass this tensor through the linear projection layer. The resulting shape stays the same because EMBED_DIM (768) happens to equal the flattened patch size (#(3)). Despite the identical shape, the values are different: each patch vector is now a learned projection of the raw pixel values, produced by the trainable weights of the linear layer.
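Many ViT implementations perform the patching and the linear projection in a single step with a strided nn.Conv2d whose kernel_size and stride both equal the patch size. The sketch below, with hypothetical small sizes, copies the weights of an Unfold + Linear pair (as in Codeblock 3) into such a convolution and checks numerically that the two give the same patch embeddings. It is meant as an illustration of the equivalence, not as a change to the Patcher class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small sizes: 3 channels, 4x4 patches, 8-dim embeddings, 16x16 images.
C, P, D = 3, 4, 8
images = torch.randn(2, C, 16, 16)

# Unfold + Linear, mirroring Codeblock 3.
unfold = nn.Unfold(kernel_size=P, stride=P)
linear = nn.Linear(C * P * P, D)
out_unfold = linear(unfold(images).permute(0, 2, 1))    # (B, num_patches, D)

# Strided convolution with kernel_size == stride == P.
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
with torch.no_grad():
    # Linear stores its weight as (D, C*P*P); Conv2d stores (D, C, P, P).
    conv.weight.copy_(linear.weight.reshape(D, C, P, P))
    conv.bias.copy_(linear.bias)
# (B, D, 4, 4) -> (B, D, 16) -> (B, 16, D): one row per patch, row-major order.
out_conv = conv(images).flatten(2).transpose(1, 2)

print(out_unfold.shape)                                  # torch.Size([2, 16, 8])
print(torch.allclose(out_unfold, out_conv, atol=1e-5))   # True
```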

Learnable positional embedding

Figure 6. Injecting the learnable positional embeddings into the embedded image patches [5].

After the input image has been converted into a sequence of patch embeddings, the next step is to inject the so-called positional embedding tensor. This is necessary because a transformer without positional embedding is permutation-invariant, meaning that it treats the input sequence as if the order of its elements did not matter. Since an image is a 2D grid rather than a literal 1D sequence, we make the positional embedding learnable so that the model can figure out for itself how best to encode the spatial position of each patch. Keep in mind that this does not physically rearrange the sequence; the positional information is captured entirely by the learned embedding weights.
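The order-blindness of attention can be observed directly. Strictly speaking, self-attention is permutation-equivariant: shuffling the input tokens only shuffles the output tokens in the same way. The sketch below checks this with a small nn.MultiheadAttention layer (all sizes and the random positional embedding are hypothetical), and then shows that adding a positional embedding breaks the symmetry, because each position now carries its own vector regardless of which token lands there.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small sizes: 5 tokens of dimension 8, 2 heads, eval mode (no dropout).
attention = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
attention.eval()

x = torch.randn(1, 5, 8)
perm = torch.tensor([3, 0, 4, 1, 2])   # an arbitrary reordering of the 5 tokens

with torch.no_grad():
    out, _ = attention(x, x, x)
    out_perm, _ = attention(x[:, perm], x[:, perm], x[:, perm])

# Without positions: permuting the input just permutes the output.
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))         # True

# With a (here random) positional embedding added, this no longer holds.
pos_embed = torch.randn(5, 8)
with torch.no_grad():
    y = x + pos_embed
    y_perm = x[:, perm] + pos_embed
    out_pos, _ = attention(y, y, y)
    out_pos_perm, _ = attention(y_perm, y_perm, y_perm)
print(torch.allclose(out_pos[:, perm], out_pos_perm, atol=1e-5))  # False
```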

The implementation is pretty simple. All we need to do is initialize a tensor using nn.Parameter, with its dimension set to match the output of the Patcher model, i.e., 576×768. Writing requires_grad=True is not strictly necessary, since it is already the default for nn.Parameter, but it makes it explicit that the tensor is trainable. Look at Codeblock 5 below for the details.

# Codeblock 5
class LearnableEmbedding(nn.Module):
   def __init__(self):
       super().__init__()
       self.learnable_embedding = nn.Parameter(torch.randn(size=(NUM_PATCHES, EMBED_DIM)),
                                               requires_grad=True)
      
   def forward(self):
       pos_embed = self.learnable_embedding
       print(f'learnable embedding\t: {pos_embed.size()}')
      
       return pos_embed

Now let’s run the following codeblock to see whether our LearnableEmbedding class works properly. You can see in the printed output that it successfully created the positional embedding tensor as expected.

# Codeblock 6
learnable_embedding = LearnableEmbedding()

pos_embed = learnable_embedding()
# Codeblock 6 Output
learnable embedding : torch.Size([576, 768])
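Wrapping the tensor in nn.Parameter is what makes it trainable: the module registers it, so it shows up in parameters() (and therefore in any optimizer built from them) and receives gradients during backpropagation. The sketch below uses hypothetical small sizes in place of 576×768 and a dummy loss just to show the gradient arriving.

```python
import torch
import torch.nn as nn

# Hypothetical small sizes standing in for NUM_PATCHES x EMBED_DIM (576 x 768).
NUM_PATCHES, EMBED_DIM = 4, 6

class LearnableEmbeddingSketch(nn.Module):
    def __init__(self):
        super().__init__()
        # requires_grad=True is already nn.Parameter's default
        self.learnable_embedding = nn.Parameter(torch.randn(NUM_PATCHES, EMBED_DIM))

    def forward(self):
        return self.learnable_embedding

module = LearnableEmbeddingSketch()
pos_embed = module()

# The tensor is registered as a parameter of the module.
names = [name for name, _ in module.named_parameters()]
print(names)                    # ['learnable_embedding']
print(pos_embed.requires_grad)  # True

# A dummy loss: d(sum)/d(each element) = 1, so the gradient is a tensor of ones.
loss = pos_embed.sum()
loss.backward()
print(module.learnable_embedding.grad.shape)   # torch.Size([4, 6])
```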

The main encoder block

Figure 7. The main encoder block [5].

Next, we construct the main encoder block displayed in Figure 7 above. This block consists of several sub-components, namely self-attention, layer norm, FFN (Feed-Forward Network), and another layer norm. Codeblock 7a below shows how I initialize these layers inside the __init__() method of the EncoderBlock class.

# Codeblock 7a
class EncoderBlock(nn.Module):
   def __init__(self):
       super().__init__()
      
       #(1)
       self.self_attention = nn.MultiheadAttention(embed_dim=EMBED_DIM,
                                                   num_heads=NUM_HEADS,
                                                   batch_first=True,  #(2)
                                                   dropout=DROP_PROB)
      
       self.layer_norm_0 = nn.LayerNorm(EMBED_DIM)  #(3)
      
       self.ffn = nn.Sequential(  #(4)
           nn.Linear(in_features=EMBED_DIM, out_features=HIDDEN_DIM),
           nn.GELU(),
           nn.Dropout(p=DROP_PROB),
           nn.Linear(in_features=HIDDEN_DIM, out_features=EMBED_DIM),
       )
      
       self.layer_norm_1 = nn.LayerNorm(EMBED_DIM)  #(5)

I've mentioned earlier that the idea of ViT is to capture the relationships between patches within an image. This is done by the multihead attention layer initialized at line #(1) in the above codeblock. One thing to keep in mind is that we need to set the batch_first parameter to True (#(2)), so that the attention layer is compatible with our tensor shape, in which the batch dimension is at axis 0. By default, nn.MultiheadAttention expects the sequence dimension first. Next, the two layer normalization layers need to be initialized separately, as shown at lines #(3) and #(5). Lastly, we initialize the FFN block at line #(4), where the layers stacked with nn.Sequential follow the structure defined in the following equation.

Figure 8. The operations done inside the FFN block [1].
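To see what batch_first actually changes, the sketch below creates two attention layers with identical weights, one with batch_first=True and one with the default layout, and feeds each the tensor layout it expects. The sizes are hypothetical. Both produce the same result up to a transpose, which also shows the danger of forgetting the flag: a (batch, seq, embed) tensor is still a valid input for a batch_first=False layer, so no error is raised, but the layer would then attend across the batch instead of across the patches.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small sizes: batch 2, sequence 5, embedding 8, 2 heads.
attn_bf = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
attn_sf = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=False)
attn_sf.load_state_dict(attn_bf.state_dict())   # copy weights: identical layers
attn_bf.eval()
attn_sf.eval()

x = torch.randn(2, 5, 8)                        # (batch, seq, embed)

with torch.no_grad():
    out_bf, _ = attn_bf(x, x, x)                # batch_first=True: (batch, seq, embed)
    x_sf = x.transpose(0, 1)                    # default layout: (seq, batch, embed)
    out_sf, _ = attn_sf(x_sf, x_sf, x_sf)

print(out_bf.shape, out_sf.shape)   # torch.Size([2, 5, 8]) torch.Size([5, 2, 8])
print(torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-5))   # True
```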

As the __init__() method is complete, we will now continue with the forward() method. Let’s take a look at the Codeblock 7b below.

# Codeblock 7b
   def forward(self, features):  #(1)
      
       residual = features  #(2)
       print(f'features & residual\t: {residual.size()}')
      
       #(3)
       features, self_attn_weights = self.self_attention(query=features,
                                                         key=features,
                                                         value=features)
       print(f'after self attention\t: {features.size()}')
       print(f"self attn weights\t: {self_attn_weights.shape}")
      
       features = self.layer_norm_0(features + residual)  #(4)
       print(f'after norm\t\t: {features.size()}')
      

       residual = features
       print(f'\nfeatures & residual\t: {residual.size()}')
      
       features = self.ffn(features)  #(5)
       print(f'after ffn\t\t: {features.size()}')
      
       features = self.layer_norm_1(features + residual)
       print(f'after norm\t\t: {features.size()}')
      
       return features

Here the input tensor is named features (#(1)), since the input of the EncoderBlock is an image that has already been processed by Patcher and LearnableEmbedding rather than a raw image. Notice in Figure 7 that there is a branch which splits off from the main flow and rejoins it at the normalization layer. This branch is commonly known as a residual connection. To implement it, we store the original input tensor in the residual variable, as shown at line #(2). Now we can process the input with the multihead attention layer (#(3)). Since this is self-attention (not cross-attention), the query, key, and value inputs for this layer all come from the same features tensor. The layer normalization is then performed at line #(4), where its input is the sum of the attention output and the residual connection. The remaining steps follow the same pattern, except that the self-attention block is replaced with the FFN (#(5)).
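The residual-then-normalize pattern used twice in EncoderBlock can be written as a tiny helper. The sketch below is a hypothetical refactoring (the function name post_norm_sublayer is mine, and the sizes are small placeholders), applied here to an FFN. It also checks what LayerNorm does to each token: with its initial weight of 1 and bias of 0, every feature vector ends up with mean ≈ 0 and variance ≈ 1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def post_norm_sublayer(x, sublayer, norm):
    """Hypothetical helper: residual connection followed by LayerNorm."""
    residual = x                 # keep a copy of the input (the branch)
    x = sublayer(x)              # attention or FFN
    return norm(x + residual)    # add the branch back, then normalize

# Hypothetical small sizes: batch 2, 5 tokens, embedding 8.
x = torch.randn(2, 5, 8)
ffn = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))
norm = nn.LayerNorm(8)

with torch.no_grad():
    y = post_norm_sublayer(x, ffn, norm)

# LayerNorm normalizes each token's feature vector separately.
mean = y.mean(dim=-1)
var = ((y - y.mean(dim=-1, keepdim=True)) ** 2).mean(dim=-1)
print(y.shape)                                # torch.Size([2, 5, 8])
print(mean.abs().max().item() < 1e-5)         # True
print((var - 1).abs().max().item() < 1e-3)    # True
```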

In the following codeblock, I’ll test the EncoderBlock class by passing a dummy tensor of size 1×576×768, simulating an output tensor from the previous operations.

# Codeblock 8
encoder_block = EncoderBlock()

features = torch.randn(BATCH_SIZE, NUM_PATCHES, EMBED_DIM)
features = encoder_block(features)

Below is what the tensor dimension looks like throughout the entire process inside the model.

# Codeblock 8 Output
features & residual  : torch.Size([1, 576, 768])  #(1)
after self attention : torch.Size([1, 576, 768])
self attn weights    : torch.Size([1, 576, 576])  #(2)
after norm           : torch.Size([1, 576, 768])

features & residual  : torch.Size([1, 576, 768])
after ffn            : torch.Size([1, 576, 768])  #(3)
after norm           : torch.Size([1, 576, 768])  #(4)

Here you can see that the final output tensor (#(4)) has the same size as the input (#(1)), which allows us to stack multiple encoder blocks without worrying about mismatched tensor dimensions. In fact, the printed size stays unchanged from the first layer all the way to the last. There are actually many transformations performed inside the attention block, but we can't see them since the entire process is handled internally by the nn.MultiheadAttention layer. One tensor produced by this layer that we can observe is the attention weight matrix (#(2)). This matrix, which has the size of 576×576, stores how strongly each patch attends to every other patch in the image. Note that by default nn.MultiheadAttention returns these weights averaged over all attention heads. Changes in tensor dimension also happen inside the FFN block: the feature vector of each patch is expanded from 768 to 3072 (HIDDEN_DIM) and immediately shrunk back to 768 (#(3)). However, this transformation is not printed since these layers are wrapped in nn.Sequential back at line #(4) in Codeblock 7a.
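Both hidden details mentioned above can be made visible. The sketch below (hypothetical small sizes, eval mode so that dropout does not touch the weights) asks nn.MultiheadAttention for per-head weights via average_attn_weights=False, available in PyTorch 1.11 and later, and confirms that the default output is just their mean over the head axis. It then runs an FFN layer by layer to reveal the intermediate expansion.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small sizes: 6 patches, embedding 8, 2 heads.
attention = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
attention.eval()
features = torch.randn(1, 6, 8)

with torch.no_grad():
    # Default: weights averaged over heads -> (batch, seq, seq)
    _, avg_weights = attention(features, features, features)
    # Per-head weights -> (batch, num_heads, seq, seq)
    _, head_weights = attention(features, features, features,
                                average_attn_weights=False)

print(avg_weights.shape)    # torch.Size([1, 6, 6])
print(head_weights.shape)   # torch.Size([1, 2, 6, 6])
# Each row is a softmax distribution over the patches being attended to.
print(torch.allclose(head_weights.sum(dim=-1), torch.ones(1, 2, 6), atol=1e-6))  # True
# The default output equals the mean over the head axis.
print(torch.allclose(avg_weights, head_weights.mean(dim=1), atol=1e-6))          # True

# Running an FFN layer by layer exposes the hidden expansion
# (hypothetical 8 -> 32 -> 8 instead of 768 -> 3072 -> 768).
ffn = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Dropout(0.1), nn.Linear(32, 8))
h = features
for layer in ffn:
    h = layer(h)
    print(type(layer).__name__, tuple(h.shape))
# Linear (1, 6, 32) / GELU (1, 6, 32) / Dropout (1, 6, 32) / Linear (1, 6, 8)
```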

ViT encoder

Figure 9. The entire ViT Encoder in the CPTR architecture [5].

Now that all encoder components are implemented, we can assemble them to construct the actual ViT Encoder, as done in the Encoder class in Codeblock 9.

# Codeblock 9
class Encoder(nn.Module):
   def __init__(self):
       super().__init__()
       self.patcher = Patcher()  #(1)
       self.learnable_embedding = LearnableEmbedding()  #(2)

       #(3)
       self.encoder_blocks = nn.ModuleList(EncoderBlock() for _ in range(NUM_ENCODER_BLOCKS))
  
   def forward(self, images):  #(4)
       print(f'images\t\t\t: {images.size()}')
      
       features = self.patcher(images)  #(5)
       print(f'after patcher\t\t: {features.size()}')
      
       features = features + self.learnable_embedding()  #(6)
       print(f'after learn embed\t: {features.size()}')
      
       for i, encoder_block in enumerate(self.encoder_blocks):
           features = encoder_block(features)  #(7)
           print(f"after encoder block #{i}\t: {features.shape}")

       return features

Inside the __init__() method, we initialize all the components we created earlier, i.e., Patcher (#(1)), LearnableEmbedding (#(2)), and EncoderBlock (#(3)). Here the EncoderBlock is initialized inside nn.ModuleList since we want to repeat it NUM_ENCODER_BLOCKS (12) times. The forward() method accepts a raw image as the input (#(4)). We then process it with the patcher layer (#(5)) to divide the image into small patches and apply the linear projection to them. The learnable positional embedding tensor is then injected into the resulting output by element-wise addition (#(6)). Lastly, we pass the tensor through the 12 encoder blocks sequentially with a simple for loop (#(7)).
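Two details of Codeblock 9 are worth checking in isolation. First, building the nn.ModuleList from a generator creates a fresh block on every iteration, so each of the 12 blocks has its own weights; repeating a single instance with [block] * n would instead reuse the same weights at every depth. Second, the (576, 768) positional embedding is added to a (batch, 576, 768) tensor by broadcasting, so every image in the batch receives the same embedding. The sketch below uses a plain nn.Linear as a hypothetical stand-in for EncoderBlock and small sizes throughout.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for EncoderBlock: any module with parameters will do.
def make_block():
    return nn.Linear(8, 8)          # 8*8 weights + 8 biases = 72 parameters

def count_params(module):
    return sum(p.numel() for p in module.parameters())

NUM_BLOCKS = 3

# One fresh instance per iteration -> independent weights (as in Codeblock 9).
independent = nn.ModuleList(make_block() for _ in range(NUM_BLOCKS))

# Repeating one instance -> the same weights reused at every depth.
block = make_block()
shared = nn.ModuleList([block] * NUM_BLOCKS)

print(count_params(independent), count_params(shared))   # 216 72
print(independent[0].weight is independent[1].weight)     # False
print(shared[0].weight is shared[1].weight)               # True

# Broadcasting: (seq, dim) is added to (batch, seq, dim) for every batch item.
features = torch.zeros(2, 4, 8)
pos_embed = torch.arange(32, dtype=torch.float32).reshape(4, 8)
out = features + pos_embed
print(out.shape)                     # torch.Size([2, 4, 8])
print(torch.equal(out[0], out[1]))   # True
```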

Now, in Codeblock 10, I pass a dummy image through the entire encoder. Since I want to focus on the flow of the Encoder class, I re-run the previously defined classes with their print() calls commented out so that the output looks neat.

# Codeblock 10
encoder = Encoder()

images = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
features = encoder(images)

Below is what the flow of the tensor looks like. Our dummy input image successfully passes through all layers in the network, including the 12 encoder blocks. The resulting output tensor is now context-aware, meaning that each patch vector contains information about its relationships with the other patches in the image. This tensor is therefore ready to be processed by the decoder, which we will discuss in the next section.

# Codeblock 10 Output
images                  : torch.Size([1, 3, 384, 384])
after patcher           : torch.Size([1, 576, 768])
after learn embed       : torch.Size([1, 576, 768])
after encoder block #0  : torch.Size([1, 576, 768])
after encoder block #1  : torch.Size([1, 576, 768])
after encoder block #2  : torch.Size([1, 576, 768])
after encoder block #3  : torch.Size([1, 576, 768])
after encoder block #4  : torch.Size([1, 576, 768])
after encoder block #5  : torch.Size([1, 576, 768])
after encoder block #6  : torch.Size([1, 576, 768])
after encoder block #7  : torch.Size([1, 576, 768])
after encoder block #8  : torch.Size([1, 576, 768])
after encoder block #9  : torch.Size([1, 576, 768])
after encoder block #10 : torch.Size([1, 576, 768])
after encoder block #11 : torch.Size([1, 576, 768])

ViT encoder (alternative)

Before we talk about the decoder, I want to show you something. If you think our approach above is too complicated, you can use nn.TransformerEncoderLayer from PyTorch instead of implementing the EncoderBlock class from scratch. To demonstrate this, I am going to reimplement the Encoder class, but this time I'll name it EncoderTorch.

# Codeblock 11
class EncoderTorch(nn.Module):
   def __init__(self):
       super().__init__()
       self.patcher = Patcher()
       self.learnable_embedding = LearnableEmbedding()
      
       #(1)
       encoder_block = nn.TransformerEncoderLayer(d_model=EMBED_DIM,
                                                  nhead=NUM_HEADS,
                                                  dim_feedforward=HIDDEN_DIM,
                                                  dropout=DROP_PROB,
                                                  activation='gelu',  # default is ReLU; GELU matches EncoderBlock
                                                  batch_first=True)
      
       #(2)
       self.encoder_blocks = nn.TransformerEncoder(encoder_layer=encoder_block,
                                                   num_layers=NUM_ENCODER_BLOCKS)
  
   def forward(self, images):
       print(f'images\t\t\t: {images.size()}')
      
       features = self.patcher(images)
       print(f'after patcher\t\t: {features.size()}')
      
       features = features + self.learnable_embedding()
       print(f'after learn embed\t: {features.size()}')
      
       features = self.encoder_blocks(features)  #(3)
       print(f'after encoder blocks\t: {features.size()}')

       return features

In the above codeblock, instead of using the EncoderBlock class, we use nn.TransformerEncoderLayer (#(1)), which creates a single encoder block based on the parameters we pass to it. To repeat it multiple times, we wrap it in nn.TransformerEncoder and pass the desired number of blocks to the num_layers parameter (#(2)). With this approach, we no longer need to write the forward pass as a loop like we did earlier (#(3)).
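Two properties of the PyTorch classes are worth confirming before treating EncoderTorch as a drop-in replacement. nn.TransformerEncoder deep-copies the layer it receives, so the stacked blocks do not share weights, just like the nn.ModuleList in Codeblock 9. And nn.TransformerEncoderLayer uses ReLU in its FFN unless activation="gelu" is passed, whereas our EncoderBlock uses GELU; its default norm_first=False gives the same post-norm ordering as EncoderBlock. The sketch below checks these points with hypothetical small sizes.

```python
import torch
import torch.nn as nn

# Hypothetical small sizes instead of EMBED_DIM=768, HIDDEN_DIM=3072, 12 blocks.
layer = nn.TransformerEncoderLayer(d_model=16, nhead=2, dim_feedforward=64,
                                   dropout=0.0, batch_first=True,
                                   activation="gelu")   # default would be ReLU
encoder = nn.TransformerEncoder(encoder_layer=layer, num_layers=3)

def count_params(module):
    return sum(p.numel() for p in module.parameters())

# The layer is cloned num_layers times, each clone with its own weights.
print(count_params(encoder) == 3 * count_params(layer))                    # True
print(encoder.layers[0].linear1.weight is encoder.layers[1].linear1.weight)  # False
# LayerNorm is applied after the residual add (post-norm), as in EncoderBlock.
print(layer.norm_first)                                                     # False

x = torch.randn(2, 10, 16)
print(encoder(x).shape)   # torch.Size([2, 10, 16])
```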

The testing code in Codeblock 12 below is the same as the one in Codeblock 10, except that here I use the EncoderTorch class. The output tensor shapes are the same as before.

# Codeblock 12
encoder_torch = EncoderTorch()

images = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
features = encoder_torch(images)
# Codeblock 12 Output
images               : torch.Size([1, 3, 384, 384])
after patcher        : torch.Size([1, 576, 768])
after learn embed    : torch.Size([1, 576, 768])
after encoder blocks : torch.Size([1, 576, 768])

Decoder

Now that we have created the encoder part of the CPTR architecture, let's move on to the decoder. In this section I am going to implement every single component inside the blue box in Figure 4. Based on the figure, the decoder accepts two inputs, i.e., the ground-truth image caption (the lower part of the blue box) and the sequence of embedded patches produced by the encoder (the arrow coming from the green box). Note that the architecture drawn in Figure 4 illustrates the training phase, where the entire ground-truth caption is fed into the decoder. In the inference phase, we only provide a <BOS> (Beginning of Sentence) token as the caption input. The decoder then predicts each word sequentially based on the given image and the previously generated words. This process is commonly known as autoregressive generation.
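The autoregressive loop itself is independent of the model architecture. The sketch below shows greedy decoding in plain Python: start from <BOS>, ask the model for a score per vocabulary word given everything generated so far, append the highest-scoring word, and stop at <EOS> or a length limit. The token ids and the toy_scores function are hypothetical stand-ins for a real call to the decoder with the encoded image; this is not the inference code used later in the article.

```python
from typing import Callable, List

# Hypothetical special-token ids: 0 = padding, 1 = <BOS>, 2 = <EOS>, 3..5 = words.
BOS, EOS = 1, 2

def greedy_decode(next_token_scores: Callable[[List[int]], List[float]],
                  max_length: int) -> List[int]:
    """Greedy autoregressive decoding starting from <BOS>."""
    tokens = [BOS]
    while len(tokens) < max_length:
        scores = next_token_scores(tokens)        # one score per vocabulary word
        next_token = max(range(len(scores)), key=scores.__getitem__)  # argmax
        tokens.append(next_token)
        if next_token == EOS:
            break
    return tokens

# Toy stand-in for "decoder(image features, tokens so far)" over a 6-word vocabulary:
# it emits words 3, 4, 5 in turn and <EOS> once the sequence has 4 tokens.
def toy_scores(tokens: List[int]) -> List[float]:
    scores = [0.0] * 6
    if len(tokens) >= 4:
        scores[EOS] = 1.0
    else:
        scores[3 + (len(tokens) - 1) % 3] = 1.0
    return scores

print(greedy_decode(toy_scores, max_length=30))   # [1, 3, 4, 5, 2]
print(greedy_decode(toy_scores, max_length=3))    # [1, 3, 4]
```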

Sinusoidal positional embedding

Figure 10. Where the sinusoidal positional embedding component is located in the decoder [5].

If you take a look at the CPTR model, you'll see that the first step in the decoder is to convert each word into its corresponding feature vector using the word embedding block. Since this step is very easy, we will implement it later. For now, let's assume that the word vectorization is already done and move on to the positional embedding.

As I've mentioned earlier, since a transformer is permutation-invariant by nature, we need to apply positional embedding to the input sequence. Unlike in the encoder, here we use the so-called sinusoidal positional embedding. We can think of it as labelling each word vector with a unique pattern of values taken from sine and cosine waves of different frequencies. From these wave patterns, the model can infer the order of the words.

If you go back to Codeblock 6 Output, you'll see that the positional embedding tensor in the encoder has the size NUM_PATCHES × EMBED_DIM (576×768). In the decoder, we instead create a tensor of size SEQ_LENGTH × EMBED_DIM (30×768), whose values are computed with the equation shown in Figure 11. This tensor is non-trainable because a sequence of words has a fixed order that must be preserved to keep its meaning.

Figure 11. The equation for creating sinusoidal positional encoding proposed in the Transformer paper [6].
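For reference, the equation in Figure 11 can be written as below, where pos is the position of the word in the caption (0 to SEQ_LENGTH - 1), d_model corresponds to EMBED_DIM (768), and i runs from 0 to d_model/2 - 1. Note that the variable i in Codeblock 13 steps through the even feature indices directly, so it plays the role of 2i in this formula.

```latex
\begin{aligned}
PE_{(pos,\,2i)}   &= \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \\
PE_{(pos,\,2i+1)} &= \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
\end{aligned}
```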

I'll explain the following code only briefly, since I have discussed it more thoroughly in my previous article about Transformer. In short, we create the sine and cosine waves using torch.sin() (#(1)) and torch.cos() (#(2)). The two resulting tensors are then interleaved at lines #(3) and #(4): stacking them along a new last axis and flattening it places the sine values at the even feature indices and the cosine values at the odd ones.

# Codeblock 13
class SinusoidalEmbedding(nn.Module):
   def forward(self):
       pos = torch.arange(SEQ_LENGTH).reshape(SEQ_LENGTH, 1)
       print(f"pos\t\t: {pos.shape}")
      
       i = torch.arange(0, EMBED_DIM, 2)
       denominator = torch.pow(10000, i/EMBED_DIM)
       print(f"denominator\t: {denominator.shape}")
      
       even_pos_embed = torch.sin(pos/denominator)  #(1)
       odd_pos_embed  = torch.cos(pos/denominator)  #(2)
       print(f"even_pos_embed\t: {even_pos_embed.shape}")
      
       stacked = torch.stack([even_pos_embed, odd_pos_embed], dim=2)  #(3)
       print(f"stacked\t\t: {stacked.shape}")

       pos_embed = torch.flatten(stacked, start_dim=1, end_dim=2)  #(4)
       print(f"pos_embed\t: {pos_embed.shape}")
      
       return pos_embed

Now we can check whether the SinusoidalEmbedding class works properly by running Codeblock 14 below. As expected, the resulting tensor has the size 30×768. This matches the output of the word embedding block for a single caption, so the two tensors can be summed element-wise (broadcasting over the batch dimension).

# Codeblock 14
sinusoidal_embedding = SinusoidalEmbedding()
pos_embed = sinusoidal_embedding()
# Codeblock 14 Output
pos            : torch.Size([30, 1])
denominator    : torch.Size([384])
even_pos_embed : torch.Size([30, 384])
stacked        : torch.Size([30, 384, 2])
pos_embed      : torch.Size([30, 768])
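The stack-and-flatten trick in Codeblock 13 is equivalent to writing the sine values into the even columns and the cosine values into the odd columns directly. And because the table never changes, it could also be computed once in __init__() and stored with register_buffer(), so that it moves with .to(device) together with the model without being treated as a trainable parameter. The sketch below does both (the class name SinusoidalBuffer is hypothetical, not part of the article's code) and compares the result with the original computation.

```python
import torch
import torch.nn as nn

SEQ_LENGTH, EMBED_DIM = 30, 768

def stacked_version():
    """Same computation as Codeblock 13, without the prints."""
    pos = torch.arange(SEQ_LENGTH).reshape(SEQ_LENGTH, 1)
    i = torch.arange(0, EMBED_DIM, 2)
    denominator = torch.pow(10000, i / EMBED_DIM)
    stacked = torch.stack([torch.sin(pos / denominator),
                           torch.cos(pos / denominator)], dim=2)
    return torch.flatten(stacked, start_dim=1, end_dim=2)

class SinusoidalBuffer(nn.Module):
    """Hypothetical variant: compute the table once and keep it as a buffer."""
    def __init__(self):
        super().__init__()
        pos = torch.arange(SEQ_LENGTH, dtype=torch.float32).unsqueeze(1)
        denominator = torch.pow(10000.0, torch.arange(0, EMBED_DIM, 2) / EMBED_DIM)
        pe = torch.zeros(SEQ_LENGTH, EMBED_DIM)
        pe[:, 0::2] = torch.sin(pos / denominator)   # even feature indices
        pe[:, 1::2] = torch.cos(pos / denominator)   # odd feature indices
        # Buffers follow .to(device) but are not returned by .parameters();
        # persistent=False also keeps the table out of the state_dict.
        self.register_buffer("pos_embed", pe, persistent=False)

    def forward(self):
        return self.pos_embed

module = SinusoidalBuffer()
print(module().shape)                                # torch.Size([30, 768])
print(torch.allclose(module(), stacked_version()))   # True
print(len(list(module.parameters())))                # 0
print(module().requires_grad)                        # False
```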

Look-ahead mask

Figure 12. A look-ahead mask needs to be applied to the masked-self attention layer [5].

The next thing I am going to talk about in the decoder is the masked self-attention layer highlighted in the above figure. I am not going to code the attention mechanism from scratch. Rather, I’ll only implement the so-called look-ahead mask, which will be useful for the self-attention layer so that it doesn’t attend to the subsequent words in the caption during the training phase.

This is pretty easy: all we need to do is create a triangular matrix whose size matches the attention weight matrix, i.e., SEQ_LENGTH × SEQ_LENGTH (30×30). Look at the create_mask() function below for the details.

# Codeblock 15
def create_mask(seq_length):
   mask = torch.tril(torch.ones((seq_length, seq_length)))  #(1)
   mask[mask == 0] = -float('inf')  #(2)
   mask[mask == 1] = 0  #(3)
   return mask

A triangular matrix can simply be created with torch.tril() and torch.ones() (#(1)), but here we need a small modification: changing the 0 values to -inf (#(2)) and the 1s to 0 (#(3)). This is because nn.MultiheadAttention applies a float mask by adding it to the attention scores before the softmax. Since exp(-inf) = 0, the subsequent words receive an attention weight of exactly zero, so the attention mechanism completely ignores them. Again, the internal process of an attention layer is discussed in detail in my previous article about Transformer.
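The effect of the mask can be checked numerically. The sketch below adds the mask from create_mask() to random scores and applies softmax, confirming that every future position gets exactly zero weight while each row still sums to 1. It also shows that nn.MultiheadAttention accepts an equivalent boolean mask, where True means "do not attend"; the sizes are hypothetical and the layer is in eval mode so dropout does not interfere.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def create_mask(seq_length):
    # Same as Codeblock 15: 0 where attention is allowed, -inf for future positions.
    mask = torch.tril(torch.ones((seq_length, seq_length)))
    mask[mask == 0] = -float('inf')
    mask[mask == 1] = 0
    return mask

L = 5
scores = torch.randn(L, L)                               # raw attention scores
weights = torch.softmax(scores + create_mask(L), dim=-1)

# exp(-inf) = 0: future positions get zero weight, rows still sum to 1.
print(torch.all(weights.triu(diagonal=1) == 0).item())       # True
print(torch.allclose(weights.sum(dim=-1), torch.ones(L)))    # True

# Equivalent boolean mask: True marks the positions that must NOT be attended to.
bool_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
attention = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True).eval()
x = torch.randn(1, L, 8)
with torch.no_grad():
    out_float, _ = attention(x, x, x, attn_mask=create_mask(L))
    out_bool, _ = attention(x, x, x, attn_mask=bool_mask)
print(torch.allclose(out_float, out_bool, atol=1e-5))        # True
```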

Now I'll run the function with seq_length=7 so that you can see what the mask looks like. Later in the complete flow, we set the seq_length parameter to SEQ_LENGTH (30) so that it matches the actual caption length.

# Codeblock 16
mask_example = create_mask(seq_length=7)
mask_example
# Codeblock 16 Output
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf],
       [0., 0., -inf, -inf, -inf, -inf, -inf],
       [0., 0., 0., -inf, -inf, -inf, -inf],
       [0., 0., 0., 0., -inf, -inf, -inf],
       [0., 0., 0., 0., 0., -inf, -inf],
       [0., 0., 0., 0., 0., 0., -inf],
       [0., 0., 0., 0., 0., 0., 0.]])

The main decoder block

Figure 13. The main decoder block [5].

We can see in the above figure that the decoder block is a bit longer than the encoder block. Most of it is the same, except that the decoder block has a cross-attention layer and an additional layer normalization step placed after it. This cross-attention layer can be seen as the bridge between the encoder and the decoder, as it captures the relationships between each word in the caption and every single patch in the input image. The two arrows coming from the encoder are the key and value inputs of this attention layer, whereas the query comes from the previous layer in the decoder itself. Look at Codeblocks 17a and 17b below to see the implementation of the entire decoder block.

# Codeblock 17a
class DecoderBlock(nn.Module):
   def __init__(self):
       super().__init__()
      
       #(1)
       self.self_attention = nn.MultiheadAttention(embed_dim=EMBED_DIM,
                                                   num_heads=NUM_HEADS,
                                                   batch_first=True,
                                                   dropout=DROP_PROB)
       #(2)
       self.layer_norm_0 = nn.LayerNorm(EMBED_DIM)
       #(3)
       self.cross_attention = nn.MultiheadAttention(embed_dim=EMBED_DIM,
                                                    num_heads=NUM_HEADS,
                                                    batch_first=True,
                                                    dropout=DROP_PROB)

       #(4)
       self.layer_norm_1 = nn.LayerNorm(EMBED_DIM)
      
       #(5)      
       self.ffn = nn.Sequential(
           nn.Linear(in_features=EMBED_DIM, out_features=HIDDEN_DIM),
           nn.GELU(),
           nn.Dropout(p=DROP_PROB),
           nn.Linear(in_features=HIDDEN_DIM, out_features=EMBED_DIM),
       )
      
       #(6)
       self.layer_norm_2 = nn.LayerNorm(EMBED_DIM)

In the __init__() method, we first initialize both the self-attention (#(1)) and cross-attention (#(3)) layers with nn.MultiheadAttention. These two layers look exactly the same for now, but you'll see the difference in the forward() method. The three layer normalization layers are initialized separately, as shown at lines #(2), #(4), and #(6), since each of them has its own learnable normalization parameters. Lastly, the ffn layer (#(5)) is exactly the same as the one in the encoder, following the equation back in Figure 8.

The forward() method below accepts three inputs: features, captions, and attn_mask, which denote the tensor coming from the encoder, the tensor from the decoder itself, and the look-ahead mask, respectively (#(1)). The remaining steps are similar to those of the EncoderBlock, except that here we have two multihead attention blocks. The first one takes captions as the query, key, and value (#(2)), because we want the layer to capture the context within the captions tensor itself, hence the name self-attention. Here we also pass the attn_mask parameter so that the layer cannot see the subsequent words during the training phase. The second attention mechanism is different (#(3)). Since we want to combine the information from the encoder and the decoder, we pass the captions tensor as the query, whereas the features tensor is passed as the key and value, hence the name cross-attention. A look-ahead mask is not necessary in the cross-attention layer: its keys and values come only from the image, which the model can see in its entirety at once (also during inference), so each word's output depends on its own query and the image patches but never on the other words in the caption. After the two attention layers, we pass the tensor through the feed-forward network (#(4)). Lastly, don't forget to create the residual connections and apply the layer normalization after each sub-component.

# Codeblock 17b
   def forward(self, features, captions, attn_mask):  #(1)
       print(f"attn_mask\t\t: {attn_mask.shape}")
       residual = captions
       print(f"captions & residual\t: {captions.shape}")
      
       #(2)
       captions, self_attn_weights = self.self_attention(query=captions,
                                                         key=captions,
                                                         value=captions,
                                                         attn_mask=attn_mask)
       print(f"after self attention\t: {captions.shape}")
       print(f"self attn weights\t: {self_attn_weights.shape}")
      
       captions = self.layer_norm_0(captions + residual)
       print(f"after norm\t\t: {captions.shape}")
      
      
       print(f"\nfeatures\t\t: {features.shape}")
       residual = captions
       print(f"captions & residual\t: {captions.shape}")
      
       #(3)
       captions, cross_attn_weights = self.cross_attention(query=captions,
                                                           key=features,
                                                           value=features)
       print(f"after cross attention\t: {captions.shape}")
       print(f"cross attn weights\t: {cross_attn_weights.shape}")
      
       captions = self.layer_norm_1(captions + residual)
       print(f"after norm\t\t: {captions.shape}")
      
       residual = captions
       print(f"\ncaptions & residual\t: {captions.shape}")
      
       captions = self.ffn(captions)  #(4)
       print(f"after ffn\t\t: {captions.shape}")
      
       captions = self.layer_norm_2(captions + residual)
       print(f"after norm\t\t: {captions.shape}")
      
       return captions
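The reason cross-attention needs no look-ahead mask can be verified directly: since its keys and values come only from the image features, the output for each caption position is computed from that position's query alone. The sketch below (hypothetical small sizes, eval mode) changes only the last caption token and confirms that the cross-attention outputs for all earlier tokens stay the same, even though no mask is passed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small sizes: 5 caption tokens, 7 image patches, embedding 8.
cross_attention = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True).eval()
captions = torch.randn(1, 5, 8)
features = torch.randn(1, 7, 8)

# Change only the last caption token.
captions_changed = captions.clone()
captions_changed[:, -1] = torch.randn(8)

with torch.no_grad():
    out, weights = cross_attention(captions, features, features)   # no attn_mask
    out_changed, _ = cross_attention(captions_changed, features, features)

print(weights.shape)   # torch.Size([1, 5, 7]): one row of patch weights per word
# The first four outputs are untouched: each query only mixes image patches.
print(torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-6))   # True
print(torch.allclose(out[:, -1], out_changed[:, -1]))                 # False
```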

Now that the DecoderBlock class is complete, we can test it with the following code.

# Codeblock 18
decoder_block = DecoderBlock()

features = torch.randn(BATCH_SIZE, NUM_PATCHES, EMBED_DIM)  #(1)
captions = torch.randn(BATCH_SIZE, SEQ_LENGTH, EMBED_DIM)   #(2)
look_ahead_mask = create_mask(seq_length=SEQ_LENGTH)  #(3)

captions = decoder_block(features, captions, look_ahead_mask)

Here we assume that features is a tensor containing a sequence of patch embeddings produced by the encoder (#(1)), while captions is a sequence of embedded words (#(2)). The seq_length parameter of the look-ahead mask is set to SEQ_LENGTH (30) to match it to the number of words in the caption (#(3)). The tensor dimensions after each step are displayed in the following output.

# Codeblock 18 Output
attn_mask             : torch.Size([30, 30])
captions & residual   : torch.Size([1, 30, 768])
after self attention  : torch.Size([1, 30, 768])
self attn weights     : torch.Size([1, 30, 30])    #(1)
after norm            : torch.Size([1, 30, 768])

features              : torch.Size([1, 576, 768])
captions & residual   : torch.Size([1, 30, 768])
after cross attention : torch.Size([1, 30, 768])
cross attn weights    : torch.Size([1, 30, 576])   #(2)
after norm            : torch.Size([1, 30, 768])

captions & residual   : torch.Size([1, 30, 768])
after ffn             : torch.Size([1, 30, 768])
after norm            : torch.Size([1, 30, 768])

Our DecoderBlock class processes the input tensors all the way to the last layer without errors. Take a closer look at the attention weights at lines #(1) and #(2). The attention weight matrix produced by the self-attention layer has the size 30×30 (#(1)), meaning that every word in the caption is scored against every word in the same caption. Meanwhile, the weight matrix generated by the cross-attention layer has the size 30×576 (#(2)), meaning that every word is scored against every image patch. These shapes are consistent with a correct implementation, and they show that after the cross-attention operation, the resulting captions tensor has been enriched with information from the image. Keep in mind, though, that shape checks alone cannot confirm that the look-ahead mask actually blocks future words.
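A behavioral test can confirm what shapes cannot: if the look-ahead mask works, changing only the tokens after some position t must leave the outputs at positions 0..t unchanged. The sketch below runs this test on PyTorch's nn.TransformerDecoderLayer as a stand-in for DecoderBlock (it has the same post-norm self-attention, cross-attention, and FFN layout, though details such as its default ReLU activation differ), with hypothetical small sizes and a mask in the same 0 / -inf format as create_mask(). The same check could be run on DecoderBlock with its print() calls removed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

SEQ, PATCHES, D = 6, 9, 16   # hypothetical small sizes

# Stand-in for DecoderBlock: self-attention, cross-attention, FFN, post-norm.
block = nn.TransformerDecoderLayer(d_model=D, nhead=2, dim_feedforward=32,
                                   dropout=0.0, batch_first=True).eval()

# Look-ahead mask: 0 on and below the diagonal, -inf above it.
mask = torch.triu(torch.full((SEQ, SEQ), float('-inf')), diagonal=1)

features = torch.randn(1, PATCHES, D)
captions = torch.randn(1, SEQ, D)

t = 3                                    # change every token after position t
captions_future = captions.clone()
captions_future[:, t + 1:] = torch.randn(1, SEQ - t - 1, D)

with torch.no_grad():
    out = block(captions, features, tgt_mask=mask)
    out_future = block(captions_future, features, tgt_mask=mask)
    out_nomask = block(captions_future, features)

# With the mask, positions 0..t cannot see the changed future tokens.
print(torch.allclose(out[:, :t + 1], out_future[:, :t + 1], atol=1e-5))   # True
# Without it, information leaks backward from the future tokens.
print(torch.allclose(out[:, :t + 1], out_nomask[:, :t + 1], atol=1e-5))   # False
```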

Transformer decoder

Figure 14. The entire Transformer Decoder in the CPTR architecture [5].

Now that we have created all components of the decoder, the next step is to put them together into a single class. Look at Codeblocks 19a and 19b below to see how I do that.

# Codeblock 19a
class Decoder(nn.Module):
   def __init__(self):
       super().__init__()

       #(1)
       self.embedding = nn.Embedding(num_embeddings=VOCAB_SIZE,
                                     embedding_dim=EMBED_DIM)

       #(2)
       self.sinusoidal_embedding = SinusoidalEmbedding()

       #(3)
       self.decoder_blocks = nn.ModuleList(DecoderBlock() for _ in range(NUM_DECODER_BLOCKS))

       #(4)
       self.linear = nn.Linear(in_features=EMBED_DIM,
                               out_features=VOCAB_SIZE)

If you compare this Decoder class with the Encoder class from Codeblock 9, you'll notice that their structures are quite similar. In the encoder, we convert image patches into vectors using Patcher, while in the decoder we convert every single word in the caption into a vector using the nn.Embedding layer (#(1)), which I haven't explained earlier. Afterward, we initialize the positional embedding layer, where for the decoder we use the sinusoidal rather than the learnable one (#(2)). Next, we stack multiple decoder blocks using nn.ModuleList (#(3)). The linear layer at line #(4), which doesn't exist in the encoder, is necessary here since it maps each of the embedded words into a vector of length VOCAB_SIZE (10000). This vector contains the logit of every word in the dictionary, so all we need to do afterward is take the index with the highest value, i.e., the most likely word.
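The two ends of the decoder mirror each other. nn.Embedding is a lookup table: it returns rows of its weight matrix, which is the same as multiplying one-hot word vectors by that matrix. The final linear layer goes the other way, mapping each EMBED_DIM vector to one logit per vocabulary word, and argmax over the last axis then picks a word index per position. The sketch below checks this with hypothetical small sizes (a 10-word vocabulary and 4-dim embeddings).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

VOCAB_SIZE, EMBED_DIM = 10, 4          # hypothetical small sizes
embedding = nn.Embedding(num_embeddings=VOCAB_SIZE, embedding_dim=EMBED_DIM)

token_ids = torch.tensor([[1, 5, 5, 2]])     # (batch, seq) of word indices
vectors = embedding(token_ids)               # (batch, seq, EMBED_DIM)

# An embedding lookup is row indexing into the weight matrix ...
print(torch.equal(vectors, embedding.weight[token_ids]))      # True
# ... which equals multiplying one-hot vectors by that matrix.
one_hot = F.one_hot(token_ids, num_classes=VOCAB_SIZE).float()
print(torch.allclose(vectors, one_hot @ embedding.weight))    # True

# The output layer maps EMBED_DIM back to one logit per vocabulary word.
to_logits = nn.Linear(EMBED_DIM, VOCAB_SIZE)
logits = to_logits(vectors)
print(logits.shape)                    # torch.Size([1, 4, 10])
predicted = logits.argmax(dim=-1)      # most likely word index at each position
print(predicted.shape)                 # torch.Size([1, 4])
```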

The flow of the tensors within the forward() method itself is also pretty similar to the one in the Encoder class. In the Codeblock 19b below we pass features, captions, and attn_mask as the input (#(1)). Keep in mind that in this case the captions tensor contains the raw word sequence, so we need to vectorize these words with the embedding layer beforehand (#(2)). Next, we inject the sinusoidal positional embedding tensor using the code at line #(3) before eventually passing it through the four decoder blocks sequentially (#(4)). Finally, we pass the resulting tensor through the last linear layer to obtain the prediction logits (#(5)).

# Codeblock 19b
   def forward(self, features, captions, attn_mask):  #(1)
       print(f"features\t\t: {features.shape}")
       print(f"captions\t\t: {captions.shape}")
      
       captions = self.embedding(captions)  #(2)
       print(f"after embedding\t\t: {captions.shape}")
      
       captions = captions + self.sinusoidal_embedding()  #(3)
       print(f"after sin embed\t\t: {captions.shape}")
      
       for i, decoder_block in enumerate(self.decoder_blocks):
           captions = decoder_block(features, captions, attn_mask)  #(4)
           print(f"after decoder block #{i}\t: {captions.shape}")
      
       captions = self.linear(captions)  #(5)
       print(f"after linear\t\t: {captions.shape}")
      
       return captions
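The SinusoidalEmbedding class used at line #(3) is not shown in this part of the article. As a reminder of what it adds, below is a minimal sketch of the standard sinusoidal encoding from the original Transformer paper, PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). The function name sinusoidal_encoding is hypothetical, and the author's class may differ in details such as the exact returned shape.

```python
import torch

def sinusoidal_encoding(seq_length: int, embed_dim: int) -> torch.Tensor:
    """Hypothetical helper: a (seq_length, embed_dim) table of fixed position codes."""
    positions = torch.arange(seq_length, dtype=torch.float32).unsqueeze(1)  # (seq, 1)
    even_dims = torch.arange(0, embed_dim, 2, dtype=torch.float32)          # 2i
    div_term = torch.pow(10000.0, even_dims / embed_dim)                    # 10000^(2i/d)

    pe = torch.zeros(seq_length, embed_dim)
    pe[:, 0::2] = torch.sin(positions / div_term)  # even dimensions use sine
    pe[:, 1::2] = torch.cos(positions / div_term)  # odd dimensions use cosine
    return pe

pe = sinusoidal_encoding(seq_length=30, embed_dim=768)
captions = torch.randn(1, 30, 768)  # an already-embedded caption
captions = captions + pe            # broadcasts over the batch dimension
print(pe.shape, captions.shape)     # torch.Size([30, 768]) torch.Size([1, 30, 768])
```

Because the table is computed rather than learned, it has no trainable parameters, which is the key difference from the trainable positional embedding used in the encoder.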

At this point you might be wondering why we don't implement the softmax activation function as drawn in the illustration. This is because during training, softmax is typically folded into the loss function (PyTorch's nn.CrossEntropyLoss, for instance, expects raw logits), whereas during inference, the index of the largest value stays the same regardless of whether softmax is applied.
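Both halves of that argument are easy to verify. In the sketch below, the random logits, the toy vocabulary of 10 words and the target IDs are all assumptions for illustration: nn.CrossEntropyLoss is fed raw logits and matches the negative log-likelihood computed by hand from log_softmax, while argmax picks the same word with or without softmax.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
VOCAB_SIZE, SEQ_LENGTH = 10, 5  # toy sizes for illustration

logits = torch.randn(1, SEQ_LENGTH, VOCAB_SIZE)          # raw decoder output
targets = torch.randint(0, VOCAB_SIZE, (1, SEQ_LENGTH))  # hypothetical ground truth

# CrossEntropyLoss expects (N, C) logits and (N,) class indices, so flatten batch and time
loss = nn.CrossEntropyLoss()(logits.view(-1, VOCAB_SIZE), targets.view(-1))

# Same value computed manually: the softmax (as log_softmax) lives inside the loss
manual = -F.log_softmax(logits, dim=-1).gather(-1, targets.unsqueeze(-1)).mean()
print(torch.allclose(loss, manual))  # True

# At inference, softmax is monotonic, so it never changes which index is largest
same_choice = torch.equal(logits.argmax(dim=-1), F.softmax(logits, dim=-1).argmax(dim=-1))
print(same_choice)  # True
```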

Now let's run the following testing code to check whether there are errors in our implementation. Previously I mentioned that the captions input of the Decoder class is a raw word sequence. To simulate this, we can simply create a sequence of SEQ_LENGTH (30) random integers from 0 up to (but not including) VOCAB_SIZE (10000) (#(1)).

# Codeblock 20
decoder = Decoder()

features = torch.randn(BATCH_SIZE, NUM_PATCHES, EMBED_DIM)
captions = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))  #(1)

captions = decoder(features, captions, look_ahead_mask)

And below is what the resulting output looks like. Here you can see in the last line that the linear layer produced a tensor of size 30×10000, indicating that our decoder model is now capable of predicting the logit scores for each word in the vocabulary across all 30 sequence positions.

# Codeblock 20 Output
features               : torch.Size([1, 576, 768])
captions               : torch.Size([1, 30])
after embedding        : torch.Size([1, 30, 768])
after sin embed        : torch.Size([1, 30, 768])
after decoder block #0 : torch.Size([1, 30, 768])
after decoder block #1 : torch.Size([1, 30, 768])
after decoder block #2 : torch.Size([1, 30, 768])
after decoder block #3 : torch.Size([1, 30, 768])
after linear           : torch.Size([1, 30, 10000])

Transformer decoder (alternative)

It is actually also possible to make the code simpler by replacing the DecoderBlock class with the nn.TransformerDecoderLayer, just like what we did in the ViT Encoder. Below is what the code looks like if we use this approach instead.

# Codeblock 21
class DecoderTorch(nn.Module):
   def __init__(self):
       super().__init__()
       self.embedding = nn.Embedding(num_embeddings=VOCAB_SIZE,
                                     embedding_dim=EMBED_DIM)
      
       self.sinusoidal_embedding = SinusoidalEmbedding()
      
       #(1)
       decoder_block = nn.TransformerDecoderLayer(d_model=EMBED_DIM,
                                                  nhead=NUM_HEADS,
                                                  dim_feedforward=HIDDEN_DIM,
                                                  dropout=DROP_PROB,
                                                  batch_first=True)
      
       #(2)
       self.decoder_blocks = nn.TransformerDecoder(decoder_layer=decoder_block,
                                                   num_layers=NUM_DECODER_BLOCKS)
      
       self.linear = nn.Linear(in_features=EMBED_DIM,
                               out_features=VOCAB_SIZE)
      
   def forward(self, features, captions, tgt_mask):
       print(f"features\t\t: {features.shape}")
       print(f"captions\t\t: {captions.shape}")
      
       captions = self.embedding(captions)
       print(f"after embedding\t\t: {captions.shape}")
      
       captions = captions + self.sinusoidal_embedding()
       print(f"after sin embed\t\t: {captions.shape}")
      
       #(3)
       captions = self.decoder_blocks(tgt=captions,
                                      memory=features,
                                      tgt_mask=tgt_mask)
       print(f"after decoder blocks\t: {captions.shape}")
      
       captions = self.linear(captions)
       print(f"after linear\t\t: {captions.shape}")
      
       return captions

The main difference in the __init__() method is the use of nn.TransformerDecoderLayer and nn.TransformerDecoder at lines #(1) and #(2), where the former initializes a single decoder block and the latter repeats the block multiple times. The forward() method is mostly similar to the one in the Decoder class, except that the forward propagation through the decoder blocks is repeated four times automatically, without needing a loop (#(3)). One thing to pay attention to in the decoder_blocks layer is that the tensor coming from the encoder (features) must be passed as the argument for the memory parameter, while the tensor from the decoder itself (captions) has to be passed to the tgt parameter.
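The look_ahead_mask passed as tgt_mask is defined outside this part of the article. For reference, one mask format that PyTorch accepts is a boolean upper-triangular matrix in which True marks the positions a token may not attend to. The sketch below uses such a mask with a small nn.TransformerDecoder (all sizes are assumptions) to show the tgt/memory roles; whether it matches the author's look_ahead_mask exactly (boolean vs. float with -inf entries) is not something this section shows.

```python
import torch
import torch.nn as nn

SEQ_LENGTH, NUM_PATCHES, EMBED_DIM = 30, 576, 64  # EMBED_DIM reduced for illustration

# True above the diagonal = "future" positions hidden from each token
causal_mask = torch.triu(torch.ones(SEQ_LENGTH, SEQ_LENGTH, dtype=torch.bool), diagonal=1)

layer = nn.TransformerDecoderLayer(d_model=EMBED_DIM, nhead=4,
                                   dim_feedforward=128, batch_first=True)
decoder = nn.TransformerDecoder(decoder_layer=layer, num_layers=2)

captions = torch.randn(1, SEQ_LENGTH, EMBED_DIM)   # tgt: the decoder's own sequence
features = torch.randn(1, NUM_PATCHES, EMBED_DIM)  # memory: the encoder output

out = decoder(tgt=captions, memory=features, tgt_mask=causal_mask)
print(causal_mask[:3, :3])
print(out.shape)  # torch.Size([1, 30, 64])
```

With batch_first=True, both tgt and memory are laid out as (batch, sequence, feature), and the output keeps the length of tgt, not of memory.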

The testing code for the DecoderTorch model below is basically the same as the one written in Codeblock 20. Here you can see that this model also generates the final output tensor of size 30×10000.

# Codeblock 22
decoder_torch = DecoderTorch()

features = torch.randn(BATCH_SIZE, NUM_PATCHES, EMBED_DIM)
captions = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))

captions = decoder_torch(features, captions, look_ahead_mask)
# Codeblock 22 Output
features             : torch.Size([1, 576, 768])
captions             : torch.Size([1, 30])
after embedding      : torch.Size([1, 30, 768])
after sin embed      : torch.Size([1, 30, 768])
after decoder blocks : torch.Size([1, 30, 768])
after linear         : torch.Size([1, 30, 10000])

The entire CPTR model

Finally, it’s time to put the encoder and the decoder part we just created into a single class to actually construct the CPTR architecture. You can see in Codeblock 23 below that the implementation is very simple. All we need to do here is just to initialize the encoder (#(1)) and the decoder (#(2)) components, then pass the raw images and the corresponding caption ground truths as well as the look-ahead mask to the forward() method (#(3)). Additionally, it is also possible for you to replace the Encoder and the Decoder with EncoderTorch and DecoderTorch, respectively.

# Codeblock 23
class EncoderDecoder(nn.Module):
   def __init__(self):
       super().__init__()
       self.encoder = Encoder()  #EncoderTorch()  #(1)
       self.decoder = Decoder()  #DecoderTorch()  #(2)
      
   def forward(self, images, captions, look_ahead_mask):  #(3)
       print(f"images\t\t\t: {images.shape}")
       print(f"captions\t\t: {captions.shape}")
      
       features = self.encoder(images)
       print(f"after encoder\t\t: {features.shape}")
      
       captions = self.decoder(features, captions, look_ahead_mask)
       print(f"after decoder\t\t: {captions.shape}")
      
       return captions

We can do the testing by passing dummy tensors through it. See the Codeblock 24 below for the details. In this case, images is basically just a tensor of random numbers having the dimension of 1×3×384×384 (#(1)), while captions is a tensor of size 1×30 containing random integers (#(2)).

# Codeblock 24
encoder_decoder = EncoderDecoder()

images = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)  #(1)
captions = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))  #(2)

captions = encoder_decoder(images, captions, look_ahead_mask)

Below is what the output looks like. We can see here that our input images and captions successfully went through all layers in the network and produced outputs of the expected shapes, which means that the CPTR model we created is ready to be trained on an image captioning dataset.

# Codeblock 24 Output
images         : torch.Size([1, 3, 384, 384])
captions       : torch.Size([1, 30])
after encoder  : torch.Size([1, 576, 768])
after decoder  : torch.Size([1, 30, 10000])

Ending

That was pretty much everything about the theory and implementation of the CaPtion TransformeR architecture. Let me know what deep learning architecture I should implement next. Feel free to leave a comment if you spot any mistakes in this article!

The code used in this article is available in my GitHub repo. Here are the links to my previous articles about image captioning, Vision Transformer (ViT), and the original Transformer.

References

[1] Wei Liu et al. CPTR: Full Transformer Network for Image Captioning. arXiv. https://arxiv.org/pdf/2101.10804 [Accessed November 16, 2024].

[2] Oriol Vinyals et al. Show and Tell: A Neural Image Caption Generator. arXiv. https://arxiv.org/pdf/1411.4555 [Accessed December 3, 2024].

[3] Image originally created by author based on: Alexey Dosovitskiy et al. An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale. arXiv. https://arxiv.org/pdf/2010.11929 [Accessed December 3, 2024].

[4] Image originally created by author based on [6].

[5] Image originally created by author based on [1].

[6] Ashish Vaswani et al. Attention Is All You Need. arXiv. https://arxiv.org/pdf/1706.03762 [Accessed December 3, 2024].

The post Image Captioning, Transformer Mode On appeared first on Towards Data Science.

Show and Tell https://towardsdatascience.com/show-and-tell-e1a1142456e2/ Mon, 03 Feb 2025 16:30:24 +0000 https://towardsdatascience.com/show-and-tell-e1a1142456e2/ Implementing one of the earliest neural image caption generator models with PyTorch.

The post Show and Tell appeared first on Towards Data Science.

Photo by Ståle Grut on Unsplash

Introduction

Natural Language Processing and Computer Vision used to be two completely different fields. At least, back when I started learning machine learning and deep learning, it felt like there were multiple paths to follow, and each of them, including NLP and Computer Vision, led to a completely different world. Over time, as AI has become more and more advanced, the intersection between multiple fields of study has become more common, including the two I just mentioned.

Today, many language models have the capability to generate images based on a given prompt. That's one example of the bridge between NLP and Computer Vision, but I'll save it for an upcoming article as it is a bit more complex. Instead, in this article I am going to discuss a simpler one: image captioning. As the name suggests, this is essentially a technique where a model accepts an image and returns a text that describes it.

One of the earliest papers on this topic is "Show and Tell: A Neural Image Caption Generator", written by Vinyals et al. back in 2015 [1]. In this article, I will focus on implementing the deep learning model proposed in the paper using PyTorch. Note that I won't demonstrate the training process here, as that's a topic of its own. Let me know in the comments if you want a separate tutorial on that.


Image Captioning Framework

Generally speaking, image captioning can be done by combining two types of models: the one specialized to process images and another one capable of processing sequences. I believe you already know what kind of models work best for the two tasks – yes, you’re right, those are CNN and RNN, respectively. The idea here is that the CNN is utilized to encode the input image (hence this part is called encoder), whereas the RNN is used for generating a sequence of words based on the features encoded by the CNN (hence the RNN part is called decoder).

In the paper, the authors do this with a CNN encoder and an LSTM decoder. The encoder is not explicitly named as GoogLeNet (a.k.a. Inception V1), but based on the illustration provided in the paper, its architecture appears to be adopted from the original GoogLeNet paper [2]. The figure below shows what the proposed architecture looks like.

Figure 1. The image captioning model proposed in [1], where the encoder part (the leftmost block) implements the GoogLeNet model [2].

Talking more specifically about the connection between the encoder and the decoder, there are several ways to connect the two, namely init-inject, pre-inject, par-inject and merge, as described in [3]. In the Show and Tell paper, the authors used pre-inject, a method where the features extracted by the encoder are treated as the 0th word of the caption. Later, in the inference phase, we expect the decoder to generate a caption based solely on these image features.

Figure 2. The four possible methods for connecting the encoder and the decoder of an image captioning model [3]. In our case we are going to use the pre-inject method (b).

Now that we understand the theory behind the image captioning model, we can jump into the code!


Implementation

I’ll break the implementation part into three sections: the Encoder, the Decoder, and the combination of the two. Before we actually get into them, we need to import the modules and initialize the required parameters in advance. Look at the Codeblock 1 below to see the modules I use.

# Codeblock 1
import torch  #(1)
import torch.nn as nn  #(2)
import torchvision.models as models  #(3)
from torchvision.models import GoogLeNet_Weights  #(4)

Let’s break down these imports quickly: the line marked with #(1) is used for basic operations, line #(2) is for initializing neural network layers, line #(3) is for loading various deep learning models, and #(4) is the pretrained weights for the GoogLeNet model.

Talking about the parameter configuration, EMBED_DIM and LSTM_HIDDEN_DIM are the only two parameters mentioned in the paper, and both are set to 512, as shown at lines #(1) and #(2) in Codeblock 2 below. The EMBED_DIM variable indicates the size of the feature vector representing a single token in the caption. In this case, we can simply think of a single token as an individual word. Meanwhile, LSTM_HIDDEN_DIM represents the hidden state size inside the LSTM cell. The paper does not mention how many LSTM layers are stacked, but based on the diagram in Figure 1, it seems to use only a single one. Thus, at line #(3) I set the NUM_LSTM_LAYERS variable to 1.

# Codeblock 2
EMBED_DIM       = 512    #(1)
LSTM_HIDDEN_DIM = 512    #(2)
NUM_LSTM_LAYERS = 1      #(3)

IMAGE_SIZE      = 224    #(4)
IN_CHANNELS     = 3      #(5)

SEQ_LENGTH      = 30     #(6)
VOCAB_SIZE      = 10000  #(7)

BATCH_SIZE      = 1

The next two parameters are related to the input image, namely IMAGE_SIZE (#(4)) and IN_CHANNELS (#(5)). Since we are about to use GoogLeNet for the encoder, we need to match its original input shape (3×224×224). Besides the image, we also need to configure the parameters for the caption. Here we assume that the caption is no longer than 30 words (#(6)) and that the number of unique words in the dictionary is 10000 (#(7)). Lastly, the BATCH_SIZE parameter is needed because PyTorch layers expect inputs with a batch dimension. To keep things simple, the number of image-caption pairs within a single batch is set to 1.

GoogLeNet Encoder

It is actually possible to use any kind of CNN-based model for the encoder. For example, [4] uses DenseNet, [5] uses Inception V3, and [6] utilizes ResNet for similar tasks. However, since my goal is to reproduce the model proposed in the paper as closely as possible, I am using the pretrained GoogLeNet model instead. Before we get into the encoder implementation, let's see what the GoogLeNet architecture looks like using the following code.

# Codeblock 3
models.googlenet()

The resulting output is very long, as it lists every layer in the architecture. Here I truncate the output since I only want you to focus on the last layer (the fc layer marked with #(1) in the Codeblock 3 Output below). You can see that this linear layer maps a feature vector of size 1024 to 1000. In a standard image classification task, each of these 1000 neurons corresponds to a specific class. So, for example, if you want to perform a 5-class classification task, you would need to modify this layer such that it projects the outputs to 5 neurons only. In our case, we need to make this layer produce a feature vector of length 512 (EMBED_DIM). With this, the input image will be represented as a 512-dimensional vector after being processed by the GoogLeNet model. This feature vector size exactly matches the token embedding dimension, allowing it to be treated as part of our word sequence.

# Codeblock 3 Output
GoogLeNet(
  (conv1): BasicConv2d(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (conv2): BasicConv2d(
    (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )

  .
  .
  .
  .

  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (dropout): Dropout(p=0.2, inplace=False)
  (fc): Linear(in_features=1024, out_features=1000, bias=True)  #(1)
)

Now let’s actually load and modify the GoogLeNet model, which I do in the InceptionEncoder class below.

# Codeblock 4a
class InceptionEncoder(nn.Module):
    def __init__(self, fine_tune):  #(1)
        super().__init__()
        self.googlenet = models.googlenet(weights=GoogLeNet_Weights.IMAGENET1K_V1)  #(2)
        self.googlenet.fc = nn.Linear(in_features=self.googlenet.fc.in_features,  #(3)
                                      out_features=EMBED_DIM)  #(4)

        if fine_tune == True:       #(5)
            for param in self.googlenet.parameters():
                param.requires_grad = True
        else:
            for param in self.googlenet.parameters():
                param.requires_grad = False

        for param in self.googlenet.fc.parameters():
            param.requires_grad = True

The first thing we do in the above code is to load the model using models.googlenet(). It is mentioned in the paper that the model is already pretrained on the ImageNet dataset. Thus, we need to pass GoogLeNet_Weights.IMAGENET1K_V1 into the weights parameter, as shown at line #(2) in Codeblock 4a. Next, at line #(3) we access the classification head through the fc attribute, where we replace the existing linear layer with a new one having the output dimension of 512 (EMBED_DIM) (#(4)). Since this GoogLeNet model is already trained, we don’t need to train it from scratch. Instead, we can either perform fine-tuning or transfer learning in order to adapt it to the image captioning task.

In case you're not yet familiar with the two terms, fine-tuning is a method where we update the weights of the entire model. Transfer learning, as the term is used in this article, is a technique where we only update the weights of the layers we replaced (in this case, the last fully-connected layer), while keeping the weights of the existing layers frozen. To switch between the two, I implement a flag named fine_tune at line #(1), which makes the model perform fine-tuning whenever it is set to True (#(5)).
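The difference between the two modes boils down to which parameters have requires_grad=True. To keep the sketch light and avoid downloading ImageNet weights, it uses a tiny hypothetical stand-in backbone (TinyBackbone) instead of GoogLeNet, together with a hypothetical build_encoder() helper whose freezing logic mirrors Codeblock 4a: set every parameter according to the flag, then always unfreeze the new fc head. Handing only the trainable parameters to the optimizer is a common convention, not something the article prescribes.

```python
import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    """Hypothetical stand-in for GoogLeNet: a feature extractor plus an `fc` head."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(),
                                      nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.fc = nn.Linear(8, 1000)

    def forward(self, x):
        return self.fc(self.features(x))

def build_encoder(fine_tune: bool, embed_dim: int = 512) -> nn.Module:
    model = TinyBackbone()
    model.fc = nn.Linear(model.fc.in_features, embed_dim)  # replace the classification head
    for param in model.parameters():
        param.requires_grad = fine_tune                     # every layer follows the flag
    for param in model.fc.parameters():
        param.requires_grad = True                          # the new head is always trained
    return model

def count_trainable(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

frozen = build_encoder(fine_tune=False)
tuned = build_encoder(fine_tune=True)
print(count_trainable(frozen), count_trainable(tuned))  # 4608 4832

# Only the trainable parameters need to be handed to the optimizer
optimizer = torch.optim.Adam(p for p in frozen.parameters() if p.requires_grad)
```

In the frozen case only the 512×8 weights and 512 biases of the new head are trainable; with fine-tuning, the 224 convolution parameters join them.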

The forward() method is pretty straightforward since what we do here is simply passing the input image through the modified GoogLeNet model. See the Codeblock 4b below for the details. Additionally, here I also print out the tensor dimension before and after processing so that you can better understand how the InceptionEncoder model works.

# Codeblock 4b
    def forward(self, images):
        print(f'original\t: {images.size()}')
        features = self.googlenet(images)
        print(f'after googlenet\t: {features.size()}')

        return features

To test whether our encoder works properly, we can pass a dummy tensor of size 1×3×224×224 through the network, as demonstrated in Codeblock 5. This tensor simulates a single RGB image of size 224×224. You can see in the resulting output that the image is now represented by a 512-dimensional feature vector (plus the batch dimension).

# Codeblock 5
inception_encoder = InceptionEncoder(fine_tune=True)

images = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
features = inception_encoder(images)
# Codeblock 5 Output
original         : torch.Size([1, 3, 224, 224])
after googlenet  : torch.Size([1, 512])

LSTM Decoder

Now that we have implemented the encoder, we are going to create the LSTM decoder, which I demonstrate in Codeblock 6a and 6b. First, we initialize the required layers, namely an embedding layer (#(1)), the LSTM layer itself (#(2)), and a standard linear layer (#(3)). The first one (nn.Embedding) is responsible for mapping every single token into a 512 (EMBED_DIM)-dimensional vector. Meanwhile, the LSTM layer produces a sequence of hidden states, each of which is mapped into a 10000 (VOCAB_SIZE)-dimensional vector by the linear layer. The values in this vector are logits which, after a softmax, represent the likelihood of each word in the dictionary being chosen.

# Codeblock 6a
class LSTMDecoder(nn.Module):
    def __init__(self):
        super().__init__()

        #(1)
        self.embedding = nn.Embedding(num_embeddings=VOCAB_SIZE,
                                      embedding_dim=EMBED_DIM)
        #(2)
        self.lstm = nn.LSTM(input_size=EMBED_DIM, 
                            hidden_size=LSTM_HIDDEN_DIM, 
                            num_layers=NUM_LSTM_LAYERS, 
                            batch_first=True)
        #(3)        
        self.linear = nn.Linear(in_features=LSTM_HIDDEN_DIM, 
                                out_features=VOCAB_SIZE)

Next, let’s define the flow of the network using the following code.

# Codeblock 6b
    def forward(self, features, captions):                 #(1)
        print(f'features original\t: {features.size()}')
        features = features.unsqueeze(1)                   #(2)
        print(f"after unsqueeze\t\t: {features.shape}")

        print(f'captions original\t: {captions.size()}')
        captions = self.embedding(captions)                #(3)
        print(f"after embedding\t\t: {captions.shape}")

        captions = torch.cat([features, captions], dim=1)  #(4)
        print(f"after concat\t\t: {captions.shape}")

        captions, _ = self.lstm(captions)                  #(5)
        print(f"after lstm\t\t: {captions.shape}")

        captions = self.linear(captions)                   #(6)
        print(f"after linear\t\t: {captions.shape}")

        return captions

You can see in the above code that the forward() method of the LSTMDecoder class accepts two inputs: features and captions, where the former is the image representation produced by the InceptionEncoder, while the latter is the caption of the corresponding image serving as the ground truth (#(1)). The idea here is to perform the pre-inject operation by prepending the features tensor to captions using the code at line #(4). However, we need to adjust the shapes of both tensors beforehand. To do so, we insert a singleton dimension at axis 1 of the image features (#(2)). Meanwhile, the captions tensor already has the required shape right after being processed by the embedding layer (#(3)). Once the features and captions have been concatenated, we pass this tensor through the LSTM layer (#(5)) before it is eventually processed by the linear layer (#(6)). Look at the testing code below to better understand the flow of the two tensors.

# Codeblock 7
lstm_decoder = LSTMDecoder()

features = torch.randn(BATCH_SIZE, EMBED_DIM)  #(1)
captions = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))  #(2)

captions = lstm_decoder(features, captions)

In Codeblock 7, I assume that features is a dummy tensor that represents the output of the InceptionEncoder model (#(1)). Meanwhile, captions is a tensor representing a sequence of tokenized words, which here I initialize as 30 (SEQ_LENGTH) random integers from 0 up to (but not including) 10000 (VOCAB_SIZE) (#(2)).

We can see in the output below that the features tensor initially has the dimension of 1×512 (#(1)). Its shape changes to 1×1×512 after the unsqueeze() operation (#(2)). The additional dimension in the middle (1) allows the tensor to be treated as a feature vector corresponding to a single timestep, which is necessary for compatibility with the LSTM layer. As for the captions tensor, its shape changes from 1×30 (#(3)) to 1×30×512 (#(4)), indicating that every single word is now represented as a 512-dimensional vector.

# Codeblock 7 Output
features original : torch.Size([1, 512])       #(1)
after unsqueeze   : torch.Size([1, 1, 512])    #(2)
captions original : torch.Size([1, 30])        #(3)
after embedding   : torch.Size([1, 30, 512])   #(4)
after concat      : torch.Size([1, 31, 512])   #(5)
after lstm        : torch.Size([1, 31, 512])   #(6)
after linear      : torch.Size([1, 31, 10000]) #(7)

After the pre-inject operation, our tensor has the dimension of 1×31×512, where the features tensor becomes the token at the 0th timestep in the sequence (#(5)). See the following figure to better illustrate this idea.

Figure 3. What the resulting tensor looks like after the pre-injection operation. [3].

Next, we pass the tensor through the LSTM layer, which in this particular case leaves the tensor dimension unchanged. However, it is important to note that the last dimensions at lines #(5) and #(6) in the above output are actually determined by different parameters: the former by EMBED_DIM and the latter by LSTM_HIDDEN_DIM. They appear to match here only because both were set to 512; with a different value for LSTM_HIDDEN_DIM, the output dimension would change accordingly. Finally, we project each of the 31 vectors to a vector of size 10000, which contains the logit of every possible token (#(7)); applying softmax to it gives the probability of each token being predicted.
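Both points (the image feature becoming timestep 0, and the LSTM output following LSTM_HIDDEN_DIM rather than EMBED_DIM) can be checked with a small sketch. Here LSTM_HIDDEN_DIM is deliberately set to 256, an assumption that departs from the paper's 512, so that the two sizes no longer coincide; note that the linear layer's in_features must then follow LSTM_HIDDEN_DIM.

```python
import torch
import torch.nn as nn

EMBED_DIM, LSTM_HIDDEN_DIM = 512, 256  # hidden size changed on purpose (assumption)
VOCAB_SIZE, SEQ_LENGTH = 10000, 30

embedding = nn.Embedding(VOCAB_SIZE, EMBED_DIM)
lstm = nn.LSTM(input_size=EMBED_DIM, hidden_size=LSTM_HIDDEN_DIM, batch_first=True)
linear = nn.Linear(LSTM_HIDDEN_DIM, VOCAB_SIZE)  # must match the LSTM output size

features = torch.randn(1, EMBED_DIM)                      # image feature from the encoder
captions = torch.randint(0, VOCAB_SIZE, (1, SEQ_LENGTH))  # tokenized caption

# Pre-inject: the image feature is prepended as the 0th "word" of the sequence
sequence = torch.cat([features.unsqueeze(1), embedding(captions)], dim=1)
assert torch.equal(sequence[:, 0], features)

hidden, _ = lstm(sequence)
logits = linear(hidden)
print(sequence.shape)  # torch.Size([1, 31, 512])
print(hidden.shape)    # torch.Size([1, 31, 256])
print(logits.shape)    # torch.Size([1, 31, 10000])
```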

GoogLeNet Encoder + LSTM Decoder

At this point, we have successfully created both the encoder and the decoder parts of the image captioning model. What I am going to do next is to combine them together in the ShowAndTell class below.

# Codeblock 8a
class ShowAndTell(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = InceptionEncoder(fine_tune=True)  #(1)
        self.decoder = LSTMDecoder()     #(2)

    def forward(self, images, captions):
        features = self.encoder(images)  #(3)
        print(f"after encoder\t: {features.shape}")

        captions = self.decoder(features, captions)      #(4)
        print(f"after decoder\t: {captions.shape}")

        return captions

I think the above code is pretty straightforward. In the __init__() method, we only need to initialize the InceptionEncoder and the LSTMDecoder models (#(1) and #(2)). Here I assume that we are going to perform fine-tuning rather than transfer learning, so I set the fine_tune parameter to True. As a rule of thumb, fine-tuning tends to work better than transfer learning if you have a relatively large dataset, since it re-adjusts the weights of the entire model, whereas with a rather small dataset transfer learning is usually the safer choice. That's only a rule of thumb, though, so it's a good idea to experiment with both options to see which works best in your case.

Still with the above codeblock, we configure the forward() method to accept image-caption pairs as input, which means this method is designed for training purposes only. First, we process the raw image with the GoogLeNet inside the encoder block (#(3)). Afterwards, we pass the extracted features as well as the tokenized captions into the decoder block and let it produce a sequence of logit vectors, one per timestep (#(4)). In actual training, this output is then compared with the ground truth to compute the loss, which is backpropagated to obtain the gradients that determine how the weights in the network are updated.
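How exactly the 31 output positions are compared with the 30 ground-truth tokens is not spelled out here, so the sketch below shows one common convention for pre-inject decoders, stated as an assumption: the output at timestep t (the image at t=0, then word t-1 as input) is trained to predict caption word t, so the last output, produced after the final word, has no target and is dropped. Random tensors stand in for the real model output, and start tokens and padding are left out.

```python
import torch
import torch.nn as nn

VOCAB_SIZE, SEQ_LENGTH = 10000, 30

# Stand-ins for the ShowAndTell forward() output and the ground-truth caption
logits = torch.randn(1, SEQ_LENGTH + 1, VOCAB_SIZE, requires_grad=True)  # (1, 31, 10000)
captions = torch.randint(0, VOCAB_SIZE, (1, SEQ_LENGTH))                 # (1, 30)

# Assumed alignment: drop the final output so the remaining 30 line up with the 30 targets
predictions = logits[:, :-1, :]                                          # (1, 30, 10000)

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(predictions.reshape(-1, VOCAB_SIZE), captions.reshape(-1))
loss.backward()  # gradients reach every position that has a target

print(predictions.shape, loss.item())
```

With real, padded captions, the padding positions would usually be excluded through the ignore_index argument of nn.CrossEntropyLoss.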

Note that we cannot use the forward() method to perform inference, since it requires the ground-truth caption as an input, which is not available at inference time. So we need a separate method, which I implement in generate() below.

# Codeblock 8b
    def generate(self, images):  #(1)
        features = self.encoder(images)              #(2)
        print(f"after encoder\t\t: {features.shape}\n")

        words = []  #(3)
        states = None  # LSTM (hidden, cell) states; None means zero-initialized
        for i in range(SEQ_LENGTH):                  #(4)
            print(f"iteration #{i}")
            features = features.unsqueeze(1)
            print(f"after unsqueeze\t\t: {features.shape}")

            # Carry the states over so that each step is conditioned on all previous
            # tokens, just like the full-sequence pass in forward()
            features, states = self.decoder.lstm(features, states)
            print(f"after lstm\t\t: {features.shape}")

            features = features.squeeze(1)           #(5)
            print(f"after squeeze\t\t: {features.shape}")

            probs = self.decoder.linear(features)    #(6)
            print(f"after linear\t\t: {probs.shape}")

            _, word = probs.max(dim=1)  #(7)
            print(f"after max\t\t: {word.shape}")

            words.append(word.item())  #(8)

            if word == 1:  #(9)
                break

            features = self.decoder.embedding(word)  #(10)
            print(f"after embedding\t\t: {features.shape}\n")

        return words       #(11)

Instead of taking two inputs like forward(), the generate() method takes a raw image as its only input (#(1)). Since we want the features extracted from the image to be the initial input token, we first process the raw image with the encoder block before generating the subsequent tokens (#(2)). Next, we allocate an empty list to store the token sequence to be produced (#(3)). The tokens themselves are generated one by one, so we wrap the entire process inside a for loop that runs for at most 30 (SEQ_LENGTH) iterations (#(4)).

The steps inside the loop are algorithmically similar to the ones we discussed earlier. However, since the LSTM here generates a single token at a time, the tensor needs to be handled a bit differently from the one passed through the forward() method of the LSTMDecoder class in Codeblock 6b. The first difference you might notice is the squeeze() operation (#(5)), which removes the timestep dimension so that the subsequent linear layer produces a 1×10000 output (#(6)). Then, we take the index of the element with the highest value, which corresponds to the token most likely to come next (#(7)), and append it to the list we allocated earlier (#(8)). The loop breaks whenever the predicted index is the stop token, which here I assume is at index 1 of the probs vector (#(9)). Otherwise, the last predicted word is converted into its 512 (EMBED_DIM)-dimensional embedding vector (#(10)), which serves as the input features for the next iteration. Lastly, the generated word sequence is returned once the loop is completed (#(11)).
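One detail matters whenever an LSTM is driven one token at a time: nn.LSTM starts from zero hidden and cell states whenever none are passed in, so the (h, c) pair returned by one call has to be fed into the next for step-by-step decoding to behave like the full-sequence pass used during training. The sketch below checks that equivalence numerically with a small standalone LSTM; its sizes are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
inputs = torch.randn(1, 5, 8)  # (batch, timesteps, features)

with torch.no_grad():
    full_output, _ = lstm(inputs)  # all 5 timesteps in a single call

    states = None  # None means zero-initialized (h, c)
    step_outputs = []
    for t in range(inputs.size(1)):
        out, states = lstm(inputs[:, t:t + 1], states)  # carry (h, c) forward
        step_outputs.append(out)
    stepwise_output = torch.cat(step_outputs, dim=1)

    # Without carrying the states, every step restarts from zeros
    fresh_output = torch.cat([lstm(inputs[:, t:t + 1])[0] for t in range(inputs.size(1))], dim=1)

print(torch.allclose(full_output, stepwise_output, atol=1e-6))  # True
print(torch.allclose(full_output, fresh_output, atol=1e-6))     # False
```

Only the first timestep of fresh_output agrees with the full pass; every later step has lost the context of the earlier tokens.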

We are going to simulate the forward pass for the training phase using Codeblock 9 below. Here I initialize the show_and_tell model (#(1)) and pass two tensors through it: a raw image of size 3×224×224 (#(2)) and a sequence of tokenized words (#(3)). The resulting output shows that our model works properly, as the two input tensors successfully pass through both the InceptionEncoder and the LSTMDecoder parts of the network.

# Codeblock 9
show_and_tell = ShowAndTell()  #(1)

images = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)  #(2)
captions = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))      #(3)

captions = show_and_tell(images, captions)
# Codeblock 9 Output
after encoder : torch.Size([1, 512])
after decoder : torch.Size([1, 31, 10000])

Now, let's assume that our show_and_tell model has already been trained on an image captioning dataset and is thus ready for inference. Codeblock 10 below shows how I do it. Here we set the model to eval() mode (#(1)), initialize the input image (#(2)), and pass it through the model using the generate() method (#(3)).

# Codeblock 10
show_and_tell.eval()  #(1)

images = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)  #(2)

with torch.no_grad():
    generated_tokens = show_and_tell.generate(images)  #(3)

The flow of the tensor can be seen in the output below. I truncated the output because it simply repeats the same token-generation process 30 times.

# Codeblock 10 Output
after encoder    : torch.Size([1, 512])

iteration #0
after unsqueeze  : torch.Size([1, 1, 512])
after lstm       : torch.Size([1, 1, 512])
after squeeze    : torch.Size([1, 512])
after linear     : torch.Size([1, 10000])
after max        : torch.Size([1])
after embedding  : torch.Size([1, 512])

iteration #1
after unsqueeze  : torch.Size([1, 1, 512])
after lstm       : torch.Size([1, 1, 512])
after squeeze    : torch.Size([1, 512])
after linear     : torch.Size([1, 10000])
after max        : torch.Size([1])
after embedding  : torch.Size([1, 512])

.
.
.
.

To see what the resulting caption looks like, we can print out the generated_tokens list as shown below. Keep in mind that this sequence is still in the form of token IDs, and since the model has not actually been trained here, these IDs are essentially random. Later, in the post-processing stage, we will need to convert them back into the words corresponding to these numbers.

# Codeblock 11
generated_tokens
# Codeblock 11 Output
[5627,
 3906,
 2370,
 2299,
 4952,
 9933,
 402,
 7775,
 602,
 4414,
 8667,
 6774,
 9345,
 8750,
 3680,
 4458,
 1677,
 5998,
 8572,
 9556,
 7347,
 6780,
 9672,
 2596,
 9218,
 1880,
 4396,
 6168,
 7999,
 454]
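The post-processing step mentioned above can be sketched as follows. This is a minimal example under stated assumptions: idx_to_word is a hypothetical toy vocabulary (the real one would come from the tokenizer built for the captioning dataset, which is not shown here), and, following the generate() method, index 1 is assumed to be the stop token.

```python
# Hypothetical index-to-word vocabulary, for illustration only; index 1 is
# assumed to be the stop token, matching the convention used in generate().
idx_to_word = {0: "<pad>", 1: "<end>", 2: "a", 3: "dog", 4: "runs", 5: "on", 6: "grass"}

def decode_tokens(token_ids, idx_to_word, stop_index=1, unk="<unk>"):
    """Convert a list of token IDs back into a caption string."""
    words = []
    for idx in token_ids:
        idx = int(idx)                            # also accepts 1-element tensors
        if idx == stop_index:                     # stop at the end-of-caption token
            break
        words.append(idx_to_word.get(idx, unk))   # unknown IDs become <unk>
    return " ".join(words)

print(decode_tokens([2, 3, 4, 5, 6, 1, 3], idx_to_word))  # a dog runs on grass
```

Everything after the stop token (the trailing 3 above) is discarded, which mirrors how the generation loop breaks at that token.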

Ending

With the above output, we've reached the end of our discussion on image captioning. Over time, many other researchers have attempted to improve on this approach, so in an upcoming article I plan to discuss a state-of-the-art method for this task.

Thanks for reading, I hope you learn something new today!

_By the way you can also find the code used in this article here._


References

[1] Oriol Vinyals et al. Show and Tell: A Neural Image Caption Generator. arXiv. https://arxiv.org/pdf/1411.4555 [Accessed November 13, 2024].

[2] Christian Szegedy et al. Going Deeper with Convolutions. arXiv. https://arxiv.org/pdf/1409.4842 [Accessed November 13, 2024].

[3] Marc Tanti et al. Where to put the Image in an Image Caption Generator. arXiv. https://arxiv.org/pdf/1703.09137 [Accessed November 13, 2024].

[4] Stepan Ulyanin. Captioning Images with CNN and RNN, using PyTorch. Medium. https://medium.com/@stepanulyanin/captioning-images-with-pytorch-bc592e5fd1a3 [Accessed November 16, 2024].

[5] Saketh Kotamraju. How to Build an Image-Captioning Model in Pytorch. Towards Data Science. https://towardsdatascience.com/how-to-build-an-image-captioning-model-in-pytorch-29b9d8fe2f8c [Accessed November 16, 2024].

[6] Code with Aarohi. Image Captioning using CNN and RNN | Image Captioning using Deep Learning. YouTube. https://www.youtube.com/watch?v=htNmFL2BG34 [Accessed November 16, 2024].

The post Show and Tell appeared first on Towards Data Science.

Meet GPT, The Decoder-Only Transformer https://towardsdatascience.com/meet-gpt-the-decoder-only-transformer-12f4a7918b36/ Mon, 06 Jan 2025 17:01:43 +0000 https://towardsdatascience.com/meet-gpt-the-decoder-only-transformer-12f4a7918b36/ Understanding and implementing the GPT-1, GPT-2 and GPT-3 architectures

The post Meet GPT, The Decoder-Only Transformer appeared first on Towards Data Science.

Photo by Emiliano Vittoriosi on Unsplash

Introduction

Large Language Models (LLMs), such as ChatGPT, Gemini, Claude, etc., have been around for a while now, and I believe all of us have already used at least one of them. As of this writing, ChatGPT already uses the fourth generation of GPT-based models, GPT-4. But do you know what GPT actually is, and what the underlying neural network architecture looks like? In this article we are going to talk about GPT models, especially GPT-1, GPT-2 and GPT-3. I will also demonstrate how to code them from scratch with PyTorch so that you can get a better understanding of the structure of these models.

A Brief History of GPT

Before we get into GPT, we need to understand the original Transformer architecture first. Generally speaking, a Transformer consists of two main components: the Encoder and the Decoder. The former is responsible for understanding the input sequence, whereas the latter generates another sequence based on that input. For example, in a question answering task, the decoder produces an answer to the input sequence, while in a machine translation task it generates the translation of the input.

Figure 1. The Transformer model. The block on the left is the Encoder and the one on the right is the Decoder [1].

The two main components of the Transformer mentioned above also consist of several sub-components, such as the attention block, the look-ahead mask, and layer normalization. Here I assume that you already have basic knowledge about them. If you don't, I highly recommend reading my previous post on the topic, which you can access through the link I provided at the end of this article [2].

The Transformer proved to perform impressively in language modeling. Interestingly, later researchers found that its encoder and decoder parts can each do so on their own. This was the moment when BERT (Bidirectional Encoder Representations from Transformers) and GPT (Generative Pre-trained Transformer) were invented, where BERT is basically a stack of encoders, while GPT is a stack of decoders.

Talking more specifically about GPT, its first version (GPT-1) was released by OpenAI back in 2018. This was then followed by GPT-2 and GPT-3 in 2019 and 2020, respectively. However, not many people outside the research community knew about GPT at the time, partly because the most capable version, GPT-3, was only accessible through an API. It wasn't until 2022 that OpenAI released ChatGPT, backed by a GPT-3.5 model, which allowed the public to interact with this LLM easily. Below is a figure showing the evolution of GPT models.

Figure 2. The evolution of GPT models over time [3].

GPT-1

The first GPT version was published in a research paper titled "Improving Language Understanding by Generative Pre-Training" by Radford et al. [4] in 2018. As mentioned earlier, GPT is basically a stack of decoders, and in the case of GPT-1 the decoder block is repeated 12 times. It is important to keep in mind that the decoder architecture implemented in GPT-1 is not completely identical to the one in the original Transformer. In the following figure, the model on the left is the decoder proposed in the GPT-1 paper, whereas the one on the right is the decoder part of the original Transformer. Here we can see that the part highlighted in red in the original decoder does not exist in GPT-1. This is because that component (the encoder-decoder cross-attention block) is employed to combine the information coming from the encoder with the information from the decoder input itself. Since GPT-1 has no encoder, this component can simply be omitted.

Figure 3. The GPT-1 architecture (left) [4] and the Decoder part of the original Transformer architecture (right) [5].

GPT-1 Pretraining

The training process of the GPT-1 model is divided into two steps: pretraining and fine-tuning. The goal of pretraining is to teach the model to predict the next token in a sequence based on the preceding tokens, a process commonly known as language modeling. This pretraining step uses a self-supervised mechanism, i.e., a training process where the labels come from the dataset itself. With this method, we don't need to perform manual labeling. Instead, we can take chunks of 513 consecutive tokens at random positions in a long text, using the first 512 as the features and the last one as the label. This number of tokens is chosen based on the context window of GPT-1, which is 512 tokens. As for tokenization, GPT-1 uses BPE (Byte Pair Encoding). This means that a token does not necessarily correspond to a whole word; it can also be a sub-word or even an individual character.
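The chunking idea can be sketched as below. This is a minimal illustration, not the original data pipeline: sample_pretraining_pair is a hypothetical helper, and the integer range stands in for a long BPE-tokenized corpus.

```python
import torch

CONTEXT = 512  # GPT-1 context window (SEQ_LENGTH in this article)

def sample_pretraining_pair(token_ids, context=CONTEXT, generator=None):
    """Cut context + 1 consecutive tokens at a random position of a long
    tokenized text: the first `context` tokens are the features and the
    final token is the label."""
    max_start = token_ids.numel() - (context + 1)
    if max_start < 0:
        raise ValueError("text is shorter than context + 1 tokens")
    start = torch.randint(0, max_start + 1, (1,), generator=generator).item()
    chunk = token_ids[start : start + context + 1]
    return chunk[:-1], chunk[-1]

corpus = torch.arange(10_000)  # stand-in for a long BPE-tokenized text
features, label = sample_pretraining_pair(corpus)
print(features.shape)                              # torch.Size([512])
print(label.item() == features[-1].item() + 1)     # True for this toy corpus
```

In practice, the same chunk usually yields a target for every position at once (chunk[:-1] as input, chunk[1:] as targets), because the look-ahead mask keeps each position from seeing the token it has to predict.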

The GPT-1 pretraining is done using the objective function shown in Figure 4 below, where uᵢ is the token being predicted, uᵢ₋ₖ, …, uᵢ₋₁ are the k previous tokens (the context window), and Θ denotes the model parameters. This equation sums, over every position in the text, the log-likelihood of a token given the tokens preceding it, and pretraining maximizes this sum. At inference time, the token with the highest probability can be returned as the predicted output, and by repeating this process iteratively, the model continues the text provided in the prompt. If we go back to Figure 3, we will see that the GPT-1 model has two heads: text prediction and task classifier. This text generation process is done using the text prediction head.

Figure 4. The objective function for pretraining [4].
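Written out, the objective in Figure 4 is L₁(U) = Σᵢ log P(uᵢ | uᵢ₋ₖ, …, uᵢ₋₁; Θ). Because row t of the text prediction output scores the token at position t+1, every position of a training sequence contributes one term at once. The sketch below uses random logits as a stand-in for a model's output (lm_log_likelihood is a hypothetical helper name) and checks that the sum equals the negative of PyTorch's summed cross-entropy, which is how the objective is usually minimized in practice.

```python
import torch
import torch.nn.functional as F

def lm_log_likelihood(logits, tokens):
    """L1 = sum_i log P(u_i | u_{i-k}, ..., u_{i-1}; Theta).

    logits: (batch, seq_len, vocab) text prediction head output, where row t
            scores the token that follows position t.
    tokens: (batch, seq_len) input token IDs.
    """
    pred = logits[:, :-1, :]        # predictions made at positions 0..seq_len-2
    target = tokens[:, 1:]          # the tokens that actually came next
    log_probs = F.log_softmax(pred, dim=-1)
    token_ll = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    return token_ll.sum()

torch.manual_seed(0)
VOCAB = 50                                   # toy vocabulary size
logits = torch.randn(1, 8, VOCAB)            # stand-in for model output
tokens = torch.randint(0, VOCAB, (1, 8))

l1 = lm_log_likelihood(logits, tokens)
ce = F.cross_entropy(logits[:, :-1].reshape(-1, VOCAB),
                     tokens[:, 1:].reshape(-1), reduction="sum")
print(torch.allclose(-l1, ce))  # True: maximizing L1 == minimizing cross-entropy
```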

GPT-1 Fine-Tuning

Even though GPT is by default a generative model, during the fine-tuning phase we treat it as a discriminative model, because the goal in this phase is to perform a typical classification task. In the following objective function, y represents the class to be predicted, while x¹, …, xᵐ denote the m input tokens of sequence x. We can simply think of this equation as categorizing a text into a specific class. This classification mechanism is then used to perform various downstream tasks, which I will explain very soon.

Figure 5. The objective function for the downstream classification task [4].
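The objective in Figure 5 is L₂(C) = Σ₍ₓ,ᵧ₎ log P(y | x¹, …, xᵐ), where the GPT-1 paper computes P(y | x¹, …, xᵐ) = softmax(hₗᵐ W_y) from the final decoder state hₗᵐ of the last input token. Below is a minimal sketch: the random decoder states are stand-ins, classification_log_likelihood is a hypothetical name, and using nn.Linear (which adds a bias) in place of the plain matrix W_y is my own assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

D_MODEL, N_CLASS = 768, 3

def classification_log_likelihood(decoder_output, labels, w_y):
    """L2 = sum over (x, y) of log P(y | x^1..x^m),
    with P(y | x^1..x^m) = softmax(h_l^m W_y).

    decoder_output: (batch, m, D_MODEL) final decoder states h_l
    labels:         (batch,) class indices y
    w_y:            linear layer playing the role of W_y
    """
    h_last = decoder_output[:, -1, :]              # h_l^m: state of the last token
    log_probs = F.log_softmax(w_y(h_last), dim=-1)
    return log_probs.gather(-1, labels.unsqueeze(-1)).sum()

torch.manual_seed(0)
w_y = nn.Linear(D_MODEL, N_CLASS)
h = torch.randn(4, 512, D_MODEL)       # stand-in for a batch of decoder outputs
y = torch.tensor([0, 2, 1, 0])
l2 = classification_log_likelihood(h, y, w_y)
print(l2.item() < 0)  # True: a sum of log-probabilities is never positive
```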

There are four different downstream tasks evaluated in the paper: classification, natural language inference (entailment), sentence similarity, and multiple-choice question answering. The figure below illustrates the workflow of these tasks.

Figure 6. The downstream task workflows of the GPT-1 model [4].

The Transformer blocks colored in green are GPT-1 models, each having exactly the same architecture. To allow the model to perform different tasks, we need to arrange the input texts accordingly. For a standard text classification task, e.g., sentiment analysis or document classification, we simply put the token sequence between the start and extract tokens, which mark the beginning and the end of the text, before feeding it into the GPT-1 model. The resulting tensor is then forwarded to a linear layer in which each neuron corresponds to a single class.

Figure 7. Examples of input texts and the corresponding labels for sentiment analysis (classification) task [3].
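The classification input format described above can be sketched as follows. The special-token IDs and the padding ID are hypothetical; GPT-1 adds learned embeddings for its start, delimiter and extract tokens on top of the BPE vocabulary, so an embedding layer used with these IDs would need more than VOCAB_SIZE rows. Because the article's GPT1 class expects exactly SEQ_LENGTH tokens, the sketch right-pads the sequence and also returns the position of the extract token, whose output is what the classifier should read. Thanks to the look-ahead mask, the padding after that token cannot influence it.

```python
import torch

SEQ_LENGTH = 512
# Hypothetical special-token IDs (placed after the 40,000 BPE tokens) and a
# hypothetical padding ID, for illustration only.
PAD, START, DELIM, EXTRACT = 0, 40000, 40001, 40002

def classification_input(text_ids, seq_length=SEQ_LENGTH):
    """Build <start> text <extract>, right-pad to seq_length, and return the
    position of <extract>, whose decoder output feeds the classifier."""
    seq = [START] + list(text_ids)[: seq_length - 2] + [EXTRACT]
    extract_pos = len(seq) - 1
    seq += [PAD] * (seq_length - len(seq))
    return torch.tensor(seq).unsqueeze(0), extract_pos   # (1, seq_length)

x, pos = classification_input([523, 87, 1290])
print(x.shape, pos, x[0, :6].tolist())
# torch.Size([1, 512]) 4 [40000, 523, 87, 1290, 40002, 0]
```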

For textual entailment, the model accepts the premise and the hypothesis as a single sequence, separated by a delimiter token. In this case, the Task Classifier head is responsible for classifying whether the premise entails the hypothesis.

Figure 8. Examples of input texts and the corresponding labels for textual entailment task [3].
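For entailment, only the sequence layout changes. A minimal sketch, again with hypothetical special-token IDs and a hypothetical label order for three entailment classes (matching N_CLASS = 3 used later):

```python
# Hypothetical special-token IDs and label order, for illustration only.
START, DELIM, EXTRACT = 40000, 40001, 40002
ENTAILMENT_LABELS = {"entailment": 0, "neutral": 1, "contradiction": 2}

def entailment_input(premise_ids, hypothesis_ids):
    """<start> premise <delim> hypothesis <extract>, as one token sequence."""
    return [START] + list(premise_ids) + [DELIM] + list(hypothesis_ids) + [EXTRACT]

tokens = entailment_input([101, 102, 103], [201, 202])
print(tokens)                          # [40000, 101, 102, 103, 40001, 201, 202, 40002]
label = ENTAILMENT_LABELS["neutral"]   # target class index for this pair
```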

In the case of the text similarity task, the model accepts the two texts to be compared in both orders: text 1 followed by text 2, and text 2 followed by text 1. These two sequences are fed into the GPT model in parallel, and the resulting outputs are summed before the model predicts whether the texts are similar. Alternatively, we can configure the output layer to perform a regression task, returning a continuous similarity score.

Figure 9. Example of a dataset for text similarity measurement [3].
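Summing the two orderings makes the final features independent of which text comes first, even though each individual pass is order-sensitive. The sketch below shows this with hypothetical special-token IDs; `encode` is a hypothetical stand-in for "run GPT-1 and take the last row of decoder_output", implemented here as a toy position-weighted embedding sum so the example runs instantly.

```python
import torch

START, DELIM, EXTRACT = 40000, 40001, 40002   # hypothetical special-token IDs

def similarity_features(text_1, text_2, encode):
    """Encode both orderings and add the two representations element-wise."""
    seq_a = [START] + text_1 + [DELIM] + text_2 + [EXTRACT]
    seq_b = [START] + text_2 + [DELIM] + text_1 + [EXTRACT]
    return encode(seq_a) + encode(seq_b)   # this sum feeds the linear head

# Toy, order-sensitive encoder so the example runs (not a real GPT-1).
torch.manual_seed(0)
emb = torch.nn.Embedding(40003, 8)

def toy_encode(seq):
    vecs = emb(torch.tensor(seq))                          # (len, 8)
    weights = torch.arange(1, len(seq) + 1).unsqueeze(1)   # position weights
    return (vecs * weights).sum(dim=0)                     # (8,)

f_12 = similarity_features([5, 6], [7], toy_encode)
f_21 = similarity_features([7], [5, 6], toy_encode)
print(torch.allclose(f_12, f_21))  # True: swapping the inputs gives the same features
```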

Lastly, for multiple-choice question answering, we wrap both the text containing the facts and the corresponding question inside the context block. Next, we place a delimiter token before appending one of the answers. We do the same for every possible answer to each question. With this dataset structure, we perform inference by passing each sequence into the model, letting it compute a score for each question-answer pair. This score indicates how well each answer addresses the question based on the given facts. We can think of this as a standard classification task, where the selected answer is the one with the highest score.

Figure 10. An example of a dataset for multiple-choice question answering task [3].
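A minimal sketch of the multiple-choice workflow, under stated assumptions: the special-token IDs are hypothetical, and the per-answer scores are made-up numbers standing in for the task head's output on each sequence. The GPT-1 paper describes normalizing the per-answer outputs with a softmax, which is what pick_answer (a hypothetical helper) does.

```python
import torch

START, DELIM, EXTRACT = 40000, 40001, 40002   # hypothetical special-token IDs

def multiple_choice_inputs(context, question, answers):
    """One sequence per candidate: <start> context question <delim> answer <extract>."""
    return [[START] + context + question + [DELIM] + ans + [EXTRACT] for ans in answers]

def pick_answer(scores):
    """scores: (n_answers,) one scalar per sequence from the task head.
    A softmax turns them into a distribution over the candidates."""
    probs = torch.softmax(scores, dim=0)
    return int(probs.argmax()), probs

seqs = multiple_choice_inputs([11, 12], [13], [[21], [22, 23], [24]])
print(len(seqs), seqs[1])   # 3 [40000, 11, 12, 13, 40001, 22, 23, 40002]

best, probs = pick_answer(torch.tensor([0.2, 1.5, -0.3]))  # made-up scores
print(best)                 # 1
```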

During the fine-tuning phase, we don't completely discard the language modeling objective, as it still provides information about which token should come next. In other words, we can treat it as an auxiliary objective, which helps accelerate convergence while also improving the generalization of the classifier. Therefore, the downstream task objective function (L2) is combined with the language modeling objective function (L1). Figure 11 below shows the formal mathematical definition, where the weight λ is typically set to less than 1, allowing the model to pay more attention to the downstream task.

Figure 11. The objective function used for fine-tuning [4].
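Figure 11 corresponds to L₃(C) = L₂(C) + λ·L₁(C). Written as a loss to minimize, both terms become summed cross-entropies, as in the sketch below. The function name, argument layout and the default λ = 0.5 are illustrative choices (0.5 is, to my knowledge, the value reported in the GPT-1 paper, but treat it as a tunable weight); the random tensors stand in for real head outputs.

```python
import torch
import torch.nn.functional as F

def finetune_loss(class_logits, labels, text_logits, tokens, lam=0.5):
    """Loss form of L3 = L2 + lambda * L1 (minimizing it maximizes L3).

    class_logits: (batch, n_class) task head output at the extract token
    labels:       (batch,) class indices
    text_logits:  (batch, seq_len, vocab) text prediction head output
    tokens:       (batch, seq_len) input token IDs
    """
    cls_loss = F.cross_entropy(class_logits, labels, reduction="sum")    # -L2
    lm_loss = F.cross_entropy(text_logits[:, :-1].flatten(0, 1),         # -L1
                              tokens[:, 1:].flatten(), reduction="sum")
    return cls_loss + lam * lm_loss

torch.manual_seed(0)
cl, y = torch.randn(2, 3), torch.tensor([0, 2])
tl, tok = torch.randn(2, 6, 50), torch.randint(0, 50, (2, 6))
loss = finetune_loss(cl, y, tl, tok)
print(loss.item() > 0)  # True: cross-entropy terms are positive
```

Setting lam=0.0 recovers plain classification fine-tuning, which makes it easy to compare training with and without the auxiliary objective.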

So, to sum up, GPT-1 basically works by continuing the preceding sequence. Without further fine-tuning, it continues the sequence based on what it learned from the data in the self-supervised pretraining phase. With fine-tuning, it still continues the sequence, but it is additionally adapted to the specific ground truths provided in the supervised learning phase.

GPT-1 Implementation: Look-Ahead Mask & Positional Encoding

Now that we know the theory behind GPT-1, let's implement the architecture from scratch! We start by importing the required modules.

# Codeblock 1
import torch
import torch.nn as nn

Afterwards, we continue with the parameter configuration shown in Codeblock 2 below. All variables set here are exactly the same as the ones specified in the GPT-1 paper, except for BATCH_SIZE and N_CLASS (at the lines marked #(1) and #(2)). The BATCH_SIZE variable is necessary because PyTorch layers expect a batch dimension, regardless of how many samples the batch contains. In this case, I assume that there is only a single sample in each batch. Meanwhile, N_CLASS is used for the task classifier head, which runs when a downstream task is performed. As an example, here I set the parameter to 3. With this configuration, we can use the head for a 3-class classification task, like the sentiment analysis or textual entailment cases I showed you earlier in Figures 7 and 8.

# Codeblock 2
BATCH_SIZE = 1          #(1)
N_CLASS    = 3          #(2)
SEQ_LENGTH = 512        #(3)
VOCAB_SIZE = 40000      #(4)

D_MODEL    = 768        #(5)
N_LAYERS   = 12         #(6)
NUM_HEADS  = 12         #(7)
HIDDEN_DIM = D_MODEL*4  #(8)
DROP_PROB  = 0.1        #(9)

The SEQ_LENGTH parameter (#(3)), another term for the context window, is set to 512. The paper reports a BPE vocabulary built with 40,000 merges, so we use 40,000 for VOCAB_SIZE (#(4)). Next, the D_MODEL parameter denotes the length of the feature vector used to represent a token, which in the case of GPT-1 is 768 (#(5)). Previously I mentioned that the decoder layer is repeated 12 times; in the above code, this number is assigned to the N_LAYERS variable (#(6)). Each decoder layer itself comprises other components whose parameters need to be configured manually as well: the number of attention heads (#(7)), the number of hidden neurons in the feed-forward block (#(8)), and the rate for the dropout layers (#(9)).

With the required parameters configured, the next step is to define a function for creating the so-called look-ahead mask and a class for creating the positional embedding. The look-ahead mask can be thought of as a tool that prevents the model from looking at subsequent tokens during the training phase, considering that in the inference phase those tokens are not yet available. Meanwhile, the positional embedding assigns each position in the sequence a distinct vector, which preserves information about the token order. Even though the look-ahead mask already carries some of this information implicitly, the positional embedding makes it explicit.

Look at Codeblocks 3 and 4 below to see how I implement the two concepts I just explained. I am not going to go any deeper into them, as I've provided the complete explanation in my article about the Transformer, whose link is included in the reference list [2]. You can just click it and scroll down to the Positional Encoding and Look-Ahead Mask sections. In fact, the following code is exactly the same as what I wrote there!

# Codeblock 3
def create_mask():
    mask = torch.tril(torch.ones((SEQ_LENGTH, SEQ_LENGTH)))
    mask[mask == 0] = -float('inf')
    mask[mask == 1] = 0
    return mask
# Codeblock 4
class PositionalEncoding(nn.Module):
    def forward(self):
        pos = torch.arange(SEQ_LENGTH).reshape(SEQ_LENGTH, 1)
        i = torch.arange(0, D_MODEL, 2)
        denominator = torch.pow(10000, i/D_MODEL)

        even_pos_embed = torch.sin(pos/denominator)
        odd_pos_embed  = torch.cos(pos/denominator)

        stacked = torch.stack([even_pos_embed, odd_pos_embed], dim=2)
        pos_embed = torch.flatten(stacked, start_dim=1, end_dim=2)

        return pos_embed
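To see what these two helpers produce, here is the same logic wrapped in size-parameterized functions and evaluated on a tiny example. The names make_look_ahead_mask and sinusoidal_encoding are my own (chosen so they do not overwrite create_mask() or PositionalEncoding in a notebook); the computations are copied from Codeblocks 3 and 4.

```python
import torch

def make_look_ahead_mask(seq_length):
    """Same logic as create_mask(), but with the length as an argument."""
    mask = torch.tril(torch.ones((seq_length, seq_length)))
    mask[mask == 0] = -float('inf')   # future positions: blocked
    mask[mask == 1] = 0               # current and past positions: allowed
    return mask

def sinusoidal_encoding(seq_length, d_model):
    """Same computation as PositionalEncoding.forward(), but parameterized."""
    pos = torch.arange(seq_length).reshape(seq_length, 1)
    i = torch.arange(0, d_model, 2)
    denominator = torch.pow(10000, i / d_model)
    stacked = torch.stack([torch.sin(pos / denominator),
                           torch.cos(pos / denominator)], dim=2)
    return torch.flatten(stacked, start_dim=1, end_dim=2)  # sin, cos, sin, cos, ...

mask = make_look_ahead_mask(4)
print(mask)        # 0 on and below the diagonal, -inf above it
pe = sinusoidal_encoding(3, 4)
print(pe.shape)    # torch.Size([3, 4])
print(pe[0])       # position 0: sin(0) = 0 and cos(0) = 1 alternate
```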

GPT-1 Implementation: Decoder

Now let's talk about the decoder part, which I implement inside the DecoderGPT1() class. I name it this way because we are going to use it exclusively for GPT-1. See the detailed implementation in Codeblocks 5a and 5b.

# Codeblock 5a
class DecoderGPT1(nn.Module):
    def __init__(self):
        super().__init__()

        self.multihead_attention = nn.MultiheadAttention(embed_dim=D_MODEL,  #(1)
                                                         num_heads=NUM_HEADS, 
                                                         batch_first=True)  #(2)
        self.dropout_0 = nn.Dropout(DROP_PROB)
        self.norm_0 = nn.LayerNorm(D_MODEL)  #(3)

        self.feed_forward = nn.Sequential(nn.Linear(D_MODEL, HIDDEN_DIM),  #(4) 
                                          nn.GELU(), 
                                          nn.Linear(HIDDEN_DIM, D_MODEL))
        self.dropout_1 = nn.Dropout(DROP_PROB)
        self.norm_1 = nn.LayerNorm(D_MODEL)  #(5)

        nn.init.normal_(self.feed_forward[0].weight, 0, 0.02)  #(6)
        nn.init.normal_(self.feed_forward[2].weight, 0, 0.02)  #(7)

Several neural network layers are initialized in the __init__() method above, each corresponding to a sub-component of the decoder shown back in Figure 3. The first one is the multihead attention layer (#(1)), where the values used for embed_dim and num_heads are taken from the variables we initialized earlier. Additionally, here I set the batch_first parameter to True (#(2)) since our batch dimension is on the 0th axis, which is common practice when working with PyTorch tensors. Next, we initialize two layer normalization layers with D_MODEL as the input argument for each (at lines #(3) and #(5)). This means that these two layers normalize across the 768 values of each token.

For the feed-forward block, I use nn.Sequential() (#(4)), where I initialize two linear layers with a GELU activation function in between. The first linear layer expands the 768 (D_MODEL)-dimensional token representation into 3072 (HIDDEN_DIM) dimensions. Afterwards, we pass it through GELU before shrinking it back to 768 dimensions. The authors of the paper mention that the weights are initialized from a normal distribution with a mean of 0 and a standard deviation of 0.02. We can configure this manually using the code at lines #(6) and #(7).

Now let’s move on to Codeblock 5b where I define the forward() method of the DecoderGPT1() class. You can see below that it works by accepting two inputs: x and attn_mask (#(1)). The first input is the embedded token sequence, while the second one is the look-ahead mask generated by the create_mask() function we defined earlier.

# Codeblock 5b
    def forward(self, x, attn_mask):  #(1)
        residual = x  #(2)
        print(f"original & residual\t: {x.shape}")

        x = self.multihead_attention(x, x, x, attn_mask=attn_mask)[0]  #(3)
        print(f"after attention\t\t: {x.shape}")

        x = self.dropout_0(x)  #(4)
        print(f"after dropout\t\t: {x.shape}")

        x = x + residual  #(5)
        print(f"after addition\t\t: {x.shape}")

        x = self.norm_0(x)  #(6)
        print(f"after normalization\t: {x.shape}")

        residual = x
        print(f"\nx & residual\t\t: {x.shape}")

        x = self.feed_forward(x)  #(7)
        print(f"after feed forward\t: {x.shape}")

        x = self.dropout_1(x)
        print(f"after dropout\t\t: {x.shape}")

        x = x + residual
        print(f"after addition\t\t: {x.shape}")

        x = self.norm_1(x)
        print(f"after normalization\t: {x.shape}")

        return x

The first thing we do inside the forward() method above is store the original input tensor x in the residual variable (#(2)). The x tensor itself is then processed with the multihead attention layer (#(3)). Since we are performing self-attention (not cross-attention), the query, key and value inputs for the layer are all derived from x. Here we also need to pass the look-ahead mask as the argument for the attn_mask parameter. After the attention layer, we pass the x tensor through a dropout layer (#(4)) before combining it with residual (#(5)) and normalizing it with the layer norm (#(6)). The remaining steps are nearly the same, except that the self.multihead_attention layer is replaced with the self.feed_forward layer (#(7)).
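The effect of passing the look-ahead mask to attn_mask can be checked directly. In this sketch (toy sizes, and square_causal_mask is a hypothetical helper that builds the same 0 / -inf layout as create_mask()), changing only the last token leaves the attention outputs of all earlier positions untouched, because none of them is allowed to attend to it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
SEQ, DIM, HEADS = 6, 16, 4   # toy sizes chosen for a quick check

def square_causal_mask(n):
    # 0 on and below the diagonal, -inf above it: same layout as create_mask()
    return torch.triu(torch.full((n, n), float("-inf")), diagonal=1)

attn = nn.MultiheadAttention(embed_dim=DIM, num_heads=HEADS, batch_first=True)
attn.eval()

x = torch.randn(1, SEQ, DIM)
x_changed = x.clone()
x_changed[:, -1] += 10.0     # modify only the last ("future") token

mask = square_causal_mask(SEQ)
with torch.no_grad():
    out = attn(x, x, x, attn_mask=mask)[0]
    out_changed = attn(x_changed, x_changed, x_changed, attn_mask=mask)[0]

# Earlier positions cannot attend to the last token, so they are unaffected ...
print(torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-6))  # True
# ... while the last position itself does change.
print(torch.allclose(out[:, -1], out_changed[:, -1]))               # False
```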

To check if our decoder works properly, we can pass a tensor with the size of 1×512×768 as shown in Codeblock 6 below. This simulates a sequence of 512 tokens, each represented as a 768-dimensional vector.

# Codeblock 6
decoder_gpt_1 = DecoderGPT1()
x = torch.randn(BATCH_SIZE, SEQ_LENGTH, D_MODEL)
look_ahead_mask = create_mask()

x = decoder_gpt_1(x, look_ahead_mask)

We can see in the resulting output that this tensor successfully passed through all components of the decoder. It is worth noting that the tensor dimensions remain the same at each step, including the final output. This property allows us to stack multiple decoders without worrying about mismatched tensor dimensions. Strictly speaking, some dimensionality changes do happen inside the attention and feed-forward layers, but the tensor immediately returns to its original dimension before being fed into the subsequent layer.

# Codeblock 6 output
original & residual   : torch.Size([1, 512, 768])
after attention       : torch.Size([1, 512, 768])
after dropout         : torch.Size([1, 512, 768])
after addition        : torch.Size([1, 512, 768])
after normalization   : torch.Size([1, 512, 768])

x & residual          : torch.Size([1, 512, 768])
after feed forward    : torch.Size([1, 512, 768])
after dropout         : torch.Size([1, 512, 768])
after addition        : torch.Size([1, 512, 768])
after normalization   : torch.Size([1, 512, 768])
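For comparison, PyTorch's built-in nn.TransformerEncoderLayer with norm_first=False follows the same post-LN layout as DecoderGPT1, and because it has no cross-attention, it behaves as a decoder-only block once it is given a causal mask (despite the "Encoder" in its name). This is only a sanity-check sketch, not a drop-in replacement: the built-in layer applies one extra dropout after the GELU inside its feed-forward block and uses PyTorch's default initialization instead of N(0, 0.02). By my arithmetic, its parameter count matches one DecoderGPT1 block.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D_MODEL, NUM_HEADS, HIDDEN_DIM, DROP_PROB, SEQ_LENGTH = 768, 12, 768 * 4, 0.1, 512

# Post-LN layout: x = norm1(x + attention(x)); x = norm2(x + feed_forward(x)).
builtin_block = nn.TransformerEncoderLayer(d_model=D_MODEL, nhead=NUM_HEADS,
                                           dim_feedforward=HIDDEN_DIM,
                                           dropout=DROP_PROB, activation="gelu",
                                           batch_first=True, norm_first=False)

# Same 0 / -inf layout as create_mask()
causal_mask = torch.triu(torch.full((SEQ_LENGTH, SEQ_LENGTH), float("-inf")), diagonal=1)
x = torch.randn(1, SEQ_LENGTH, D_MODEL)
out = builtin_block(x, src_mask=causal_mask)

print(out.shape)                                           # torch.Size([1, 512, 768])
print(sum(p.numel() for p in builtin_block.parameters()))  # 7087872
```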

GPT-1 Implementation: Decoder with Input & Text Prediction

As we have completed the decoder block, we will now connect the input layer before it and attach the text prediction head to the output. You can see how I implement them in the GPT1() class below.

# Codeblock 7a
class GPT1(nn.Module):
    def __init__(self):
        super().__init__()

        self.token_embedding = nn.Embedding(num_embeddings=VOCAB_SIZE, 
                                            embedding_dim=D_MODEL)  #(1)

        self.positional_encoding = PositionalEncoding()  #(2)

        self.decoders = nn.ModuleList([DecoderGPT1() for _ in range(N_LAYERS)])  #(3)

        self.linear = nn.Linear(in_features=D_MODEL, out_features=VOCAB_SIZE)  #(4)

        nn.init.normal_(self.token_embedding.weight, mean=0, std=0.02)  #(5)
        nn.init.normal_(self.linear.weight, mean=0, std=0.02)  #(6)

Inside the __init__() method, we first initialize an nn.Embedding() layer, which maps each token into a 768 (D_MODEL)-dimensional vector (#(1)). Secondly, we initialize the positional encoding using the PositionalEncoding() class we created earlier (#(2)). The 12 decoder layers are initialized one by one, which in this case I do with a list comprehension, and all of them are stored in self.decoders (#(3)). Next, we initialize a linear layer, which corresponds to the text prediction head (#(4)). This layer maps each vector into VOCAB_SIZE (40,000) output values, each indicating how likely a specific token is to be selected (these become probabilities after a softmax). Again, here I also manually configure the weight initialization distribution using the code at lines #(5) and #(6).

Moving on to the forward() method in Codeblock 7b, the first thing we do is process the input tensor with the self.token_embedding layer (#(1)). Next, we inject the positional encoding tensor into x by element-wise addition (#(2)). The resulting tensor is then forwarded through the stack of 12 decoders, which we do with another loop as shown at line #(3). Remember that the GPT-1 model has two heads. In this case, the text prediction head is included inside the forward() method, whereas the task classifier head will later be implemented in a separate class. To accomplish this, I return both the raw decoder output (decoder_output) and the next-word prediction output (text_output), as shown at line #(5). Later on, I will use decoder_output as the input for the task classifier head.

# Codeblock 7b
    def forward(self, x):
        print(f"original input\t\t: {x.shape}")

        x = self.token_embedding(x.long())  #(1)
        print(f"embedded tokens\t\t: {x.shape}")

        x = x + self.positional_encoding()  #(2)
        print(f"after addition\t\t: {x.shape}")

        # Build the look-ahead mask here rather than relying on the
        # global variable created in Codeblock 6
        look_ahead_mask = create_mask()

        for i, decoder in enumerate(self.decoders):
            x = decoder(x, attn_mask=look_ahead_mask)  #(3)
            print(f"after decoder #{i}\t: {x.shape}")

        decoder_output = x  #(4)
        print(f"decoder_output\t\t: {decoder_output.shape}")

        text_output = self.linear(x)
        print(f"text_output\t\t: {text_output.shape}")

        return decoder_output, text_output  #(5)

We can check whether our GPT1() class works properly with Codeblock 8 below. The x tensor here represents a sequence of SEQ_LENGTH (512) tokens, in which each element is a random integer from 0 up to (but not including) VOCAB_SIZE (40,000), representing the encoded tokens.

# Codeblock 8
gpt1 = GPT1()

x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))
x = gpt1(x)
# Codeblock 8 output
original input     : torch.Size([1, 512])  #(1)
embedded tokens    : torch.Size([1, 512, 768])  #(2)
after addition     : torch.Size([1, 512, 768])
after decoder #0   : torch.Size([1, 512, 768])
after decoder #1   : torch.Size([1, 512, 768])
after decoder #2   : torch.Size([1, 512, 768])
after decoder #3   : torch.Size([1, 512, 768])
after decoder #4   : torch.Size([1, 512, 768])
after decoder #5   : torch.Size([1, 512, 768])
after decoder #6   : torch.Size([1, 512, 768])
after decoder #7   : torch.Size([1, 512, 768])
after decoder #8   : torch.Size([1, 512, 768])
after decoder #9   : torch.Size([1, 512, 768])
after decoder #10  : torch.Size([1, 512, 768])
after decoder #11  : torch.Size([1, 512, 768])
decoder_output     : torch.Size([1, 512, 768])  #(3)
text_output        : torch.Size([1, 512, 40000])  #(4)

Based on the above output, we can see that our self.token_embedding layer successfully converted the sequence of 512 tokens (#(1)) into a sequence of 768-dimensional token vectors (#(2)). This tensor dimension remained the same all the way to the last decoder layer, whose output was then stored in the decoder_output variable (#(3)). Finally, after being processed by the text prediction head, the tensor dimension changed to 1×512×40000 (#(4)), containing the next-token predictions. In other words, the output is shifted one position relative to the input (closely related to the "shifted right" arrangement in the original Transformer): the 1st row holds the prediction for the 2nd token, the 2nd row holds the prediction for the 3rd token, and so on. Hence, to predict the 513th token, we can simply take the last (512th) row and select the token with the highest score.
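Repeating the last-row selection in a loop gives greedy autoregressive generation. The sketch below assumes a hypothetical model_fn that returns (decoder_output, text_output) like GPT1.forward(), and uses a toy stand-in model so it runs instantly. Note that the GPT1 class in this article builds its mask and positional encoding for exactly 512 tokens, so plugging it in directly would require padding the input or slicing those tensors to the current length.

```python
import torch

def greedy_next_token(text_output):
    """text_output: (batch, seq_len, vocab). Row t scores the token after
    position t, so the last row predicts the next, unseen token."""
    return text_output[:, -1, :].argmax(dim=-1)           # (batch,)

def greedy_generate(model_fn, tokens, n_new, context=512):
    """Append n_new tokens one at a time, cropping the input to the last
    `context` tokens (the context window) at every step."""
    for _ in range(n_new):
        _, text_output = model_fn(tokens[:, -context:])
        next_token = greedy_next_token(text_output)
        tokens = torch.cat([tokens, next_token.unsqueeze(1)], dim=1)
    return tokens

# Toy stand-in model: puts the highest score on (last token + 1) % VOCAB.
VOCAB = 10
def toy_model(tokens):
    logits = torch.zeros(tokens.shape[0], tokens.shape[1], VOCAB)
    next_ids = ((tokens[:, -1] + 1) % VOCAB).unsqueeze(1)
    logits[:, -1, :].scatter_(1, next_ids, 1.0)
    return None, logits

print(greedy_generate(toy_model, torch.tensor([[3, 4]]), n_new=3).tolist())
# [[3, 4, 5, 6, 7]]
```

Greedy selection is the simplest decoding strategy; sampling from the softmax of the last row is a common alternative when more varied text is wanted.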

To calculate the number of model parameters, we can use the count_parameters() function below.

# Codeblock 9
def count_parameters(model):
    return sum([params.numel() for params in model.parameters()])

count_parameters(gpt1)
# Codeblock 9 output
146534464

We can see here that our GPT-1 implementation has approximately 146 million parameters. I do need to acknowledge that this number is different from the one disclosed in the original paper, i.e., 117 million. The difference is probably because I missed some intricate details. Feel free to comment if you know which part of the code I should change to achieve this number!
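One plausible explanation for the gap, offered as an assumption rather than a verified answer: GPT-1's reference implementation is commonly described as computing the output logits with the token-embedding matrix itself (weight tying), so the separate self.linear head here adds about 30.76 million parameters that GPT-1 does not have. In addition, the GPT-1 paper reports using learned position embeddings rather than sinusoidal ones, which adds 512×768 parameters. The arithmetic below applies both adjustments (the figure 146,534,464 is from Codeblock 9) and then shows weight tying on a toy module; TiedLM is a hypothetical name. The small remaining gap to 117 million may come from details such as the exact vocabulary size including special tokens.

```python
import torch.nn as nn

# Parameter arithmetic for the article's configuration
VOCAB_SIZE, D_MODEL, SEQ_LENGTH = 40000, 768, 512
total_in_article = 146_534_464                       # Codeblock 9 output
head_params = D_MODEL * VOCAB_SIZE + VOCAB_SIZE      # self.linear weight + bias
print(head_params)                                   # 30760000
print(total_in_article - head_params)                # 115774464
print(total_in_article - head_params + SEQ_LENGTH * D_MODEL)  # 116167680

# Weight tying on toy sizes: the output layer reuses the embedding matrix.
class TiedLM(nn.Module):
    def __init__(self, vocab=100, dim=16):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab, dim)    # weight: (vocab, dim)
        self.linear = nn.Linear(dim, vocab, bias=False)    # weight: (vocab, dim)
        self.linear.weight = self.token_embedding.weight   # share one Parameter

print(sum(p.numel() for p in TiedLM().parameters()))  # 1600: shared matrix counted once
```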

GPT-1 Implementation: Task Classifier Head

Remember that our GPT1() class only includes the text prediction head. For language modeling alone, this is already sufficient, yet for fine-tuning, we need to manually create the task classifier head. Look at the Codeblock 10 below to see how I implement it.

# Codeblock 10
class TaskClassifier(nn.Module):
    def __init__(self):
        super().__init__()

        self.linear = nn.Linear(in_features=D_MODEL, out_features=N_CLASS)  #(1)
        nn.init.normal_(self.linear.weight, mean=0, std=0.02)

    def forward(self, x):  #(2)
        print(f"decoder_output\t: {x.shape}")

        class_output = self.linear(x)
        print(f"class_output\t: {class_output.shape}")

        return class_output

Similar to text prediction, the task classifier head is basically just a linear layer. However, in this case, it maps every 768-dimensional token embedding into 3 (N_CLASS) output values, corresponding to the number of classes for the classification task we want to train it on (#(1)). Later on, the output from the decoder will be used as the input for the forward() method (#(2)). Thus, to test this TaskClassifier() class, I will pass a dummy tensor whose dimensions exactly match the decoder output, i.e., 1×512×768. We can see in Codeblock 11 below that this tensor successfully passes through the task classifier head.

# Codeblock 11
task_classifier = TaskClassifier()

x = torch.randn(BATCH_SIZE, SEQ_LENGTH, D_MODEL)
x = task_classifier(x)
# Codeblock 11 output
decoder_output : torch.Size([1, 512, 768])
class_output   : torch.Size([1, 512, 3])  #(1)

If we take a closer look at the above output, we can see that the resulting tensor now has the shape of 1×512×3 (#(1)). This essentially means that every single token is now represented by 3 numbers. As mentioned earlier, in this example we are simulating a sentiment analysis task with 3 classes: positive, negative and neutral. To determine the sentiment of the entire sequence, we can either aggregate the logits across all tokens or use only the logits from the last token (since, because of the look-ahead mask, it is the only token that has attended to the entire sequence). Additionally, with the same output tensor shape, we can use the same idea to perform token-level classification tasks, such as NER (Named Entity Recognition) or POS (Part-of-Speech) tagging.
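A minimal sketch of both aggregation options, plus the token-level variant, using a random tensor in place of the real class_output. The label order in LABELS is a hypothetical assumption; in practice it is fixed by how the training labels were encoded.

```python
import torch

# Random stand-in for the task classifier output: (batch, seq_len, n_class)
class_output = torch.randn(1, 512, 3)
LABELS = ["positive", "negative", "neutral"]  # hypothetical label order

# Option 1: average the logits over all tokens
mean_logits = class_output.mean(dim=1)        # (1, 3)
# Option 2: use only the last token, the one that has attended to the whole sequence
last_logits = class_output[:, -1, :]          # (1, 3)

sentence_label = LABELS[last_logits.argmax(dim=-1).item()]

# Token-level tasks (NER, POS tagging) keep one prediction per token instead
token_labels = class_output.argmax(dim=-1)    # (1, 512)
print(mean_logits.shape, last_logits.shape, token_labels.shape, sentence_label)
```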

Later in the inference phase, we will use the TaskClassifier() head every time we want to perform a specific downstream task. Codeblock 12 below shows sample code for the forward pass. First, we pass the tokenized sentence into the gpt1 model, which returns the raw decoder output and the next-word prediction (#(1)). Then, we use the decoder output as the input for the task classifier head, which returns the logits of the available classes (#(2)).

# Codeblock 12
def gpt1_fine_tune(x, gpt1, task_classifier):
    print(f"original input\t\t: {x.shape}")

    decoder_output, text_output = gpt1(x)  #(1)
    print(f"decoder_output\t\t: {decoder_output.shape}")
    print(f"text_output\t\t: {text_output.shape}")

    class_output = task_classifier(decoder_output)  #(2)
    print(f"class_output\t\t: {class_output.shape}")

    return text_output, class_output

Based on the output produced by the following codeblock, we can see that our gpt1_fine_tune() function above works properly.

# Codeblock 13
gpt1 = GPT1()
task_classifier = TaskClassifier()

x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))
text_output, class_output = gpt1_fine_tune(x, gpt1, task_classifier)
# Codeblock 13 output
original input  : torch.Size([1, 512])
decoder_output  : torch.Size([1, 512, 768])
text_output     : torch.Size([1, 512, 40000])
class_output    : torch.Size([1, 512, 3])

GPT-1 Limitations

Despite obtaining remarkable results on the four downstream tasks I showed in Figure 6, this approach has some drawbacks. First, the training procedure is complex since we need to perform pretraining and fine-tuning as separate processes. Second, since fine-tuning is a discriminative process, we still need manually labeled data (unlike generative pretraining, which uses self-supervised labels). Third, the model is not flexible, as it only works on the task it is fine-tuned on. For instance, a model specialized for sentiment analysis cannot be used for a question answering task. Fortunately, GPT-2 was introduced soon after to address these issues.


GPT-2

GPT-2 was introduced in the paper titled "Language Models are Unsupervised Multitask Learners", published several months after GPT-1 [6]. The authors of this paper found that the plain GPT language model could actually perform various downstream tasks without fine-tuning. This is achieved by modifying the objective so that it is conditioned on the task as well. Whereas GPT-1 makes predictions based solely on the previous token sequence, i.e., P(output | input), GPT-2 makes them based on both the sequence and the given task, i.e., P(output | input, task). With this property, the same input causes the model to produce different outputs whenever the given task is different. Interestingly, we can simply include the task in the prompt as natural language.

As an example, if you prompt a model with "lorem ipsum dolor sit amet", it will likely continue with "consectetur adipiscing elit." But if you include a task like "what does it mean?" in the prompt, the model will instead explain what the phrase actually is. I tried this in ChatGPT, and the answer was exactly what I expected.

Figure 12. ChatGPT only continues the input sentence if the task is not specified [3].
Figure 13. An example of how assigning a specific task causes the model to respond differently [3].

The ability to accept a task in the form of natural language comes from training the model on an enormous amount of text in a self-supervised manner. For comparison, the dataset used to pretrain GPT-1 is BooksCorpus, which contains more than 7000 unpublished books, equivalent to approximately 5 GB of text. Meanwhile, the dataset used for GPT-2 is WebText, which is approximately 40 GB in size. Not only the dataset but also the model itself is larger. The authors of the GPT-2 paper created four model variants, each with a different configuration, as summarized in Figure 14 below. The one in the first row is equivalent to the GPT-1 model we just implemented, whereas the model recognized as GPT-2 is the one in the last row. Here we can see that GPT-2 is roughly 13 times larger than GPT-1 in terms of the number of parameters (1542M vs. 117M). Given the larger dataset and model size, we can expect GPT-2 to perform much better than its predecessor.

Figure 14. The four model variations proposed in the GPT-2 paper [6].

It is important to know that N_LAYERS and D_MODEL are not the only parameters we need to change if we were to actually create the model. The codeblock below shows the complete parameter configuration for GPT-2.

# Codeblock 14
BATCH_SIZE = 1
SEQ_LENGTH = 1024  #(1)
VOCAB_SIZE = 50257  #(2)

D_MODEL    = 1600
NUM_HEADS  = 25  #(3)
HIDDEN_DIM = D_MODEL*4  #(4)
N_LAYERS   = 48
DROP_PROB  = 0.1

In this GPT version, instead of considering only 512 tokens when predicting the next token, the authors extended the context to 1024 tokens (#(1)), so the model can attend to and process longer token sequences, allowing it to accept longer prompts. The vocabulary size also gets larger: in GPT-1 the number of unique tokens was only 40,000, while in GPT-2 it increased to 50,257 (#(2)). The last thing we need to change is the number of attention heads, which is now set to 25, as shown at line #(3). The HIDDEN_DIM parameter also changes, but we don’t need to specify its value manually since it remains configured to be 4 times the embedding dimension (#(4)).
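Since nn.MultiheadAttention splits D_MODEL evenly across the heads, D_MODEL must be divisible by NUM_HEADS (PyTorch raises an AssertionError otherwise). A quick check of the per-head size for each configuration; the GPT-1 values (768 dimensions, 12 heads) are taken from its paper and assumed here, since that configuration is not repeated in this section.

```python
# (d_model, num_heads) for each configuration; GPT-1 values assumed from its paper
CONFIGS = {"GPT-1": (768, 12), "GPT-2": (1600, 25), "GPT-3": (12288, 96)}

head_dims = {}
for name, (d_model, num_heads) in CONFIGS.items():
    # nn.MultiheadAttention requires embed_dim to be divisible by num_heads
    assert d_model % num_heads == 0, f"{name}: {d_model} is not divisible by {num_heads}"
    head_dims[name] = d_model // num_heads
    print(f"{name}: {num_heads} heads x {head_dims[name]} dims")
```

So the odd-looking 25 heads in GPT-2 keeps the per-head size at 64, the same as in GPT-1.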

GPT-2 Implementation: Decoder

Regarding the architecture, it is important to know that the decoder used in GPT-2 is somewhat different from the one used in GPT-1. GPT-2 uses so-called pre-normalization, as opposed to GPT-1, which uses post-normalization. The idea of pre-normalization is that we place the layer norm before the main operation, i.e., before the multihead attention and the feed forward blocks. You can see the illustration in the following figure.

Figure 15. The GPT-1 architecture without Task Classifier head (left) and the GPT-2 architecture (right) [3].
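Written as equations, with Sublayer(·) standing for either the multihead attention or the feed forward block and dropout omitted for brevity, the two orderings in Figure 15 are:

```latex
\begin{aligned}
\text{Post-norm (GPT-1):}\quad & x \leftarrow \mathrm{LayerNorm}\big(x + \mathrm{Sublayer}(x)\big) \\
\text{Pre-norm (GPT-2):}\quad  & x \leftarrow x + \mathrm{Sublayer}\big(\mathrm{LayerNorm}(x)\big)
\end{aligned}
```

In the pre-norm form, the residual stream itself is never normalized inside a decoder block, which is why GPT-2 needs one more layer norm after the final decoder.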

I implement the decoder for GPT-2 in the DecoderGPT23() class below. Spoiler alert: I named it this way because the structure of the GPT-2 and GPT-3 architectures is exactly the same.

# Codeblock 15
class DecoderGPT23(nn.Module):
    def __init__(self):
        super().__init__()

        self.norm_0 = nn.LayerNorm(D_MODEL)
        self.multihead_attention = nn.MultiheadAttention(embed_dim=D_MODEL, 
                                                         num_heads=NUM_HEADS, 
                                                         batch_first=True)
        self.dropout_0 = nn.Dropout(DROP_PROB)

        self.norm_1 = nn.LayerNorm(D_MODEL)
        self.feed_forward = nn.Sequential(nn.Linear(D_MODEL, HIDDEN_DIM), 
                                          nn.GELU(), 
                                          nn.Linear(HIDDEN_DIM, D_MODEL))
        self.dropout_1 = nn.Dropout(DROP_PROB)

        nn.init.normal_(self.feed_forward[0].weight, 0, 0.02)
        nn.init.normal_(self.feed_forward[2].weight, 0, 0.02)

    def forward(self, x, attn_mask):
        residual = x
        print(f"original & residual\t: {x.shape}")

        x = self.norm_0(x)
        print(f"after normalization\t: {x.shape}")

        x = self.multihead_attention(x, x, x, attn_mask=attn_mask)[0]
        print(f"after attention\t\t: {x.shape}")

        x = self.dropout_0(x)
        print(f"after dropout\t\t: {x.shape}")

        x = x + residual
        print(f"after addition\t\t: {x.shape}")

        residual = x
        print(f"\nx & residual\t\t: {x.shape}")

        x = self.norm_1(x)
        print(f"after normalization\t: {x.shape}")

        x = self.feed_forward(x)
        print(f"after feed forward\t: {x.shape}")

        x = self.dropout_1(x)
        print(f"after dropout\t\t: {x.shape}")

        x = x + residual
        print(f"after addition\t\t: {x.shape}")

        return x

Well, I don’t think I need to explain the above code any further since it is mostly the same as the decoder for GPT-1, except that the layer normalization blocks are placed at different positions. So, let's jump directly into the testing code in Codeblock 16 below.

# Codeblock 16
decoder_gpt_2 = DecoderGPT23()
x = torch.randn(BATCH_SIZE, SEQ_LENGTH, D_MODEL)
look_ahead_mask = create_mask()

x = decoder_gpt_2(x, look_ahead_mask)

We can see in the resulting output that our x tensor successfully passed through all sub-components inside the decoder layer.

# Codeblock 16 output
original & residual : torch.Size([1, 1024, 1600])
after normalization : torch.Size([1, 1024, 1600])
after attention     : torch.Size([1, 1024, 1600])
after dropout       : torch.Size([1, 1024, 1600])
after addition      : torch.Size([1, 1024, 1600])

x & residual        : torch.Size([1, 1024, 1600])
after normalization : torch.Size([1, 1024, 1600])
after feed forward  : torch.Size([1, 1024, 1600])
after dropout       : torch.Size([1, 1024, 1600])
after addition      : torch.Size([1, 1024, 1600])
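Codeblock 16 relies on the create_mask() helper defined earlier in the article, which is not shown in this section. As a hedged illustration only, the sketch below builds a look-ahead mask in the boolean convention accepted by nn.MultiheadAttention, where True marks a position that must not be attended to, and checks that changing the last token leaves the outputs of earlier positions untouched. make_causal_mask() is a hypothetical name, and the author's create_mask() may use a different convention, such as a float mask with -inf entries.

```python
import torch
import torch.nn as nn

def make_causal_mask(seq_length):
    # Entries above the diagonal are future positions: True = blocked
    return torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool()

mask = make_causal_mask(4)
print(mask)

torch.manual_seed(0)
attention = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
x = torch.randn(1, 4, 8)
out_1 = attention(x, x, x, attn_mask=mask)[0]

x_changed = x.clone()
x_changed[:, -1] = torch.randn(8)          # modify only the last (future) token
out_2 = attention(x_changed, x_changed, x_changed, attn_mask=mask)[0]

# Earlier positions cannot see the last token, so their outputs are identical
print(torch.allclose(out_1[:, :-1], out_2[:, :-1]))  # True
```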

GPT-2 Implementation: Decoder with Input & Text Prediction

Although the decoder used in GPT-2 is different from the one used in GPT-1, the other components, namely the positional encoding and the look-ahead mask, remain the same. Hence, we can just reuse them. The code used to attach these two components is mostly the same, but there are still some details to pay attention to in Codeblock 17 below. First, we initialize another layer normalization layer at line #(1) and place it in the flow at line #(2). This is because GPT-2 has an additional layer norm block placed after the last decoder, which does not exist in GPT-1 (see Figure 15). Second, it is not necessary to store the raw decoder output as we did in the GPT1() class (at line #(4) in Codeblock 7b). This is because GPT-2 does not require fine-tuning to perform downstream tasks. Instead, it relies solely on the text prediction head to do so.

# Codeblock 17
class GPT23(nn.Module):
    def __init__(self):
        super().__init__()

        self.token_embedding = nn.Embedding(num_embeddings=VOCAB_SIZE, 
                                            embedding_dim=D_MODEL)

        self.positional_encoding = PositionalEncoding()

        self.decoders = nn.ModuleList([DecoderGPT23() for _ in range(N_LAYERS)])

        self.norm_final = nn.LayerNorm(D_MODEL)  #(1)

        self.linear = nn.Linear(in_features=D_MODEL, out_features=VOCAB_SIZE)

        nn.init.normal_(self.token_embedding.weight, mean=0, std=0.02)
        nn.init.normal_(self.linear.weight, mean=0, std=0.02)

    def forward(self, x):
        print(f"original input\t\t: {x.shape}")

        x = self.token_embedding(x.long())
        print(f"embedded tokens\t\t: {x.shape}")

        x = x + self.positional_encoding()
        print(f"after addition\t\t: {x.shape}")

        # Build the look-ahead mask here instead of relying on a global variable
        # created in an earlier codeblock, so it follows the current configuration
        look_ahead_mask = create_mask()

        for i, decoder in enumerate(self.decoders):
            x = decoder(x, attn_mask=look_ahead_mask)
            print(f"after decoder #{i}\t: {x.shape}")

        x = self.norm_final(x)  #(2)
        print(f"after final norm\t: {x.shape}")

        text_output = self.linear(x)
        print(f"text_output\t\t: {text_output.shape}")

        return text_output

Now we can test the GPT23() class above with the following codeblock, using a sequence of 1024 tokens. The resulting output is very long since the decoder layer is repeated 48 times.

# Codeblock 18
gpt2 = GPT23()

x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))
x = gpt2(x)
# Codeblock 18 output
original input    : torch.Size([1, 1024])
embedded tokens   : torch.Size([1, 1024, 1600])
after addition    : torch.Size([1, 1024, 1600])
after decoder #0  : torch.Size([1, 1024, 1600])
after decoder #1  : torch.Size([1, 1024, 1600])
after decoder #2  : torch.Size([1, 1024, 1600])
after decoder #3  : torch.Size([1, 1024, 1600])
.
.
.
.
after decoder #44 : torch.Size([1, 1024, 1600])
after decoder #45 : torch.Size([1, 1024, 1600])
after decoder #46 : torch.Size([1, 1024, 1600])
after decoder #47 : torch.Size([1, 1024, 1600])
after final norm  : torch.Size([1, 1024, 1600])
text_output       : torch.Size([1, 1024, 50257])

If we print out the number of parameters, we can see that our GPT-2 implementation has around 1.6 billion. Just like the GPT-1 implementation we did earlier, this number is slightly different from the one disclosed in the paper, which is around 1.5 billion as shown in Figure 14.

# Codeblock 19
count_parameters(gpt2)
# Codeblock 19 output
1636434257
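Because the shape of every layer is known, the parameter count can also be derived by hand. The sketch below is a hypothetical helper, estimate_parameters(), which assumes the layer types used in this article: nn.MultiheadAttention (4d² + 4d parameters for the Q/K/V and output projections with biases), a feed forward block with HIDDEN_DIM = 4d (8d² + 5d), two LayerNorms per decoder (4d), an untied output layer, and a positional encoding without learnable weights. Under those assumptions, it reproduces the outputs of Codeblock 9 and Codeblock 19 exactly.

```python
def estimate_parameters(vocab_size, d_model, n_layers, final_norm):
    """Closed-form parameter count for the GPT1()/GPT23() classes in this article."""
    d = d_model
    attention = 4 * d * d + 4 * d          # in_proj (3d x d + 3d) + out_proj (d x d + d)
    feed_forward = 8 * d * d + 5 * d       # Linear(d, 4d) + Linear(4d, d), with biases
    layer_norms = 2 * (2 * d)              # two LayerNorms, each with weight and bias
    per_decoder = attention + feed_forward + layer_norms

    embedding = vocab_size * d
    output_layer = d * vocab_size + vocab_size    # untied weight + bias
    extra_norm = 2 * d if final_norm else 0       # final LayerNorm in GPT-2/GPT-3
    return n_layers * per_decoder + embedding + output_layer + extra_norm

print(estimate_parameters(40000, 768, 12, final_norm=False))   # 146534464 (GPT-1)
print(estimate_parameters(50257, 1600, 48, final_norm=True))   # 1636434257 (GPT-2)
print(estimate_parameters(50257, 12288, 96, final_norm=True))  # 175196701777 (GPT-3)
```

The 12d² term per decoder dominates, so the count grows roughly with N_LAYERS × D_MODEL². The same formula gives about 175.2 billion for the GPT-3 configuration discussed next, in line with the figure reported in the GPT-3 paper, without having to allocate the model at all.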

GPT-3

GPT-3 was proposed in the paper titled "Language Models are Few-Shot Learners", published in 2020 [7]. The title signifies that the proposed model is able to perform a wide range of tasks given only a few examples, a.k.a. "shots." Despite this emphasis on few-shot learning, in practice the model is also able to perform one-shot or even zero-shot learning. In case you’re not yet familiar with few-shot learning, it is basically a method to adapt a model to a specific task using only a small number of examples. Even though the objective is similar to fine-tuning, few-shot learning achieves it without updating the model weights: the examples are simply placed in the prompt. In the case of GPT models, this is possible thanks to the attention mechanism, which allows the model to dynamically focus on the most relevant parts of the instruction and examples provided in the prompt. Similar to the improvements from GPT-1 to GPT-2, GPT-3's much better few-shot performance compared to its predecessors is also due to the increased amount of training data and the larger model size.
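To make the idea of "shots" concrete, here is a minimal sketch of how a few-shot prompt can be assembled. The format mirrors the English-to-French example in the GPT-3 paper; build_few_shot_prompt() is a hypothetical helper, and the "=>" separator is just one possible convention. Passing zero examples turns the same prompt into a zero-shot one.

```python
def build_few_shot_prompt(task_description, examples, query):
    """Hypothetical helper: task description + solved examples + unsolved query.

    The examples only live in the prompt; no model weights are updated.
    """
    lines = [task_description]
    for source, target in examples:        # 0 examples -> zero-shot, 1 -> one-shot
        lines.append(f"{source} => {target}")
    lines.append(f"{query} =>")            # the model is expected to complete this line
    return "\n".join(lines)

examples = [
    ("sea otter", "loutre de mer"),
    ("peppermint", "menthe poivrée"),
    ("plush girafe", "girafe peluche"),
]
prompt = build_few_shot_prompt("Translate English to French:", examples, "cheese")
print(prompt)
```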

GPT-3 Implementation: Model Configuration & Architectural Design

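You already read the spoiler, right? The architectural design of GPT-3 is essentially the same as GPT-2; the main difference is the model size, which we can adjust by using larger values for the parameters. (The GPT-3 paper does mention one architectural change, namely alternating dense and locally banded sparse attention patterns in the transformer layers, which is not reproduced here.) Codeblock 20 below shows the parameter configuration for GPT-3.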

# Codeblock 20
BATCH_SIZE = 1
SEQ_LENGTH = 2048
VOCAB_SIZE = 50257

D_MODEL    = 12288
NUM_HEADS  = 96
HIDDEN_DIM = D_MODEL*4
N_LAYERS   = 96
DROP_PROB  = 0.1

As the above variables have been updated, we can simply run the following codeblock to initialize the GPT-3 model (#(1)) and pass a tensor representing a sequence of tokens through it (#(2)).

# Codeblock 21
gpt3 = GPT23()  #(1)

x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))
x = gpt3(x)  #(2)

Unfortunately, I cannot run the above code due to my limited memory. I even tried to run it on a Kaggle Notebook with 30 GB of memory, but the out-of-memory error persisted. So, for this one, I cannot show you the number of parameters the model has when it is initialized. However, the paper states that GPT-3 consists of around 175 billion parameters, which means it is more than 100 times larger than GPT-2. It therefore makes sense that it can only be run on an extremely large and powerful machine. Look at the figure below to see how the GPT versions differ from each other.

Figure 16. Comparison of different GPT versions [3].
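A back-of-the-envelope calculation shows why 30 GB was not enough. PyTorch creates parameters in float32 (4 bytes each) by default, so the 175 billion weights alone would need about 700 GB, before counting activations, gradients, or optimizer states. The float16 line is included only for comparison.

```python
N_PARAMS = 175e9  # parameter count reported in the GPT-3 paper
BYTES_PER_PARAM = {"float32": 4, "float16": 2}

weight_memory_gb = {}
for dtype, n_bytes in BYTES_PER_PARAM.items():
    # Memory for the weights only: no activations, gradients or optimizer states
    weight_memory_gb[dtype] = N_PARAMS * n_bytes / 1e9
    print(f"{dtype}: ~{weight_memory_gb[dtype]:,.0f} GB for the weights alone")
```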

Ending

That’s pretty much everything about the theory and implementation of the different GPT versions, namely GPT-1, GPT-2 and GPT-3. At the time of writing, OpenAI hasn’t officially disclosed the architectural details of GPT-4, so we can’t reproduce it just yet. I hope OpenAI will publish the paper very soon!

Thank you for reading my article up to this point. I do appreciate your time, and I hope you learn something new here. Have a nice day!

Note: you can also access the code used in this article here.

References

[1] Ashish Vaswani et al. Attention Is All You Need. arXiv. https://arxiv.org/pdf/1706.03762 [Accessed October 31, 2024].

[2] Muhammad Ardi. Paper Walkthrough: Attention Is All You Need. Towards Data Science. https://medium.com/towards-data-science/paper-walkthrough-attention-is-all-you-need-80399cdc59e1 [Accessed November 4, 2024].

[3] Image originally created by author.

[4] Alec Radford et al. Improving Language Understanding by Generative Pre-Training. OpenAI. https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf [Accessed October 31, 2024].

[5] Image originally created by author based on [1].

[6] Alec Radford et al. Language Models are Unsupervised Multitask Learners. OpenAI. https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf [Accessed October 31, 2024].

[7] Tom B. Brown et al. Language Models are Few-Shot Learners. arXiv. https://arxiv.org/pdf/2005.14165 [Accessed October 31, 2024].

The post Meet GPT, The Decoder-Only Transformer appeared first on Towards Data Science.

Paper Walkthrough: Neural Style Transfer https://towardsdatascience.com/paper-walkthrough-neural-style-transfer-fc5c978cdaed/ Tue, 03 Dec 2024 13:01:13 +0000 https://towardsdatascience.com/paper-walkthrough-neural-style-transfer-fc5c978cdaed/ Turn your photos into paintings with deep learning - Implementing NST from scratch using PyTorch

Introduction

Lately, the term "Generative AI" has become a trending topic around the world thanks to the release of publicly available AI models such as ChatGPT, Gemini, and Claude. As we all know, their capabilities were initially limited to understanding and generating text, but they soon gained the ability to do the same with images. Talking more specifically about generative models for image data, there are plenty of model variants we can use, each with its own purpose. I have already published some articles about generative AI for image data on Medium, such as Autoencoder and Variational Autoencoder (VAE). In today’s article, I am going to talk about another fascinating generative algorithm: Neural Style Transfer (NST).

NST was first introduced in a paper titled "A Neural Algorithm of Artistic Style" written by Gatys et al. back in 2015 [1]. The paper explains that the main objective is to transfer the artistic style of an image (typically a painting) onto a different image, hence the name "Style Transfer." Look at some examples in Figure 1 below, where the authors restyled the picture on the top left with different paintings.

Figure 1. Applying NST to the original image (top left) using styles from The Shipwreck of the Minotaur by J.M.W. Turner (top right), The Starry Night by Vincent van Gogh (bottom left), and Der Schrei by Edvard Munch (bottom right) [1].

The Idea Behind NST

The authors of this research explained that the content and the style of an image can be separated by a CNN. This implies that if we have two images, we can take the content from the first image and the artistic style from the second one. By combining them, we can obtain a new image that retains the content of the first image yet is painted in the style of the second image. This separation is possible because shallower CNN layers typically focus on extracting low-level features, e.g., edges, corners, and textures, while deeper layers capture higher-level features, e.g., patterns that resemble a specific object. Hence, we can think of the low-level features as the style of an image and the higher-level ones as the image content.

To exploit this behavior, we need three images: a content image, a style image, and a generated image. The content image is the one whose style will be replaced with the artistic patterns from the style image. Neither the content nor the style image is modified in the process, since these two images act as the ground truths. The generated image, on the other hand, is the one we are going to modify based on the content information from the content image and the style information from the style image. Initially, the generated image can either be random noise or a clone of the content image. During the process, we gradually update its pixel values so that it minimizes both its content difference from the content image and its style difference from the style image.

NST Architecture

According to the paper, the backbone of NST is the VGG-19 model. The flow of the three images in the network can be seen in Figure 2 below.

Figure 2. The flow of the generated, style, and content images in the pretrained VGG19 model [2].

The VGG-19 network above accepts our content, style and generated images simultaneously. The content image (blue) is processed from the beginning of the network all the way to the conv4_2 layer. The style image (green) is also passed from the input layer, but for this one we take the feature maps from conv1_1, conv2_1, conv3_1, conv4_1, and conv5_1. Similarly, the generated image (orange) is passed through the network, and we extract its feature maps from the same layers used for both the content and style images. Additionally, we can see in the figure that the layers after conv5_1 do not need to be implemented, since our images never go through them.

Content & Style Loss

There are two loss functions implemented in NST, namely content loss and style loss. As the name suggests, content loss is employed to calculate the difference between the content image and the generated image. By minimizing this loss, we preserve the content information of the content image within the generated image. In Figure 2 above, content loss is applied to the feature maps carried by the blue arrow and the corresponding orange arrow (the two arrows coming out of the conv4_2 layer). Meanwhile, style loss computes the difference between the feature maps from the style image and the generated image, i.e., between the green arrows and the corresponding orange arrows. With the style loss minimized, our generated image should look similar to the style image in terms of its artistic patterns.

Mathematically speaking, content loss can be defined using the equation displayed in Figure 3. In the equation, P represents the feature map corresponding to the content image p, while F is the feature map obtained from the generated image x. The parameter l indicates that feature maps P and F are taken from the same layer, which in this case is conv4_2. By the way, if you often work with regression models, this equation should look familiar, since it is essentially a squared-error loss, i.e., MSE up to a constant scaling factor.

Figure 3. The mathematical definition of content loss [1].
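A minimal sketch of this loss in PyTorch, following the paper's form L_content = ½ Σ (F_ij − P_ij)², with random tensors standing in for the conv4_2 feature maps. It also shows the relation to MSE: F.mse_loss averages instead of summing, so the two differ only by the constant factor ½ × (number of elements), which can be absorbed into the weight α.

```python
import torch
import torch.nn.functional as F_nn

def content_loss(F, P):
    """L_content = 1/2 * sum_ij (F_ij - P_ij)^2.

    F: feature map of the generated image at conv4_2
    P: feature map of the content image at the same layer
    """
    return 0.5 * torch.sum((F - P) ** 2)

F_gen = torch.randn(1, 512, 28, 28)       # random stand-ins for conv4_2 outputs
P_content = torch.randn(1, 512, 28, 28)

loss = content_loss(F_gen, P_content)
# The same quantity expressed through PyTorch's (mean-reduced) MSE
loss_via_mse = 0.5 * F_gen.numel() * F_nn.mse_loss(F_gen, P_content)
print(loss.item(), loss_via_mse.item())
print(content_loss(P_content, P_content).item())  # 0.0 when the feature maps match
```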

As for the style loss, we can calculate it using the equation in Figure 4, which is a weighted sum of the per-layer style losses E_l, each multiplied by a weighting factor w_l.

Figure 4. Equation for calculating the overall style loss [1].

The style loss of each layer is defined by the equation in Figure 5. It is actually just another squared-error function, this time computing the difference between A, the Gram matrix of the style image's feature map, and G, the Gram matrix of the generated image's feature map. Don’t worry if you’re not yet familiar with Gram matrices, as I’ll talk about them in the next section.

Figure 5. Equation for computing the style loss from a single feature map [1].

Now that we know how to compute the content and style losses, we can combine them to form the total loss. You can see in Figure 6 that the content and style losses are summed with the weighting parameters α and β. These two coefficients allow us to control the emphasis of the loss function: if we want to emphasize the content, we can increase α, and if we want the style to be more dominant, we can use a higher value for β.

Figure 6. Equation for computing the total loss [1].

Later in the training phase, the weights of the VGG-based network will be frozen, which means that we will not train the model any further. Instead, the value of our loss function is used to update the pixel values of the generated image. For this reason, "training" is not the most accurate way to describe this process, since the network itself does not undergo training. A better term would be "optimization," since our goal is to optimize the generated image. So, from now on, I will use the term "optimization" to refer to this process.
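To make the "optimization" idea concrete, here is a minimal sketch in which the network is frozen and the only tensor handed to the optimizer is the generated image. To keep it small and runnable, a hypothetical one-layer extractor stands in for VGG-19, only the content term is used, and Adam is an assumed optimizer choice; in full NST the loss would be ALPHA × content loss + BETA × style loss.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-in for the frozen VGG-19 feature extractor
extractor = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
extractor.requires_grad_(False)            # freeze every network weight
weights_before = extractor[0].weight.clone()

content_image = torch.rand(1, 3, 32, 32)
with torch.no_grad():
    target = extractor(content_image)      # fixed "ground truth" features

# The generated image is the only tensor the optimizer is allowed to update
generated = torch.rand(1, 3, 32, 32, requires_grad=True)
optimizer = optim.Adam([generated], lr=0.01)

losses = []
for step in range(200):
    optimizer.zero_grad()
    # Content term only; full NST would use ALPHA * content + BETA * style
    loss = 0.5 * torch.sum((extractor(generated) - target) ** 2)
    loss.backward()                        # gradients reach the pixels, not the weights
    optimizer.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.2f} -> {losses[-1]:.2f}")
print(torch.equal(extractor[0].weight, weights_before))  # True: the network is unchanged
```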

Gram Matrix

In the previous section I mentioned that the MSE for the style loss is computed on Gram matrices rather than on the plain feature maps. We use Gram matrices because they are an effective way to capture the correlations between the channels of a feature map. Look at Figures 7 and 8 to see how a Gram matrix is constructed. In this illustration, I assume that our feature map has 8 channels, each with a spatial dimension of 4×4. The first thing we need to do is flatten the spatial dimensions and stack the channels vertically, as shown below.

Figure 7. Flattening the spatial dimension [2].

Afterwards, the resulting array will be multiplied by its transpose to construct the actual Gram matrix which has the size of C×C (in this case it’s 8×8). Such a matrix multiplication operation causes the feature map to lose its spatial information, but in return it captures the correlation between channels, representing textures and patterns that correspond to its artistic style. Hence, it should make a lot of sense now why we need to use Gram matrices for computing style loss.

Figure 8. Gram matrix is obtained by multiplying the spatially-flattened feature map with its transpose [2].
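A minimal sketch of Figures 7 and 8, together with the style loss from Figures 4 and 5. The per-layer normalization 1/(4 N_l² M_l²), with N_l the number of channels and M_l = H × W, follows the paper's equation; the equal layer weights in style_loss() are an assumption of this sketch, and all function names are hypothetical.

```python
import torch

def gram_matrix(feature_map):
    """(1, C, H, W) feature map -> (C, C) Gram matrix."""
    _, c, h, w = feature_map.shape
    flat = feature_map.reshape(c, h * w)   # flatten spatial dims, one row per channel
    return flat @ flat.t()                 # multiply by the transpose: (C, HW) @ (HW, C)

def layer_style_loss(gen_features, style_features):
    """E_l = 1 / (4 * N_l^2 * M_l^2) * sum_ij (G_ij - A_ij)^2."""
    _, n, h, w = gen_features.shape        # N_l = number of channels, M_l = H * W
    m = h * w
    G = gram_matrix(gen_features)          # generated image
    A = gram_matrix(style_features)        # style image
    return torch.sum((G - A) ** 2) / (4 * n ** 2 * m ** 2)

def style_loss(gen_feature_list, style_feature_list):
    """L_style = sum_l w_l * E_l, assuming equal weights w_l = 1 / (number of layers)."""
    w = 1.0 / len(gen_feature_list)
    return sum(w * layer_style_loss(g, s)
               for g, s in zip(gen_feature_list, style_feature_list))

fmap = torch.randn(1, 8, 4, 4)             # 8 channels of 4x4, as in Figures 7 and 8
G = gram_matrix(fmap)
print(G.shape)                             # torch.Size([8, 8])
```

Shuffling the 16 spatial positions of fmap (the same way in every channel) leaves G unchanged, which is exactly the loss of spatial information described above.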

Implementing NST from Scratch

Now that we understand the theory behind NST, let's get our hands dirty with the code. The very first thing we need to do is to import all the required modules.

# Codeblock 1
import os
import torch

import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from torchvision.models import VGG19_Weights
from torchvision.utils import save_image

These modules are pretty standard. I believe they should not confuse you especially if you have experience in training PyTorch models. But don’t worry if you’re not familiar with them yet, since you’ll definitely understand their use as we go.

Next, we check whether our computer has a GPU. If it does, the code will automatically assign 'cuda' to the device variable. Even though this NST implementation can work without a GPU, I highly recommend against that, because NST optimization is computationally very expensive.

# Codeblock 2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # is_available must be called; the bare function object is always truthy

Parameter Initialization

There are several parameters we need to configure for this optimization task, as shown in Codeblock 3 below.

# Codeblock 3
IMAGE_SIZE    = 224    #(1)
EPOCHS        = 20001  #(2)
LEARNING_RATE = 0.001  #(3)
ALPHA         = 1      #(4)
BETA          = 1000   #(5)

Here I set IMAGE_SIZE to 224, as shown at line #(1). I chose this number simply because it matches the original VGG input size. It is technically possible to use a larger size if you want your image to have a higher resolution. However, keep in mind that this makes the optimization process take longer.

Next, I set EPOCHS to 20,001 (#(2)), yes, with that extra 1. I admit that this number is a bit strange, but it is just a technical detail that allows me to get the result at epoch 20,000 (you’ll see why later). One important thing to note about EPOCHS is that a higher number doesn’t necessarily produce a better result for everyone. This is due to the nature of generative AI, where at some point it is just a matter of preference. Later in the optimization process, even though I use a large value for EPOCHS, I will save the generated image at certain intervals so that I can choose the result I like the most.
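The extra 1 presumably comes from how Python's range() works: range(EPOCHS) stops at EPOCHS - 1, so with EPOCHS = 20001 the last epoch index is exactly 20,000. A small sketch of saving at fixed intervals, where SAVE_EVERY is a hypothetical value; inside the real loop, this is where something like save_image() would be called.

```python
EPOCHS = 20001
SAVE_EVERY = 2000  # hypothetical saving interval

saved_epochs = [epoch for epoch in range(EPOCHS) if epoch % SAVE_EVERY == 0]
print(saved_epochs[-1])     # 20000
print(len(saved_epochs))    # 11
print(max(range(20000)))    # 19999: without the extra 1, epoch 20,000 never runs
```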

As for LEARNING_RATE (#(3)), 0.001 is simply a value I often use for this parameter. In principle, changing it affects how fast the generated image is updated, and therefore the speed of the optimization process. Lastly, I configure ALPHA (#(4)) and BETA (#(5)) such that they have a ratio of 1/1000. The paper mentions that using a smaller ratio (i.e., setting BETA even higher) makes the artistic style too dominant, leaving the content of the image less visible. Look at Figure 9 below to see how different α/β ratios affect the resulting image.

Figure 9. The generated image created with different alpha/beta ratios [1]. The style image appears extremely dominant (leftmost) when the ratio is set to 1/100,000. Additionally, the artistic style becomes more complex as we move towards deeper layers.
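The role of the two coefficients can be written out explicitly. The total loss minimized later in the optimization loop is ALPHA times the content loss plus BETA times the style loss, following the form used in the paper [1]. Dividing both sides by α shows that, up to an overall scale, only the ratio between the two weights matters:

```latex
\mathcal{L}_{\text{total}} = \alpha\,\mathcal{L}_{\text{content}} + \beta\,\mathcal{L}_{\text{style}}
\quad\Longrightarrow\quad
\frac{\mathcal{L}_{\text{total}}}{\alpha} = \mathcal{L}_{\text{content}} + \frac{\beta}{\alpha}\,\mathcal{L}_{\text{style}},
\qquad
\frac{\alpha}{\beta} = \frac{1}{1000} \;\Rightarrow\; \frac{\beta}{\alpha} = 1000
```

With α/β = 1/1000, the style term is weighted 1000 times more heavily than the content term. Shrinking the ratio to 1/100,000 raises that factor to 100,000, which is consistent with the style dominating the leftmost image in Figure 9.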

Image Loading & Preprocessing

With the parameters initialized, we can continue with the image loading and preprocessing function. See the implementation in Codeblock 4 below.

# Codeblock 4
def load_image(filename):

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),  #(1)
        transforms.ToTensor(),  #(2)
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  #(3)
                             std=[0.229, 0.224, 0.225])
    ])

    image = Image.open(filename).convert('RGB')  #(4) convert() guarantees 3 channels (drops alpha, expands grayscale)
    image = transform(image)  #(5)
    image = image.unsqueeze(0)  #(6)

    return image

This function accepts the name of the image file to be loaded. Before actually loading the image, we first define the preprocessing steps using transforms.Compose(), which consist of resizing (#(1)), conversion to a PyTorch tensor (#(2)), and normalization (#(3)). The normalization parameters I use here are the per-channel mean and standard deviation of ImageNet, i.e., the dataset on which the pretrained VGG-19 was trained. Using the same statistics ensures that our inputs follow the distribution the pretrained model expects, which allows it to perform at its best.

The image itself is loaded using the Image.open() function from PIL (#(4)). Then, we directly preprocess it with the transformation steps we just defined (#(5)). Lastly, we apply the unsqueeze() method to create the batch dimension (#(6)). Even though we only have a single image in each batch, this dimension is still necessary because PyTorch models are designed to process batches of images.

Here we are going to use the picture of Victoria Library and the Starry Night painting. The two images in their unprocessed form are shown in Figure 10 below.

Figure 10. The picture of Victoria Library which I took back in 2015 (left) [2] will be used as both the content and generated image, while the Starry Night painting by Vincent van Gogh (right) [3] will act as the style image.

Now let's load these images using the load_image() function we defined above. See Codeblock 5 for the details.

# Codeblock 5
content_image = load_image('Victoria Library.jpg').to(device)  #(1)
style_image = load_image('Starry Night.jpg').to(device)  #(2)
gen_image = content_image.clone().requires_grad_(True)  #(3)

Here I'm using the picture of Victoria Library as the content image (#(1)), while the painting serves as the style image (#(2)). The same Victoria Library picture is also used to initialize the generated image (#(3)). As I mentioned earlier, it is possible to use random noise for it instead. However, I decided not to do so because in my experiments the information from the content image did not transfer properly to the generated image for some reason. We also need to apply requires_grad_(True) to the generated image so that its pixel values can be updated through backpropagation.
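The behavior of clone().requires_grad_(True) can be checked on a toy tensor. In this sketch, the random 1x3x4x4 tensor is a hypothetical stand-in for the loaded photo, and the in-place addition inside torch.no_grad() imitates what an optimizer step does to the pixels.

```python
import torch

content_image = torch.rand(1, 3, 4, 4)                  # stand-in for the loaded photo
gen_image = content_image.clone().requires_grad_(True)  # independent, trainable copy

print(gen_image.is_leaf, gen_image.requires_grad)  # True True
print(content_image.requires_grad)                 # False

# Updating the copy in place (as optimizer.step() does) leaves the original untouched.
with torch.no_grad():
    gen_image += 0.5
print(torch.equal(content_image + 0.5, gen_image.detach()))  # True
```

Because clone() copies the data, the content image keeps serving as a fixed target while the generated image drifts away from it.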

We can check whether the images have been loaded and preprocessed properly by running the following code. You can see in the resulting output that both images now have a height of 224 pixels, which is exactly what we set earlier. When transforms.Resize() receives a single integer, it resizes the shorter edge (here, the height) to that size and adjusts the other edge to maintain the aspect ratio, so the images stay proportional. You may also notice that their colors look darker, which is caused by the normalization process: many normalized values fall below 0, and Matplotlib clips them to the displayable range.

# Codeblock 6
plt.imshow(content_image.permute(0, 2, 3, 1).squeeze().to('cpu'))
plt.show()

plt.imshow(style_image.permute(0, 2, 3, 1).squeeze().to('cpu'))
plt.show()
Figure 11. Both images have been successfully loaded and preprocessed (output from Codeblock 6) [2].

Modifying the VGG-19 Model

In PyTorch, the VGG-19 architecture can easily be loaded using models.vgg19(). Since we want to utilize its pretrained version, we need to pass VGG19_Weights.IMAGENET1K_V1 for the weights parameter. If this is your first time running the code, it will automatically start downloading the weights, which is around 550 MB.

# Codeblock 7
models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

Before we actually modify the architecture, I want you to see its complete version below.

# Codeblock 7 output
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace=True)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace=True)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace=True)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace=True)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace=True)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace=True)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

I have to admit that the VGG-19 architecture I illustrated in Figure 2 is a bit oversimplified. However, the idea is the same, in the sense that we will take the output of the conv4_2 layer for the content image, and of conv1_1, conv2_1, conv3_1, conv4_1, and conv5_1 for the style image. In the Codeblock 7 output above, conv4_2 corresponds to layer number 21, whereas the five layers for the style image correspond to layers 0, 5, 10, 19, and 28, respectively. We are going to modify the pretrained model based on this requirement, which I do in the ModifiedVGG() class shown below.

# Codeblock 8
class ModifiedVGG(nn.Module):
    def __init__(self):
        super().__init__()

        self.layer_content_idx = [21]  #(1)
        self.layer_style_idx = [0, 5, 10, 19, 28]  #(2)

        #(3)
        self.model = models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features[:29]

    def forward(self, x):

        content_features = []  #(4)
        style_features = []    #(5)

        for layer_idx, layer in enumerate(self.model):
            x = layer(x)  #(6)

            if layer_idx in self.layer_content_idx:
                content_features.append(x)  #(7)

            if layer_idx in self.layer_style_idx:
                style_features.append(x)  #(8)

        return content_features, style_features  #(9)

The first thing we do inside the class is create the __init__() method. Here we specify the indices of the layers from which the feature maps are going to be extracted, as shown at lines #(1) and #(2). The pretrained VGG-19 model itself is initialized at line #(3). Notice that here I use [:29] to take all layers from the beginning up to layer number 28 only. This is done because flowing the tensors all the way to the end of the network is unnecessary for this NST task: no feature map beyond layer 28 is used.

Next, inside the forward() method we first allocate two lists, one for storing the feature maps of the content image (#(4)) and another for the feature maps of the style image (#(5)). Since the VGG feature extractor consists only of sequential layers, we can do the forward propagation using a typical for loop. With this approach, the feature map from the previous layer is directly fed into the subsequent one (#(6)). The content_features (#(7)) and style_features (#(8)) lists are appended with a feature map whenever their corresponding if statement evaluates to True. It is worth noting that the if statement for the content features is only True once, since we only want to keep the feature map from layer 21. Despite this, I still store the indices in a list for the sake of flexibility, so that you can take the content feature maps from multiple layers if you want.

Both the content_features and style_features lists will be the return values of our forward() method (#(9)). Later on, if you feed the content image into the network, you can just take the first output. If you pass the style image into it, then you can take the second output. And you will need to take both outputs whenever you pass the generated image into the network.
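The layer indices used in ModifiedVGG can also be derived instead of counted by hand. The sketch below is an illustration, not part of the original code: the helper conv_layer_indices() and the VGG19_CFG list are hypothetical names, and the list is transcribed from the Codeblock 7 printout, where each number stands for a Conv2d followed by a ReLU and each 'M' for a MaxPool2d.

```python
# Output channels of each conv layer in VGG-19's feature extractor; 'M' = MaxPool2d.
VGG19_CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
             512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']

def conv_layer_indices(cfg):
    """Map names like 'conv4_2' to their index inside model.features."""
    names = {}
    idx, block, conv_in_block = 0, 1, 0
    for v in cfg:
        if v == 'M':
            idx += 1            # the MaxPool2d occupies one slot
            block += 1          # the next conv starts a new block
            conv_in_block = 0
        else:
            conv_in_block += 1
            names[f'conv{block}_{conv_in_block}'] = idx
            idx += 2            # a Conv2d followed by its ReLU
    return names

indices = conv_layer_indices(VGG19_CFG)
print(indices['conv4_2'])  # 21
print([indices[n] for n in ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']])
# [0, 5, 10, 19, 28]
```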

Now we can check if our ModifiedVGG() class works properly by passing content_image and style_image through it. See the details in Codeblock 9 below.

# Codeblock 9
modified_vgg = ModifiedVGG().to(device).eval()  #(1)

content_features = modified_vgg(content_image)[0]  #(2)
style_features = modified_vgg(style_image)[1]  #(3)

print('content_features length :', len(content_features))
print('style_features length   :', len(style_features))
# Codeblock 9 output
content_features length : 1
style_features length   : 5

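The first thing we do in the above code is initialize the model we just created (#(1)). Since we won't train the network any further, we also call eval() to put it in inference mode. Keep in mind, though, that eval() does not actually freeze the weights: it only changes the behavior of layers such as dropout and batch normalization, none of which appear in features[:29]. The weights stay unchanged here simply because the optimizer we will create later only receives the generated image, and calling requires_grad_(False) on the model would additionally stop PyTorch from computing gradients for them. Next, we forward-propagate the content image (#(2)) and the style image (#(3)). If we print out the number of elements of both outputs, we can see that content_features consists of only a single element, whereas style_features contains 5 elements, each corresponding to the feature map from one of the selected layers.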
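The difference between calling eval() and actually freezing the weights is easy to verify on a tiny stand-in network. The two-layer nn.Sequential below is hypothetical (it is not VGG-19). It shows that the parameters keep requires_grad=True after eval(), and that an input image can still receive gradients once the weights are frozen with requires_grad_(False).

```python
import torch
import torch.nn as nn

# Hypothetical tiny stand-in for the feature extractor (not VGG-19).
net = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU())

net.eval()
trainable_after_eval = all(p.requires_grad for p in net.parameters())    # True

net.requires_grad_(False)
trainable_after_freeze = any(p.requires_grad for p in net.parameters())  # False

# An input image can still receive gradients through the frozen network.
x = torch.rand(1, 3, 8, 8, requires_grad=True)
net(x).sum().backward()

print(trainable_after_eval, trainable_after_freeze)     # True False
print(x.grad is not None, net[0].weight.grad is None)   # True True
```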

To make the underlying process clearer, I would like to display the feature maps stored in the two lists. To do so, there are a few technical steps to follow, which are actually required every time we want to display an image processed with PyTorch. As seen in Codeblock 10, since PyTorch places the channel dimension of a tensor at axis 1 (right after the batch axis), we need to move it to the last axis using the permute() method so that Matplotlib can display it. Next, we use squeeze() to drop the batch dimension, and detach() because the feature maps are still part of the computational graph. Since the conv4_2 layer has 512 kernels, our content image is now represented as a feature map with 512 channels, each storing different information about the content of the image. For the sake of simplicity, I will only display the first 5 channels, which can be selected with simple indexing.

# Codeblock 10
plt.imshow(content_features[0].permute(0, 2, 3, 1).squeeze()[:,:,0].to('cpu').detach())
plt.show()

plt.imshow(content_features[0].permute(0, 2, 3, 1).squeeze()[:,:,1].to('cpu').detach())
plt.show()

plt.imshow(content_features[0].permute(0, 2, 3, 1).squeeze()[:,:,2].to('cpu').detach())
plt.show()

plt.imshow(content_features[0].permute(0, 2, 3, 1).squeeze()[:,:,3].to('cpu').detach())
plt.show()

plt.imshow(content_features[0].permute(0, 2, 3, 1).squeeze()[:,:,4].to('cpu').detach())
plt.show()

Below is what the Victoria Library looks like after being processed by the VGG network from its input layer up to the conv4_2 layer. Even though these representations are abstract and may seem difficult to interpret visually, they contain important information that the network uses to reconstruct the content.

Figure 12. Visualization of the content image after being processed through all layers of the VGG-19 network up to conv4_2 layer, showing channels 0, 1, 2, 3, and 4, respectively from left to right (output from Codeblock 10) [2].
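The permute, squeeze, and index chain in Codeblocks 10 and 11 is easier to follow with the shapes written out. The random tensor below is a hypothetical stand-in for content_features[0]. Its 28x42 spatial size assumes a 224x336 input, since conv4_2 comes after three 2x2 max-pooling layers (224 / 8 = 28 and 336 / 8 = 42).

```python
import torch

# Hypothetical stand-in for content_features[0] computed from a 224x336 input.
feature_map = torch.rand(1, 512, 28, 42, requires_grad=True)

hwc = feature_map.permute(0, 2, 3, 1)        # (1, 28, 42, 512): channels moved last
hwc = hwc.squeeze()                          # (28, 42, 512): batch axis dropped
channel_0 = hwc[:, :, 0].to('cpu').detach()  # (28, 42): a single 2D map to plot

print(tuple(hwc.shape), tuple(channel_0.shape))  # (28, 42, 512) (28, 42)

# Indexing the batch and channel axes directly yields the same map.
print(torch.equal(channel_0, feature_map[0, 0].detach()))  # True
```

Since Matplotlib only needs the 2D slice of one channel, feature_map[0, c] is a shorter alternative to the permute-then-index pattern.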

With the same mechanism, we can also display the style image after being processed from the input layer up to each of the five selected layers. If you check the original VGG paper [4], you will see that the feature maps produced by conv1_1, conv2_1, conv3_1, conv4_1, and conv5_1 have 64, 128, 256, 512, and 512 channels, respectively. In the code below, I arbitrarily pick one channel from each feature map to be displayed.

# Codeblock 11
plt.imshow(style_features[0].permute(0, 2, 3, 1).squeeze()[:,:,60].to('cpu').detach())
plt.show()

plt.imshow(style_features[1].permute(0, 2, 3, 1).squeeze()[:,:,12].to('cpu').detach())
plt.show()

plt.imshow(style_features[2].permute(0, 2, 3, 1).squeeze()[:,:,71].to('cpu').detach())
plt.show()

plt.imshow(style_features[3].permute(0, 2, 3, 1).squeeze()[:,:,152].to('cpu').detach())
plt.show()

plt.imshow(style_features[4].permute(0, 2, 3, 1).squeeze()[:,:,76].to('cpu').detach())
plt.show()

You can see in the resulting output that the style image appears very clearly in the initial layers, indicating that the feature maps from these layers are useful for preserving style information. However, it is also worth noting that taking style information from deeper layers is important in order to preserve higher-order artistic style. This notion is supported by Figure 9, where the artistic style appears more complex at layer conv3_1 than at layer conv1_1.

Figure 13. The style image after being processed sequentially through the VGG-19 network up to conv1_1, conv2_1, conv3_1, conv4_1, and conv5_1 layers, respectively from left to right (output from Codeblock 11) [2].

Creating the Gram Matrix

The feature maps from both the style and the generated image will be converted to Gram matrices before the loss is computed using MSE. The Gram matrix computation previously illustrated in Figures 7 and 8 is implemented in the compute_gram_matrix() function below. The way this function works is pretty straightforward: it first flattens the spatial dimensions (#(1)), then matrix-multiplies the resulting tensor with its transpose (#(2)). Note that the flattening at line #(1) also drops the batch dimension, so the function assumes a batch of a single image, which is always the case in this article.

# Codeblock 12
def compute_gram_matrix(feature_map):
    batch_size, num_channels, height, width = feature_map.shape

    feature_map_flat = feature_map.view(num_channels, height*width)  #(1) 
    gram_matrix = torch.matmul(feature_map_flat, feature_map_flat.t())  #(2)

    return gram_matrix

Now I am going to apply this function to compute the Gram matrices of the style image feature maps that we stored earlier in the style_features list. I will also visualize them so that you can get a better understanding of this matrix. Look at Codeblock 13 below to see how I do it.

# Codeblock 13
style_features_0 = compute_gram_matrix(style_features[0])
style_features_1 = compute_gram_matrix(style_features[1])
style_features_2 = compute_gram_matrix(style_features[2])
style_features_3 = compute_gram_matrix(style_features[3])
style_features_4 = compute_gram_matrix(style_features[4])

plt.imshow(style_features_0.to('cpu').detach())
plt.show()

plt.imshow(style_features_1.to('cpu').detach())
plt.show()

plt.imshow(style_features_2.to('cpu').detach())
plt.show()

plt.imshow(style_features_3.to('cpu').detach())
plt.show()

plt.imshow(style_features_4.to('cpu').detach())
plt.show()
Figure 14. The Gram matrices of different feature maps from style image (output from Codeblock 13) [2].

The output shown in Figure 14 aligns with the illustration in Figure 8, where the size of each matrix matches the number of channels in the corresponding feature map. The colors inside these matrices indicate the correlation scores between pairs of channels, with higher values represented by lighter colors. There is not much we can interpret from these matrices directly, but keep in mind that they contain the style information of an image. The one pattern we can see here is the subtle diagonal line spanning from the top left all the way to the bottom right. This pattern makes sense because the correlation between a channel and itself (the diagonal elements) is typically higher than the correlation between different channels (the off-diagonal elements). In fact, by the Cauchy-Schwarz inequality, an off-diagonal entry can never exceed the geometric mean of the two corresponding diagonal entries.
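The symmetry and the diagonal pattern described above can be checked numerically. This sketch reimplements the flatten-and-multiply from Codeblock 12, using reshape, which also works on non-contiguous tensors. It applies the function to a random, ReLU-activated toy feature map with 8 channels; the sizes are arbitrary.

```python
import torch

def compute_gram_matrix(feature_map):
    # Same idea as Codeblock 12; assumes a batch containing a single image.
    _, num_channels, height, width = feature_map.shape
    flat = feature_map.reshape(num_channels, height * width)
    return flat @ flat.t()

torch.manual_seed(0)
feature_map = torch.relu(torch.randn(1, 8, 5, 6))  # toy map: 8 channels of 5x6
gram = compute_gram_matrix(feature_map)
flat = feature_map.reshape(8, -1)
diag = gram.diagonal()

is_symmetric = torch.allclose(gram, gram.t())
diag_is_sq_norm = torch.allclose(diag, (flat ** 2).sum(dim=1))  # G_ii = ||channel_i||^2
# Cauchy-Schwarz: |G_ij| <= sqrt(G_ii * G_jj) for every pair of channels.
bound = torch.sqrt(torch.outer(diag, diag))
within_bound = bool((gram.abs() <= bound * (1 + 1e-5) + 1e-6).all())

print(gram.shape, is_symmetric, diag_is_sq_norm, within_bound)
```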

Loss Function & Optimizer

The pixel intensity values of the generated image will be updated based on the weighted sum of the content and style losses. As I've mentioned earlier, these two loss functions are actually the same: the Mean Squared Error. For this reason, we don't need to create separate functions for them. As for the optimizer, many sources suggest using the L-BFGS optimizer for NST. However, I didn't find any explicit information about it in the paper, so I think it's fine for us to go with any optimizer. In this case, I will just use Adam.

In the following codeblock, I implement the MSE loss from scratch and initialize the Adam optimizer from PyTorch's optim module. One thing you need to pay attention to is that we pass our generated image to the params parameter, not the weights of the model. This way, each optimization step updates the pixel values of gen_image while keeping the model weights unchanged.

# Codeblock 14
def MSE(tensor_0, tensor_1):
    return torch.mean((tensor_0-tensor_1)**2)

optimizer = optim.Adam(params=[gen_image], lr=LEARNING_RATE)
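As a quick sanity check, the hand-written MSE() matches PyTorch's built-in torch.nn.functional.mse_loss with its default mean reduction. The sketch also shows that handing only an image tensor to Adam means only that tensor changes. The tiny all-zeros "image" and all-ones target are hypothetical toy values.

```python
import torch
import torch.nn.functional as F
import torch.optim as optim

def MSE(tensor_0, tensor_1):
    return torch.mean((tensor_0 - tensor_1) ** 2)

torch.manual_seed(0)
a, b = torch.randn(4, 4), torch.randn(4, 4)
matches_builtin = torch.allclose(MSE(a, b), F.mse_loss(a, b))  # reduction='mean'

image = torch.zeros(1, 3, 2, 2, requires_grad=True)  # toy "generated image"
target = torch.ones(1, 3, 2, 2)
optimizer = optim.Adam(params=[image], lr=0.001)

optimizer.zero_grad()
MSE(image, target).backward()
optimizer.step()

# Adam's first step moves every pixel by roughly lr toward the target.
print(matches_builtin, image.detach().flatten()[:3])
```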

Denormalization Function

If we go back to Figure 11, you will notice that the coloration of the content and the style image became strange after normalization. Hence, we need to apply a so-called denormalization process to the resulting generated image so that its colors return to their original state. We implement this mechanism inside the denormalize() function below. The mean (#(1)) and std (#(2)) values are the same ones used in the normalization process in Codeblock 4. Using these two values, we apply the operation at line #(3), which reverses the normalization and rescales the pixel values from being centered around 0 back to their original range.

# Codeblock 15
def denormalize(gen_image):
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)  #(1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)  #(2)

    gen_image = gen_image*std + mean  #(3)

    return gen_image
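The view(1, 3, 1, 1) reshaping is what lets the three per-channel statistics broadcast over a (batch, channel, height, width) tensor. The following toy round trip, on a random 4x5 image rather than the real one, checks that the formula in Codeblock 15 exactly undoes the normalization. The normalize() helper here is a hypothetical re-expression of what transforms.Normalize computes.

```python
import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)  # ImageNet mean
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)   # ImageNet std

def normalize(image):
    return (image - mean) / std   # per-channel, like transforms.Normalize

def denormalize(image):
    return image * std + mean     # the inverse, as in Codeblock 15

torch.manual_seed(0)
image = torch.rand(1, 3, 4, 5)     # toy image with values in [0, 1)
normalized = normalize(image)
restored = denormalize(normalized)

has_negative = bool((normalized < 0).any())  # no longer a valid [0, 1] image
round_trip_ok = torch.allclose(restored, image, atol=1e-6)
print(has_negative, round_trip_ok)
```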

The NST Optimization

Now that we have all the necessary components prepared, we will compile them into a single function which I name optimize(). See Codeblocks 16a, 16b, and 16c below for the details.

# Codeblock 16a
def optimize():

    #(1)
    content_losses = []
    style_losses = []
    total_losses = []

    for epoch in tqdm(range(EPOCHS)):
        content_features = modified_vgg(content_image)[0]  #(2)
        style_features = modified_vgg(style_image)[1]  #(3)

        gen_features = modified_vgg(gen_image)
        gen_features_content, gen_features_style = gen_features  #(4)

This function starts by allocating 3 empty lists for keeping track of the content, style, and total loss (#(1)). In each epoch, we pass the content, style, and generated image through the modified VGG network we created earlier. Remember that for the content image we only extract the content features (#(2)), while for the style image we take only its style features (#(3)). This is why I index the two outputs with [0] and [1], respectively. As for the generated image, we need both its content and style features, so we store them separately in gen_features_content and gen_features_style (#(4)).

Previously I mentioned that our three input images are processed simultaneously. However, in the above code I feed them one by one instead. Don't worry about this difference, since it is only an implementation detail: feeding the images separately gives the same feature maps, and it also avoids having to make all three images the same size so that they can be stacked into a single batch. I do it this way for the sake of simplicity so that you can better understand the entire NST optimization algorithm.

# Codeblock 16b
        content_loss = 0  #(1)
        style_loss = 0  #(2)

        for content_feature, gen_feature_content in zip(content_features, gen_features_content):
            content_loss += MSE(content_feature, gen_feature_content)  #(3)

        for style_feature, gen_feature_style in zip(style_features, gen_features_style):

            style_gram = compute_gram_matrix(style_feature)  #(4)
            gen_gram = compute_gram_matrix(gen_feature_style)  #(5)

            style_loss += MSE(style_gram, gen_gram)  #(6)

        total_loss = ALPHA*content_loss + BETA*style_loss  #(7)

Still inside the same loop, we initialize the content and style loss to 0 as shown at lines #(1) and #(2) in Codeblock 16b. Afterwards, we iterate through the content features of the content image and the generated image to calculate the MSE (#(3)). Again, remember that this loop only iterates once. We create a similar loop for the style features, where in this case we compute the Gram matrix of each style feature from both the style image (#(4)) and the generated image (#(5)) before computing the MSE (#(6)) and accumulating it in style_loss. After content_loss and style_loss are obtained, we weight them with the ALPHA and BETA coefficients, which we previously set to 1 and 1000, and sum them into total_loss (#(7)).

The optimize() function is not finished yet. We will continue it with Codeblock 16c below. The following code simply implements the standard procedure for training PyTorch models. Here, we use the zero_grad() method to clear the gradients tracked by the optimizer (#(1)) before computing the new ones for the current epoch (#(2)). Then, we update the trainable parameters based on the gradient values using the step() method (#(3)), where in our case these trainable parameters are the pixel intensities of the generated image.

# Codeblock 16c
        optimizer.zero_grad()  #(1)
        total_loss.backward()  #(2)
        optimizer.step()  #(3)

        #(4)
        content_losses.append(content_loss.item())
        style_losses.append(style_loss.item())
        total_losses.append(total_loss.item())

        #(5)
        if epoch % 200 == 0:
            gen_denormalized = denormalize(gen_image)
            save_image(gen_denormalized, f'gen_image{epoch}.png')

    return content_losses, style_losses, total_losses

Afterwards, we append all loss values we obtained in the current epoch to the lists we initialized earlier (#(4)). This step is not mandatory, but I do it anyway since I want to display how our loss values change as we iterate through the optimization process. Finally, we denormalize and save the generated image every 200 epochs so that we can choose the result we prefer the most (#(5)).
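The save schedule also explains the unusual EPOCHS value of 20,001 from Codeblock 3. Since range(EPOCHS) stops one short of EPOCHS, this small pure-Python check lists the epochs at which an image gets saved:

```python
EPOCHS = 20001

# Epochs at which Codeblock 16c saves an image: those with epoch % 200 == 0.
saved_epochs = [epoch for epoch in range(EPOCHS) if epoch % 200 == 0]

# With EPOCHS = 20000, range() would stop at 19999.
last_without_extra = [e for e in range(20000) if e % 200 == 0][-1]

print(len(saved_epochs), saved_epochs[-1], last_without_extra)  # 101 20000 19800
```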

With the optimization function complete, we can now run it using the code below. Here I store the loss values in the content_losses, style_losses, and total_losses lists. Sit back and relax while the GPU blends the content and style images. In my case, I am using a Kaggle Notebook with an Nvidia P100 GPU enabled, and it takes around 15 minutes to complete the 20,001 optimization steps.

# Codeblock 17
losses = optimize()
content_losses, style_losses, total_losses = losses

Finally, after the process is done, we get the Victoria Library picture redrawn in the style of Van Gogh's Starry Night painting. You can see in the following figure that the effect of the style image becomes more apparent in later epochs.

Figure 15. The NST optimization result at epoch 0, 5000, 10000, 15000 and 20000, respectively from left to right [2].

Regarding the training progress in Figure 16, the vertical axis of the three plots represents the loss value, whereas the horizontal axis denotes the epoch. You might notice something unusual here. When we train a deep learning model, we typically expect the loss to decrease as the training progresses. This is indeed the case for the style and total loss. However, the content loss increases instead.

This phenomenon occurs because our generated image was initialized as a clone of the content image, which means that our initial content loss is 0. As the training progresses, the artistic style from the style image is gradually infused into the generated image, causing the style loss to decrease but, in return, making the content loss increase. This makes sense because the generated image slowly drifts away from the content image. If we instead initialized the generated image with random noise, we could expect high values for both the initial content and style loss before they eventually decrease in the subsequent epochs.

# Codeblock 18
plt.title('content_losses')
plt.plot(content_losses)
plt.show()

plt.title('style_losses')
plt.plot(style_losses)
plt.show()

plt.title('total_losses')
plt.plot(total_losses)
plt.show()
Figure 16. What the training progress looks like when the generated image is initialized to be the same as the content image.

That's pretty much everything I can tell you about the theory and the implementation of NST. Feel free to comment if you have any thoughts about this article. Thanks for reading, and have a nice day!

P.S. You can find the code used in this article in my GitHub repo as well. Here's the link to it.


References

[1] Leon A. Gatys, Alexander S. Ecker, Matthias Bethge. A Neural Algorithm of Artistic Style. Arxiv. https://arxiv.org/pdf/1508.06576 [Accessed October 6, 2024].

[2] Image created originally by author.

[3] Van Gogh – Starry Night – Google Art Project. Wikimedia Commons. https://commons.wikimedia.org/wiki/File:VanGogh-_StarryNight-_Google_Art_Project.jpg [Accessed October 7, 2024].

[4] Karen Simonyan, Andrew Zisserman. Very Deep Convolutional Networks for Large-Scale Image Recognition. Arxiv. https://arxiv.org/pdf/1409.1556 [Accessed October 11, 2024].

The post Paper Walkthrough: Neural Style Transfer appeared first on Towards Data Science.

]]>
Paper Walkthrough: Attention Is All You Need https://towardsdatascience.com/paper-walkthrough-attention-is-all-you-need-80399cdc59e1/ Sun, 03 Nov 2024 12:02:16 +0000 https://towardsdatascience.com/paper-walkthrough-attention-is-all-you-need-80399cdc59e1/ The complete guide to implementing a Transformer from scratch

The post Paper Walkthrough: Attention Is All You Need appeared first on Towards Data Science.

]]>
Introduction

As the title suggests, in this article I am going to implement the Transformer architecture from scratch with PyTorch (yes, literally from scratch). Before we get into it, let me provide a brief overview of the architecture. The Transformer was first introduced in the paper titled "Attention Is All You Need" by Vaswani et al. back in 2017 [1]. This neural network model is designed to perform seq2seq (sequence-to-sequence) tasks such as machine translation and question answering, where it accepts a sequence as the input and is expected to return another sequence as the output.

Before the Transformer was introduced, we usually used RNN-based models like LSTM or GRU to accomplish seq2seq tasks. These models are indeed capable of capturing context, yet they do so in a sequential manner. This approach makes it challenging to capture long-range dependencies, especially when the important context is very far behind the current timestep. In contrast, the Transformer can freely attend to any part of the sequence that it considers important without being constrained by sequential processing.

Transformer Components

The main Transformer architecture can be seen in Figure 1 below. It might look a bit intimidating at first, but don't worry: I am going to explain the entire implementation as completely as possible.

Figure 1. The Transformer architecture [1].

You can see in the figure that the Transformer comprises many components. The large block on the left is called the Encoder, while the one on the right is called the Decoder. In the case of machine translation, for example, the Encoder is responsible for capturing the pattern of the original sentence, whereas the Decoder is employed to generate the corresponding translation.

The ability of the Transformer to freely attend to specific words comes from the Multihead Attention blocks, which work by comparing each word with every other word within the sequence. It is important to note that the three Multihead Attention blocks (highlighted in orange) are not exactly the same despite their similar purpose. Nevertheless, while the attention mechanism captures the relationships between words, it does not account for the order of the words, which is crucial in NLP. Thus, to retain sequence information, we employ the so-called Positional Encoding.
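The "compare every word with every other word" idea boils down to scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V, as defined in the paper [1]. The sketch below is a single-head illustration under simplifying assumptions (no learned projections, no masking, and the same toy tensor used as queries, keys, and values), not the implementation built later in this article. It also shows why Positional Encoding is needed: shuffling the input tokens merely shuffles the outputs in the same way, so attention alone is blind to word order.

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    # q, k, v: (batch, seq_len, dim). Every query is scored against every key.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # (batch, seq, seq)
    weights = torch.softmax(scores, dim=-1)                    # each row sums to 1
    return weights @ v, weights

torch.manual_seed(0)
x = torch.randn(1, 5, 16)  # toy sequence: 5 tokens with 16 features each
output, weights = scaled_dot_product_attention(x, x, x)

# Reordering the tokens only reorders the outputs: attention ignores position.
perm = torch.tensor([4, 0, 3, 1, 2])
x_perm = x[:, perm]
output_perm, _ = scaled_dot_product_attention(x_perm, x_perm, x_perm)

print(output.shape, weights.shape)  # torch.Size([1, 5, 16]) torch.Size([1, 5, 5])
print(torch.allclose(output_perm, output[:, perm], atol=1e-5))  # True
```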

I think the remaining components of the network are pretty straightforward: the Add & Norm block (colored in yellow) is basically an addition (a residual connection) followed by layer normalization, Feed Forward (blue) is a small fully connected network made of two linear layers with a ReLU activation in between, Input & Output Embedding (red) convert tokens into vectors, the Linear block after the Decoder (purple) is another standard linear layer, and Softmax (green) is the layer responsible for generating a probability distribution over the vocabulary to predict the next word.
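For reference, the paper defines the Feed Forward block as FFN(x) = max(0, xW1 + b1)W2 + b2, applied to each token position independently. Below is a minimal sketch assuming the D_MODEL = 512 and HIDDEN_DIM = 2048 values configured in Codeblock 2 below. The class name FeedForwardSketch is hypothetical, and dropout is deliberately left out to keep the example aligned with the formula.

```python
import torch
import torch.nn as nn

D_MODEL, HIDDEN_DIM = 512, 2048  # values from Codeblock 2

class FeedForwardSketch(nn.Module):
    """Position-wise FFN: Linear -> ReLU -> Linear, applied to every token."""
    def __init__(self, d_model, hidden_dim):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, hidden_dim)  # expand 512 -> 2048
        self.relu = nn.ReLU()                           # the max(0, .) in the formula
        self.linear_2 = nn.Linear(hidden_dim, d_model)  # project 2048 -> 512

    def forward(self, x):
        # x: (batch, seq_len, d_model); nn.Linear acts on the last axis only.
        return self.linear_2(self.relu(self.linear_1(x)))

ffn = FeedForwardSketch(D_MODEL, HIDDEN_DIM)
x = torch.randn(1, 200, D_MODEL)  # one sequence of 200 embedded tokens
out = ffn(x)
print(out.shape)  # torch.Size([1, 200, 512])

# "Position-wise": each token is transformed independently of the others.
print(torch.allclose(out[0, 7], ffn(x[0, 7]), atol=1e-5))  # True
```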


Imports and Configurations

Now let’s actually start coding by importing the required modules: the base torch module for basic functionalities, the nn submodule for initializing neural network layers, and the summary() function from torchinfo which I will use to display the details of the entire Deep Learning model.

# Codeblock 1
import torch
import torch.nn as nn
from torchinfo import summary

Afterwards, I am going to initialize the parameters for the Transformer model. The first parameter is SEQ_LENGTH, whose value is set to 200 (marked with #(1) in Codeblock 2 below). This means the model always processes sequences of exactly 200 tokens: if a sequence is longer, it will be truncated, and if it has fewer than 200 tokens, padding will be applied. By the way, the term token itself does not necessarily correspond to a single word, as each word can actually be broken down into several tokens. However, we will not talk about these kinds of preprocessing details here, as the main goal of this article is to implement the architectural design. In this particular case we assume that the sequence has already been preprocessed and is ready to be fed into the network. The subsequent parameters are VOCAB_SIZE_SRC (#(2)) and VOCAB_SIZE_DST (#(3)): the former denotes the number of unique tokens that can appear in the original sequence, while the latter is the same thing but for the translated sequence. It is worth noting that the numbers for these parameters are chosen arbitrarily. In practice, sequence lengths can range from a few hundred to several thousand tokens, while the vocabulary size typically ranges from tens of thousands to a few hundred thousand tokens.
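Since the preprocessing is out of scope here, the following is only a minimal sketch of what forcing a sequence to exactly SEQ_LENGTH tokens could look like in plain Python. The pad_or_truncate() helper and PAD_ID = 0 are hypothetical choices for illustration (real pipelines take the padding ID from their tokenizer), and padding on the right is an assumption.

```python
SEQ_LENGTH = 200
PAD_ID = 0  # hypothetical padding-token ID; real pipelines take it from the tokenizer

def pad_or_truncate(token_ids, seq_length=SEQ_LENGTH, pad_id=PAD_ID):
    """Return a new list forced to exactly seq_length items (hypothetical helper)."""
    if len(token_ids) >= seq_length:
        return token_ids[:seq_length]                              # truncate the tail
    return token_ids + [pad_id] * (seq_length - len(token_ids))   # pad on the right

short_seq = pad_or_truncate([12, 7, 33])
long_seq = pad_or_truncate(list(range(1, 251)))   # 250 tokens

print(len(short_seq), short_seq[:5])   # 200 [12, 7, 33, 0, 0]
print(len(long_seq), long_seq[-1])     # 200 200
```

In practice a padding mask is usually built alongside this so that attention can ignore the pad positions; this sketch does not create one.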

# Codeblock 2
SEQ_LENGTH     = 200    #(1)
VOCAB_SIZE_SRC = 100    #(2)
VOCAB_SIZE_DST = 120    #(3)

BATCH_SIZE = 1     #(4)
D_MODEL    = 512   #(5)
NUM_HEADS  = 8     #(6)
HEAD_DIM   = D_MODEL//NUM_HEADS    # 512 // 8 = 64    #(7)
HIDDEN_DIM = 2048  #(8)
N          = 6     #(9)
DROP_PROB  = 0.1   #(10)

Still with Codeblock 2, here I set BATCH_SIZE to 1 (#(4)). You can actually use any number for the batch size since it does not affect the model architecture at all. The D_MODEL and NUM_HEADS parameters, on the other hand, cannot be chosen arbitrarily, in the sense that D_MODEL (#(5)) needs to be divisible by NUM_HEADS (#(6)). D_MODEL itself corresponds to the model dimension, which is also equivalent to the embedding dimension. This implies that every single token is going to be represented as a vector of size 512. Meanwhile, NUM_HEADS=8 means that there will be 8 heads inside a Multihead Attention layer. Later on, the 512 features of each token will be spread evenly across these 8 attention heads, so every single head will be responsible for handling 64 features (HEAD_DIM), as marked at line #(7). The HIDDEN_DIM parameter, whose value is set to 2048 (#(8)), denotes the number of neurons in the hidden layer inside the Feed Forward blocks. Next, if you go back to Figure 1, you will notice an N× symbol next to the Encoder and the Decoder, which means that each of them is stacked N times. In this case, we set N to 6 as marked at line #(9). Lastly, we can control the dropout rate through the DROP_PROB parameter (#(10)).
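The divisibility requirement between D_MODEL and NUM_HEADS can be made explicit with a small guard. The head_dim() helper below is hypothetical and not part of the article’s code; it simply returns D_MODEL // NUM_HEADS and refuses configurations that would not split evenly.

```python
def head_dim(d_model: int, num_heads: int) -> int:
    """Return the number of features handled by each attention head (hypothetical helper)."""
    if d_model % num_heads != 0:
        raise ValueError(f"D_MODEL={d_model} is not divisible by NUM_HEADS={num_heads}")
    return d_model // num_heads

print(head_dim(512, 8))   # 64, the HEAD_DIM used in this article

try:
    head_dim(512, 6)      # 512 / 6 is not an integer
except ValueError as err:
    print(err)            # D_MODEL=512 is not divisible by NUM_HEADS=6
```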

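In fact, the architectural hyperparameters above (D_MODEL, NUM_HEADS, HEAD_DIM, HIDDEN_DIM, N and DROP_PROB) are taken from the base configuration of the Transformer model shown in the figure below, whereas SEQ_LENGTH, the vocabulary sizes and BATCH_SIZE are our own choices.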

Figure 2. The Transformer configuration we will implement (highlighted in green) [1].

Input & Output Embedding

With all parameters initialized, we can jump into the first component: the Input and Output Embedding. The two serve the same purpose, namely converting each token in the sequence into its corresponding 512 (D_MODEL)-dimensional vector representation. What makes them different is that Input Embedding processes the tokens from the original sentence, whereas Output Embedding does the same thing for the translated sentence. Codeblock 3 below shows how I implement them.

# Codeblock 3
class InputEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=VOCAB_SIZE_SRC,  #(1) 
                                      embedding_dim=D_MODEL)

    def forward(self, x):
        print(f"original\t: {x.shape}")

        x = self.embedding(x)  #(2)
        print(f"after embedding\t: {x.shape}")

        return x

class OutputEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=VOCAB_SIZE_DST,  #(3)
                                      embedding_dim=D_MODEL)

    def forward(self, x):
        print(f"original\t: {x.shape}")

        x = self.embedding(x)  #(4)
        print(f"after embedding\t: {x.shape}")

        return x

The InputEmbedding() and OutputEmbedding() classes above appear to be identical. However, if you take a closer look at the nn.Embedding() layer in the two classes (at the lines marked with #(1) and #(3)), you will see that in InputEmbedding() I set the num_embeddings parameter to VOCAB_SIZE_SRC (100), while in OutputEmbedding() I set it to VOCAB_SIZE_DST (120). This approach allows us to handle two languages that have different vocabulary sizes, where in this case we assume that the source language and the destination language have 100 and 120 unique tokens, respectively. Next, the forward() method of the two classes is exactly the same: it accepts a sequence of tokens and returns the result produced by the self.embedding() layer (#(2) and #(4)). Here I also print out the dimension of the tensor before and after processing so you can better understand how the tensors are actually processed.
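To see what nn.Embedding actually does, here is a minimal sketch showing that it behaves like a trainable lookup table: each output vector is just the row of the layer’s weight matrix selected by the token ID. The token IDs are arbitrary examples, and the two tables mirror the vocabulary sizes from Codeblock 2.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
VOCAB_SIZE_SRC, VOCAB_SIZE_DST, D_MODEL = 100, 120, 512

src_embedding = nn.Embedding(num_embeddings=VOCAB_SIZE_SRC, embedding_dim=D_MODEL)
dst_embedding = nn.Embedding(num_embeddings=VOCAB_SIZE_DST, embedding_dim=D_MODEL)

tokens = torch.tensor([[5, 17, 99]])   # one sequence of 3 arbitrary token IDs (all < 100)
vectors = src_embedding(tokens)        # (1, 3, 512)

# Each output vector is simply the row of .weight selected by the token ID
print(torch.equal(vectors[0, 1], src_embedding.weight[17]))  # True

# One row per vocabulary entry, so the two tables differ only in their number of rows
print(src_embedding.weight.shape)  # torch.Size([100, 512])
print(dst_embedding.weight.shape)  # torch.Size([120, 512])
```

This is also why every token ID fed to a given embedding layer must be smaller than its num_embeddings.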

To check whether our code is working properly, we can test it by passing a dummy tensor through the network. In Codeblock 4 below, we first initialize the InputEmbedding() layer (#(1)), followed by a batch containing a single sequence of token IDs (#(2)). This tensor is generated using torch.randint(), which I configure to produce SEQ_LENGTH (200) random integers from 0 up to, but not including, VOCAB_SIZE_SRC (100), i.e., valid indices into the embedding table. Afterwards, we can just pass the x_src tensor through the input_embedding layer (#(3)).

# Codeblock 4
input_embedding = InputEmbedding()  #(1)

x_src = torch.randint(0, VOCAB_SIZE_SRC, (BATCH_SIZE, SEQ_LENGTH))  #(2)
x_src = input_embedding(x_src)  #(3)

You can see in the output that the sequence which initially has the length of 200 now becomes 200×512. This indicates that our InputEmbedding() class successfully converted a sequence of 200 tokens into a sequence of 200 vectors with 512 dimensions each.

# Codeblock 4 output
original        : torch.Size([1, 200])
after embedding : torch.Size([1, 200, 512])

We can also test our OutputEmbedding() class in the exact same way as shown in Codeblock 5.

# Codeblock 5
output_embedding = OutputEmbedding()

x_dst = torch.randint(0, VOCAB_SIZE_DST, (BATCH_SIZE, SEQ_LENGTH))
x_dst = output_embedding(x_dst)

# Codeblock 5 output
original        : torch.Size([1, 200])
after embedding : torch.Size([1, 200, 512])

Regarding the Output Embedding layer, you can see back in Figure 1 that it accepts the shifted-right outputs as its input. This means that the decoder input is the translated sentence shifted one position to the right, so at each position the model receives the token preceding the one it has to predict (the token at timestep t of the translation is used to predict the token at timestep t+1). This shifting is necessary because the first position of the decoder input is reserved for the so-called start token, which signals to the network the beginning of the sentence to generate. However, as I mentioned earlier, we are not going to go deeper into such a preprocessing step. Here we assume that the x_dst tensor passed through the output_embedding layer in Codeblock 5 above already includes the start token. See Figure 3 below to better understand this idea. In this example, the sequence on the left is a sentence in English, and the sequence on the right is the corresponding shifted-right output in Indonesian.

Figure 3. An illustration of the shifted-right output sequence [2].
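The shifting itself can be sketched with plain Python lists. The token IDs and the SOS_ID and EOS_ID values below are hypothetical, and dropping the last target token to keep both sequences the same length is one common convention rather than necessarily what the author’s pipeline does.

```python
SOS_ID = 1  # hypothetical start-of-sentence token ID
EOS_ID = 2  # hypothetical end-of-sentence token ID

target = [57, 81, 64, 23, EOS_ID]       # hypothetical token IDs of a translated sentence

decoder_input = [SOS_ID] + target[:-1]  # shift right: prepend <sos>, drop the last token
labels = target                         # what the model should predict at each position

for t, (fed_in, expected) in enumerate(zip(decoder_input, labels)):
    print(f"position {t}: input {fed_in:>2} -> predict {expected:>2}")
```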

Positional Encoding

Now that the raw token sequences have been processed by the Input and Output Embedding layers, we are going to inject positional encoding into them. The plus symbol in the architecture indicates that this is done by element-wise addition between the positional encoding values and the tensor produced by the Input and Output Embedding. See the zoomed-in version of the Transformer model in Figure 4 below.

Figure 4. Positional encoding is injected to the Input and Output Embedding by element-wise addition [1].

According to the original paper, positional encoding is defined by the following equation, where pos is the position along the sequence axis, and 2i and 2i+1 are the indices of the even and odd elements of the 512 (D_MODEL)-dimensional token vector.

Figure 5. The equation to create positional encoding [1].
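For reference, the equation shown in Figure 5 can be written out as follows, where d_model corresponds to D_MODEL and i runs from 0 to d_model/2 - 1:

```latex
PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right),
\qquad
PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
```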

The above equation looks scary at first glance, but the idea is actually pretty simple. For each embedding dimension (of the D_MODEL-dimensional vector), we create a sequence of numbers ranging from -1 to 1 that follows a sine or cosine wave pattern along the sequence axis. This is illustrated in Figure 6 below.

Figure 6. Positional encoding Illustration [2].

The lines drawn in orange indicate sine waves, while the ones in green are cosine waves. For each token, the value of the wave in a specific embedding dimension is taken and added to the corresponding element of the embedding tensor. Furthermore, notice that the even embedding dimensions (0, 2, 4, …) use the sine pattern while the odd ones (1, 3, 5, …) use the cosine pattern, so the two alternate, and the frequency of the waves decreases as we move from left to right across the embedding dimensions. By doing all these things, we allow the model to preserve information regarding the position of every token.

The implementation of this concept is done in the PositionalEncoding() class which you can see in Codeblock 6.

# Codeblock 6
class PositionalEncoding(nn.Module):

    def forward(self):
        pos = torch.arange(SEQ_LENGTH).reshape(SEQ_LENGTH, 1)  #(1)
        print(f"pos\t\t: {pos.shape}")

        i = torch.arange(0, D_MODEL, 2)  #(2)
        denominator = torch.pow(10000, i/D_MODEL)  #(3)
        print(f"denominator\t: {denominator.shape}")

        even_pos_embed = torch.sin(pos/denominator)  #(4)
        odd_pos_embed  = torch.cos(pos/denominator)  #(5)
        print(f"even_pos_embed\t: {even_pos_embed.shape}")
        print(f"odd_pos_embed\t: {odd_pos_embed.shape}")

        stacked = torch.stack([even_pos_embed, odd_pos_embed], dim=2)  #(6)
        print(f"stacked\t\t: {stacked.shape}")

        pos_embed = torch.flatten(stacked, start_dim=1, end_dim=2)  #(7)
        print(f"pos_embed\t: {pos_embed.shape}")

        return pos_embed

The above code might seem somewhat unusual since I go directly to the forward() method, omitting the __init__() method that is typically included when working with Python classes. This is because there are no neural network layers to instantiate when a PositionalEncoding() object is initialized. The configuration parameters themselves are defined as global variables (i.e., SEQ_LENGTH and D_MODEL), so they can be used directly inside the forward() method.

The forward pass implements the equation shown in Figure 5. The pos variable I create at line #(1) corresponds to the same symbol in the equation: it is just a sequence of numbers from 0 to SEQ_LENGTH-1 (199). I want these numbers to span along the SEQ_LENGTH axis of the embedding tensor, so I add a new axis using the reshape() method. Next, at line #(2) I initialize the i array with values from 0 up to, but not including, D_MODEL (512) with a step of 2, i.e., 0, 2, 4, …, 510. Hence, only 256 numbers are generated. Note that this i already holds the quantity written as 2i in the equation, which is why the exponent at line #(3) is i/D_MODEL. I do this because in the subsequent steps I want to use these values twice: once for the even embedding dimensions and once for the odd embedding dimensions. However, the i array itself is not used for the two directly; rather, we employ it to compute the entire denominator in the equation (#(3)) before eventually using it to create the sine (#(4)) and cosine waves (#(5)). At this point we have two positional embedding tensors: even_pos_embed and odd_pos_embed. What we are going to do next is to combine them such that the resulting tensor has the alternating sine and cosine pattern shown back in Figure 6. This can be achieved using a little trick that I do at lines #(6) and #(7).
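The trick at lines #(6) and #(7) is easier to see on a tiny example. In this sketch, the even and odd tensors hold placeholder values instead of real sine and cosine waves, so that the interleaving order is visible at a glance.

```python
import torch

even = torch.tensor([[0., 2., 4.]])   # placeholder for the sine columns (dims 0, 2, 4)
odd  = torch.tensor([[1., 3., 5.]])   # placeholder for the cosine columns (dims 1, 3, 5)

stacked = torch.stack([even, odd], dim=2)                      # (1, 3, 2): each row is an (even, odd) pair
interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)   # (1, 6): pairs laid out side by side

print(interleaved)  # tensor([[0., 1., 2., 3., 4., 5.]])
```

Stacking on a new last axis pairs each sine column with its cosine partner, and flattening that pair axis into the feature axis yields the alternating layout.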

Next, we will run the following code to test if our PositionalEncoding() class works properly.

# Codeblock 7
positional_encoding = PositionalEncoding()

positional_embedding = positional_encoding()

# Codeblock 7 output
pos            : torch.Size([200, 1])
denominator    : torch.Size([256])
even_pos_embed : torch.Size([200, 256])    #(1)
odd_pos_embed  : torch.Size([200, 256])    #(2)
stacked        : torch.Size([200, 256, 2])
pos_embed      : torch.Size([200, 512])    #(3)

Here I print out every single step in the forward() function so that you can see what is actually going on under the hood. The main idea of this process is that, once you get the even_pos_embed (#(1)) and odd_pos_embed (#(2)) tensors, you need to merge them such that the resulting dimension becomes 200×512, as shown at line #(3) in the Codeblock 7 output. This dimension exactly matches the size of each sequence in the embedding tensor we discussed in the previous section (SEQ_LENGTH×D_MODEL), allowing element-wise addition to be performed, with the positional encoding broadcast across the batch dimension.
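As an extra sanity check, the sketch below rebuilds the same positional encoding as Codeblock 6 (without the prints), compares one sine/cosine pair against the closed-form equation computed with Python’s math module, and then adds the encoding to a random stand-in for the embedding output. The position p = 7 and pair index k = 3 are arbitrary choices.

```python
import math
import torch

SEQ_LENGTH, D_MODEL, BATCH_SIZE = 200, 512, 1

# Same construction as Codeblock 6, without the prints
pos = torch.arange(SEQ_LENGTH).reshape(SEQ_LENGTH, 1)
i = torch.arange(0, D_MODEL, 2)
denominator = torch.pow(10000, i/D_MODEL)
stacked = torch.stack([torch.sin(pos/denominator), torch.cos(pos/denominator)], dim=2)
pos_embed = torch.flatten(stacked, start_dim=1, end_dim=2)   # (200, 512)

# Spot-check one sine/cosine pair against the closed-form equation
p, k = 7, 3   # arbitrary position and pair index (dimensions 2k and 2k+1)
expected_sin = math.sin(p / 10000 ** (2 * k / D_MODEL))
expected_cos = math.cos(p / 10000 ** (2 * k / D_MODEL))
print(abs(pos_embed[p, 2 * k].item() - expected_sin) < 1e-5)      # True
print(abs(pos_embed[p, 2 * k + 1].item() - expected_cos) < 1e-5)  # True

# Element-wise addition: (1, 200, 512) + (200, 512) broadcasts over the batch axis
token_embeddings = torch.randn(BATCH_SIZE, SEQ_LENGTH, D_MODEL)   # stand-in for InputEmbedding output
x = token_embeddings + pos_embed
print(x.shape)  # torch.Size([1, 200, 512])
```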


The Attention Mechanism

The three Multihead Attention blocks highlighted in orange in Figure 1 share the same basic concept, hence they all have the same structure. The following figure shows what the components inside a single Multihead Attention block look like.

Figure 7. The components inside a Multi-Head Attention block [1].

The Scaled Dot-Product Attention block (purple) in the figure above itself consists of several other components, which you can see in the illustration below.

Figure 8. The Scaled Dot-Product Attention block [1].
Figure 9. The formal mathematical expression of the operations in Figure 8 [1].
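Written out, the expression in Figure 9 is the following, where d_k is the dimension of each key vector (HEAD_DIM = 64 here). When the look-ahead mask discussed later is used, a mask matrix M with 0 on and below the diagonal and -∞ above it is added to the scaled scores before the softmax (the name MaskedAttention is only a label used here):

```latex
\text{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
\qquad
\text{MaskedAttention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}} + M\right)V,
\quad
M_{ij} = \begin{cases} 0 & j \le i \\ -\infty & j > i \end{cases}
```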

Let’s now break all these things down one by one.

Scaled Dot-Product Attention

I am going to start with the Scaled Dot-Product Attention block first. In Codeblock 8, I implement it inside the Attention() class.

# Codeblock 8
class Attention(nn.Module):

    def create_mask(self):  #(1)
        mask = torch.tril(torch.ones((SEQ_LENGTH, SEQ_LENGTH)))  #(2)
        mask[mask == 0] = -float('inf')
        mask[mask == 1] = 0
        return mask.clone().detach()

    def forward(self, q, k, v, look_ahead_mask=False):  #(3)
        print(f"q\t\t\t: {q.shape}")
        print(f"k\t\t\t: {k.shape}")
        print(f"v\t\t\t: {v.shape}")

        multiplied = torch.matmul(q, k.transpose(-1,-2))  #(4)
        print(f"multiplied\t\t: {multiplied.shape}")

        scaled = multiplied / torch.sqrt(torch.tensor(HEAD_DIM))  #(5)
        print(f"scaled\t\t\t: {scaled.shape}")

        if look_ahead_mask:  #(6)
            mask = self.create_mask()
            print(f"mask\t\t\t: {mask.shape}")
            scaled += mask  #(7)

        attn_output_weights = torch.softmax(scaled, dim=-1)  #(8)
        print(f"attn_output_weights\t: {attn_output_weights.shape}")

        attn_output = torch.matmul(attn_output_weights, v)  #(9)
        print(f"attn_output\t\t: {attn_output.shape}")

        return attn_output, attn_output_weights  #(10)

Similar to the PositionalEncoding() class in the previous section, here I also omit the __init__() method since there are no neural network layers to be implemented. If you take a look at Figure 8, you will see that this block only comprises standard mathematical operations.

The Attention() class works by accepting four inputs: query (q), key (k), value (v), and a boolean look_ahead_mask parameter, as written at line #(3). The query, key and value are three different tensors, yet in this case they have exactly the same shape: their dimensions are all 200×64, where 200 is the sequence length while 64 is the head dimension. Remember that the value of 64 for HEAD_DIM is obtained by dividing D_MODEL (512) by NUM_HEADS (8). Based on this, you now know that the Attention() class implemented here contains the operations done within each of the 8 attention heads.

The first operation inside the Scaled Dot-Product Attention block is matrix multiplication between the query and key (#(4)). Remember that we need to transpose the key matrix so that its dimension becomes 64×200, allowing it to be multiplied with the query, whose dimension is 200×64. The idea behind this multiplication is to compute the relationship between each token and all other tokens. The output of this multiplication is commonly known as the unnormalized attention scores or attention logits, and the variance of its elements grows with the head dimension. To scale these values down, we divide the tensor by the square root of the head dimension (√64 = 8), resulting in the scaled attention scores (#(5)). The actual attention weights tensor is then obtained after we pass it through a softmax function (#(8)). Lastly, this attention weights tensor is multiplied with the value (#(9)). And here is where the magic happens: the v tensor, which is initially just a sequence of 64-dimensional token vectors, now becomes context-aware. This means that each token vector is now enriched with information about the relationships between tokens, leading to a better understanding of the entire sentence. Finally, this forward() method returns both the context-aware token sequence (attn_output) and the attention weights (attn_output_weights), as written at line #(10).
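The reason for dividing by √64 can be checked numerically. This sketch assumes query and key entries drawn independently from a standard normal distribution (an idealization, not what trained projections produce): the dot product of two such 64-dimensional vectors then has a variance of about 64, and dividing by 8 brings it back to about 1, which keeps the softmax from becoming overly peaked.

```python
import torch

torch.manual_seed(0)
HEAD_DIM = 64
n_pairs = 10_000

# Idealized assumption: query and key entries are independent N(0, 1) samples
q = torch.randn(n_pairs, HEAD_DIM)
k = torch.randn(n_pairs, HEAD_DIM)

logits = (q * k).sum(dim=-1)        # one unscaled dot product per (q, k) pair
scaled = logits / HEAD_DIM ** 0.5   # divide by sqrt(d_k) = 8

print(f"variance before scaling: {logits.var().item():.1f}")  # roughly 64
print(f"variance after scaling : {scaled.var().item():.2f}")  # roughly 1
```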

Look-Ahead Mask

One thing that I haven’t explained regarding Codeblock 8 above is the create_mask() function (#(1)). The purpose of this function is to generate the so-called look-ahead mask, which prevents the model from attending to subsequent words, hence the name look-ahead. This mask will later be applied inside the first Multihead Attention block in the Decoder (the Masked Multi-Head Attention block, see Figure 1). The look-ahead mask itself is a square matrix whose height and width are both SEQ_LENGTH (200), as written at line #(2). Since it is not feasible to draw a 200×200 matrix, here I give you an illustration of the same thing for a sequence of only 7 tokens.

Figure 10. An illustration of a look-ahead mask for a sentence with the sequence length of 7 [2].

As you can see in the figure above, the look-ahead mask is essentially a triangular matrix whose lower part (including the diagonal) is filled with zeros, while the upper part is filled with -inf (negative infinity). At this point you need to remember a property of the softmax function: a very large negative input is mapped to a value close to 0, and -inf is mapped to exactly 0. Based on this fact, we can think of these -inf values as a mask that won’t allow any information to pass through, since it causes the weight matrix to pay zero attention to the corresponding token. By using this matrix, we essentially force a token to only pay attention to itself and to the previous tokens. For example, token 3 (from the Query axis) can only pay attention to tokens 3, 2, 1 and 0 (from the Key axis). This technique is essential during the training phase to ensure that the Decoder doesn’t rely on future tokens, as they are unavailable during the inference phase (since tokens will be generated one by one).

Regarding the implementation, the create_mask() function is only called when the look_ahead_mask parameter is set to True (#(6) in Codeblock 8). Afterwards, the resulting mask is applied to the scaled attention scores tensor (scaled) by element-wise addition (#(7)). With this operation, any number in the scaled tensor summed with 0 remains unchanged, whereas the numbers masked with -inf become -inf as well, causing their output after the softmax to be 0.
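Here is a minimal sketch of the masking effect on a sequence of only 4 tokens. It builds the mask the same way as create_mask() and, as a simplifying assumption, uses all-zero raw scores so that the mask is the only thing shaping the weights.

```python
import torch

SEQ_LENGTH = 4  # small stand-in for the 200-token setting

# Same construction as create_mask() in Codeblock 8
mask = torch.tril(torch.ones((SEQ_LENGTH, SEQ_LENGTH)))
mask[mask == 0] = -float('inf')
mask[mask == 1] = 0

scores = torch.zeros(SEQ_LENGTH, SEQ_LENGTH)      # simplifying assumption: equal raw scores
weights = torch.softmax(scores + mask, dim=-1)
print(weights)
# Row t spreads its attention evenly over tokens 0..t and gives exactly 0 to later tokens
```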

As always, to check whether our Scaled Dot-Product Attention mechanism and the masking process are working properly, we can run the following codeblock.

# Codeblock 9
attention = Attention()

q = torch.randn(BATCH_SIZE, SEQ_LENGTH, HEAD_DIM)
k = torch.randn(BATCH_SIZE, SEQ_LENGTH, HEAD_DIM)
v = torch.randn(BATCH_SIZE, SEQ_LENGTH, HEAD_DIM)

attn_output, attn_output_weights = attention(q, k, v, look_ahead_mask=True)

# Codeblock 9 output
q                    : torch.Size([1, 200, 64])
k                    : torch.Size([1, 200, 64])
v                    : torch.Size([1, 200, 64])
multiplied           : torch.Size([1, 200, 200])  #(1)
scaled               : torch.Size([1, 200, 200])
mask                 : torch.Size([200, 200])
attn_output_weights  : torch.Size([1, 200, 200])
attn_output          : torch.Size([1, 200, 64])   #(2)

In the Codeblock 9 output above, we can see that the multiplication between q (200×64) and transposed k (64×200) results in a tensor of size 200×200 (#(1)). The scaling operation, mask application, and softmax do not alter this dimension. The tensor eventually changes back to the original q, k, and v size (200×64) after we perform matrix multiplication between attn_output_weights (200×200) and v (200×64), with the result stored in the attn_output variable (#(2)). Note that the 200×200 mask is broadcast across the batch dimension when it is added to the 1×200×200 scaled tensor.
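If you have PyTorch 2.0 or newer, the manual computation can also be cross-checked against the built-in torch.nn.functional.scaled_dot_product_attention(), whose is_causal=True option applies a lower-triangular mask equivalent to our look-ahead mask and whose default scale is 1/√(last dimension of q). This is only a verification sketch; the article’s own Attention() class is what we build on.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
BATCH_SIZE, SEQ_LENGTH, HEAD_DIM = 1, 200, 64

q = torch.randn(BATCH_SIZE, SEQ_LENGTH, HEAD_DIM)
k = torch.randn(BATCH_SIZE, SEQ_LENGTH, HEAD_DIM)
v = torch.randn(BATCH_SIZE, SEQ_LENGTH, HEAD_DIM)

# Manual version: the same steps as Codeblock 8, without the prints
mask = torch.tril(torch.ones((SEQ_LENGTH, SEQ_LENGTH)))
mask[mask == 0] = -float('inf')
mask[mask == 1] = 0
scaled = torch.matmul(q, k.transpose(-1, -2)) / torch.sqrt(torch.tensor(HEAD_DIM))
manual = torch.matmul(torch.softmax(scaled + mask, dim=-1), v)

# Built-in version; is_causal=True applies the equivalent lower-triangular mask
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, builtin, atol=1e-5))
```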


Multihead Self-Attention

The Scaled Dot-Product Attention mechanism we just discussed is actually the core of a Multihead Attention layer. In this section, we are going to discuss how to implement it inside the so-called Multihead Self-Attention layer. It is called self-attention because the query, key and value fed into it are all derived from the same sequence. The two Attention blocks in the Transformer architecture that implement the Self-Attention mechanism can be seen in Figure 11. We can see here that the three arrows coming into both blocks are from the same source.

Figure 11. The Multihead Self-Attention block in the Encoder (left) and the Masked Multihead Self-Attention block in the Decoder (right) [1].

Generally speaking, the objective of a Self-Attention layer is to capture the context (relationships between words) within the same sequence. In the case of machine translation, the Self-Attention block in the Encoder (left) is responsible for doing so for the sentence in the original language, whereas the one in the Decoder (right) does it for the sentence in the destination language. Previously I mentioned that we need to apply the look-ahead mask to the first Attention block in the Decoder. This is because later in the inference phase, the Decoder will generate a single word at a time. Hence, during the training phase, the mask prevents the model from attending to subsequent words. In contrast, the Encoder accepts the entire sequence at once in both the training and inference phases. Thus, we should not apply the look-ahead mask there, since we want the model to capture the context based on the entire sentence, not only on the previous and current tokens.

Look at the Codeblock 10 below to see how I implement the Self-Attention block. Remember that it is actually created based on the diagram in Figure 7.

# Codeblock 10
class SelfAttention(nn.Module):

    def __init__(self, look_ahead_mask=False):  #(1)
        super().__init__()
        self.look_ahead_mask = look_ahead_mask

        self.qkv_linear = nn.Linear(D_MODEL, 3*D_MODEL)  #(2)
        self.attention = Attention()  #(3)
        self.linear = nn.Linear(D_MODEL, D_MODEL)  #(4)

I want the SelfAttention() class above to be flexible, so that we can use it either with or without a mask. To do so, I define the look_ahead_mask parameter, which is set to False by default (#(1)). Next, there will be two linear layers in this class. The first one is placed before the Scaled Dot-Product Attention operation (#(2)), and the second one is placed after it (#(4)). Notice that the first linear layer (self.qkv_linear) is set to accept an input tensor of size D_MODEL (512) and return another tensor three times larger than the input (3 × 512 = 1536). This means that every single token, which is initially represented as a 512-dimensional vector, now becomes 1536-dimensional. The idea behind this operation is that we want to allocate 512 dimensions to each of the query, key and value used later in the Scaled Dot-Product Attention operation (#(3)). Meanwhile, the second linear layer (self.linear) is configured to accept a token sequence where the dimensionality of each token is 512 (D_MODEL), and return another sequence with exactly the same size. This layer will later be employed to combine the information from all attention heads.
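One way to convince yourself that a single 512→1536 layer is enough is to compare it with three separate 512→512 projections. In this sketch the separate layers are hypothetical and exist only for the comparison; for simplicity it assigns the 1536 outputs to three contiguous 512-wide blocks, whereas Codeblock 11 decides which columns act as q, k and v through its reshape and chunk, which is equally valid because the weights are learned.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D_MODEL = 512

fused = nn.Linear(D_MODEL, 3 * D_MODEL)                      # like self.qkv_linear
separate = [nn.Linear(D_MODEL, D_MODEL) for _ in range(3)]   # hypothetical q, k, v layers

def n_params(module):
    return sum(p.numel() for p in module.parameters())

print(n_params(fused))                      # 787968 = 512*1536 + 1536
print(sum(n_params(m) for m in separate))   # 787968 = 3 * (512*512 + 512)

# Copy the three separate weight matrices into the fused layer, stacked row-wise
with torch.no_grad():
    fused.weight.copy_(torch.cat([m.weight for m in separate], dim=0))  # (1536, 512)
    fused.bias.copy_(torch.cat([m.bias for m in separate], dim=0))      # (1536,)

x = torch.randn(2, D_MODEL)
out_fused = fused(x)                                        # (2, 1536)
out_separate = torch.cat([m(x) for m in separate], dim=-1)  # (2, 1536)
print(torch.allclose(out_fused, out_separate, atol=1e-5))
```

The two designs have the same number of parameters and compute the same function; the fused layer simply does it in one matrix multiplication.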

Now let’s move on to the forward() method of the SelfAttention() class. Below is what it looks like.

# Codeblock 11
    def forward(self, x):
        print(f"original\t\t: {x.shape}")

        x = self.qkv_linear(x)  #(1)
        print(f"after qkv_linear\t: {x.shape}")

        x = x.reshape(BATCH_SIZE, SEQ_LENGTH, NUM_HEADS, 3*HEAD_DIM)  #(2)
        print(f"after reshape\t\t: {x.shape}")

        x = x.permute(0, 2, 1, 3)  #(3)
        print(f"after permute\t\t: {x.shape}")

        q, k, v = x.chunk(3, dim=-1)  #(4)
        print(f"q\t\t\t: {q.shape}")
        print(f"k\t\t\t: {k.shape}")
        print(f"v\t\t\t: {v.shape}")

        attn_output, attn_output_weights = self.attention(q, k, v, 
                                                          look_ahead_mask=self.look_ahead_mask) #(5)
        print(f"attn_output\t\t: {attn_output.shape}")
        print(f"attn_output_weights\t: {attn_output_weights.shape}")

        x = attn_output.permute(0, 2, 1, 3)  #(6)
        print(f"after permute\t\t: {x.shape}")

        x = x.reshape(BATCH_SIZE, SEQ_LENGTH, NUM_HEADS*HEAD_DIM)  #(7)
        print(f"after reshape\t\t: {x.shape}")

        x = self.linear(x)  #(8)
        print(f"after linear\t\t: {x.shape}")

        return x

Here we can see that the input tensor x is directly processed with the self.qkv_linear layer (#(1)). The resulting tensor is then reshaped to BATCH_SIZE × SEQ_LENGTH × NUM_HEADS × 3*HEAD_DIM as demonstrated at line #(2). Next, the permute() method is used to swap the SEQ_LENGTH and NUM_HEADS axes (#(3)). Such a reshaping and permutation process is actually a trick to distribute the 1536-dimensional token vectors into 8 attention heads, allowing them to be processed in parallel without needing to be separated into different tensors. Next, we use the chunk() method to divide the tensor into 3 parts, which will correspond to q, k and v (#(4)). One thing to keep in mind is that the division will operate on the last (token embedding) dimension, leaving the sequence length axis unchanged.

With the query, key, and value ready, we can now pass them all together through the Scaled Dot-Product Attention block (#(5)). Although the attention mechanism returns two tensors, in this case we will only bring attn_output to the next process, since it is the one that actually contains the context-aware token sequence (recall that attn_output_weights is just a matrix containing the relationships between tokens). The next step is to swap the NUM_HEADS and SEQ_LENGTH axes back (#(6)) before eventually reshaping the tensor to its original dimension (#(7)). If you take a closer look at this line, you will see that NUM_HEADS is directly multiplied with HEAD_DIM. This operation flattens the embeddings from the 8 attention heads back into a single dimension, which is equivalent to concatenating the outputs of all heads, as illustrated in Figure 7. Lastly, to actually combine the information from these 8 heads, we need to pass the tensor through the other linear layer we discussed earlier (#(8)). We can think of the operation done in this linear layer as a way to let the attention heads interact with each other, which results in a better understanding of the context.
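The claim that the reshape at line #(7) amounts to concatenating the heads can be verified directly. In this sketch, a random tensor stands in for attn_output, and the merged result is compared against an explicit torch.cat() over the 8 heads.

```python
import torch

torch.manual_seed(0)
BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, HEAD_DIM = 1, 8, 200, 64

attn_output = torch.randn(BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, HEAD_DIM)  # stand-in for per-head results

# What Codeblock 11 does at lines #(6) and #(7)
merged = attn_output.permute(0, 2, 1, 3).reshape(BATCH_SIZE, SEQ_LENGTH, NUM_HEADS * HEAD_DIM)

# Explicit concatenation of the 8 heads along the feature axis
concatenated = torch.cat([attn_output[:, h] for h in range(NUM_HEADS)], dim=-1)

print(merged.shape)                        # torch.Size([1, 200, 512])
print(torch.equal(merged, concatenated))   # True
```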

Let’s test the code by running the codeblock below. By the way, here I re-run all previous codeblocks with their print() calls commented out, since I only want to focus on the flow of the SelfAttention() class we just created.

# Codeblock 12
self_attention = SelfAttention()

x = torch.randn(BATCH_SIZE, SEQ_LENGTH, D_MODEL)
x = self_attention(x)

# Codeblock 12 output
original            : torch.Size([1, 200, 512])      #(1)
after qkv_linear    : torch.Size([1, 200, 1536])     #(2)
after reshape       : torch.Size([1, 200, 8, 192])   #(3)
after permute       : torch.Size([1, 8, 200, 192])   #(4)
q                   : torch.Size([1, 8, 200, 64])    #(5)
k                   : torch.Size([1, 8, 200, 64])    #(6)
v                   : torch.Size([1, 8, 200, 64])    #(7)
attn_output         : torch.Size([1, 8, 200, 64])    #(8)
attn_output_weights : torch.Size([1, 8, 200, 200])
after permute       : torch.Size([1, 200, 8, 64])    #(9)
after reshape       : torch.Size([1, 200, 512])      #(10)
after linear        : torch.Size([1, 200, 512])      #(11)

Based on the output above, we can see that our tensor successfully flows through all the layers. The input tensor, which initially has the size of 200×512 (#(1)), becomes 200×1536 thanks to the expansion done by the first linear layer (#(2)). The last dimension of the tensor is then distributed evenly into 8 attention heads, resulting in each head processing 192-dimensional token vectors (#(3)). The permutation that swaps the 8-head axis with the 200-token sequence axis is essentially just a method that allows PyTorch to do the computation for each head in parallel (#(4)). Next, at lines #(5) to #(7) you can see that each of the q, k, and v tensors has the dimension of 200×64 for each head, which matches our discussion of Codeblock 9. After being processed by the Attention() layer, we get the attn_output tensor (#(8)), which is then permuted (#(9)) and reshaped (#(10)) back to the original input tensor dimension. It is important to note that the permutation and reshaping operations need to be performed in this exact order because we initially changed the dimension by reshaping followed by a permutation. Technically, you could revert to the original dimension without permuting, but that would scramble your tensor elements, so you really need to keep this in mind. Finally, the last step is to pass the tensor through the second linear layer in the SelfAttention() block, which does not change the tensor dimension at all (#(11)).
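To see why the order matters, this sketch uses a tiny tensor of consecutive integers (the small sizes are arbitrary, chosen only for readability) and compares the correct inverse with a shortcut that skips the permutation.

```python
import torch

BATCH_SIZE, SEQ_LENGTH, NUM_HEADS, HEAD_DIM = 1, 5, 2, 3   # tiny sizes for readability
D_MODEL = NUM_HEADS * HEAD_DIM

x = torch.arange(BATCH_SIZE * SEQ_LENGTH * D_MODEL).reshape(BATCH_SIZE, SEQ_LENGTH, D_MODEL)

# Split into heads: reshape first, then swap the sequence and head axes
heads = x.reshape(BATCH_SIZE, SEQ_LENGTH, NUM_HEADS, HEAD_DIM).permute(0, 2, 1, 3)

# Correct inverse: undo the permutation first, then reshape
restored = heads.permute(0, 2, 1, 3).reshape(BATCH_SIZE, SEQ_LENGTH, D_MODEL)

# Shortcut that skips the permutation: right shape, wrong element order
scrambled = heads.reshape(BATCH_SIZE, SEQ_LENGTH, D_MODEL)

print(torch.equal(restored, x))    # True
print(torch.equal(scrambled, x))   # False
print(scrambled[0, 0])             # tensor([0, 1, 2, 6, 7, 8]): head 0 of tokens 0 and 1 glued together
```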


Multihead Cross-Attention

If the Self-Attention layer is used to capture the relationships between all tokens within the same sequence, the Cross-Attention layer captures relationships between tokens in two different sequences, i.e., between the translated sentence and the original sentence. By doing so, the model can obtain context from the original language for each token in the translated language. You can find this mechanism in the second Multihead Attention layer in the Decoder. Below is what it actually looks like.

Figure 12. The Multihead Cross-Attention block in the Decoder [1].

You can see in Figure 12 above that the arrows coming into the attention layer are from different sources. The arrows on the left and in the middle are the key and value coming from the Encoder, while the arrow on the right is the query from the Decoder itself. We can think of this mechanism as the Decoder querying information from the Encoder. Furthermore, since the Encoder accepts and reads the entire sequence at once, we don't need to apply a look-ahead mask in this attention block, which lets it access the full context of the original sequence even during the inference phase.

The implementation of a Cross-Attention layer is a little different from Self-Attention. As you can see in Codeblock 13 below, there are three linear layers to be implemented. The first one is self.kv_linear, which is responsible for doubling the token embedding dimension of the tensor coming from the Encoder (#(1)). As you have probably guessed, the resulting tensor will later be divided into two parts, representing the key and the value. The second linear layer is named self.q_linear, whose output tensor will act as the query (#(2)). Lastly, the role of the self.linear layer is the same as in Self-Attention, i.e., to combine the information from all attention heads without changing the tensor dimension (#(3)).

# Codeblock 13
class CrossAttention(nn.Module):

    def __init__(self):
        super().__init__()

        self.kv_linear = nn.Linear(D_MODEL, 2*D_MODEL)  #(1)
        self.q_linear = nn.Linear(D_MODEL, D_MODEL)  #(2)
        self.attention = Attention()
        self.linear = nn.Linear(D_MODEL, D_MODEL)  #(3)

The forward() method of the CrossAttention() class accepts two inputs, x_enc and x_dec, as shown at line #(1) in Codeblock 14. The former denotes the tensor coming from the Encoder, while the latter represents the one from the Decoder.

# Codeblock 14
    def forward(self, x_enc, x_dec):  #(1)
        print(f"x_enc original\t\t: {x_enc.shape}")
        print(f"x_dec original\t\t: {x_dec.shape}")

        x_enc = self.kv_linear(x_enc)  #(2)
        print(f"\nafter kv_linear\t\t: {x_enc.shape}")

        x_enc = x_enc.reshape(BATCH_SIZE, SEQ_LENGTH, NUM_HEADS, 2*HEAD_DIM)  #(3)
        print(f"after reshape\t\t: {x_enc.shape}")

        x_enc = x_enc.permute(0, 2, 1, 3)  #(4)
        print(f"after permute\t\t: {x_enc.shape}")

        k, v = x_enc.chunk(2, dim=-1)  #(5)
        print(f"k\t\t\t: {k.shape}")
        print(f"v\t\t\t: {v.shape}")

        x_dec = self.q_linear(x_dec)  #(6)
        print(f"\nafter q_linear\t\t: {x_dec.shape}")

        x_dec = x_dec.reshape(BATCH_SIZE, SEQ_LENGTH, NUM_HEADS, HEAD_DIM)  #(7)
        print(f"after reshape\t\t: {x_dec.shape}")

        q = x_dec.permute(0, 2, 1, 3)  #(8)
        print(f"after permute (q)\t: {q.shape}")

        attn_output, attn_output_weights = self.attention(q, k, v)  #(9)
        print(f"\nattn_output\t\t: {attn_output.shape}")
        print(f"attn_output_weights\t: {attn_output_weights.shape}")

        x = attn_output.permute(0, 2, 1, 3)
        print(f"after permute\t\t: {x.shape}")

        x = x.reshape(BATCH_SIZE, SEQ_LENGTH, NUM_HEADS*HEAD_DIM)
        print(f"after reshape\t\t: {x.shape}")

        x = self.linear(x)
        print(f"after linear\t\t: {x.shape}")

        return x

The x_enc and x_dec tensors are processed separately using steps similar to those in the Self-Attention layer, i.e., a linear layer, reshaping, and permuting. Notice that the operations applied to the two input tensors mirror each other: line #(2) is equivalent to line #(6), line #(3) corresponds to line #(7), and line #(4) matches line #(8). We apply the chunk() method to split the x_enc tensor into key and value (#(5)), whereas x_dec does not need to be split since it directly serves as the query tensor. Next, we feed q, k, and v into the Scaled Dot-Product Attention layer (#(9)). This is the step where the Decoder queries information from the Encoder. Additionally, keep in mind that here we should not pass the look-ahead mask parameter, since we want to leave the attention weights unmasked. The remaining steps are exactly the same as those in the Multihead Self-Attention mechanism, which we already discussed in the previous section.
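Using one fused self.kv_linear layer instead of separate key and value layers is mainly a convenience: a single Linear(D_MODEL, 2*D_MODEL) behaves like two Linear(D_MODEL, D_MODEL) layers whose weight matrices are stacked on top of each other. The sketch below checks this for a plain half-split of the last axis. Note that Codeblock 14 chunks after splitting into heads, so there the key and value rows of the weight matrix are interleaved per head rather than split into two halves, but the same reasoning applies.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D_MODEL = 512  # embedding size used throughout the article

kv_linear = nn.Linear(D_MODEL, 2 * D_MODEL)  # one fused projection for key and value
x_enc = torch.randn(1, 200, D_MODEL)

# Fused path: project once, then cut the last axis in half
k_fused, v_fused = kv_linear(x_enc).chunk(2, dim=-1)

# Separate path: the first D_MODEL rows of the weight matrix act as the key projection,
# the remaining D_MODEL rows as the value projection
W, b = kv_linear.weight, kv_linear.bias  # W has shape (2*D_MODEL, D_MODEL)
k_sep = nn.functional.linear(x_enc, W[:D_MODEL], b[:D_MODEL])
v_sep = nn.functional.linear(x_enc, W[D_MODEL:], b[D_MODEL:])

print(torch.allclose(k_fused, k_sep, atol=1e-5))  # True
print(torch.allclose(v_fused, v_sep, atol=1e-5))  # True
```

The fused version simply runs one larger matrix multiplication instead of two smaller ones.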

Now let’s test our CrossAttention() class by passing dummy x_enc and x_dec tensors. See Codeblock 15 and the output below for the details.

# Codeblock 15
cross_attention = CrossAttention()

x_enc = torch.randn(BATCH_SIZE, SEQ_LENGTH, D_MODEL)
x_dec = torch.randn(BATCH_SIZE, SEQ_LENGTH, D_MODEL)

x = cross_attention(x_enc, x_dec)
# Codeblock 15 output
x_enc original      : torch.Size([1, 200, 512])     #(1)
x_dec original      : torch.Size([1, 200, 512])     #(2)

after kv_linear     : torch.Size([1, 200, 1024])    #(3)
after reshape       : torch.Size([1, 200, 8, 128])
after permute       : torch.Size([1, 8, 200, 128])
k                   : torch.Size([1, 8, 200, 64])   #(4)
v                   : torch.Size([1, 8, 200, 64])   #(5)

after q_linear      : torch.Size([1, 200, 512])
after reshape       : torch.Size([1, 200, 8, 64])
after permute (q)   : torch.Size([1, 8, 200, 64])   #(6)

attn_output         : torch.Size([1, 8, 200, 64])   #(7)
attn_output_weights : torch.Size([1, 8, 200, 200])
after permute       : torch.Size([1, 200, 8, 64])
after reshape       : torch.Size([1, 200, 512])     #(8)
after linear        : torch.Size([1, 200, 512])     #(9)

Initially, the x_enc and x_dec tensors have exactly the same dimensions, as shown at lines #(1) and #(2) in the above output. After being passed through the self.kv_linear layer, the embedding dimension of x_enc expands from 512 to 1024 (#(3)). This means that each token is now represented by a 1024-dimensional vector. This tensor is then reshaped, permuted, and chunked into k and v. At this point the embedding dimensions of these two tensors are already split into 8 attention heads, ready to be used as the input for the Scaled Dot-Product Attention layer (#(4) and #(5)). We also apply the reshaping and permuting steps to x_dec, but we omit the chunking step since this entire tensor acts as q (#(6)). Once this is done, the q, k, and v tensors all have the same dimensions, just like the ones we had earlier in the SelfAttention() block. Processing them with the self.attention layer produces the attn_output tensor (#(7)), which is then permuted and reshaped back to the initial tensor dimension (#(8)). Finally, after being processed by the self.linear layer (#(9)), the tensor represents the translated sequence enriched with contextual information from the original language.
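One limitation worth noting: Codeblock 14 reshapes both inputs using the global BATCH_SIZE and SEQ_LENGTH, so it assumes the source and target sequences have the same length. In cross-attention they don't generally need to. Below is a minimal sketch of a hypothetical FlexibleCrossAttention class that reads the sizes from the input tensors instead. It assumes the Attention() layer computes standard scaled dot-product attention, which is inlined here so the example runs on its own.

```python
import math
import torch
import torch.nn as nn

class FlexibleCrossAttention(nn.Module):
    """Hypothetical variant of CrossAttention that infers batch size and
    sequence lengths from its inputs instead of global constants."""

    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must split evenly across heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_linear = nn.Linear(d_model, 2 * d_model)
        self.q_linear = nn.Linear(d_model, d_model)
        self.linear = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        # (batch, seq, heads*dim) -> (batch, heads, seq, dim)
        b, s, _ = x.shape
        return x.reshape(b, s, self.num_heads, -1).permute(0, 2, 1, 3)

    def forward(self, x_enc, x_dec):
        k, v = self.split_heads(self.kv_linear(x_enc)).chunk(2, dim=-1)
        q = self.split_heads(self.q_linear(x_dec))
        # Scaled dot-product attention, no look-ahead mask
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        weights = scores.softmax(dim=-1)         # (batch, heads, S_dec, S_enc)
        out = (weights @ v).permute(0, 2, 1, 3)  # (batch, S_dec, heads, dim)
        out = out.reshape(x_dec.shape[0], x_dec.shape[1], -1)
        return self.linear(out), weights

torch.manual_seed(0)
cross_attention = FlexibleCrossAttention()
x_enc = torch.randn(2, 7, 512)  # 7 source tokens
x_dec = torch.randn(2, 5, 512)  # 5 target tokens
out, weights = cross_attention(x_enc, x_dec)
print(out.shape)      # torch.Size([2, 5, 512])
print(weights.shape)  # torch.Size([2, 8, 5, 7])
```

The output always follows the Decoder's sequence length, while the attention weights have shape (batch, heads, target length, source length), so each target token receives a probability distribution over all source tokens.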


Feed Forward Blocks

Our previous discussion about the attention mechanism was quite intense, especially for those who have never heard about it before (well, at least it was for me when I first tried to understand the idea). To give our brains a little rest, let's shift our focus to the simplest component of the Transformer: the Feed Forward block.

In the Transformer architecture, you will find two structurally identical Feed Forward blocks: one in the Encoder and another in the Decoder. Take a look at Figure 13 below to see where they are located. Adding Feed Forward blocks like this increases the depth of the network, as well as the number of learnable parameters. This allows the network to capture more complex patterns in the data so that it does not rely solely on the information extracted by the attention blocks.

Figure 13. The Feed Forward blocks (highlighted in blue) in the Encoder (left) and Decoder (right) [1].

Each of the two Feed Forward blocks above consists of two linear layers with a ReLU activation function and a dropout layer in between. Implementing this structure is straightforward, as you can simply stack these layers one after another, as I do in Codeblock 16 below.

# Codeblock 16
class FeedForward(nn.Module):

    def __init__(self):
        super().__init__()

        self.linear_0 = nn.Linear(D_MODEL, HIDDEN_DIM)  #(1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=DROP_PROB)  #(2)
        self.linear_1 = nn.Linear(HIDDEN_DIM, D_MODEL)  #(3)

    def forward(self, x):
        print(f"original\t: {x.shape}")

        x = self.linear_0(x)
        print(f"after linear_0\t: {x.shape}")

        x = self.relu(x)
        print(f"after relu\t: {x.shape}")

        x = self.dropout(x)
        print(f"after dropout\t: {x.shape}")

        x = self.linear_1(x)
        print(f"after linear_1\t: {x.shape}")

        return x

There are several things I want to emphasize in the above codeblock. First, the self.linear_0 layer is configured to accept a tensor of size D_MODEL (512) and expand it to HIDDEN_DIM (2048), as shown at line #(1). As I mentioned earlier, we do this so that the model can extract more information from the dataset. This tensor dimension is eventually shrunk back down to 512 by the self.linear_1 layer (#(3)), which keeps the input and output dimensions consistent throughout the network. Next, we set the rate of our dropout layer to DROP_PROB (0.1) (#(2)), according to the configuration table provided in Figure 2. As for the forward() method, I won't go into the details, as it simply connects the layers we initialized in the __init__() method.

As usual, I test the FeedForward() class by passing a tensor of size BATCH_SIZE×SEQ_LENGTH×D_MODEL through it, as shown in Codeblock 17.

# Codeblock 17
feed_forward = FeedForward()

x = torch.randn(BATCH_SIZE, SEQ_LENGTH, D_MODEL)
x = feed_forward(x)
# Codeblock 17 output
original       : torch.Size([1, 200, 512])
after linear_0 : torch.Size([1, 200, 2048])  #(1)
after relu     : torch.Size([1, 200, 2048])
after dropout  : torch.Size([1, 200, 2048])
after linear_1 : torch.Size([1, 200, 512])   #(2)

We can see in the output above that the dimensionality of each token, which is initially 512, becomes 2048 thanks to the self.linear_0 layer (#(1)). This size remains unchanged through the ReLU and dropout layers before the tensor is eventually projected back to 512 by the self.linear_1 layer (#(2)).
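Because both linear layers act only on the last axis, the Feed Forward block processes every token independently, which is why the original Transformer paper calls it a position-wise feed-forward network. The sketch below rebuilds the same layer stack with nn.Sequential (a print-free equivalent of FeedForward(), using the configuration values from the article), checks that running the whole sequence at once gives the same result as running each token separately, and counts the parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D_MODEL, HIDDEN_DIM, DROP_PROB = 512, 2048, 0.1  # configuration values used in the article

# Same layer stack as FeedForward(), without the debug prints
ffn = nn.Sequential(
    nn.Linear(D_MODEL, HIDDEN_DIM),
    nn.ReLU(),
    nn.Dropout(DROP_PROB),
    nn.Linear(HIDDEN_DIM, D_MODEL),
)
ffn.eval()  # disable dropout so the comparison is deterministic

x = torch.randn(1, 200, D_MODEL)
whole = ffn(x)  # all 200 tokens at once
per_token = torch.stack([ffn(x[:, t]) for t in range(x.shape[1])], dim=1)  # one token at a time
print(torch.allclose(whole, per_token, atol=1e-5))  # True

# Weights and biases of both linear layers: 512*2048 + 2048 + 2048*512 + 512
n_params = sum(p.numel() for p in ffn.parameters())
print(n_params)  # 2099712
```

The parameter count matches the 2,099,712 reported for each FeedForward block in the model summary at the end of this article.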


Layer Normalization

The last Transformer component I want to talk about is the one that you can see throughout the entire Encoder and Decoder, namely the Add & Norm block (colored in yellow in Figure 14).

Figure 14. The Encoder contains two Add & Norm blocks, while the Decoder has three [1].

As the name suggests, this block essentially comprises an element-wise addition and a layer normalization operation. However, for the sake of simplicity, in this section I will only focus on the normalization process. The element-wise addition will later be discussed when we assemble the entire Transformer architecture.

The purpose of layer normalization here is to normalize the tensor right after it has been processed by the preceding block. Keep in mind that what we use here is layer normalization, not batch normalization. In case you're not yet familiar with Layer Norm, it performs normalization where the statistics (i.e., mean and variance) are computed across the features (embedding dimensions) of each individual token. This is why, in Figure 15, all embedding dimensions of a given token share the same color. In Batch Norm, on the other hand, cells with the same color span the batch and sequence dimensions, indicating that the mean and variance are computed along these axes.

Figure 15. An illustration of how Batch Normalization and Layer Normalization work on a batch of four sequences, each consisting of 8 tokens represented as 6-dimensional embedding vectors [2][3].
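The difference in which axes the statistics are computed over can be shown directly on a tensor with the same shape as the example in Figure 15. This sketch computes both kinds of statistics by hand and cross-checks the results against PyTorch's nn.LayerNorm and nn.BatchNorm1d (the latter expects the feature axis in position 1, hence the permute calls).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 8, 6)  # 4 sequences, 8 tokens each, 6-dimensional embeddings
eps = 1e-5

# Layer Norm: one mean/variance per token, computed over the embedding axis
ln_mean = x.mean(dim=-1, keepdim=True)                      # (4, 8, 1)
ln_var = ((x - ln_mean) ** 2).mean(dim=-1, keepdim=True)
ln_out = (x - ln_mean) / (ln_var + eps).sqrt()

# Batch Norm: one mean/variance per embedding dimension, over the batch and sequence axes
bn_mean = x.mean(dim=(0, 1), keepdim=True)                  # (1, 1, 6)
bn_var = ((x - bn_mean) ** 2).mean(dim=(0, 1), keepdim=True)
bn_out = (x - bn_mean) / (bn_var + eps).sqrt()

print(ln_mean.shape, bn_mean.shape)  # torch.Size([4, 8, 1]) torch.Size([1, 1, 6])

# Cross-check with the built-in layers (default gamma=1, beta=0, eps=1e-5);
# a freshly created BatchNorm1d is in training mode, so it uses batch statistics
ln_ref = nn.LayerNorm(6)(x)
bn_ref = nn.BatchNorm1d(6)(x.permute(0, 2, 1)).permute(0, 2, 1)
print(torch.allclose(ln_out, ln_ref, atol=1e-5), torch.allclose(bn_out, bn_ref, atol=1e-5))  # True True
```

Layer Norm therefore gives every token its own 32 (4×8) mean and variance pairs in total, whereas Batch Norm produces just 6, one per embedding dimension, shared across all tokens in the batch.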

You can see the implementation of layer normalization in Codeblock 18. There are several variables that I need to initialize within the __init__() method of the LayerNorm() class. First, there is a small number called epsilon (#(1)), which prevents a potential division-by-zero error at line #(8). Next, we also need to initialize gamma (#(2)) and beta (#(3)). These two variables can be thought of as the weight and bias in linear regression, where gamma is responsible for scaling the normalized output, whereas beta shifts it. With gamma fixed at 1 and beta at 0, the normalized values would pass through unchanged. We do use these two values to initialize gamma and beta, but I set requires_grad=True (which is also the default for nn.Parameter) so that they get updated as training goes.

# Codeblock 18
class LayerNorm(nn.Module):
    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps  #(1)
        self.gamma = nn.Parameter(torch.ones(D_MODEL), requires_grad=True)  #(2)
        self.beta = nn.Parameter(torch.zeros(D_MODEL), requires_grad=True)  #(3)

    def forward(self, x):  #(4)
        print(f"original\t: {x.shape}")

        mean = x.mean(dim=[-1], keepdim=True)  #(5)
        print(f"mean\t\t: {mean.shape}")

        var = ((x - mean) ** 2).mean(dim=[-1], keepdim=True)  #(6)
        print(f"var\t\t: {var.shape}")

        stddev = (var + self.eps).sqrt()  #(7)
        print(f"stddev\t\t: {stddev.shape}")

        x = (x - mean) / stddev  #(8)
        print(f"normalized\t: {x.shape}")

        x = (self.gamma * x) + self.beta  #(9)
        print(f"after scaling and shifting\t: {x.shape}")

        return x

The forward() method starts by accepting a tensor x (#(4)). Afterwards, we calculate its mean (#(5)) and variance (#(6)). Because we want to compute these statistics for each row (token), we use dim=[-1], since the embedding dimension is the last axis of the tensor. Next, we calculate the standard deviation (#(7)) so that the normalized tensor can be obtained (#(8)). Lastly, the normalized tensor is scaled and shifted using self.gamma and self.beta, as shown at line #(9).
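Since the computation inside LayerNorm() is standard layer normalization, a stronger check than the shape test in Codeblock 19 is to compare it numerically against PyTorch's built-in nn.LayerNorm, which also uses the biased (population) variance and a default eps of 1e-5. The class below is a hypothetical print-free copy of Codeblock 18 so that the snippet runs on its own.

```python
import torch
import torch.nn as nn

D_MODEL = 512  # embedding size used throughout the article

class LayerNormNoPrint(nn.Module):
    """Same computation as LayerNorm() in Codeblock 18, without the debug prints."""

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(D_MODEL))   # scale, starts at 1
        self.beta = nn.Parameter(torch.zeros(D_MODEL))   # shift, starts at 0

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance
        return self.gamma * (x - mean) / (var + self.eps).sqrt() + self.beta

torch.manual_seed(0)
x = torch.randn(1, 200, D_MODEL)
ours = LayerNormNoPrint()
reference = nn.LayerNorm(D_MODEL, eps=1e-5)  # built-in layer normalization

print(torch.allclose(ours(x), reference(x), atol=1e-5))  # True
```

Both modules start with a scale of ones and a shift of zeros, so they agree before any training has taken place.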

Now that the LayerNorm() class has been constructed, we will run the following codeblock to check that a tensor flows through it as expected.

# Codeblock 19
layer_norm = LayerNorm()

x = torch.randn(BATCH_SIZE, SEQ_LENGTH, D_MODEL)
x = layer_norm(x)
# Codeblock 19 output
original    : torch.Size([1, 200, 512])
mean        : torch.Size([1, 200, 1])
var         : torch.Size([1, 200, 1])
stddev      : torch.Size([1, 200, 1])
normalized  : torch.Size([1, 200, 512])
after scaling and shifting  : torch.Size([1, 200, 512])

We can see in the output above that the processes inside the LayerNorm() class do not alter the tensor dimension at all. The mean, var, and stddev are statistics computed for each row (token), hence the embedding dimension collapses to 1 for these tensors. By the way, in case you're wondering why we use keepdim=True: setting it to False would give mean, var, and stddev a dimension of 1×200 rather than 1×200×1, and such tensors would not broadcast correctly against the 1×200×512 input in the subsequent operations.
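Here is a quick standalone demonstration of that broadcasting issue, using a tensor with the same 1×200×512 shape as in Codeblock 19.

```python
import torch

x = torch.randn(1, 200, 512)

mean_keep = x.mean(dim=-1, keepdim=True)  # (1, 200, 1): broadcasts across the 512 features
mean_drop = x.mean(dim=-1)                # (1, 200): its last axis now has size 200

print((x - mean_keep).shape)  # torch.Size([1, 200, 512])

# Broadcasting aligns shapes from the right, so 512 is compared against 200 and fails
failed = False
try:
    x - mean_drop
except RuntimeError:
    failed = True
print(failed)  # True
```

Note that the failure only happens because 200 and 512 differ. If the sequence length happened to equal the embedding size, the subtraction would silently broadcast along the wrong axis, which is another good reason to always keep the reduced dimension.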


The Entire Transformer Architecture

At this point we have successfully created all components for the Transformer architecture, so they are now ready to be assembled. We will start by assembling the Encoder, followed by the Decoder, and finally, I will connect the two as well as the other remaining components.

Encoder

The Encoder consists of four blocks placed sequentially: Multihead Self-Attention, Layer Norm, Feed Forward, and another Layer Norm. Additionally, there are two residual connections that skip over the Multihead Self-Attention block and the Feed Forward block. See the detailed structure in Figure 16 below.

Figure 16. The Encoder block [1].

Now let’s discuss the implementation in the following codeblock.

# Codeblock 20
class Encoder(nn.Module):

    def __init__(self):
        super().__init__()

        self.self_attention = SelfAttention(look_ahead_mask=False)  #(1)
        self.dropout_0 = nn.Dropout(DROP_PROB)  #(2)
        self.layer_norm_0 = LayerNorm()         #(3)
        self.feed_forward = FeedForward()
        self.dropout_1 = nn.Dropout(DROP_PROB)  #(4)
        self.layer_norm_1 = LayerNorm()         #(5)

    def forward(self, x):
        residual = x
        print(f"original & residual\t: {x.shape}")

        x = self.self_attention(x)  #(6)
        print(f"after self attention\t: {x.shape}")

        x = self.dropout_0(x)  #(7)
        print(f"after dropout\t\t: {x.shape}")

        x = self.layer_norm_0(x + residual)  #(8)
        print(f"after layer norm\t: {x.shape}")

        residual = x
        print(f"\nx & residual\t\t: {x.shape}")

        x = self.feed_forward(x)  #(9)
        print(f"after feed forward\t: {x.shape}")

        x = self.dropout_1(x)
        print(f"after dropout\t\t: {x.shape}")

        x = self.layer_norm_1(x + residual)
        print(f"after layer norm\t: {x.shape}")

        return x

I initialize the four blocks mentioned earlier in the __init__() method of the Encoder() class. Since the Encoder reads the entire input sequence at once, we need to set the look_ahead_mask parameter to False so that every token can attend to all other tokens (#(1)). Next, the two Layer Norm blocks are initialized separately, which I name self.layer_norm_0 and self.layer_norm_1, as shown at lines #(3) and #(5). Here I also initialize two dropout layers at lines #(2) and #(4), which will later be placed just before each normalization block.

In the forward() method, we first keep a reference to the input tensor in the residual variable. Since the Multihead Self-Attention layer (#(6)) returns a new tensor rather than modifying x in place, the original input remains available for the residual connection. Next, we pass the resulting output through the first dropout layer (#(7)). Note that the Layer Norm block doesn't just use the output tensor from the dropout layer. Instead, we also need to add the residual tensor to x element-wise before applying the normalization step (#(8)). Afterwards, we repeat the same process, except that this time we replace the Multihead Self-Attention block with the Feed Forward network (#(9)).
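The "sublayer, dropout, add residual, normalize" pattern appears twice in the Encoder (and three times in the Decoder), so it could optionally be factored out into a reusable wrapper. Below is a sketch of a hypothetical PostNormSublayer class that is not part of the article's code; it uses PyTorch's nn.LayerNorm in place of our own LayerNorm() class. "Post-norm" refers to applying the normalization after the residual addition, as in the original Transformer.

```python
import torch
import torch.nn as nn

class PostNormSublayer(nn.Module):
    """Hypothetical helper wrapping one 'sublayer -> dropout -> add residual -> layer norm' step."""

    def __init__(self, sublayer, d_model=512, drop_prob=0.1):
        super().__init__()
        self.sublayer = sublayer
        self.dropout = nn.Dropout(drop_prob)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x is used twice: as the sublayer input and as the residual. The sublayer
        # returns a new tensor, so x itself is left untouched for the addition.
        return self.norm(x + self.dropout(self.sublayer(x)))

# Example: wrap a feed-forward network the way the Encoder wraps FeedForward()
ffn = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512))
block = PostNormSublayer(ffn)

x = torch.randn(1, 200, 512)
print(block(x).shape)  # torch.Size([1, 200, 512])
```

With such a wrapper, the Encoder's forward() would reduce to two calls, one wrapping the self-attention block and one wrapping the feed-forward block. Keeping the explicit version, as in Codeblock 20, makes each step easier to print and inspect.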

If you look back at the classes I created earlier, you will notice that all of them (specifically, the ones intended to be placed inside the Encoder and Decoder) have exactly the same input and output dimensions. We can check this by passing a tensor through the entire Encoder architecture, as shown in Codeblock 21 below. You will see in the output that the tensor size after each step is exactly the same.

# Codeblock 21
encoder = Encoder()

x = torch.randn(BATCH_SIZE, SEQ_LENGTH, D_MODEL)
x = encoder(x)
# Codeblock 21 output
original & residual   : torch.Size([1, 200, 512])
after self attention  : torch.Size([1, 200, 512])
after dropout         : torch.Size([1, 200, 512])
after layer norm      : torch.Size([1, 200, 512])

x & residual          : torch.Size([1, 200, 512])
after feed forward    : torch.Size([1, 200, 512])
after dropout         : torch.Size([1, 200, 512])
after layer norm      : torch.Size([1, 200, 512])

Decoder

The Decoder architecture, which you can see in Figure 17, is a little bit longer than the Encoder. Initially, the tensor passed into it will be processed with a Masked Multihead Self-Attention layer. Next, we send the resulting tensor as the query input for the subsequent Multihead Cross-Attention layer. The key and value input for this layer will be obtained from the Encoder output. Lastly, we propagate the tensor through the Feed Forward block. Remember that here we will also implement the layer normalization operations as well as the residual connections.

Figure 17. The Decoder block [1].

In the implementation in Codeblock 22, we need to initialize two attention blocks inside the __init__() method. The first one is SelfAttention() with look_ahead_mask=True (#(1)), and the second one is CrossAttention() (#(3)). I also apply dropout layers here, which I initialize at lines #(2), #(4), and #(5).

# Codeblock 22
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.self_attention = SelfAttention(look_ahead_mask=True)  #(1)
        self.dropout_0 = nn.Dropout(DROP_PROB)  #(2)
        self.layer_norm_0 = LayerNorm()

        self.cross_attention = CrossAttention()  #(3)
        self.dropout_1 = nn.Dropout(DROP_PROB)  #(4)
        self.layer_norm_1 = LayerNorm()

        self.feed_forward = FeedForward()
        self.dropout_2 = nn.Dropout(DROP_PROB)  #(5)
        self.layer_norm_2 = LayerNorm()

    def forward(self, x_enc, x_dec):  #(6)
        residual = x_dec
        print(f"x_dec & residual\t: {x_dec.shape}")

        x_dec = self.self_attention(x_dec)  #(7)
        print(f"after self attention\t: {x_dec.shape}")

        x_dec = self.dropout_0(x_dec)
        print(f"after dropout\t\t: {x_dec.shape}")

        x_dec = self.layer_norm_0(x_dec + residual)  #(8)
        print(f"after layer norm\t: {x_dec.shape}")

        residual = x_dec
        print(f"\nx_dec & residual\t: {x_dec.shape}")

        x_dec = self.cross_attention(x_enc, x_dec)  #(9)
        print(f"after cross attention\t: {x_dec.shape}")

        x_dec = self.dropout_1(x_dec)
        print(f"after dropout\t\t: {x_dec.shape}")

        x_dec = self.layer_norm_1(x_dec + residual)
        print(f"after layer norm\t: {x_dec.shape}")

        residual = x_dec
        print(f"\nx_dec & residual\t: {x_dec.shape}")

        x_dec = self.feed_forward(x_dec)  #(10)
        print(f"after feed forward\t: {x_dec.shape}")

        x_dec = self.dropout_2(x_dec)
        print(f"after dropout\t\t: {x_dec.shape}")

        x_dec = self.layer_norm_2(x_dec + residual)
        print(f"after layer norm\t: {x_dec.shape}")

        return x_dec

The forward() method, meanwhile, is basically just a stack of layers placed one after another, but there are several things I need to highlight. First, this method accepts two input parameters: x_enc and x_dec (#(6)). As the names suggest, the former is the tensor coming from the Encoder, while the latter is the one we obtain from the previous layer in the Decoder. We initially work only with the x_dec tensor, which is processed using the first attention (#(7)) and layer normalization (#(8)) blocks. Once this is done, we use x_enc alongside the processed x_dec as the input for the cross_attention layer (#(9)), which is where our model fuses information from the Encoder and the Decoder. Lastly, the resulting output is fed into the Feed Forward block (#(10)).

We test the Decoder by passing two tensors of the same dimensions to simulate the actual x_enc and x_dec. Based on the output of the following codeblock, we can see that these two tensors successfully pass through every step, indicating that the Decoder is wired correctly, at least as far as tensor shapes are concerned.

# Codeblock 23
decoder = Decoder()

x_enc = torch.randn(BATCH_SIZE, SEQ_LENGTH, D_MODEL)
x_dec = torch.randn(BATCH_SIZE, SEQ_LENGTH, D_MODEL)

x = decoder(x_enc, x_dec)
# Codeblock 23 output
x_dec & residual      : torch.Size([1, 200, 512])
after self attention  : torch.Size([1, 200, 512])
after dropout         : torch.Size([1, 200, 512])
after layer norm      : torch.Size([1, 200, 512])

x_dec & residual      : torch.Size([1, 200, 512])
after cross attention : torch.Size([1, 200, 512])
after dropout         : torch.Size([1, 200, 512])
after layer norm      : torch.Size([1, 200, 512])

x_dec & residual      : torch.Size([1, 200, 512])
after feed forward    : torch.Size([1, 200, 512])
after dropout         : torch.Size([1, 200, 512])
after layer norm      : torch.Size([1, 200, 512])

Combining Encoder and Decoder

Now that we have created the Encoder() and Decoder() classes, we can get into the very last part of this article: connecting the Encoder to the Decoder along with the other components that interact with them. Here I provide a figure showing the entire Transformer architecture for reference, so you don't need to scroll all the way back to Figure 1 to verify our implementation in Codeblocks 24 and 25.

Figure 18. The Transformer architecture (copied from Figure 1) [1].

In the codeblock below, I implement the architecture inside the Transformer() class. In the __init__() method, we initialize the input and output embedding layers (#(1) and #(2)). These two layers are responsible for converting tokens into their corresponding 512-dimensional vector representations. Next, we initialize a single positional_encoding layer (#(3)), which will be used twice: once for the embedded input tokens, and once for the embedded output tokens. The Encoder (#(4)) and Decoder (#(5)) blocks, meanwhile, are initialized a little differently, using nn.ModuleList(). We can think of this as a list of modules that we will connect sequentially later in the forward pass, where each block type is repeated N (6) times. This is why I name them self.encoders and self.decoders (with an s). The last thing we need to do in the __init__() method is to initialize the self.linear layer (#(6)), which maps the 512-dimensional token embeddings to all possible tokens in the destination language. We can view this as a classification task, where the model chooses one token at a time as the prediction result based on their probability scores.

# Codeblock 24
class Transformer(nn.Module):
    def __init__(self):
        super().__init__()

        self.input_embedding = InputEmbedding()  #(1)
        self.output_embedding = OutputEmbedding()  #(2)

        self.positional_encoding = PositionalEncoding()  #(3)

        self.encoders = nn.ModuleList([Encoder() for _ in range(N)])  #(4)
        self.decoders = nn.ModuleList([Decoder() for _ in range(N)])  #(5)

        self.linear = nn.Linear(D_MODEL, VOCAB_SIZE_DST)  #(6)
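One detail worth double-checking about the nn.ModuleList in Codeblock 24: the list comprehension calls the constructor N times, so it creates N independent blocks, each with its own weights. Writing [Encoder()] * N instead would put the same instance in the list N times, so all blocks would share a single set of parameters. The sketch below illustrates the difference, using nn.Linear(512, 512) as a lightweight stand-in for Encoder().

```python
import torch.nn as nn

N = 6  # number of stacked blocks, as in the article's configuration

# Independent blocks: the constructor runs N times
independent = nn.ModuleList([nn.Linear(512, 512) for _ in range(N)])

# Shared blocks: list multiplication repeats ONE instance N times
shared = nn.ModuleList([nn.Linear(512, 512)] * N)

def count(module):
    # parameters() yields each Parameter object only once, even if it is registered repeatedly
    return sum(p.numel() for p in module.parameters())

print(count(independent))  # 1575936 (6 x 262656)
print(count(shared))       # 262656
```

Weight sharing across layers is a legitimate design in some models, but it is not what the original Transformer does, so the list comprehension is the right choice here.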

The way our forward() method works is a little unusual. Remember that the entire Transformer accepts two inputs: a sequence from the original language, and the shifted-right sequence from the translated language. Hence, in Codeblock 25 below, you will see that this method accepts two sequences: x_enc_raw and x_dec_raw (#(1)). The _raw suffix indicates a raw token sequence, i.e., a sequence of integers rather than tokens already converted into 512-dimensional vectors. This conversion is done at lines #(2) and #(5). Afterwards, we inject positional encoding into the resulting tensors by element-wise addition, which is done at line #(3) for the sequence fed into the Encoder, and at line #(6) for the one passed through the Decoder. Next, we use a loop to feed the output of each Encoder block into the subsequent one (#(4)). We do the same for the Decoder blocks, except that each of them accepts both x_enc and x_dec (#(7)). Notice that the x_enc fed into every Decoder block is only the output of the last Encoder block, whereas the x_dec fed into each Decoder is always the one produced by the previous Decoder block. You can verify this by taking a closer look at line #(7), where x_dec is updated at each iteration while x_enc is not. Lastly, once the Decoder loop is completed, we pass the resulting tensor to the linear layer (#(8)). If you take a look at Figure 18, you will notice that there is a softmax layer placed after this linear layer. However, we won't implement it here, because the loss function typically used for this task, nn.CrossEntropyLoss, already applies log-softmax internally and therefore expects raw logits.

# Codeblock 25
    def forward(self, x_enc_raw, x_dec_raw):  #(1)
        print(f"x_enc_raw\t\t: {x_enc_raw.shape}")
        print(f"x_dec_raw\t\t: {x_dec_raw.shape}")

        # Encoder
        x_enc = self.input_embedding(x_enc_raw)  #(2)
        print(f"\nafter input embedding\t: {x_enc.shape}")

        x_enc = x_enc + self.positional_encoding()  #(3)
        print(f"after pos encoding\t: {x_enc.shape}")

        for i, encoder in enumerate(self.encoders):
            x_enc = encoder(x_enc)  #(4)
            print(f"after encoder #{i}\t: {x_enc.shape}")

        # Decoder
        x_dec = self.output_embedding(x_dec_raw)  #(5)
        print(f"\nafter output embedding\t: {x_dec.shape}")

        x_dec = x_dec + self.positional_encoding()  #(6)
        print(f"after pos encoding\t: {x_dec.shape}")

        for i, decoder in enumerate(self.decoders):
            x_dec = decoder(x_enc, x_dec)  #(7)
            print(f"after decoder #{i}\t: {x_dec.shape}")

        x = self.linear(x_dec)  #(8)
        print(f"\nafter linear\t\t: {x.shape}")

        return x
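Regarding the softmax that Codeblock 25 leaves out, here is a minimal sketch of how the raw output of self.linear could be fed into nn.CrossEntropyLoss. The target token IDs are random placeholders, and the shapes follow the test run in Codeblock 26 (batch size 1, 200 tokens, 120-token destination vocabulary). The manual log_softmax plus nll_loss computation confirms that no explicit softmax layer is needed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE_DST = 1, 200, 120  # sizes from the article's test run

logits = torch.randn(BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE_DST)          # stands in for the output of self.linear
targets = torch.randint(0, VOCAB_SIZE_DST, (BATCH_SIZE, SEQ_LENGTH))  # placeholder target token IDs

# nn.CrossEntropyLoss takes inputs of shape (N, C), so batch and sequence are flattened into N
loss = nn.CrossEntropyLoss()(logits.reshape(-1, VOCAB_SIZE_DST), targets.reshape(-1))

# The same value by hand: log-softmax over the vocabulary, then negative log-likelihood
log_probs = F.log_softmax(logits, dim=-1)
manual = F.nll_loss(log_probs.reshape(-1, VOCAB_SIZE_DST), targets.reshape(-1))

print(torch.allclose(loss, manual))  # True
```

Applying an extra softmax before nn.CrossEntropyLoss would not raise an error, but it would distort the loss and weaken the gradients, so the logits should be passed in as they are. At inference time, softmax (or simply argmax) can be applied to the logits to pick the predicted token.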

With the Transformer() class complete, we can now test it using the following codeblock. You can see in the resulting output that x_enc_raw and x_dec_raw successfully pass through the entire Transformer architecture, which means that the forward pass works end to end and the network is structurally ready to be trained for seq2seq tasks.

# Codeblock 26
transformer = Transformer()

x_enc_raw = torch.randint(0, VOCAB_SIZE_SRC, (BATCH_SIZE, SEQ_LENGTH))
x_dec_raw = torch.randint(0, VOCAB_SIZE_DST, (BATCH_SIZE, SEQ_LENGTH))

y = transformer(x_enc_raw, x_dec_raw).shape
# Codeblock 26 output
x_enc_raw              : torch.Size([1, 200])
x_dec_raw              : torch.Size([1, 200])

after input embedding  : torch.Size([1, 200, 512])
after pos encoding     : torch.Size([1, 200, 512])
after encoder #0       : torch.Size([1, 200, 512])
after encoder #1       : torch.Size([1, 200, 512])
after encoder #2       : torch.Size([1, 200, 512])
after encoder #3       : torch.Size([1, 200, 512])
after encoder #4       : torch.Size([1, 200, 512])
after encoder #5       : torch.Size([1, 200, 512])

after output embedding : torch.Size([1, 200, 512])
after pos encoding     : torch.Size([1, 200, 512])
after decoder #0       : torch.Size([1, 200, 512])
after decoder #1       : torch.Size([1, 200, 512])
after decoder #2       : torch.Size([1, 200, 512])
after decoder #3       : torch.Size([1, 200, 512])
after decoder #4       : torch.Size([1, 200, 512])
after decoder #5       : torch.Size([1, 200, 512])

after linear           : torch.Size([1, 200, 120])

Looking more closely at the flow, you can see that the tensor dimensions are consistent throughout all Encoder and Decoder blocks. This property allows us to scale the model easily. For example, if we want to increase the model capacity so that it can learn from a larger dataset, we can simply stack more Encoders and Decoders. Conversely, if you want the model to be more efficient, you can decrease the number of these blocks. In fact, beyond the number of Encoders and Decoders, you can change the value of essentially any parameter defined in Codeblock 2 according to your needs, as long as D_MODEL remains divisible by NUM_HEADS so that the embedding can be split evenly across the attention heads.
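Because every block has the same input and output shape, the parameter count grows linearly with the number of stacked blocks. The hypothetical helpers below estimate the count per block from the hyperparameters. They assume the layer shapes described in this article (a 3×D_MODEL first projection in Self-Attention, as in the 512 to 1536 expansion discussed earlier) and assume that the attention computation itself has no learnable parameters, which is consistent with the 1,050,624 parameters reported for SelfAttention in the Codeblock 27 summary. Embeddings and the final linear layer are excluded.

```python
def linear_params(d_in, d_out):
    """Weights plus biases of an nn.Linear(d_in, d_out) layer."""
    return d_in * d_out + d_out

def encoder_block_params(d_model, hidden_dim):
    self_attn = linear_params(d_model, 3 * d_model) + linear_params(d_model, d_model)
    feed_forward = linear_params(d_model, hidden_dim) + linear_params(hidden_dim, d_model)
    layer_norms = 2 * (2 * d_model)  # gamma and beta for each of the two LayerNorm blocks
    return self_attn + feed_forward + layer_norms

def decoder_block_params(d_model, hidden_dim):
    # kv_linear (doubling), q_linear, and the output linear layer of CrossAttention
    cross_attn = (linear_params(d_model, 2 * d_model)
                  + linear_params(d_model, d_model)
                  + linear_params(d_model, d_model))
    third_norm = 2 * d_model
    # Decoder = Encoder-style block + cross-attention + one extra LayerNorm
    return encoder_block_params(d_model, hidden_dim) + cross_attn + third_norm

print(encoder_block_params(512, 2048))  # 3152384
print(decoder_block_params(512, 2048))  # 4204032

for n in (2, 6, 12):  # number of Encoder/Decoder pairs
    total = n * (encoder_block_params(512, 2048) + decoder_block_params(512, 2048))
    print(f"N={n}: {total:,} parameters in the Encoder and Decoder stacks")
```

The Encoder figure matches the summary (1,050,624 + 2,099,712 + 2×1,024), and doubling N doubles the size of the stacks, which is worth keeping in mind when scaling the model up.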

The following code is an optional step, but in case you’re wondering what the overall structure of the Transformer architecture looks like, you can just run it.

# Codeblock 27
transformer = Transformer()
summary(transformer, input_data=(x_enc_raw, x_dec_raw))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Transformer                              [1, 200, 120]             --
├─InputEmbedding: 1-1                    [1, 200, 512]             --
│    └─Embedding: 2-1                    [1, 200, 512]             51,200
├─PositionalEncoding: 1-2                [200, 512]                --
├─ModuleList: 1-3                        --                        --
│    └─Encoder: 2-2                      [1, 200, 512]             --
│    │    └─SelfAttention: 3-1           [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-2                 [1, 200, 512]             --
│    │    └─LayerNorm: 3-3               [1, 200, 512]             1,024
│    │    └─FeedForward: 3-4             [1, 200, 512]             2,099,712
│    │    └─Dropout: 3-5                 [1, 200, 512]             --
│    │    └─LayerNorm: 3-6               [1, 200, 512]             1,024
│    └─Encoder: 2-3                      [1, 200, 512]             --
│    │    └─SelfAttention: 3-7           [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-8                 [1, 200, 512]             --
│    │    └─LayerNorm: 3-9               [1, 200, 512]             1,024
│    │    └─FeedForward: 3-10            [1, 200, 512]             2,099,712
│    │    └─Dropout: 3-11                [1, 200, 512]             --
│    │    └─LayerNorm: 3-12              [1, 200, 512]             1,024
│    └─Encoder: 2-4                      [1, 200, 512]             --
│    │    └─SelfAttention: 3-13          [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-14                [1, 200, 512]             --
│    │    └─LayerNorm: 3-15              [1, 200, 512]             1,024
│    │    └─FeedForward: 3-16            [1, 200, 512]             2,099,712
│    │    └─Dropout: 3-17                [1, 200, 512]             --
│    │    └─LayerNorm: 3-18              [1, 200, 512]             1,024
│    └─Encoder: 2-5                      [1, 200, 512]             --
│    │    └─SelfAttention: 3-19          [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-20                [1, 200, 512]             --
│    │    └─LayerNorm: 3-21              [1, 200, 512]             1,024
│    │    └─FeedForward: 3-22            [1, 200, 512]             2,099,712
│    │    └─Dropout: 3-23                [1, 200, 512]             --
│    │    └─LayerNorm: 3-24              [1, 200, 512]             1,024
│    └─Encoder: 2-6                      [1, 200, 512]             --
│    │    └─SelfAttention: 3-25          [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-26                [1, 200, 512]             --
│    │    └─LayerNorm: 3-27              [1, 200, 512]             1,024
│    │    └─FeedForward: 3-28            [1, 200, 512]             2,099,712
│    │    └─Dropout: 3-29                [1, 200, 512]             --
│    │    └─LayerNorm: 3-30              [1, 200, 512]             1,024
│    └─Encoder: 2-7                      [1, 200, 512]             --
│    │    └─SelfAttention: 3-31          [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-32                [1, 200, 512]             --
│    │    └─LayerNorm: 3-33              [1, 200, 512]             1,024
│    │    └─FeedForward: 3-34            [1, 200, 512]             2,099,712
│    │    └─Dropout: 3-35                [1, 200, 512]             --
│    │    └─LayerNorm: 3-36              [1, 200, 512]             1,024
├─OutputEmbedding: 1-4                   [1, 200, 512]             --
│    └─Embedding: 2-8                    [1, 200, 512]             61,440
├─PositionalEncoding: 1-5                [200, 512]                --
├─ModuleList: 1-6                        --                        --
│    └─Decoder: 2-9                      [1, 200, 512]             --
│    │    └─SelfAttention: 3-37          [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-38                [1, 200, 512]             --
│    │    └─LayerNorm: 3-39              [1, 200, 512]             1,024
│    │    └─CrossAttention: 3-40         [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-41                [1, 200, 512]             --
│    │    └─LayerNorm: 3-42              [1, 200, 512]             1,024
│    │    └─FeedForward: 3-43            [1, 200, 512]             2,099,712
│    │    └─Dropout: 3-44                [1, 200, 512]             --
│    │    └─LayerNorm: 3-45              [1, 200, 512]             1,024
│    └─Decoder: 2-10                     [1, 200, 512]             --
│    │    └─SelfAttention: 3-46          [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-47                [1, 200, 512]             --
│    │    └─LayerNorm: 3-48              [1, 200, 512]             1,024
│    │    └─CrossAttention: 3-49         [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-50                [1, 200, 512]             --
│    │    └─LayerNorm: 3-51              [1, 200, 512]             1,024
│    │    └─FeedForward: 3-52            [1, 200, 512]             2,099,712
│    │    └─Dropout: 3-53                [1, 200, 512]             --
│    │    └─LayerNorm: 3-54              [1, 200, 512]             1,024
│    └─Decoder: 2-11                     [1, 200, 512]             --
│    │    └─SelfAttention: 3-55          [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-56                [1, 200, 512]             --
│    │    └─LayerNorm: 3-57              [1, 200, 512]             1,024
│    │    └─CrossAttention: 3-58         [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-59                [1, 200, 512]             --
│    │    └─LayerNorm: 3-60              [1, 200, 512]             1,024
│    │    └─FeedForward: 3-61            [1, 200, 512]             2,099,712
│    │    └─Dropout: 3-62                [1, 200, 512]             --
│    │    └─LayerNorm: 3-63              [1, 200, 512]             1,024
│    └─Decoder: 2-12                     [1, 200, 512]             --
│    │    └─SelfAttention: 3-64          [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-65                [1, 200, 512]             --
│    │    └─LayerNorm: 3-66              [1, 200, 512]             1,024
│    │    └─CrossAttention: 3-67         [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-68                [1, 200, 512]             --
│    │    └─LayerNorm: 3-69              [1, 200, 512]             1,024
│    │    └─FeedForward: 3-70            [1, 200, 512]             2,099,712
│    │    └─Dropout: 3-71                [1, 200, 512]             --
│    │    └─LayerNorm: 3-72              [1, 200, 512]             1,024
│    └─Decoder: 2-13                     [1, 200, 512]             --
│    │    └─SelfAttention: 3-73          [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-74                [1, 200, 512]             --
│    │    └─LayerNorm: 3-75              [1, 200, 512]             1,024
│    │    └─CrossAttention: 3-76         [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-77                [1, 200, 512]             --
│    │    └─LayerNorm: 3-78              [1, 200, 512]             1,024
│    │    └─FeedForward: 3-79            [1, 200, 512]             2,099,712
│    │    └─Dropout: 3-80                [1, 200, 512]             --
│    │    └─LayerNorm: 3-81              [1, 200, 512]             1,024
│    └─Decoder: 2-14                     [1, 200, 512]             --
│    │    └─SelfAttention: 3-82          [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-83                [1, 200, 512]             --
│    │    └─LayerNorm: 3-84              [1, 200, 512]             1,024
│    │    └─CrossAttention: 3-85         [1, 200, 512]             1,050,624
│    │    └─Dropout: 3-86                [1, 200, 512]             --
│    │    └─LayerNorm: 3-87              [1, 200, 512]             1,024
│    │    └─FeedForward: 3-88            [1, 200, 512]             2,099,712
│    │    └─Dropout: 3-89                [1, 200, 512]             --
│    │    └─LayerNorm: 3-90              [1, 200, 512]             1,024
├─Linear: 1-7                            [1, 200, 120]             61,560
==========================================================================================
Total params: 44,312,696
Trainable params: 44,312,696
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 44.28
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 134.54
Params size (MB): 177.25
Estimated Total Size (MB): 311.79
==========================================================================================
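You can reproduce the parameter counts in the summary by hand, which also shows how the total grows with each stacked block. The sketch makes three assumptions, all inferred from the summary's numbers rather than read from the article's code:

- each attention block uses four 512×512 linear projections with biases;
- the feed-forward hidden size is 2048;
- the vocabulary sizes are 100 (input) and 120 (output).

```python
# Sizes inferred from the summary above (assumptions, not read from Codeblock 2)
D_MODEL = 512
D_FF = 2048
SRC_VOCAB = 100  # 51,200 / 512
TGT_VOCAB = 120  # 61,440 / 512, also the width of the final Linear layer

def linear(n_in, n_out):
    return n_in * n_out + n_out  # weights + biases

attention = 4 * linear(D_MODEL, D_MODEL)  # Q, K, V and output projections
feed_forward = linear(D_MODEL, D_FF) + linear(D_FF, D_MODEL)
layer_norm = 2 * D_MODEL                  # gamma and beta

encoder = attention + feed_forward + 2 * layer_norm      # self-attention block
decoder = 2 * attention + feed_forward + 3 * layer_norm  # self- and cross-attention

def total_params(n_enc, n_dec):
    embeddings = SRC_VOCAB * D_MODEL + TGT_VOCAB * D_MODEL
    head = linear(D_MODEL, TGT_VOCAB)
    return embeddings + n_enc * encoder + n_dec * decoder + head

print(attention, feed_forward, encoder, decoder)  # 1050624 2099712 3152384 4204032
print(total_params(6, 6))                         # 44312696
print(total_params(6, 6) * 4 / 1e6)               # about 177.25 MB as float32
```

Each extra Encoder adds 3,152,384 parameters and each extra Decoder adds 4,204,032. The model size therefore grows linearly with depth.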

Ending

And that’s all for today’s tutorial about the Transformer and its PyTorch implementation. Congratulations to those who followed all the discussions above, as you’ve spent more than 40 minutes reading this article! By the way, feel free to comment if you spot any mistakes in my explanation or the code.

I hope you find this article useful. Thanks for reading, and see ya in the next one!

_P.S. Here’s the link to the GitHub repository._


References

[1] Ashish Vaswani et al. Attention Is All You Need. arXiv. https://arxiv.org/pdf/1706.03762 [Accessed September 29, 2024].

[2] Image originally created by the author.

[3] Sheng Shen et al. PowerNorm: Rethinking Batch Normalization in Transformers. arXiv. https://arxiv.org/abs/2003.07845 [Accessed October 3, 2024].

The post Paper Walkthrough: Attention Is All You Need appeared first on Towards Data Science.

Paper Walkthrough: U-Net https://towardsdatascience.com/paper-walkthrough-u-net-98877a2cd33c/ Fri, 20 Sep 2024 16:43:59 +0000 https://towardsdatascience.com/paper-walkthrough-u-net-98877a2cd33c/ A PyTorch implementation of one of the most popular semantic segmentation models.

The post Paper Walkthrough: U-Net appeared first on Towards Data Science.

Introduction to U-Net

When we talk about image segmentation, we should not forget U-Net, a neural network architecture first proposed by Ronneberger et al. [1] back in 2015. The model was initially intended for segmenting medical images, but other researchers later found that the architecture works for general semantic segmentation tasks as well. It can also be used for tasks such as super resolution (i.e., upscaling a low-resolution image to a higher resolution) and diffusion (i.e., generating images from noise). In this article, I would like to show you how to implement U-Net from scratch using PyTorch. You can see the entire U-Net architecture in Figure 1. Looking at this structure, it is easy to see how the network got its name.

Figure 1. The U-Net architecture [1].

There are several key components in the architecture. First, there is the Contracting Path, also known as the Encoder. This component gradually shrinks the spatial dimension of the input image from 572×572 to 64×64. Notice, however, that the number of channels doubles at each downsampling stage to compensate for the information lost through the spatial reduction. In contrast, the Expansive Path (Decoder) expands the feature maps to a larger spatial dimension while reducing the number of channels. Despite the symmetrical architecture, the output of the final upsampling stage has a lower resolution than the input (388×388 versus 572×572). This is because every unpadded convolution trims the border of the feature map.

Two types of connections link the Encoder and the Decoder: the Bottleneck and the Residual-Like paths. The Bottleneck covers everything between the last pooling layer and the first transpose convolution layer, i.e., the lowermost part of the network shown in Figure 1. The Residual-Like paths are the gray arrows. They help the network preserve high-resolution features from the Encoder, since relying solely on the Bottleneck would lose a significant amount of spatial information. I call them Residual-Like because they differ from the shortcut connections proposed in ResNet. ResNet merges the two tensors by element-wise summation, whereas U-Net concatenates them.
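The sketch below contrasts the two merging strategies. It uses random tensors with the shapes of the first upsampling stage that appears later in this article. Summation keeps the channel count, while concatenation along the channel axis (dim=1) adds the two channel counts together.

```python
import torch

# Shapes taken from the first upsampling stage (see Codeblock 10 output)
decoder_features = torch.randn(1, 512, 56, 56)  # after the transpose convolution
encoder_features = torch.randn(1, 512, 56, 56)  # cropped skip connection

summed = decoder_features + encoder_features  # ResNet-style merge
concatenated = torch.cat([decoder_features, encoder_features], dim=1)  # U-Net-style merge

print(summed.shape)        # torch.Size([1, 512, 56, 56])
print(concatenated.shape)  # torch.Size([1, 1024, 56, 56])
```

Because concatenation keeps both tensors intact side by side, the following convolutions can learn how much to rely on each source. Summation mixes them before any layer sees them.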


Implementing U-Net with PyTorch

This project needs only three imports: the base PyTorch module (torch) for standard mathematical functionality, the nn submodule for neural network layers, and the summary() function from torchinfo to print out the details of a model.

# Codeblock 1
import torch
import torch.nn as nn
from torchinfo import summary

The Encoder

Now that the modules have been loaded, we can start coding. Let’s begin with the Encoder. In Figure 2 below, all components belonging to the Encoder are highlighted in green. Look closely at these stages and you can see that each consists of two consecutive convolution layers with 3×3 kernels, followed by a 2×2 max-pooling layer. Since all four stages perform the same process, we can wrap a single stage into a reusable class and repeat it four times.

Figure 2. The Encoder part of U-Net consists of four downsampling stages [2].

The stack of two convolution layers is implemented in the DoubleConv() class shown in Codeblocks 2 and 3 below. Here I initialize both convolution layers (conv_0 and conv_1) as well as their corresponding batch normalization layers (bn_0 and bn_1). Note that batch normalization is not used in the original U-Net paper. I include it anyway since it often helps the model reach better accuracy. This article only demonstrates how to implement the U-Net architecture. If you want to actually train this model, try it both with and without batch normalization to see whether this holds for your data. If you decide not to use batch normalization, make sure to set the bias parameter of the convolution layers to True (at the lines marked with #(1) and #(2) in Codeblock 2). With batch normalization, a convolution bias is pointless: the normalization layer subtracts the per-channel mean, which cancels out any constant bias. The normalization layer also has its own learnable shift parameter anyway.
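You can check numerically that batch normalization cancels the convolution bias. In this sketch, the layer sizes and bias values are arbitrary illustrative choices. Two convolutions share the same weights, but only one of them has a bias. After a BatchNorm2d layer in training mode, which normalizes with the statistics of the current batch, their outputs match up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

conv_no_bias = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, bias=False)
conv_with_bias = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, bias=True)

# Give both convolutions identical weights and a clearly non-zero bias
with torch.no_grad():
    conv_with_bias.weight.copy_(conv_no_bias.weight)
    conv_with_bias.bias.copy_(torch.randn(8) * 2)

bn = nn.BatchNorm2d(num_features=8)  # training mode by default: uses batch statistics
x = torch.randn(4, 1, 32, 32)

out_no_bias = bn(conv_no_bias(x))
out_with_bias = bn(conv_with_bias(x))

# The per-channel bias only shifts the mean, which BN subtracts again
print(torch.allclose(out_no_bias, out_with_bias, atol=1e-4))  # True
```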

# Codeblock 2
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv_0 = nn.Conv2d(in_channels=in_channels, 
                                out_channels=out_channels, 
                                kernel_size=3, bias=False)    #(1)
        self.bn_0 = nn.BatchNorm2d(num_features=out_channels)

        self.conv_1 = nn.Conv2d(in_channels=out_channels,
                                out_channels=out_channels, 
                                kernel_size=3, bias=False)    #(2)
        self.bn_1 = nn.BatchNorm2d(num_features=out_channels)

        self.relu = nn.ReLU(inplace=True)

Now that all layers and the ReLU activation function have been initialized, we need to string them together in the forward() function. Just a quick reminder: the conventional layer order in CNNs is Conv-BN-ReLU, and this is exactly the structure I implement below. To make the process clearer, I also print out the tensor dimensions after each convolution operation.

# Codeblock 3
    def forward(self, x):
        print(f'original\t\t: {x.size()}')

        x = self.conv_0(x)
        x = self.bn_0(x)
        x = self.relu(x)
        print(f'after first conv\t: {x.size()}')

        x = self.conv_1(x)
        x = self.bn_1(x)
        x = self.relu(x)
        print(f'after second conv\t: {x.size()}')

        return x

We can run Codeblock 4 to check whether our DoubleConv() works properly. Here I set the block to accept a 1-channel image and output a 64-channel feature map, as written at the line marked with #(1). The tensor of random numbers (#(2)), which we treat as an image, has a shape of 1×1×572×572. Its axes represent the batch size, the number of color channels, the image height, and the image width, respectively.

# Codeblock 4
double_conv = DoubleConv(in_channels=1, out_channels=64)    #(1)
x = torch.randn((1, 1, 572, 572))    #(2)
x = double_conv(x).size()
# Codeblock 4 output
original                : torch.Size([1, 1, 572, 572])
after first conv        : torch.Size([1, 64, 570, 570])
after second conv       : torch.Size([1, 64, 568, 568])

The above output shows that both the height and width of the input image shrink by two pixels after each convolution layer. This exactly matches the first two convolutions in Figure 1 (572 to 570 and 570 to 568). The reduction happens because the 3×3 kernel is applied without padding, so it cannot be centered on the outermost row and column on each side.
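The two-pixel reduction follows from the output-size formula documented for nn.Conv2d. With dilation 1, it simplifies to (size + 2 × padding - kernel_size) // stride + 1. The helper below applies it to the 572×572 input; conv2d_out is a hypothetical name, not part of PyTorch.

```python
def conv2d_out(size, kernel_size=3, stride=1, padding=0):
    """Output height/width of nn.Conv2d with dilation=1 (hypothetical helper)."""
    return (size + 2 * padding - kernel_size) // stride + 1

after_first = conv2d_out(572)           # 572 + 0 - 3 + 1 = 570
after_second = conv2d_out(after_first)  # 570 + 0 - 3 + 1 = 568
same_size = conv2d_out(572, padding=1)  # a 1-pixel padding keeps 572

print(after_first, after_second, same_size)  # 570 568 572
```

The last line shows why many U-Net re-implementations pad their 3×3 convolutions by one pixel: the feature maps then keep their size and no cropping is needed. This article follows the original paper and keeps padding at 0.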


As mentioned earlier, every downsampling stage consists of two convolution layers and a single max-pooling layer. The DoubleConv() class now implements the two convolutions, but the pooling operation is still missing. So I am going to create a new class named DownSample() that encapsulates both the convolutions and the pooling. The detailed code is shown in Codeblock 5 below.

# Codeblock 5
class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.double_conv = DoubleConv(in_channels=in_channels, 
                                      out_channels=out_channels)    #(1)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)    #(2)

    def forward(self, x):
        print(f'original\t\t: {x.size()}')

        convolved = self.double_conv(x)
        print(f'after double conv\t: {convolved.size()}')

        maxpooled = self.maxpool(convolved)
        print(f'after pooling\t\t: {maxpooled.size()}')

        return convolved, maxpooled    #(3)

Everything inside the __init__() function above is pretty straightforward. First, we use the DoubleConv() class, whose numbers of input and output channels are adjustable (#(1)). Second, we use an nn.MaxPool2d() layer with a value of 2 for both the kernel_size and stride parameters (#(2)). With this pooling configuration, the height and width of the output are halved.
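A quick check of this halving with nn.MaxPool2d is shown below, using single-channel random tensors to keep memory use small. For an odd input size the result is rounded down, so the last row and column are simply dropped. This is why the U-Net paper recommends choosing the input tile size so that every 2×2 max-pooling operation is applied to an even size.

```python
import torch
import torch.nn as nn

maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

even_input = torch.randn(1, 1, 568, 568)  # size after the first DoubleConv
odd_input = torch.randn(1, 1, 567, 567)   # hypothetical odd size

even_output = maxpool(even_input)
odd_output = maxpool(odd_input)

print(even_output.shape)  # torch.Size([1, 1, 284, 284])
print(odd_output.shape)   # torch.Size([1, 1, 283, 283]): floor(567 / 2)
```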

Take a closer look at the forward() function, especially the line marked by #(3). You can see that we return both the output of the double convolution (convolved) and the output of the pooling layer (maxpooled). I do this because every downsampling stage has two branches. convolved is transferred directly to the corresponding upsampling stage in the Decoder, while maxpooled is passed on to the next stage. In the figure below, convolved is highlighted in pink while maxpooled is highlighted in cyan.

Figure 3. The feature maps to be transferred directly to the decoder through Residual-Like connections (pink) and the ones to be fed into the subsequent layers (cyan) [2].

Next, we check whether the DownSample() class we just created works properly by running Codeblock 6 below. Here I assume that we are initializing the very first downsampling stage, which includes the two convolutions as well as the pooling layer. According to Figure 1, the output of this stage should have a height and width of 284 with 64 channels, and the output below confirms this.

# Codeblock 6
down_sample = DownSample(in_channels=1, out_channels=64)
x = torch.randn((1, 1, 572, 572))
x = down_sample(x)
# Codeblock 6 output
original                : torch.Size([1, 1, 572, 572])
after double conv       : torch.Size([1, 64, 568, 568])
after pooling           : torch.Size([1, 64, 284, 284])

The Decoder

Now let’s move on to the counterpart of the Encoder: the Decoder. This part of U-Net essentially reverses the downsampling process done by the Encoder, so the processes inside the Decoder are known as upsampling. The four upsampling stages in the architecture are highlighted in orange in Figure 4 below. Every upsampling stage consists of a transpose convolution layer followed by two consecutive standard convolution layers. In U-Net, the transpose convolution doubles the spatial dimensions of the feature map while halving the number of channels. The two standard convolutions that follow then refine the features. After the skip connection has been concatenated, they also bring the channel count back down to that halved value.
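The sketch below covers the size arithmetic, assuming the channel counts of the first upsampling stage. For nn.ConvTranspose2d with no padding, the documented output size reduces to (size - 1) × stride + kernel_size, so 28 becomes (28 - 1) × 2 + 2 = 56. Because kernel_size equals stride here, the 2×2 output patches do not overlap, and each input pixel simply scales its own copy of the kernel. The second part of the sketch checks this against torch.kron on a tiny 2×2 input with made-up values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Size check with the channel counts of the first upsampling stage
up = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
x = torch.randn(1, 1024, 28, 28)
upsampled = up(x)
print(upsampled.shape)  # torch.Size([1, 512, 56, 56])

# Non-overlapping case: each input pixel scales a 2x2 copy of the kernel
pixels = torch.tensor([[1., 2.],
                       [3., 4.]])
kernel = torch.tensor([[1., 0.],
                       [2., 1.]])
reference = F.conv_transpose2d(pixels[None, None], kernel[None, None], stride=2)[0, 0]
manual = torch.kron(pixels, kernel)  # block (i, j) equals pixels[i, j] * kernel
print(torch.allclose(reference, manual))  # True
```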

Figure 4. The four upsampling stages in the Decoder [2].

However, it is not quite that simple, because we also need to think about how the Residual-Like paths connect to each upsampling stage. The main idea is simple: just concatenate. The tricky part is that each feature map from a downsampling stage is spatially larger than the one in the corresponding upsampling stage. For instance, look at the first (uppermost) Residual-Like path in Figure 1. The feature map from the Encoder is 568×568, whereas the one in the Decoder is 392×392. Therefore, to make concatenation possible, feature maps from the Encoder need to be cropped so that the sizes of the two tensors match. In Codeblock 7 below, I create a function named crop_image() to do so.

# Codeblock 7
def crop_image(original, expected):    #(1)

    original_dim = original.size()[-1]    #(2)
    expected_dim = expected.size()[-1]    #(3)

    difference = original_dim - expected_dim    #(4)
    padding = difference // 2    #(5)

    cropped = original[:, :, padding:original_dim-padding, padding:original_dim-padding]    #(6)

    return cropped

This function accepts two tensors: original and expected (#(1)). The former is the tensor coming from the Encoder, while the latter is the tensor already in the Decoder. The goal is to crop the original tensor so that it has the same size as the expected tensor. At the lines marked with #(2) and #(3), I use simple indexing to get the width of both images. We don’t need the height since all feature maps in this network are square. Next, we calculate the width difference between the original and the expected images to determine how much of the original image to crop (#(4)). At line #(5), we halve the difference so that the same amount is trimmed from each side, keeping the cropped region centered. For the uppermost path, that is (568 - 392) / 2 = 88 pixels per side. Finally, line #(6) performs the actual crop.
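crop_image() relies on two assumptions that hold for the standard U-Net input size: the feature maps are square, and the size difference is even. With an odd difference such as 65 versus 56, padding becomes 4 and the crop returns 57 pixels instead of 56. The sketch below handles height and width separately and slices exactly the target size, so it covers those cases too. center_crop_to is a hypothetical name; crop_image is copied from Codeblock 7 for comparison.

```python
import torch

def crop_image(original, expected):  # copied from Codeblock 7 for comparison
    original_dim = original.size()[-1]
    expected_dim = expected.size()[-1]
    difference = original_dim - expected_dim
    padding = difference // 2
    return original[:, :, padding:original_dim-padding, padding:original_dim-padding]

def center_crop_to(original, expected):
    """Hypothetical helper: center-crop `original` (N, C, H, W) to the H and W of
    `expected`, treating height and width separately."""
    h, w = original.shape[-2:]
    target_h, target_w = expected.shape[-2:]
    top = (h - target_h) // 2
    left = (w - target_w) // 2
    return original[:, :, top:top + target_h, left:left + target_w]

connection = torch.randn(1, 512, 64, 64)
x = torch.randn(1, 512, 56, 56)
cropped = center_crop_to(connection, x)
print(cropped.shape)  # torch.Size([1, 512, 56, 56])

# An odd difference (65 -> 56): crop_image overshoots, center_crop_to does not
odd = torch.randn(1, 1, 65, 65)
target = torch.empty(1, 1, 56, 56)
print(crop_image(odd, target).shape)      # torch.Size([1, 1, 57, 57])
print(center_crop_to(odd, target).shape)  # torch.Size([1, 1, 56, 56])
```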


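With the crop_image() function finished, we can move on to the UpSample() class, which encapsulates an entire upsampling stage. You can see in Codeblock 8 below that I initialize a DoubleConv() block after an nn.ConvTranspose2d() layer. Since we want the resulting image to be spatially twice as large as the input, we set both kernel_size and stride to 2, as written at line #(1). Notice that the DoubleConv() block uses in_channels rather than out_channels as its number of input channels. The cropped skip connection has out_channels channels, and so does the transpose convolution output, so concatenating them brings the total back to in_channels. Figure 5 illustrates how a transpose convolution layer works in case you’re not yet familiar with it.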

# Codeblock 8
class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv_transpose = nn.ConvTranspose2d(in_channels=in_channels,
                                                 out_channels=out_channels, 
                                                 kernel_size=2, stride=2)    #(1)
        self.double_conv = DoubleConv(in_channels=in_channels,
                                      out_channels=out_channels)

Figure 5. An example of a transpose convolution operation on a 2×2 input with a 2×2 kernel and stride 2, which results in a 4×4 output [3].

Moving on to the forward() function of the UpSample() class, Codeblock 9 shows that it accepts two inputs. x is the tensor from the main flow, and connection is the one coming directly from the Encoder (#(1)). First, we apply the transpose convolution to tensor x (#(2)). Then, we pass both connection and x into the crop_image() function so that connection gets the same spatial dimension as x (#(3)). Once the cropping is done, the two tensors are concatenated along the channel dimension, which is why we use torch.cat() with dim=1 (#(4)). Finally, we pass the tensor through the DoubleConv() block before returning it (#(5)).

# Codeblock 9
    def forward(self, x, connection):    #(1)
        print(f'x original\t\t\t: {x.size()}')
        print(f'connection original\t\t: {connection.size()}')

        x = self.conv_transpose(x)    #(2)
        print(f'x after conv transpose\t\t: {x.size()}')

        cropped_connection = crop_image(connection, x)    #(3)
        print(f'connection after cropped\t: {cropped_connection.size()}')

        x = torch.cat([x, cropped_connection], dim=1)    #(4)
        print(f'after concatenation\t\t: {x.size()}')

        x = self.double_conv(x)    #(5)
        print(f'after double conv\t\t: {x.size()}')

        return x

Now let’s check whether our UpSample() class works properly by running the following codeblock. For this example, I simulate the first upsampling stage. The tensor coming from the Bottleneck is denoted as x (#(2)), while the one coming from the last downsampling stage is denoted as connection (#(3)). I set the UpSample() stage to accept an input with 1024 channels and return 512 channels (#(1)). If our implementation is correct, the resulting image should have a height and width of 52.

# Codeblock 10
up_sample = UpSample(1024, 512)    #(1)

x = torch.randn((1, 1024, 28, 28))    #(2)
connection = torch.randn((1, 512, 64, 64))    #(3)

x = up_sample(x, connection)
# Codeblock 10 output
x original                      : torch.Size([1, 1024, 28, 28])    #(1)
connection original             : torch.Size([1, 512, 64, 64])     #(2)
x after conv transpose          : torch.Size([1, 512, 56, 56])     #(3)
connection after cropped        : torch.Size([1, 512, 56, 56])     #(4)
after concatenation             : torch.Size([1, 1024, 56, 56])    #(5)
after double conv               : torch.Size([1, 512, 52, 52])     #(6)

In the output above, we can observe that the transpose convolution changes the shape of x from 1×1024×28×28 (#(1)) to 1×512×56×56 (#(3)). The tensor connection initially has a shape of 1×512×64×64 (#(2)) and is successfully cropped to 1×512×56×56 (#(4)). Since both x and connection now have 512 channels, the total number of channels becomes 1024 after they are concatenated (#(5)). Finally, the double convolution transforms the tensor to 1×512×52×52 (#(6)), which matches our expectations.


The Complete U-Net Architecture

So far, we have created the DoubleConv(), DownSample(), and UpSample() classes. Next, we use them all to construct the entire U-Net architecture, as shown in Codeblock 11 below.

# Codeblock 11
class UNet(nn.Module):
    def __init__(self, in_channels=1, num_classes=2):    #(1)
        super().__init__()

        # Encoder    #(2)
        self.downsample_0 = DownSample(in_channels=in_channels, out_channels=64)
        self.downsample_1 = DownSample(in_channels=64, out_channels=128)
        self.downsample_2 = DownSample(in_channels=128, out_channels=256)
        self.downsample_3 = DownSample(in_channels=256, out_channels=512)

        # Bottleneck    #(3)
        self.bottleneck   = DoubleConv(in_channels=512, out_channels=1024)

        # Decoder    #(4)
        self.upsample_0   = UpSample(in_channels=1024, out_channels=512)
        self.upsample_1   = UpSample(in_channels=512, out_channels=256)
        self.upsample_2   = UpSample(in_channels=256, out_channels=128)
        self.upsample_3   = UpSample(in_channels=128, out_channels=64)

        # Output    #(5)
        self.output = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

At the line marked with #(1), I set the default number of input channels to 1 and the number of classes to 2. This matches the original U-Net architecture described in the paper (see the first and last layers in Figure 1). In other words, by default this model accepts a grayscale image and outputs a two-class (binary) segmentation map. You can change these numbers if you want to use the model for a more complex segmentation task.

Next, we create the Encoder by stacking four downsampling stages (#(2)). Make sure that the number of input channels of each DownSample() stage equals the number of output channels of the previous stage. The same principle applies to the upsampling stages in the Decoder (#(4)). However, the number of channels doubles at each stage as the Encoder goes deeper, whereas it halves at each Decoder stage as we move towards the output.

For the Bottleneck, we employ a DoubleConv() block, as it is essentially just a stack of two convolution layers (#(3)). Finally, we use a standard nn.Conv2d() for the output layer, with a 1×1 kernel and the number of output channels set to the number of classes (#(5)). If you’re new to image segmentation models, note that the number of classes equals the number of channels in the output layer, with each channel representing one segment. For instance, with segments for the sky, the ground, and an object, you would set the number of output channels to 3: one channel each for the sky, the ground, and the object. You can think of this as classifying each pixel in the output image into one of these classes.
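The sketch below makes the per-pixel classification concrete. It uses a random tensor as a stand-in for the output of a 3-class model; the class count and the 388×388 size are illustrative assumptions. Taking the argmax over the channel dimension (dim=1) turns the scores into a label map. nn.CrossEntropyLoss accepts exactly this (N, C, H, W) layout together with an (N, H, W) tensor of integer class labels.

```python
import torch
import torch.nn as nn

num_classes = 3  # e.g. sky, ground, object (illustrative)
logits = torch.randn(1, num_classes, 388, 388)  # stand-in for a UNet output

probabilities = torch.softmax(logits, dim=1)  # per-pixel class probabilities
label_map = logits.argmax(dim=1)              # per-pixel predicted class index

# Random ground-truth labels, only to show the expected target shape
target = torch.randint(0, num_classes, (1, 388, 388))
loss = nn.CrossEntropyLoss()(logits, target)

print(label_map.shape)  # torch.Size([1, 388, 388])
print(loss.shape)       # torch.Size([]): a scalar averaged over all pixels
```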

Now that all U-Net components have been initialized, we can define the flow of the network in the following forward() function.

# Codeblock 12
    def forward(self, x):
        print(f'original\t\t: {x.size()}')

        convolved_0, maxpooled_0 = self.downsample_0(x)    #(1)
        print(f'convolved_0\t\t: {convolved_0.size()}')
        print(f'maxpooled_0\t\t: {maxpooled_0.size()}')

        convolved_1, maxpooled_1 = self.downsample_1(maxpooled_0)    #(2)
        print(f'convolved_1\t\t: {convolved_1.size()}')
        print(f'maxpooled_1\t\t: {maxpooled_1.size()}')

        convolved_2, maxpooled_2 = self.downsample_2(maxpooled_1)    #(3)
        print(f'convolved_2\t\t: {convolved_2.size()}')
        print(f'maxpooled_2\t\t: {maxpooled_2.size()}')

        convolved_3, maxpooled_3 = self.downsample_3(maxpooled_2)    #(4)
        print(f'convolved_3\t\t: {convolved_3.size()}')
        print(f'maxpooled_3\t\t: {maxpooled_3.size()}')

        x = self.bottleneck(maxpooled_3)
        print(f'after bottleneck\t: {x.size()}')

        upsampled_0 = self.upsample_0(x, convolved_3)    #(5)
        print(f'upsampled_0\t\t: {upsampled_0.size()}')

        upsampled_1 = self.upsample_1(upsampled_0, convolved_2)    #(6)
        print(f'upsampled_1\t\t: {upsampled_1.size()}')

        upsampled_2 = self.upsample_2(upsampled_1, convolved_1)
        print(f'upsampled_2\t\t: {upsampled_2.size()}')

        upsampled_3 = self.upsample_3(upsampled_2, convolved_0)
        print(f'upsampled_3\t\t: {upsampled_3.size()}')

        x = self.output(upsampled_3)
        print(f'final output\t\t: {x.size()}')

        return x

There are several things I want to emphasize in the code. First, remember that the downsampling stages return two outputs: convolved and maxpooled (at line #(1) to #(4)). Later in the upsampling stages, we pair x with convolved_3 (at line #(5)), upsampled_0 with convolved_2 (at line #(6)), and so on for the remaining stages. This pairing is essential because the output of each upsampling stage is produced by combining the feature map from the previous stage with the corresponding feature map from the Encoder.

We can test the UNet() class we just created using the following codeblock. As the output shows, every tensor dimension in the flow matches the dimensions given in the original paper (refer to Figure 1), which indicates that we have implemented the U-Net architecture correctly, at least as far as the tensor shapes are concerned.

# Codeblock 13
unet = UNet()
x = torch.randn((1, 1, 572, 572))
x = unet(x)
# Codeblock 13 output
original                : torch.Size([1, 1, 572, 572])
convolved_0             : torch.Size([1, 64, 568, 568])
maxpooled_0             : torch.Size([1, 64, 284, 284])
convolved_1             : torch.Size([1, 128, 280, 280])
maxpooled_1             : torch.Size([1, 128, 140, 140])
convolved_2             : torch.Size([1, 256, 136, 136])
maxpooled_2             : torch.Size([1, 256, 68, 68])
convolved_3             : torch.Size([1, 512, 64, 64])
maxpooled_3             : torch.Size([1, 512, 32, 32])
after bottleneck        : torch.Size([1, 1024, 28, 28])
upsampled_0             : torch.Size([1, 512, 52, 52])
upsampled_1             : torch.Size([1, 256, 100, 100])
upsampled_2             : torch.Size([1, 128, 196, 196])
upsampled_3             : torch.Size([1, 64, 388, 388])
final output            : torch.Size([1, 2, 388, 388])
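The spatial sizes above follow from simple arithmetic, since the convolutions are unpadded. As a small sketch (the helper name is mine), each DoubleConv removes 4 pixels (2 per 3×3 convolution), each 2×2 max pooling halves the size, and each 2×2 up-convolution doubles it:

```python
def unet_shape_trace(size=572, depth=4):
    """Track the spatial size through an unpadded U-Net (3x3 convs, 2x2 pooling/up-convs)."""
    trace = []
    for _ in range(depth):
        size -= 4                  # two unpadded 3x3 convolutions remove 2 pixels each
        trace.append(('convolved', size))
        size //= 2                 # 2x2 max pooling halves the size
        trace.append(('maxpooled', size))
    size -= 4                      # bottleneck DoubleConv
    trace.append(('bottleneck', size))
    for _ in range(depth):
        size = size * 2 - 4        # 2x2 up-convolution doubles, then DoubleConv removes 4
        trace.append(('upsampled', size))
    return trace

for name, s in unet_shape_trace():
    print(f'{name:<12}{s}x{s}')
```

This also shows why the input size cannot be arbitrary: every size reaching a pooling layer (568, 280, 136, 64) must be even so that halving loses no pixels.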

In order to display the detailed architecture, including the parameter count, model size, etc., we can use the summary() function we imported earlier.

# Codeblock 14
summary(unet, input_size=(1,1,572,572))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
UNet                                     [1, 2, 388, 388]          --
├─DownSample: 1-1                        [1, 64, 568, 568]         --
│    └─DoubleConv: 2-1                   [1, 64, 568, 568]         --
│    │    └─Conv2d: 3-1                  [1, 64, 570, 570]         576
│    │    └─BatchNorm2d: 3-2             [1, 64, 570, 570]         128
│    │    └─ReLU: 3-3                    [1, 64, 570, 570]         --
│    │    └─Conv2d: 3-4                  [1, 64, 568, 568]         36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 568, 568]         128
│    │    └─ReLU: 3-6                    [1, 64, 568, 568]         --
│    └─MaxPool2d: 2-2                    [1, 64, 284, 284]         --
├─DownSample: 1-2                        [1, 128, 280, 280]        --
│    └─DoubleConv: 2-3                   [1, 128, 280, 280]        --
│    │    └─Conv2d: 3-7                  [1, 128, 282, 282]        73,728
│    │    └─BatchNorm2d: 3-8             [1, 128, 282, 282]        256
│    │    └─ReLU: 3-9                    [1, 128, 282, 282]        --
│    │    └─Conv2d: 3-10                 [1, 128, 280, 280]        147,456
│    │    └─BatchNorm2d: 3-11            [1, 128, 280, 280]        256
│    │    └─ReLU: 3-12                   [1, 128, 280, 280]        --
│    └─MaxPool2d: 2-4                    [1, 128, 140, 140]        --
├─DownSample: 1-3                        [1, 256, 136, 136]        --
│    └─DoubleConv: 2-5                   [1, 256, 136, 136]        --
│    │    └─Conv2d: 3-13                 [1, 256, 138, 138]        294,912
│    │    └─BatchNorm2d: 3-14            [1, 256, 138, 138]        512
│    │    └─ReLU: 3-15                   [1, 256, 138, 138]        --
│    │    └─Conv2d: 3-16                 [1, 256, 136, 136]        589,824
│    │    └─BatchNorm2d: 3-17            [1, 256, 136, 136]        512
│    │    └─ReLU: 3-18                   [1, 256, 136, 136]        --
│    └─MaxPool2d: 2-6                    [1, 256, 68, 68]          --
├─DownSample: 1-4                        [1, 512, 64, 64]          --
│    └─DoubleConv: 2-7                   [1, 512, 64, 64]          --
│    │    └─Conv2d: 3-19                 [1, 512, 66, 66]          1,179,648
│    │    └─BatchNorm2d: 3-20            [1, 512, 66, 66]          1,024
│    │    └─ReLU: 3-21                   [1, 512, 66, 66]          --
│    │    └─Conv2d: 3-22                 [1, 512, 64, 64]          2,359,296
│    │    └─BatchNorm2d: 3-23            [1, 512, 64, 64]          1,024
│    │    └─ReLU: 3-24                   [1, 512, 64, 64]          --
│    └─MaxPool2d: 2-8                    [1, 512, 32, 32]          --
├─DoubleConv: 1-5                        [1, 1024, 28, 28]         --
│    └─Conv2d: 2-9                       [1, 1024, 30, 30]         4,718,592
│    └─BatchNorm2d: 2-10                 [1, 1024, 30, 30]         2,048
│    └─ReLU: 2-11                        [1, 1024, 30, 30]         --
│    └─Conv2d: 2-12                      [1, 1024, 28, 28]         9,437,184
│    └─BatchNorm2d: 2-13                 [1, 1024, 28, 28]         2,048
│    └─ReLU: 2-14                        [1, 1024, 28, 28]         --
├─UpSample: 1-6                          [1, 512, 52, 52]          --
│    └─ConvTranspose2d: 2-15             [1, 512, 56, 56]          2,097,664
│    └─DoubleConv: 2-16                  [1, 512, 52, 52]          --
│    │    └─Conv2d: 3-25                 [1, 512, 54, 54]          4,718,592
│    │    └─BatchNorm2d: 3-26            [1, 512, 54, 54]          1,024
│    │    └─ReLU: 3-27                   [1, 512, 54, 54]          --
│    │    └─Conv2d: 3-28                 [1, 512, 52, 52]          2,359,296
│    │    └─BatchNorm2d: 3-29            [1, 512, 52, 52]          1,024
│    │    └─ReLU: 3-30                   [1, 512, 52, 52]          --
├─UpSample: 1-7                          [1, 256, 100, 100]        --
│    └─ConvTranspose2d: 2-17             [1, 256, 104, 104]        524,544
│    └─DoubleConv: 2-18                  [1, 256, 100, 100]        --
│    │    └─Conv2d: 3-31                 [1, 256, 102, 102]        1,179,648
│    │    └─BatchNorm2d: 3-32            [1, 256, 102, 102]        512
│    │    └─ReLU: 3-33                   [1, 256, 102, 102]        --
│    │    └─Conv2d: 3-34                 [1, 256, 100, 100]        589,824
│    │    └─BatchNorm2d: 3-35            [1, 256, 100, 100]        512
│    │    └─ReLU: 3-36                   [1, 256, 100, 100]        --
├─UpSample: 1-8                          [1, 128, 196, 196]        --
│    └─ConvTranspose2d: 2-19             [1, 128, 200, 200]        131,200
│    └─DoubleConv: 2-20                  [1, 128, 196, 196]        --
│    │    └─Conv2d: 3-37                 [1, 128, 198, 198]        294,912
│    │    └─BatchNorm2d: 3-38            [1, 128, 198, 198]        256
│    │    └─ReLU: 3-39                   [1, 128, 198, 198]        --
│    │    └─Conv2d: 3-40                 [1, 128, 196, 196]        147,456
│    │    └─BatchNorm2d: 3-41            [1, 128, 196, 196]        256
│    │    └─ReLU: 3-42                   [1, 128, 196, 196]        --
├─UpSample: 1-9                          [1, 64, 388, 388]         --
│    └─ConvTranspose2d: 2-21             [1, 64, 392, 392]         32,832
│    └─DoubleConv: 2-22                  [1, 64, 388, 388]         --
│    │    └─Conv2d: 3-43                 [1, 64, 390, 390]         73,728
│    │    └─BatchNorm2d: 3-44            [1, 64, 390, 390]         128
│    │    └─ReLU: 3-45                   [1, 64, 390, 390]         --
│    │    └─Conv2d: 3-46                 [1, 64, 388, 388]         36,864
│    │    └─BatchNorm2d: 3-47            [1, 64, 388, 388]         128
│    │    └─ReLU: 3-48                   [1, 64, 388, 388]         --
├─Conv2d: 1-10                           [1, 2, 388, 388]          130
==========================================================================================
Total params: 31,036,546
Trainable params: 31,036,546
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 167.34
==========================================================================================
Input size (MB): 1.31
Forward/backward pass size (MB): 1992.61
Params size (MB): 124.15
Estimated Total Size (MB): 2118.07
==========================================================================================
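The parameter counts in the summary can be verified by hand. This is a minimal sketch with my own helper names, assuming (as the numbers in the summary imply) that the 3×3 convolutions inside DoubleConv are created with bias=False, since a bias would be redundant with the following BatchNorm2d, while the transposed convolutions and the final 1×1 convolution keep their biases.

```python
def conv_params(c_in, c_out, k, bias):
    # Same count for Conv2d and ConvTranspose2d: c_in * c_out * k * k weights (+ c_out biases)
    return c_in * c_out * k * k + (c_out if bias else 0)

def batchnorm_params(c):
    return 2 * c   # learnable scale and shift; running statistics are buffers, not parameters

def double_conv_params(c_in, c_out):
    # Conv(3x3, no bias) -> BN -> ReLU -> Conv(3x3, no bias) -> BN -> ReLU
    return (conv_params(c_in, c_out, 3, bias=False) + batchnorm_params(c_out)
            + conv_params(c_out, c_out, 3, bias=False) + batchnorm_params(c_out))

def unet_total_params(in_ch=1, num_classes=2, widths=(64, 128, 256, 512, 1024)):
    total, c = 0, in_ch
    for w in widths:                          # 4 encoder stages + bottleneck
        total += double_conv_params(c, w)
        c = w
    for w in reversed(widths[:-1]):           # 4 decoder stages
        total += conv_params(c, w, 2, bias=True)   # 2x2 up-convolution halves the channels
        total += double_conv_params(2 * w, w)       # input = upsampled + cropped skip
        c = w
    total += conv_params(c, num_classes, 1, bias=True)   # 1x1 output layer
    return total

print(f'{unet_total_params():,}')   # 31,036,546
```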

This concludes our exploration of the U-Net architecture. Feel free to leave a comment if you notice any mistakes, especially regarding the implementation. I would be very happy to hear your feedback.

For your reference, all the code used in this article is available on my GitHub repo, which you can find here.

Thank you for reading!


References

[1] Olaf Ronneberger et al. U-Net: Convolutional Networks for Biomedical Image Segmentation. arXiv. https://arxiv.org/pdf/1505.04597 [Accessed August 21, 2024].

[2] Image created originally by author based on [1].

[3] Image created originally by author.

The post Paper Walkthrough: U-Net appeared first on Towards Data Science.

Paper Walkthrough: Vision Transformer (ViT) https://towardsdatascience.com/paper-walkthrough-vision-transformer-vit-c5dcf76f1a7a/ Tue, 13 Aug 2024 14:17:09 +0000 https://towardsdatascience.com/paper-walkthrough-vision-transformer-vit-c5dcf76f1a7a/ Exploring Vision Transformer (ViT) through PyTorch Implementation from Scratch.

The post Paper Walkthrough: Vision Transformer (ViT) appeared first on Towards Data Science.

Introduction

Vision Transformer, commonly abbreviated as ViT, can be seen as a breakthrough in the field of Computer Vision. Vision-related tasks have traditionally been addressed with CNN-based models, which for years outperformed other types of neural networks on these tasks. This changed in 2020, when Dosovitskiy et al. [1] published a paper titled "An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale", showing that a pure transformer can match or outperform state-of-the-art CNNs on image classification when it is pre-trained on large datasets.

A single convolution layer in a CNN extracts features using kernels. Since a kernel is small relative to the input image, each output value can only capture information contained within that small region. In other words, a convolution layer focuses on extracting local features, and a stack of multiple convolution layers is required to understand the global context of an image. ViT addresses this limitation because its self-attention mechanism lets every patch interact with every other patch, so global information is captured from the very first layer. Stacking multiple layers in ViT then allows even more comprehensive information extraction.

Figure 1. By stacking multiple convolution layers, CNNs can achieve larger receptive fields, which is essential for capturing the global context of an image [2].
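To put numbers on Figure 1, the receptive field of a stack of convolution layers can be computed with the standard recurrence: each k×k layer adds (k − 1) times the cumulative stride of the layers before it. The sketch below is illustrative only; it assumes identical layers and no dilation, and the function name is my own.

```python
def receptive_field(num_layers, kernel_size=3, stride=1):
    """Receptive field (in input pixels) of a stack of identical conv layers."""
    rf, jump = 1, 1          # jump = distance in input pixels between adjacent outputs
    for _ in range(num_layers):
        rf += (kernel_size - 1) * jump
        jump *= stride
    return rf

for n in (1, 2, 3, 10):
    print(n, 'layer(s) ->', receptive_field(n), 'pixels')

# How many stride-1 3x3 layers are needed before one output "sees" a full 224-pixel width?
print(min(n for n in range(1, 200) if receptive_field(n) >= 224))   # 112
```

By contrast, a single self-attention layer in ViT lets each patch token attend to all other patches at once, regardless of how far apart they are in the image.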

The Vision Transformer Architecture

If you have ever learned about transformers, you should be familiar with the terms encoder and decoder. In NLP, specifically in tasks like machine translation, the encoder captures the relationships between tokens (i.e., words) in the input sequence, while the decoder is responsible for generating the output sequence. ViT only needs the encoder part, and it treats every patch of an image as a token. Following the same idea, the encoder finds the relationships between patches.

The entire Vision Transformer architecture is displayed in Figure 2. Before we get into the code, I am going to explain each component of the architecture in the following sections.

Figure 2. The Vision Transformer architecture [1].

Patch Flattening & Linear Projection

According to the above figure, the first step is to divide an image into patches, which are then arranged to form a sequence. Each patch is flattened into a one-dimensional array, and the resulting sequence of tokens is projected into a D-dimensional embedding space through a linear projection. At this point, we can think of each projected patch as a word embedding in NLP, i.e., a vector representing a single word. Technically speaking, the linear projection can be done either with a simple MLP (linear) layer or with a convolution layer. I will explain more about this later in the implementation.
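As a quick sketch of patch flattening on its own (before any projection), an image can be cut into non-overlapping patches with plain tensor reshaping. The helper name patchify is hypothetical; the check at the end confirms that it produces exactly the same sequence as nn.Unfold() followed by a permute, which is the approach used later in this article.

```python
import torch
import torch.nn as nn

def patchify(images, patch_size):
    """(B, C, H, W) -> (B, num_patches, C*patch_size*patch_size), patches in row-major order."""
    B, C, H, W = images.shape
    p = patch_size
    x = images.reshape(B, C, H // p, p, W // p, p)   # split H and W into (grid, within-patch)
    x = x.permute(0, 2, 4, 1, 3, 5)                  # (B, grid_h, grid_w, C, p, p)
    return x.reshape(B, (H // p) * (W // p), C * p * p)

images = torch.randn(2, 3, 224, 224)
patches = patchify(images, 16)
print(patches.shape)   # torch.Size([2, 196, 768])

# Same values as nn.Unfold with kernel_size = stride = patch size, after swapping axes 1 and 2
unfolded = nn.Unfold(kernel_size=16, stride=16)(images).permute(0, 2, 1)
print(torch.equal(patches, unfolded))   # True
```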

Class Token & Positional Embedding

Since we are dealing with a classification task, we need to prepend an extra token to the projected patch sequence. This token, known as the class token, aggregates information from the other patches through attention, which assigns importance weights to each patch. It is important to note that after patch flattening and linear projection, the model has no built-in notion of where each patch came from, so spatial information is lost. To address this issue, a positional embedding is added to all tokens, including the class token, so that spatial information can be reintroduced.

Transformer Encoder & MLP Head

At this stage, the tensor is ready to be fed into the Transformer Encoder block, whose detailed structure is shown on the right-hand side of Figure 2. This block comprises four components: layer normalization, multi-head attention, another layer normalization, and an MLP layer. It is also worth noting that there are two residual connections implemented here. The L× written at the top-left corner of the Transformer Encoder block indicates that the block is repeated L times, depending on the size of the model to be constructed.

Lastly, we connect the encoder block to the MLP head. Keep in mind that only the output corresponding to the class token is forwarded to the head. The MLP head itself comprises a fully-connected layer followed by an output layer, where each neuron in the output layer represents a class in the dataset.
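In code, forwarding only the class token amounts to indexing position 0 of the sequence axis. A minimal sketch, with a random tensor standing in for the encoder output and a single hypothetical 10-class linear layer standing in for the full MLP head:

```python
import torch
import torch.nn as nn

encoder_output = torch.randn(4, 197, 768)   # (batch, 1 class token + 196 patch tokens, embed_dim)
class_token_output = encoder_output[:, 0]   # (batch, 768): keep only the class token's vector
head = nn.Linear(768, 10)                   # hypothetical stand-in for the MLP head, 10 classes
logits = head(class_token_output)           # (batch, 10): one score per class

print(class_token_output.shape, logits.shape)
```

The 196 patch tokens are discarded at this point; whatever they contribute must already have been gathered into the class token by the attention layers.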

Vision Transformer Variants

There are three ViT variants proposed in its original paper, namely ViT-B, ViT-L, and ViT-H as shown in Figure 3, where:

  • Layers (L): number of transformer encoders.
  • Hidden size (D): embedding dimensionality to represent a single patch.
  • MLP size: number of neurons in the MLP hidden layer.
  • Heads: number of attention heads in the Multi-Head Attention layer.
  • Params: number of parameters of the model.
Figure 3. The details of the three Vision Transformer variants [1].
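As a rough sanity check on the Params column, the count can be estimated from the other columns. This is my own back-of-the-envelope sketch, not taken from the paper. It assumes a 224×224 RGB input, biases in every linear layer, a final LayerNorm, and no classification head.

```python
def vit_backbone_params(layers, hidden, mlp_size, patch, image=224, channels=3):
    """Rough parameter count of a ViT backbone, excluding the classification head."""
    num_patches = (image // patch) ** 2
    patch_embed = channels * patch * patch * hidden + hidden      # linear projection + bias
    cls_and_pos = hidden + (num_patches + 1) * hidden             # class token + positional embedding
    per_layer = (
        2 * 2 * hidden                                            # two LayerNorms (scale + shift)
        + 4 * hidden * hidden + 4 * hidden                        # Q, K, V and output projections
        + hidden * mlp_size + mlp_size                            # MLP: first linear layer
        + mlp_size * hidden + hidden                              # MLP: second linear layer
    )
    final_norm = 2 * hidden
    return patch_embed + cls_and_pos + layers * per_layer + final_norm

print(f"ViT-B/16: {vit_backbone_params(12, 768, 3072, 16):,}")   # ViT-B/16: 85,798,656
```

The estimate of roughly 85.8M parameters agrees with the 86M listed for ViT-Base in Figure 3, and it shows that almost all parameters live in the L repeated encoder blocks rather than in the patch embedding.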

In this article, I would like to implement the ViT-Base architecture from scratch using PyTorch. By the way, PyTorch's torchvision library also provides several pre-trained ViT models [3], namely vit_b_16, vit_b_32, vit_l_16, vit_l_32, and vit_h_14, where the number at the end of each name refers to the patch size used.


Implementing ViT from Scratch

Now let's begin the fun part: coding! The very first thing to do is import the modules. In this case we will only rely on PyTorch functionality to construct the ViT architecture. The summary() function from torchinfo will help us display the details of the model.

# Codeblock 1
import torch
import torch.nn as nn
from torchinfo import summary

Parameter Configuration

In Codeblock 2 we initialize several variables to configure the model. Here we assume that a single batch contains only 1 image, with dimensions 3×224×224 (marked by #(1)). The variant we are going to employ is ViT-Base, meaning that we need to set the patch size to 16, the number of attention heads to 12, the number of encoders to 12, and the embedding dimension to 768 (#(2)). The MLP size is four times the embedding dimension, i.e., 3072, as listed for ViT-Base in Figure 3. With this configuration, the number of patches is 196 (#(3)): dividing a 224×224 image into 16×16 patches results in a 14×14 grid, hence 196 patches per image.

We are also going to use a rate of 0.1 for the dropout layers (#(4)). The paper only describes dropout briefly in its appendix on training details, where it is applied after the dense layers and directly after adding the positional embeddings. Since dropout is standard practice when constructing deep learning models anyway, I implement it here as well. Additionally, we assume that we have 10 classes in the dataset, so I set the NUM_CLASSES variable accordingly.

# Codeblock 2
#(1)
BATCH_SIZE   = 1
IMAGE_SIZE   = 224
IN_CHANNELS  = 3

#(2)
PATCH_SIZE   = 16
NUM_HEADS    = 12
NUM_ENCODERS = 12
EMBED_DIM    = 768
MLP_SIZE     = EMBED_DIM * 4    # 768*4 = 3072

#(3)
NUM_PATCHES  = (IMAGE_SIZE//PATCH_SIZE) ** 2    # (224//16)**2 = 196

#(4)
DROPOUT_RATE = 0.1
NUM_CLASSES  = 10

Since the main focus of this article is implementing the model, I am not going to talk about how to actually train it. However, if you want to do so, make sure you have a GPU available on your machine, as it makes training much faster. Codeblock 3 below checks whether PyTorch successfully detects your NVIDIA GPU, and falls back to the CPU otherwise.

# Codeblock 3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Codeblock 3 output
cuda

Patch Flattening & Linear Projection Implementation

I mentioned earlier that the patch flattening and linear projection operations can be done using either a simple MLP or a convolution layer. Here I am going to implement both of them, in the PatcherUnfold() and PatcherConv() classes respectively. Later on, you can choose either of the two to use in the main ViT class.

Let's start with PatcherUnfold(), whose details can be seen in Codeblock 4. Here I employ an nn.Unfold() layer with a kernel_size and stride of PATCH_SIZE (16) at the line marked with #(1). With this configuration, the layer applies a non-overlapping sliding window to the input image, and at each step the patch under the window is flattened. Figure 4 below illustrates this operation: there, an unfold operation is applied to a 4×4 image using a kernel size and stride of 2.

# Codeblock 4
class PatcherUnfold(nn.Module):
    def __init__(self):
        super().__init__()
        self.unfold = nn.Unfold(kernel_size=PATCH_SIZE, stride=PATCH_SIZE)    #(1)
        self.linear_projection = nn.Linear(in_features=IN_CHANNELS*PATCH_SIZE*PATCH_SIZE, 
                                           out_features=EMBED_DIM)    #(2)
Figure 4. Applying an unfold operation with kernel size and stride 2 on a 4×4 image.

Next, the linear projection is done with a standard nn.Linear() layer (#(2)). To match the size of a flattened patch, we need to use IN_CHANNELS*PATCH_SIZE*PATCH_SIZE for the in_features parameter, i.e., 3×16×16 = 768. The dimension of the projection result is then determined by the out_features parameter, which I set to EMBED_DIM (768). Note that for ViT-B/16 the projection result and the flattened patch happen to have exactly the same dimension, since 3×16×16 equals the ViT-Base hidden size of 768. If you want to implement ViT-L or ViT-H instead, you should change the projection dimension to 1024 or 1280, respectively, in which case it will no longer match the size of the flattened patches.
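The toy case from Figure 4 can be reproduced directly, which makes the ordering of nn.Unfold()'s output easy to see. In this small sketch the 4×4 single-channel image simply holds the values 0 to 15:

```python
import torch
import torch.nn as nn

image = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)   # a 1-channel 4x4 "image"
unfold = nn.Unfold(kernel_size=2, stride=2)

columns = unfold(image)              # (1, 1*2*2, 4): one column per 2x2 patch
patches = columns.permute(0, 2, 1)   # (1, 4, 4): one row per flattened patch
print(patches[0].tolist())
# [[0.0, 1.0, 4.0, 5.0], [2.0, 3.0, 6.0, 7.0], [8.0, 9.0, 12.0, 13.0], [10.0, 11.0, 14.0, 15.0]]
```

Each row is one 2×2 patch read left to right, top to bottom, and the patches themselves are ordered row by row across the image. With a 3-channel 224×224 input and a 16×16 window, the same operation yields 196 rows of length 3×16×16 = 768, which is exactly what the in_features of the linear projection expects.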

Now that the nn.Unfold() and nn.Linear() layers have been initialized, we need to connect them in the forward() function below. One thing to pay attention to is that the first and second axes of the unfolded tensor (counting the batch axis as the 0-th) need to be swapped using the permute() method (#(1)). This is done because we want to treat the flattened patches as a sequence of tokens, similar to how tokens are processed in NLP models: nn.Unfold() outputs a tensor of shape (batch, flattened patch size, number of patches), whereas nn.Linear() operates on the last axis. I also print out the shape after every step in the codeblock to help you keep track of the tensor dimensions.

# Codeblock 5
    def forward(self, x):
        print(f'original\t: {x.size()}')

        x = self.unfold(x)
        print(f'after unfold\t: {x.size()}')

        x = x.permute(0, 2, 1)    #(1)
        print(f'after permute\t: {x.size()}')

        x = self.linear_projection(x)
        print(f'after lin proj\t: {x.size()}')

        return x

At this point the PatcherUnfold() class has been completed. To check whether it works properly, we can try to feed it with a tensor of random values which simulates a single RGB image of size 224×224.

# Codeblock 6
patcher_unfold = PatcherUnfold()
x = torch.randn(1, 3, 224, 224)
x = patcher_unfold(x)

As the output below shows, our original image has been successfully converted to a tensor of shape 1×196×768, where 1 is the number of images in a single batch, 196 denotes the sequence length (number of patches), and 768 is the embedding dimension.

# Codeblock 6 output
original        : torch.Size([1, 3, 224, 224])
after unfold    : torch.Size([1, 768, 196])
after permute   : torch.Size([1, 196, 768])
after lin proj  : torch.Size([1, 196, 768])

That was the implementation of patch flattening and linear projection with the PatcherUnfold() class. We can achieve the same thing using PatcherConv(), whose code is shown below.

# Codeblock 7
class PatcherConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=IN_CHANNELS, 
                              out_channels=EMBED_DIM, 
                              kernel_size=PATCH_SIZE, 
                              stride=PATCH_SIZE)

        self.flatten = nn.Flatten(start_dim=2)

    def forward(self, x):
        print(f'original\t\t: {x.size()}')

        x = self.conv(x)    #(1)
        print(f'after conv\t\t: {x.size()}')

        x = self.flatten(x)    #(2)
        print(f'after flatten\t\t: {x.size()}')

        x = x.permute(0, 2, 1)    #(3)
        print(f'after permute\t\t: {x.size()}')

        return x

This approach might not seem as straightforward as the previous one because it does not explicitly flatten the patches. Instead, it uses a convolution layer with EMBED_DIM (768) kernels of size 16×16 and a stride of 16, which produces a 14×14 feature map with 768 channels (#(1)). To obtain the same output dimension as PatcherUnfold(), we then flatten the spatial dimensions (#(2)) and swap the first and second axes of the resulting tensor (#(3)). Look at the output of Codeblock 8 below to see the tensor shape after each step.

# Codeblock 8
patcher_conv = PatcherConv()
x = torch.randn(1, 3, 224, 224)
x = patcher_conv(x)
# Codeblock 8 output
original                : torch.Size([1, 3, 224, 224])
after conv              : torch.Size([1, 768, 14, 14])
after flatten           : torch.Size([1, 768, 196])
after permute           : torch.Size([1, 196, 768])

Additionally, it is worth noting that using nn.Conv2d() in PatcherConv() is typically more efficient than the separate unfolding and linear projection in PatcherUnfold(), as it combines the two steps into a single operation.
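To see why the two patchers are interchangeable, here is a small check using hypothetical sizes (a 32×32 image, 8×8 patches, and a 16-dimensional embedding) so that it runs quickly. If the linear layer's weight is set to the convolution kernel reshaped into a matrix, the unfold-and-project path and the convolution path give the same output up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C, P, D = 3, 8, 16          # small, hypothetical sizes so the check runs quickly
images = torch.randn(2, C, 32, 32)

conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
linear = nn.Linear(C * P * P, D)
with torch.no_grad():
    # A conv kernel of shape (D, C, P, P), flattened in (C, P, P) order, is exactly the
    # (D, C*P*P) weight matrix that the linear layer applies to each unfolded patch.
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)

out_conv = conv(images).flatten(start_dim=2).permute(0, 2, 1)                      # PatcherConv path
out_unfold = linear(nn.Unfold(kernel_size=P, stride=P)(images).permute(0, 2, 1))   # PatcherUnfold path
print(out_conv.shape, torch.allclose(out_conv, out_unfold, atol=1e-5))
```

In other words, a convolution whose kernel size equals its stride is just a shared linear layer applied to every non-overlapping patch, so the choice between the two classes is about convenience and speed rather than modeling power.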


Class Token & Positional Embedding Implementation

After all patches have been projected into the embedding space and arranged into a sequence, the next step is to prepend the class token, placing it before the first patch token in the sequence. This process is wrapped together with the positional embedding implementation inside the PosEmbedding() class, as shown in Codeblock 9.

# Codeblock 9
class PosEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.class_token = nn.Parameter(torch.randn(size=(1, 1, EMBED_DIM)), 
                                        requires_grad=True)    #(1)
        self.pos_embedding = nn.Parameter(torch.randn(size=(1, NUM_PATCHES+1, EMBED_DIM)), 
                                          requires_grad=True)    #(2)
        self.dropout = nn.Dropout(p=DROPOUT_RATE)  #(3)

The class token itself is initialized using nn.Parameter(), which is essentially a learnable weight tensor (#(1)). Its last dimension needs to match the embedding dimension so that it can be concatenated with the existing token sequence. I give it a batch dimension of 1 and expand it to the actual batch size inside forward(), so the same learned token is shared by every image and the model is not tied to a fixed BATCH_SIZE. This tensor initially contains random values, which will be updated during the training process; setting requires_grad to True (already the default for nn.Parameter()) allows this. Similarly, we employ nn.Parameter() to create the positional embedding (#(2)), yet with a different shape: its sequence dimension is one token longer than the patch sequence to accommodate the class token we just created, and its batch dimension of 1 lets it broadcast over any batch size during the addition. Finally, I also initialize a dropout layer with the rate that we specified earlier (#(3)).

Afterwards, I am going to connect these layers in the forward() function in Codeblock 10 below. The class token is first expanded to the batch size of the incoming tensor and then concatenated with it using torch.cat(), as written at the line marked by #(1). Next, we perform element-wise addition between the resulting output and the positional embedding tensor (#(2)) before passing it through the dropout layer (#(3)).

# Codeblock 10
    def forward(self, x):

        class_token = self.class_token.expand(x.size(0), -1, -1)    # one shared token per image in the batch
        print(f'class_token dim\t\t: {class_token.size()}')

        print(f'before concat\t\t: {x.size()}')
        x = torch.cat([class_token, x], dim=1)    #(1)
        print(f'after concat\t\t: {x.size()}')

        x = self.pos_embedding + x    #(2)
        print(f'after pos_embedding\t: {x.size()}')

        x = self.dropout(x)    #(3)
        print(f'after dropout\t\t: {x.size()}')

        return x

As usual, let's forward-propagate a tensor through this network to see if it works as expected. Keep in mind that the input of the pos_embedding model is the tensor produced by either PatcherUnfold() or PatcherConv().

# Codeblock 11
pos_embedding = PosEmbedding()
x = pos_embedding(x)

If we take a closer look at the tensor dimension of each step, we can observe that the size of tensor x is initially 1×196×768. After the class token has been prepended to it, the dimension becomes 1×197×768.

# Codeblock 11 output
class_token dim         : torch.Size([1, 1, 768])
before concat           : torch.Size([1, 196, 768])
after concat            : torch.Size([1, 197, 768])
after pos_embedding     : torch.Size([1, 197, 768])
after dropout           : torch.Size([1, 197, 768])
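The run above uses a batch of one. As a standalone sketch (the function name and the batch size of 4 are my own assumptions), this is how a single learned class token of shape 1×1×D can be expanded to any batch size, and how a 1×(N+1)×D positional embedding broadcasts over the batch during the addition:

```python
import torch
import torch.nn as nn

EMBED_DIM, NUM_PATCHES = 768, 196
class_token = nn.Parameter(torch.randn(1, 1, EMBED_DIM))                 # one token shared by all images
pos_embedding = nn.Parameter(torch.randn(1, NUM_PATCHES + 1, EMBED_DIM))

def add_class_token_and_positions(patch_tokens):
    batch_size = patch_tokens.size(0)
    cls = class_token.expand(batch_size, -1, -1)     # (B, 1, D) view, the parameter is not copied
    x = torch.cat([cls, patch_tokens], dim=1)        # (B, N+1, D)
    return x + pos_embedding                         # (1, N+1, D) broadcasts over the batch axis

tokens = torch.randn(4, NUM_PATCHES, EMBED_DIM)
out = add_class_token_and_positions(tokens)
print(out.shape)   # torch.Size([4, 197, 768])
```

Because expand() returns a view, gradients from every image in the batch flow back into the same class token parameter during training.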

Transformer Encoder Implementation

If we go back to Figure 2, we can see that the Transformer Encoder block comprises four components. We are going to define all these components inside the TransformerEncoder() class shown below.

# Codeblock 12
class TransformerEncoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.norm_0 = nn.LayerNorm(EMBED_DIM)    #(1)

        self.multihead_attention = nn.MultiheadAttention(EMBED_DIM,    #(2) 
                                                         num_heads=NUM_HEADS, 
                                                         batch_first=True, 
                                                         dropout=DROPOUT_RATE)

        self.norm_1 = nn.LayerNorm(EMBED_DIM)    #(3)

        self.mlp = nn.Sequential(    #(4)
            nn.Linear(in_features=EMBED_DIM, out_features=MLP_SIZE),    #(5)
            nn.GELU(), 
            nn.Dropout(p=DROPOUT_RATE), 
            nn.Linear(in_features=MLP_SIZE, out_features=EMBED_DIM),    #(6) 
            nn.Dropout(p=DROPOUT_RATE)
        )

The two normalization steps at the lines marked with #(1) and #(3) are implemented using nn.LayerNorm(). Keep in mind that the layer normalization we use here is different from the batch normalization commonly found in CNNs. Batch normalization normalizes each feature across all samples in a batch, whereas layer normalization normalizes all features within a single sample. Figure 5 below illustrates this concept. In this example, we assume that every row represents a single sample, whereas every column is a single feature. Cells of the same color indicate that their values are normalized together.

Figure 5. Illustration of the difference between batch and layer normalization. Batch normalization normalizes across the batch dimension while layer normalization normalizes across the feature dimension.
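The setting of Figure 5 can be reproduced on a small random matrix. In this sketch, 4 samples are the rows and 6 features are the columns; nn.LayerNorm(6) should give every row a mean of roughly zero, while nn.BatchNorm1d(6), in its default training mode, should give every column a mean of roughly zero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 6) * 3 + 5      # 4 samples (rows) x 6 features (columns), as in Figure 5

layer_norm = nn.LayerNorm(6)       # statistics computed per row, over its 6 features
batch_norm = nn.BatchNorm1d(6)     # training mode: statistics computed per column, over the 4 samples

ln_out = layer_norm(x)
bn_out = batch_norm(x)

print(ln_out.mean(dim=1))   # close to 0 for every sample (row)
print(bn_out.mean(dim=0))   # close to 0 for every feature (column)
```

Since layer normalization never looks across samples, it behaves identically for any batch size, including the batch of one used throughout this article.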

Subsequently, we initialize an nn.MultiheadAttention() layer with EMBED_DIM (768) as the input size at the line marked by #(2) in Codeblock 12. The batch_first parameter is set to True to indicate that the batch dimension is placed at the 0-th axis of the input tensor. Multi-head attention allows the model to capture various types of relationships between image patches simultaneously, since each head can focus on different aspects of these relationships. In the forward pass, this layer accepts three inputs, namely query, key, and value, which are used to compute the so-called attention weights. These weights determine how much each patch should attend to every other patch, which is how the layer captures the relationships between patches. The attention mechanism can be seen as the core of ViT, since it is the component that enables ViT to rival or surpass CNNs on image recognition tasks, particularly when the model is pre-trained on large datasets.
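To show what happens inside nn.MultiheadAttention() when the same tensor is passed as query, key, and value, here is a sketch of multi-head scaled dot-product attention written out by hand. It reuses the layer's own projection weights (in_proj_weight, in_proj_bias, and out_proj, which exist when query, key, and value share the embedding dimension) so that the two results can be compared. The small sizes and the helper name manual_self_attention are assumptions for illustration.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def manual_self_attention(x, mha):
    """Multi-head self-attention computed by hand with the weights of `mha` (batch_first=True)."""
    B, N, D = x.shape
    h = mha.num_heads
    d = D // h                                                # dimension per head
    qkv = F.linear(x, mha.in_proj_weight, mha.in_proj_bias)  # (B, N, 3D): Q, K, V stacked
    q, k, v = qkv.chunk(3, dim=-1)                            # each (B, N, D)
    # Split the embedding into h heads: (B, N, D) -> (B, h, N, d)
    q, k, v = (t.reshape(B, N, h, d).transpose(1, 2) for t in (q, k, v))
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)           # (B, h, N, N) token-to-token scores
    weights = scores.softmax(dim=-1)                          # each row sums to 1
    out = (weights @ v).transpose(1, 2).reshape(B, N, D)      # merge the heads back: (B, N, D)
    return mha.out_proj(out), weights

torch.manual_seed(0)
mha = nn.MultiheadAttention(16, num_heads=4, batch_first=True).eval()   # small, hypothetical sizes
x = torch.randn(2, 5, 16)                                               # 2 sequences of 5 tokens

ref_out, ref_weights = mha(x, x, x)   # same call pattern as self.multihead_attention(x, x, x)
out, weights = manual_self_attention(x, mha)
print(torch.allclose(out, ref_out, atol=1e-5))   # True
```

Note that the weights returned by nn.MultiheadAttention() are averaged over the heads by default, which is why they have shape (batch, N, N) rather than (batch, heads, N, N).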

The MLP component inside the Transformer Encoder is constructed using nn.Sequential() (#(4)). Here we implement two consecutive linear layers, each followed by a dropout layer. We also need to put a GELU activation function right after the first linear layer. No activation function is used after the second linear layer, since its purpose is just to project the tensor back to the original embedding dimension.

Now it's time to connect all the layers we just initialized using the codeblock below.

# Codeblock 13
    def forward(self, x):

        residual = x    #(1)
        print(f'residual dim\t\t: {residual.size()}')

        x = self.norm_0(x)    #(2)
        print(f'after norm\t\t: {x.size()}')

        x = self.multihead_attention(x, x, x)[0]    #(3)
        print(f'after attention\t\t: {x.size()}')

        x = x + residual    #(4)
        print(f'after addition\t\t: {x.size()}')

        residual = x    #(5)
        print(f'residual dim\t\t: {residual.size()}')

        x = self.norm_1(x)    #(6)
        print(f'after norm\t\t: {x.size()}')

        x = self.mlp(x)    #(7)
        print(f'after mlp\t\t: {x.size()}')

        x = x + residual    #(8)
        print(f'after addition\t\t: {x.size()}')

        return x

In the above forward() function, we first store the input tensor x in the residual variable (#(1)), which is used to create the residual connection. Next, we normalize the input tensor (#(2)) prior to feeding it into the multi-head attention layer (#(3)). As I mentioned earlier, this layer takes query, key, and value as its inputs. Since this is self-attention, tensor x is used as the argument for all three. Notice that I also write [0] on the same line of code. This is because an nn.MultiheadAttention() object returns two values, the attention output and the attention weights, and in this case we only need the former. Next, at the line marked with #(4), we perform element-wise addition between the output of the multi-head attention layer and the original input tensor. We then update the residual variable with the current tensor x (#(5)) after the first residual operation is performed. The second normalization is done at line #(6) before feeding the tensor into the MLP block (#(7)) and performing another element-wise addition (#(8)).

We can check whether our Transformer Encoder block implementation is correct using the Codeblock 14 below. Keep in mind that the input of the transformer_encoder model has to be the output produced by PosEmbedding().

# Codeblock 14
transformer_encoder = TransformerEncoder()
x = transformer_encoder(x)
# Codeblock 14 output
residual dim            : torch.Size([1, 197, 768])
after norm              : torch.Size([1, 197, 768])
after attention         : torch.Size([1, 197, 768])
after addition          : torch.Size([1, 197, 768])
residual dim            : torch.Size([1, 197, 768])
after norm              : torch.Size([1, 197, 768])
after mlp               : torch.Size([1, 197, 768])
after addition          : torch.Size([1, 197, 768])

You can see in the above output that the tensor dimension does not change after any of the steps. However, if you take a closer look at how the MLP block was constructed in Codeblock 12, you will observe that its hidden layer expands the tensor to MLP_SIZE (3072) at the line marked by #(5). It is then projected back to its original dimension, i.e., EMBED_DIM (768), at line #(6).
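The expand-then-project pattern can be checked in isolation. The sketch below is not Codeblock 12. It is a hypothetical MLP block that assumes two Linear layers with a GELU and dropout in between. That layout is consistent with the 4,722,432 parameters reported for each Sequential block in the model summary (Codeblock 19 output). The 0.1 dropout rate is also an assumption, and it does not affect the shapes or the parameter count.

```python
import torch
import torch.nn as nn

EMBED_DIM, MLP_SIZE = 768, 3072  # ViT-Base values used in this article

# Hypothetical stand-in for the encoder's MLP block (not the author's Codeblock 12)
mlp = nn.Sequential(
    nn.Linear(EMBED_DIM, MLP_SIZE),   # expand: 768 -> 3072
    nn.GELU(),
    nn.Dropout(p=0.1),                # assumed rate; dropout has no parameters
    nn.Linear(MLP_SIZE, EMBED_DIM),   # project back: 3072 -> 768
    nn.Dropout(p=0.1),
)

x = torch.randn(1, 197, EMBED_DIM)
hidden = mlp[1](mlp[0](x))            # activation after the first Linear and GELU
out = mlp(x)

print(hidden.size())  # torch.Size([1, 197, 3072])
print(out.size())     # torch.Size([1, 197, 768])
print(sum(p.numel() for p in mlp.parameters()))  # 4722432
```

Only the intermediate activation is 3072 wide, so the printed shapes before and after the block stay at 768.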


MLP Head Implementation

The last class we are going to implement is MLPHead(). Similar to the MLP block inside the Transformer Encoder, MLPHead() consists of fully-connected layers with a GELU activation function in between. In addition, it applies layer normalization to its input. The entire implementation of this class can be seen in Codeblock 15 below.

# Codeblock 15
class MLPHead(nn.Module):
    def __init__(self):
        super().__init__()

        self.norm = nn.LayerNorm(EMBED_DIM)
        self.linear_0 = nn.Linear(in_features=EMBED_DIM, 
                                  out_features=EMBED_DIM)
        self.gelu = nn.GELU()
        self.linear_1 = nn.Linear(in_features=EMBED_DIM, 
                                  out_features=NUM_CLASSES)    #(1)

    def forward(self, x):
        print(f'original\t\t: {x.size()}')

        x = self.norm(x)
        print(f'after norm\t\t: {x.size()}')

        x = self.linear_0(x)
        print(f'after layer_0 mlp\t: {x.size()}')

        x = self.gelu(x)
        print(f'after gelu\t\t: {x.size()}')

        x = self.linear_1(x)
        print(f'after layer_1 mlp\t: {x.size()}')

        return x

One thing to note is that the second fully-connected layer produces the output of the entire ViT architecture (#(1)). Hence, we need to ensure that its number of neurons matches the number of classes in the dataset we are going to train the model on. In this case, I assume that we have NUM_CLASSES (10) classes. Furthermore, it is worth noting that I don't use a softmax layer at the end. nn.CrossEntropyLoss(), which you would use to actually train this model, already applies log-softmax to the raw logits internally.
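The point about nn.CrossEntropyLoss() can be verified numerically. On raw logits, it gives the same value as applying log-softmax followed by the negative log-likelihood loss. The batch of four samples and the random logits below are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
NUM_CLASSES = 10

logits = torch.randn(4, NUM_CLASSES)   # raw MLP head outputs, no softmax applied
targets = torch.tensor([3, 0, 9, 5])   # arbitrary ground-truth class indices

ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

print(torch.allclose(ce, manual))  # True
```

This is also why an explicit softmax layer should not be placed before nn.CrossEntropyLoss(), since the softmax would then effectively be applied twice.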

In order to test the MLPHead() model, we first need to slice the tensor produced by the Transformer Encoder block, as shown at line #(1) in Codeblock 16. We do this because we want to take the 0-th element of the token sequence, which corresponds to the class token we prepended earlier to the front of the patch token sequence.
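Here is a small sketch of what x[:, 0] does to a 1×197×768 token sequence. The class token is a hypothetical random vector prepended with torch.cat. This mimics the prepending done earlier, but it is not the article's PosEmbedding() code. The sketch confirms that index 0 recovers that token.

```python
import torch

torch.manual_seed(0)
class_token = torch.randn(1, 1, 768)     # hypothetical class token
patch_tokens = torch.randn(1, 196, 768)  # 14 x 14 patch embeddings

# Prepend the class token, giving a (1, 197, 768) sequence
x = torch.cat([class_token, patch_tokens], dim=1)

cls_out = x[:, 0]     # integer index drops the token dimension -> (1, 768)
cls_keep = x[:, 0:1]  # a slice keeps it -> (1, 1, 768)

print(x.size(), cls_out.size(), cls_keep.size())
print(torch.equal(cls_out, class_token[:, 0]))  # True
```

In the real model, the encoder layers sit between these two steps. Position 0 then holds a class token that has attended to every patch, which is why it alone is passed to the MLP head.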

# Codeblock 16
x = x[:, 0]    #(1)
mlp_head = MLPHead()
x = mlp_head(x)
# Codeblock 16 output
original                : torch.Size([1, 768])
after norm              : torch.Size([1, 768])
after layer_0 mlp       : torch.Size([1, 768])
after gelu              : torch.Size([1, 768])
after layer_1 mlp       : torch.Size([1, 10])

After running the test code in Codeblock 16, we can see that the final tensor shape is 1×10, which is exactly what we expect.


The Entire ViT Architecture

At this point, all ViT components have been created, so we can now use them to construct the entire Vision Transformer architecture. Look at Codeblock 17 below to see how I do it.

# Codeblock 17
class ViT(nn.Module):
    def __init__(self):
        super().__init__()

        #self.patcher = PatcherUnfold()
        self.patcher = PatcherConv()    #(1) 
        self.pos_embedding = PosEmbedding()
        self.transformer_encoders = nn.Sequential(
            *[TransformerEncoder() for _ in range(NUM_ENCODERS)]    #(2)
            )
        self.mlp_head = MLPHead()

    def forward(self, x):

        x = self.patcher(x)
        x = self.pos_embedding(x)
        x = self.transformer_encoders(x)
        x = x[:, 0]    #(3)
        x = self.mlp_head(x)

        return x

There are several things I want to emphasize regarding the above code. First, at line #(1) we can use either PatcherUnfold() or PatcherConv(), as they both play the same role, i.e., performing the patch flattening and linear projection step. In this case, I use the latter for no specific reason. Secondly, the Transformer Encoder block is repeated NUM_ENCODERS (12) times (#(2)), since we are implementing ViT-Base as stated in Figure 3. Lastly, don't forget to slice the tensor output by the Transformer Encoder, since our MLP head only processes the class token part of the output (#(3)).
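To illustrate why the two patchers are interchangeable, the sketch below compares two generic approaches. The first is a Conv2d whose kernel size and stride both equal the patch size. The second is nn.Unfold followed by an nn.Linear. These are stand-ins rather than the author's PatcherConv() and PatcherUnfold() classes, and they assume 16×16 patches and a 768-dimensional embedding. Once the convolution kernel is copied into the linear layer, both paths produce the same 196×768 patch embeddings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
PATCH, EMBED = 16, 768
img = torch.randn(1, 3, 224, 224)

conv = nn.Conv2d(3, EMBED, kernel_size=PATCH, stride=PATCH)
linear = nn.Linear(3 * PATCH * PATCH, EMBED)

# Copy the conv kernel (768, 3, 16, 16) into the linear layer (768, 768)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(EMBED, -1))
    linear.bias.copy_(conv.bias)

# Conv path: (1, 768, 14, 14) -> (1, 768, 196) -> (1, 196, 768)
conv_out = conv(img).flatten(2).transpose(1, 2)

# Unfold path: each of the 196 patches becomes a 3*16*16 = 768-value vector
patches = nn.Unfold(kernel_size=PATCH, stride=PATCH)(img).transpose(1, 2)
unfold_out = linear(patches)

print(conv_out.size(), unfold_out.size())
print(torch.allclose(conv_out, unfold_out, atol=1e-4))  # True
```

The convolutional version fuses patch extraction and projection into a single call, while the unfold version makes the two steps explicit.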

We can test whether our ViT model works properly using the following code.

# Codeblock 18
vit = ViT().to(device)
x = torch.randn(1, 3, 224, 224).to(device)
print(vit(x).size())

You can see here that the input, whose dimension is 1×3×224×224, has been converted to 1×10, which indicates that our model works as expected.

Note: you need to comment out all the prints to make the output look as concise as this.

# Codeblock 18 output
torch.Size([1, 10])
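PyTorch's forward hooks offer an alternative to adding and later commenting out print() calls, because they record output shapes from outside the modules. The sketch below applies them to a hypothetical small network built from the same layer types as MLPHead(). The same register_forward_hook() call works on any nn.Module, including the children of vit.

```python
import torch
import torch.nn as nn

# Small stand-in network with the same layer types as MLPHead()
model = nn.Sequential(
    nn.LayerNorm(768),
    nn.Linear(768, 768),
    nn.GELU(),
    nn.Linear(768, 10),
)

shapes = []

def record_shape(module, inputs, output):
    # Called after each module's forward(); stores its class name and output shape
    shapes.append((module.__class__.__name__, tuple(output.shape)))

handles = [layer.register_forward_hook(record_shape) for layer in model]

with torch.no_grad():
    model(torch.randn(1, 768))

for handle in handles:
    handle.remove()  # detach the hooks so later forward passes stay silent

for name, shape in shapes:
    print(f'{name:<10}: {shape}')
```

Because the hooks are removed after use, the model itself never needs to be edited to switch the shape logging on or off.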

Additionally, we can see the detailed structure of the network using the summary() function we imported at the beginning of the code. You can observe that the total number of parameters is around 86 million, which matches the number stated in Figure 3.

# Codeblock 19
summary(vit, input_size=(1,3,224,224))
# Codeblock 19 output
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ViT                                      [1, 10]                   --
├─PatcherConv: 1-1                       [1, 196, 768]             --
│    └─Conv2d: 2-1                       [1, 768, 14, 14]          590,592
│    └─Flatten: 2-2                      [1, 768, 196]             --
├─PosEmbedding: 1-2                      [1, 197, 768]             152,064
│    └─Dropout: 2-3                      [1, 197, 768]             --
├─Sequential: 1-3                        [1, 197, 768]             --
│    └─TransformerEncoder: 2-4           [1, 197, 768]             --
│    │    └─LayerNorm: 3-1               [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-2      [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-3               [1, 197, 768]             1,536
│    │    └─Sequential: 3-4              [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-5           [1, 197, 768]             --
│    │    └─LayerNorm: 3-5               [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-6      [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-7               [1, 197, 768]             1,536
│    │    └─Sequential: 3-8              [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-6           [1, 197, 768]             --
│    │    └─LayerNorm: 3-9               [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-10     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-11              [1, 197, 768]             1,536
│    │    └─Sequential: 3-12             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-7           [1, 197, 768]             --
│    │    └─LayerNorm: 3-13              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-14     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-15              [1, 197, 768]             1,536
│    │    └─Sequential: 3-16             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-8           [1, 197, 768]             --
│    │    └─LayerNorm: 3-17              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-18     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-19              [1, 197, 768]             1,536
│    │    └─Sequential: 3-20             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-9           [1, 197, 768]             --
│    │    └─LayerNorm: 3-21              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-22     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-23              [1, 197, 768]             1,536
│    │    └─Sequential: 3-24             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-10          [1, 197, 768]             --
│    │    └─LayerNorm: 3-25              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-26     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-27              [1, 197, 768]             1,536
│    │    └─Sequential: 3-28             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-11          [1, 197, 768]             --
│    │    └─LayerNorm: 3-29              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-30     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-31              [1, 197, 768]             1,536
│    │    └─Sequential: 3-32             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-12          [1, 197, 768]             --
│    │    └─LayerNorm: 3-33              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-34     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-35              [1, 197, 768]             1,536
│    │    └─Sequential: 3-36             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-13          [1, 197, 768]             --
│    │    └─LayerNorm: 3-37              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-38     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-39              [1, 197, 768]             1,536
│    │    └─Sequential: 3-40             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-14          [1, 197, 768]             --
│    │    └─LayerNorm: 3-41              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-42     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-43              [1, 197, 768]             1,536
│    │    └─Sequential: 3-44             [1, 197, 768]             4,722,432
│    └─TransformerEncoder: 2-15          [1, 197, 768]             --
│    │    └─LayerNorm: 3-45              [1, 197, 768]             1,536
│    │    └─MultiheadAttention: 3-46     [1, 197, 768]             2,362,368
│    │    └─LayerNorm: 3-47              [1, 197, 768]             1,536
│    │    └─Sequential: 3-48             [1, 197, 768]             4,722,432
├─MLPHead: 1-4                           [1, 10]                   --
│    └─LayerNorm: 2-16                   [1, 768]                  1,536
│    └─Linear: 2-17                      [1, 768]                  590,592
│    └─GELU: 2-18                        [1, 768]                  --
│    └─Linear: 2-19                      [1, 10]                   7,690
==========================================================================================
Total params: 86,396,938
Trainable params: 86,396,938
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 173.06
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 102.89
Params size (MB): 231.59
Estimated Total Size (MB): 335.08
==========================================================================================
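The 86,396,938 total in the summary above can be reproduced by hand from the layer shapes. The arithmetic below follows the layer layout reported by summary(). It assumes that PosEmbedding() holds a learnable 197×768 position embedding plus a 1×768 class token, which is consistent with its 152,064 parameters, and that each encoder MLP is made of two Linear layers.

```python
# Reproduce the parameter count of ViT-Base with 10 output classes, layer by layer.
# Assumes the layer layout shown in the torchinfo summary above.
EMBED_DIM, MLP_SIZE, NUM_ENCODERS, NUM_CLASSES = 768, 3072, 12, 10
PATCH_SIZE, IN_CHANNELS, NUM_PATCHES = 16, 3, 196

def linear(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias

def layer_norm(dim):
    return 2 * dim               # scale (gamma) + shift (beta)

patcher = IN_CHANNELS * PATCH_SIZE**2 * EMBED_DIM + EMBED_DIM   # Conv2d kernel + bias
pos_embedding = (NUM_PATCHES + 1) * EMBED_DIM + EMBED_DIM       # positions + class token (assumed)
attention = linear(EMBED_DIM, 3 * EMBED_DIM) + linear(EMBED_DIM, EMBED_DIM)  # QKV + output projection
mlp = linear(EMBED_DIM, MLP_SIZE) + linear(MLP_SIZE, EMBED_DIM)
encoder = 2 * layer_norm(EMBED_DIM) + attention + mlp
mlp_head = layer_norm(EMBED_DIM) + linear(EMBED_DIM, EMBED_DIM) + linear(EMBED_DIM, NUM_CLASSES)

total = patcher + pos_embedding + NUM_ENCODERS * encoder + mlp_head
print(f'{encoder:,} per encoder, {total:,} in total')  # 7,087,872 per encoder, 86,396,938 in total
```

Nearly 99% of the parameters sit in the 12 encoder blocks, and within each block the MLP holds about twice as many parameters as the attention layer.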

I think that’s pretty much all about the Vision Transformer architecture. Feel free to comment if you spot any mistakes in the code.

All the code used in this article is also available in my GitHub repository. Here’s the link to it.

Thanks for reading, and I hope you learn something new today!


References

[1] Alexey Dosovitskiy et al. An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale. arXiv. https://arxiv.org/pdf/2010.11929 [Accessed August 8, 2024].

[2] Haoning Lin et al. Maritime Semantic Labeling of Optical Remote Sensing Images with Multi-Scale Fully Convolutional Network. ResearchGate. https://www.researchgate.net/publication/316950618_Maritime_Semantic_Labeling_of_Optical_Remote_Sensing_Images_with_Multi-Scale_Fully_Convolutional_Network [Accessed August 8, 2024].

[3] Vision Transformer. PyTorch. https://pytorch.org/vision/main/models/vision_transformer.html [Accessed August 8, 2024].


If you enjoyed this article and want to learn about OpenCV and image processing, check out my comprehensive course on Udemy: Image Processing with OpenCV. You’ll gain hands-on experience and master essential techniques.

The post Paper Walkthrough: Vision Transformer (ViT) appeared first on Towards Data Science.

]]>