Diving into the Transformer architecture and what makes it so effective at language tasks

Image by the author

In the rapidly evolving landscape of artificial intelligence and machine learning, one innovation stands out for its profound impact on how we process, understand, and generate data: Transformers. Transformers have revolutionized the field of natural language processing (NLP) and beyond, powering some of today’s most advanced AI applications. But what exactly are Transformers, and how do they manage to transform data in such groundbreaking ways? This article demystifies the inner workings of Transformer models, focusing on the encoder architecture. We will start by going through the implementation of a Transformer encoder in Python, breaking down its main components. Then, we will visualize how Transformers process and adapt input data during training.

While this blog post doesn’t cover every architectural detail, it provides an implementation and an overall understanding of the transformative power of Transformers. For an in-depth explanation of Transformers, I suggest you look at the excellent Stanford CS224n course.

What is a Transformer encoder architecture?

The Transformer model from Attention Is All You Need

This picture shows the original Transformer architecture, combining an encoder and a decoder for sequence-to-sequence language tasks.

In this article, we will focus on the encoder architecture (the red block in the picture). This is what the popular BERT model uses under the hood: the primary focus is on understanding and representing the data, rather than generating sequences. It can be used for a variety of applications: text classification, named-entity recognition (NER), extractive question answering, etc.

So, how is the data actually transformed by this architecture? We will explain each component in detail, but here is an overview of the process.

It is important to understand that the Transformer architecture transforms the embedding vectors by mapping them from one representation in a high-dimensional space to another within the same space, applying a series of complex transformations.

Implementing an encoder architecture in Python

The Positional Encoder layer

Unlike RNN models, the attention mechanism makes no use of the order of the input sequence. The PositionalEncoder class adds positional encodings to the input embeddings, using two mathematical functions: cosine and sine.

Positional encoding matrix definition from Attention Is All You Need

Note that positional encodings don’t contain trainable parameters: they are the results of deterministic computations, which makes this method very tractable. Also, sine and cosine functions take values between -1 and 1 and have useful periodicity properties to help the model learn patterns about the relative positions of words.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_length):
        super(PositionalEncoder, self).__init__()
        self.d_model = d_model
        self.max_length = max_length

        # Initialize the positional encoding matrix
        pe = torch.zeros(max_length, d_model)

        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model))

        # Calculate and assign position encodings to the matrix
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Register as a buffer: not trainable, but moved along with the module by .to(device)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]  # update embeddings
        return x
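As a quick sanity check on the encoding, here is a minimal sketch that computes the same sinusoidal table with plain Python lists, so it runs without PyTorch. The function name `sinusoidal_encoding` and the tiny sizes are assumptions chosen for illustration, not part of the implementation above.

```python
import math

def sinusoidal_encoding(max_length, d_model):
    """Return a max_length x d_model table of sinusoidal positional encodings."""
    table = [[0.0] * d_model for _ in range(max_length)]
    for pos in range(max_length):
        for i in range(0, d_model, 2):
            # Same frequency term as div_term in PositionalEncoder
            freq = math.exp(-math.log(10000.0) * i / d_model)
            table[pos][i] = math.sin(pos * freq)  # even dimensions use sine
            if i + 1 < d_model:
                table[pos][i + 1] = math.cos(pos * freq)  # odd dimensions use cosine
    return table

pe = sinusoidal_encoding(max_length=4, d_model=4)
for pos, row in enumerate(pe):
    print(pos, [round(v, 3) for v in row])
```

Position 0 always maps to alternating 0 and 1 (sin(0) and cos(0)), and the first dimensions oscillate faster than the last ones, which is what lets the model tell nearby and distant positions apart.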

Multi-Head Self-Attention

The self-attention mechanism is the key component of the encoder architecture. Let’s ignore the “multi-head” for now. Attention is a way to determine for each token (i.e. each embedding) the relevance of all other embeddings to that token, to obtain a more refined and contextually relevant encoding.

How does “it” pay attention to other words of the sequence? (The Illustrated Transformer)

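There are 3 steps in the self-attention mechanism. First, each input embedding is projected into a query, a key, and a value vector using the Q, K, and V matrices. Second, attention scores are computed as scaled dot products between each query and all the keys, then normalized with a softmax. Third, each output embedding is the weighted sum of the value vectors, using these attention weights.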

Mathematically, this corresponds to the following formula.

The Attention Mechanism from Attention Is All You Need
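In the paper's notation, the formula is Attention(Q, K, V) = softmax(QKᵀ / √d_k) V. The following is a minimal sketch of that computation for a single sequence, written in plain Python so it runs without PyTorch. It assumes the queries, keys, and values are already given (in the real layer they come from learned linear projections), and the function names are illustrative.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values):
    """Scaled dot-product attention for one sequence, without masking."""
    d_k = len(keys[0])
    outputs = []
    for q in queries:
        # Similarity of this query with every key, scaled by sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        weights = softmax(scores)  # attention weights for this token sum to 1
        # Weighted average of the value vectors
        outputs.append([sum(w * v[j] for w, v in zip(weights, values))
                        for j in range(len(values[0]))])
    return outputs

# Toy sequence of 3 tokens with 2-dimensional queries, keys and values
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = attention(Q, K, V)
print(out)
```

Each output row is a convex combination of the rows of V, so every token ends up with a blend of the whole sequence, weighted by how well its query matches each key.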

What does “multi-head” mean? We apply the self-attention mechanism described above several times in parallel, then concatenate and project the outputs. This allows each head to focus on different semantic aspects of the sentence.

We start by defining the number of heads, the dimension of the embeddings (d_model), and the dimension of each head (head_dim). We also initialize the Q, K, and V matrices (linear layers), and the final projection layer.

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads

        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)

When using multi-head attention, we apply each attention head with a reduced dimension (head_dim instead of d_model) as in the original paper, making the total computational cost similar to a one-head attention layer with full dimensionality. Note this is a logical split only. What makes multi-head attention so powerful is that it can still be represented via a single matrix operation, making computations very efficient on GPUs.

    def split_heads(self, x, batch_size):
        # Split the sequence embeddings in x across the attention heads
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.head_dim)
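To make the logical split concrete, here is a small sketch in plain Python (no PyTorch) for a single token embedding. The helper names and the toy numbers are assumptions for illustration; the method above does the same thing on whole batches with view and permute.

```python
def split_heads(embedding, num_heads):
    """Split one d_model-sized embedding into num_heads contiguous chunks."""
    d_model = len(embedding)
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    head_dim = d_model // num_heads
    return [embedding[h * head_dim:(h + 1) * head_dim] for h in range(num_heads)]

def merge_heads(heads):
    """Concatenate the per-head chunks back into one vector."""
    return [value for head in heads for value in head]

token = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]  # d_model = 8
heads = split_heads(token, num_heads=4)            # 4 heads, head_dim = 2
print(heads)               # [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
print(merge_heads(heads))  # the original 8 values, in order
```

No values are created or dropped: each head simply sees its own slice of the projected embedding. This is also why d_model should be divisible by num_heads, since the integer division used for head_dim would otherwise silently discard dimensions.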

We compute the attention scores and use a mask to avoid using attention on padded tokens. We apply a softmax activation to make these scores probabilities.

    def compute_attention(self, query, key, mask=None):
        # Compute dot-product attention scores
        # dimensions of query and key are (batch_size * num_heads, seq_length, head_dim)
        scores = query @ key.transpose(-2, -1) / math.sqrt(self.head_dim)
        # Now, dimensions of scores is (batch_size * num_heads, seq_length, seq_length)
        if mask is not None:
            # split_heads orders the first dimension batch-first, so unfold it as (batch_size, num_heads, ...)
            scores = scores.view(-1, self.num_heads, mask.shape[1], mask.shape[2])
            # mask is (batch_size, seq_length, seq_length); unsqueeze(1) broadcasts it over the heads
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e20)  # mask to avoid attention on padding tokens
            scores = scores.view(-1, mask.shape[1], mask.shape[2])  # reshape back to original shape
        # Normalize attention scores into attention weights
        attention_weights = F.softmax(scores, dim=-1)

        return attention_weights
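The effect of the mask on one row of scores can be sketched in plain Python as follows. The function name `masked_softmax` and the toy scores are assumptions for illustration; the fill value mirrors the -1e20 used above.

```python
import math

def masked_softmax(scores, mask, fill_value=-1e20):
    """Softmax over scores after replacing masked positions (mask == 0) with fill_value."""
    filled = [s if m else fill_value for s, m in zip(scores, mask)]
    top = max(filled)
    exps = [math.exp(s - top) for s in filled]
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, 1.0, 0.5, 3.0]
mask = [1, 1, 1, 0]  # the last token is padding
weights = masked_softmax(scores, mask)
print([round(w, 3) for w in weights])

# A row where every position is masked falls back to uniform weights
print(masked_softmax([1.0, 2.0], [0, 0]))
```

The padded position receives a weight of zero even though its raw score was the highest, and the remaining weights still sum to 1. Using a large negative number rather than negative infinity also keeps fully masked rows finite instead of producing NaN values.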

The forward method performs the multi-head logical split and computes the attention weights. Then, we get the output by multiplying these weights by the values. Finally, we reshape the output and project it with a linear layer.

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        query = self.split_heads(self.query_linear(query), batch_size)
        key = self.split_heads(self.key_linear(key), batch_size)
        value = self.split_heads(self.value_linear(value), batch_size)

        attention_weights = self.compute_attention(query, key, mask)

        # Multiply attention weights by values, concatenate and linearly project outputs
        output = torch.matmul(attention_weights, value)
        output = output.view(batch_size, self.num_heads, -1, self.head_dim).permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.d_model)
        return self.output_linear(output)

The Encoder Layer

This is the main component of the architecture, which leverages multi-head self-attention. We first implement a simple class to perform a feed-forward operation through 2 dense layers.

class FeedForwardSubLayer(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForwardSubLayer, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

We can now code the logic for the encoder layer. We start by applying self-attention to the input, which gives a vector of the same dimension. We then use our mini feed-forward network with Layer Norm layers. Note that we also use skip connections before applying normalization.

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))  # skip connection and normalization
        ff_output = self.feed_forward(x)
        return self.norm2(x + self.dropout(ff_output))  # skip connection and normalization
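The "add, then normalize" pattern used twice in this layer can be sketched on a single vector in plain Python. This is a simplified illustration under stated assumptions: it leaves out dropout and the learned scale and shift parameters of nn.LayerNorm, and the function names are hypothetical.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale or shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def add_and_norm(x, sublayer_output):
    """Skip connection followed by normalization: LayerNorm(x + Sublayer(x))."""
    return layer_norm([a + b for a, b in zip(x, sublayer_output)])

x = [1.0, 2.0, 3.0, 4.0]
sublayer_output = [0.5, -0.5, 0.5, -0.5]  # stand-in for the attention or feed-forward output
y = add_and_norm(x, sublayer_output)
print([round(v, 3) for v in y])
```

The skip connection means the sublayer only has to learn a correction to x rather than a full replacement, and the normalization keeps each embedding on a comparable scale from one layer to the next.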

Putting Everything Together

It’s time to create our final model. We pass our data through an embedding layer. This transforms our raw tokens (integers) into a numerical vector. We then apply our positional encoder and several (num_layers) encoder layers.

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length):
        super(TransformerEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoder(d_model, max_sequence_length)
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

    def forward(self, x, mask):
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x

We also create a ClassifierHead class which is used to transform the final embedding into class probabilities for our classification task.

class ClassifierHead(nn.Module):
    def __init__(self, d_model, num_classes):
        super(ClassifierHead, self).__init__()
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        logits = self.fc(x[:, 0, :])  # first token corresponds to the classification token
        return F.softmax(logits, dim=-1)

Note that the dense and softmax layers are only applied on the first embedding (corresponding to the first token of our input sequence). This is because when tokenizing the text, the first token is the [CLS] token which stands for “classification.” The [CLS] token is designed to aggregate the entire sequence’s information into a single embedding vector, serving as a summary representation that can be used for classification tasks.

Note: the concept of including a [CLS] token originates from BERT, which was initially trained on tasks like next-sentence prediction. The [CLS] token was inserted to predict the likelihood that sentence B follows sentence A, with a [SEP] token separating the 2 sentences. For our model, the [SEP] token simply marks the end of the input sentence, as shown below.

[CLS] Token in BERT Architecture (All About AI)

When you think about it, it’s really mind-blowing that this single [CLS] embedding is able to capture so much information about the entire sequence, thanks to the self-attention mechanism’s ability to weigh and synthesize the importance of every piece of the text in relation to each other.

Training and visualization

Hopefully, the previous section gives you a better understanding of how our Transformer model transforms the input data. We will now write our training pipeline for our binary classification task using the IMDB dataset (movie reviews). Then, we will visualize the embedding of the [CLS] token during the training process to see how our model transformed it.

We first define our hyperparameters, as well as a BERT tokenizer. In the GitHub repository, you can see that I also coded a function to select a subset of the dataset with only 1200 train and 200 test examples.

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

num_classes = 2        # binary classification
d_model = 256          # dimension of the embedding vectors
num_heads = 4          # number of heads for self-attention
num_layers = 4         # number of encoder layers
d_ff = 512             # dimension of the dense layers in the encoder layers
sequence_length = 256  # maximum sequence length
dropout = 0.4          # dropout to avoid overfitting
num_epochs = 20
batch_size = 32

loss_function = torch.nn.CrossEntropyLoss()

dataset = load_dataset("imdb")
dataset = balance_and_create_dataset(dataset, 1200, 200)  # check GitHub repo

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', model_max_length=sequence_length)

You can try to use the BERT tokenizer on one of the sentences:


Every sequence should start with the token 101, corresponding to [CLS], followed by some non-zero integers and padded with zeros if the sequence length is smaller than 256. Note that these zeros are ignored during the self-attention computation using our “mask”.
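To see what this padding looks like, and how the padding vector becomes the matrix mask that the training loop passes to the encoder, here is a small sketch in plain Python. The token ids between 101 and 102 are arbitrary placeholders and the helper name `pairwise_mask` is an assumption; in practice the tokenizer returns the attention mask directly.

```python
def pairwise_mask(attention_mask):
    """Turn a 0/1 padding vector into a seq_length x seq_length matrix."""
    # Entry (i, j) is 1 only if both token i and token j are real tokens
    return [[mi & mj for mj in attention_mask] for mi in attention_mask]

# 101 is [CLS], 102 is [SEP], 0 is padding; the ids in between are placeholders
input_ids = [101, 2023, 3185, 102, 0, 0]
attention_mask = [1 if token_id != 0 else 0 for token_id in input_ids]
print(attention_mask)

for row in pairwise_mask(attention_mask):
    print(row)
```

Rows and columns that correspond to padding are all zeros, so real tokens never attend to padding.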

tokenized_datasets = dataset.map(encode_examples, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

train_dataloader = DataLoader(tokenized_datasets['train'], batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(tokenized_datasets['test'], batch_size=batch_size, shuffle=True)

vocab_size = tokenizer.vocab_size

encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length=sequence_length)
classifier = ClassifierHead(d_model, num_classes)

optimizer = torch.optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=1e-4)

We can now write our train function:

def train(dataloader, encoder, classifier, optimizer, loss_function, num_epochs):
    for epoch in range(num_epochs):
        # Collect and store embeddings before each epoch starts for visualization purposes (check repo)
        all_embeddings, all_labels = collect_embeddings(encoder, dataloader)
        reduced_embeddings = visualize_embeddings(all_embeddings, all_labels, epoch, show=False)
        dic_embeddings[epoch] = [reduced_embeddings, all_labels]  # dic_embeddings is a dict defined outside this function

        correct_predictions = 0
        total_predictions = 0
        for batch in tqdm(dataloader, desc="Training"):
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']  # indicate where padded tokens are
            # These 2 lines make the attention_mask a matrix instead of a vector
            attention_mask = attention_mask.unsqueeze(-1)
            attention_mask = attention_mask & attention_mask.transpose(1, 2)
            labels = batch['label']

            optimizer.zero_grad()  # reset gradients accumulated by the previous batch
            output = encoder(input_ids, attention_mask)
            classification = classifier(output)
            # The classifier head returns probabilities, while CrossEntropyLoss expects
            # unnormalized scores, so we pass log-probabilities instead
            loss = loss_function(torch.log(classification.clamp_min(1e-9)), labels)
            loss.backward()   # backpropagate the loss
            optimizer.step()  # update encoder and classifier weights

            preds = torch.argmax(classification, dim=1)
            correct_predictions += torch.sum(preds == labels).item()
            total_predictions += labels.size(0)

        epoch_accuracy = correct_predictions / total_predictions
        print(f'Epoch {epoch} Training Accuracy: {epoch_accuracy:.4f}')

You can find the collect_embeddings and visualize_embeddings functions in the GitHub repo. They store the [CLS] token embedding for each sentence of the training set, apply a dimensionality reduction technique called t-SNE to make them 2D vectors (instead of 256-dimensional vectors), and save an animated plot.

Let’s visualize the results.

Projected [CLS] embeddings for each training point (blue corresponds to positive sentences, red corresponds to negative sentences)

Observing the plot of projected [CLS] embeddings for each training point, we can see the clear distinction between positive (blue) and negative (red) sentences after a few epochs. This visual shows the remarkable capability of the Transformer architecture to adapt embeddings over time and highlights the power of the self-attention mechanism. The data is transformed in such a way that embeddings for each class are well separated, thereby significantly simplifying the task for the classifier head.


As we conclude our exploration of the Transformer architecture, it’s evident that these models are adept at tailoring data to a given task. With the use of positional encoding and multi-head self-attention, Transformers go beyond mere data processing: they interpret and understand information with a level of sophistication previously unseen. The ability to dynamically weigh the relevance of different parts of the input data allows for a more nuanced understanding and representation of the input text. This enhances performance across a wide array of downstream tasks, including text classification, question answering, named entity recognition, and more.

Now that you have a better understanding of the encoder architecture, you are ready to delve into decoder and encoder-decoder models, which are very similar to what we have just explored. Decoders play a pivotal role in generative tasks and are at the core of the popular GPT models.


