In [None]:
import torch
import torch.nn as nn
import math

# 1. Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch of embeddings.

    Even feature indices hold ``sin(pos * freq_i)`` and odd feature indices
    hold ``cos(pos * freq_i)``, where ``freq_i = 10000 ** (-2i / d_model)``.
    The table is computed once for ``max_len`` positions and registered as the
    buffer ``pe`` with shape ``(1, max_len, d_model)``.

    Args:
        d_model: Size of the last (feature) dimension of the inputs. Both even
            and odd values are supported.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # (max_len, ceil(d_model / 2)): one angle column per sine slot.
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd (cosine) column than even
        # (sine) columns, so trim the angles to the columns that exist. For an
        # even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # Add batch dimension so it broadcasts over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` with the encodings of its first ``x.size(1)`` positions added.

        Args:
            x: Tensor assumed to be ``(batch_size, seq_len, d_model)``;
                ``seq_len`` is expected not to exceed ``max_len``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Slice the precomputed table to the sequence length; the leading
        # dimension of size 1 broadcasts over the batch.
        return x + self.pe[:, :x.size(1)]

# 2. Multi-Head Self-Attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    ``query``, ``key`` and ``value`` are each projected by their own linear
    layer, split into ``num_heads`` heads of size ``embed_dim // num_heads``,
    attended per head, re-concatenated and passed through a final linear layer.

    Args:
        embed_dim: Size of the feature dimension of inputs and output.
        num_heads: Number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Validate before deriving head_dim so an invalid configuration never
        # produces a silently truncated head size.
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.out_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: Tensor assumed to be ``(batch_size, q_len, embed_dim)``.
            key: Tensor assumed to be ``(batch_size, k_len, embed_dim)``.
            value: Tensor assumed to be ``(batch_size, k_len, embed_dim)``.
            mask: Optional tensor broadcastable to
                ``(batch_size, num_heads, q_len, k_len)``. Entries equal to 0
                mark key positions that must not be attended to.

        Returns:
            Tensor of shape ``(batch_size, q_len, embed_dim)``.
        """
        batch_size = query.size(0)

        # 1. Linear projections for Q, K, V, reshaped to
        #    (batch_size, num_heads, seq_len, head_dim)
        Q = self.q_linear(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_linear(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_linear(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # 2. Scaled dot-product scores: (batch_size, num_heads, q_len, k_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # 3. Apply mask (e.g., for padding)
        if mask is not None:
            # Use the most negative finite value rather than -inf: a row whose
            # keys are all masked would otherwise turn into NaN in the softmax
            # and poison every later layer. Rows with at least one visible key
            # still give the masked keys a weight of exactly zero.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # 4. Softmax over the key dimension to get attention weights
        attention_weights = torch.softmax(scores, dim=-1)

        # 5. Weighted sum of values: (batch_size, num_heads, q_len, head_dim)
        attended_output = torch.matmul(attention_weights, V)

        # 6. Concatenate heads -> (batch_size, q_len, embed_dim), then project
        attended_output = attended_output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        output = self.out_linear(attended_output)
        return output

# 3. Position-wise Feed-Forward Network
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied to the last dimension.

    Computes ``linear2(relu(linear1(x)))``: the features are expanded from
    ``embed_dim`` to ``ff_dim``, passed through a ReLU, then projected back
    to ``embed_dim``.

    Args:
        embed_dim: Size of the input and output feature dimension.
        ff_dim: Size of the hidden feature dimension.
    """

    def __init__(self, embed_dim, ff_dim):
        super().__init__()
        # Expansion, non-linearity, and projection back to the model width.
        self.linear1 = nn.Linear(in_features=embed_dim, out_features=ff_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(in_features=ff_dim, out_features=embed_dim)

    def forward(self, x):
        """Apply the feed-forward network to ``x`` and return the result."""
        hidden = self.linear1(x)
        activated = self.relu(hidden)
        projected = self.linear2(activated)
        return projected

# 4. Encoder Layer (combining components)
class EncoderLayer(nn.Module):
    """A single encoder block: self-attention followed by a feed-forward network.

    Each sub-layer's output goes through dropout, is added to the sub-layer's
    input (residual connection) and is then layer-normalised.

    Args:
        embed_dim: Feature size of the inputs and outputs.
        num_heads: Number of attention heads passed to ``MultiHeadAttention``.
        ff_dim: Hidden size passed to ``PositionwiseFeedForward``.
        dropout_rate: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate):
        super().__init__()
        # Sub-layers
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.feed_forward = PositionwiseFeedForward(embed_dim, ff_dim)
        # One LayerNorm per residual branch; a single dropout module is shared.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded to the attention sub-layer."""
        # Sub-layer 1: self-attention (query, key and value are all ``x``),
        # then Add & Norm.
        attended = self.dropout(self.self_attn(x, x, x, mask))
        after_attention = self.norm1(x + attended)

        # Sub-layer 2: position-wise feed-forward, then Add & Norm.
        transformed = self.dropout(self.feed_forward(after_attention))
        return self.norm2(after_attention + transformed)

# 5. Full Encoder (stack of Encoder Layers)
class Encoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of encoder layers.

    Args:
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Feature size of the embeddings and of every layer.
        num_layers: Number of ``EncoderLayer`` instances in the stack.
        num_heads: Number of attention heads in each layer.
        ff_dim: Hidden size of each layer's feed-forward network.
        dropout_rate: Dropout probability used on the embeddings and inside
            each layer.
        max_len: Number of positions precomputed by the positional encoding.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, ff_dim, dropout_rate, max_len=5000):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim, max_len)
        self.layers = nn.ModuleList([EncoderLayer(embed_dim, num_heads, ff_dim, dropout_rate) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, src, src_mask=None):
        """Encode a batch of token ids.

        Args:
            src: Integer tensor of token ids, assumed ``(batch_size, seq_len)``.
            src_mask: Optional mask handed unchanged to every layer. It
                defaults to ``None`` (no masking), matching the ``mask=None``
                default of ``EncoderLayer.forward``.

        Returns:
            The output of the last layer; with the assumed input shape this is
            ``(batch_size, seq_len, embed_dim)``.
        """
        # 1. Token Embedding
        x = self.token_embedding(src)
        # 2. Add Positional Encoding, then dropout on the summed embeddings
        x = self.dropout(self.positional_encoding(x))

        # 3. Pass through N encoder layers
        for layer in self.layers:
            x = layer(x, src_mask)
        return x



In [None]:
# Example Usage (conceptual)
# Seed the global RNG first: the model weights, the random token ids and the
# dropout masks (the model is never switched to eval mode here) are all drawn
# from it, so without a seed every run prints different values.
torch.manual_seed(0)

vocab_size = 10000
embed_dim = 512
num_layers = 6
num_heads = 8
ff_dim = 2048
dropout_rate = 0.1

encoder = Encoder(vocab_size, embed_dim, num_layers, num_heads, ff_dim, dropout_rate)

# Example input: a batch of sentences represented as token IDs
# (batch_size, seq_len)
src_tokens = torch.randint(0, vocab_size, (32, 50))

# Example source mask (1 for actual tokens, 0 for padding)
# (batch_size, 1, 1, seq_len) for broadcasting with attention scores
# Token id 0 is treated as padding here.
src_mask = (src_tokens != 0).unsqueeze(1).unsqueeze(2)

encoded_representation = encoder(src_tokens, src_mask)
print(encoded_representation.shape) # Expected: (batch_size, seq_len, embed_dim)

torch.Size([32, 50, 512])


In [None]:
# Show the encoder output assigned in the example-usage cell; as the last
# expression of the cell, its value is rendered as the cell output.
encoded_representation

tensor([[[ 1.0533,  0.7145,  0.1139,  ...,  2.6528, -0.3629, -0.7344],
         [-0.0685,  0.0221, -0.0389,  ...,  0.7509,  0.1320, -1.2755],
         [ 0.3046, -0.2413,  0.8813,  ..., -0.1554, -1.6855, -0.0717],
         ...,
         [-0.7506, -0.6016,  0.9016,  ..., -0.7824,  0.4439, -0.3708],
         [-0.5158,  0.5084,  1.4304,  ...,  0.2439,  0.0417,  0.6725],
         [-2.0073,  0.4224, -1.9239,  ...,  0.2734,  0.4171, -0.0964]],

        [[-0.4678,  0.6313, -1.0587,  ..., -0.3810, -0.5124,  0.2979],
         [ 0.6610, -0.1026,  0.1196,  ...,  0.8262, -0.1782, -1.6137],
         [-0.8347, -0.4412, -0.4669,  ...,  1.9667, -0.1553, -0.4937],
         ...,
         [-1.1995, -0.5709,  0.9507,  ...,  0.3323, -0.5252, -1.6482],
         [ 0.5750, -0.4350,  0.2250,  ..., -1.4591, -1.5233,  0.0812],
         [-0.4801,  0.6986,  0.2974,  ...,  1.0846,  1.0973, -0.5184]],

        [[-1.4572,  0.7231,  1.1838,  ...,  1.1310,  0.1103,  0.1918],
         [ 0.2012,  0.2403,  0.9877,  ...,  3