In [None]:
# Importing all dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Module for Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a sequence-first tensor.

    A table of shape ``(max_len, 1, d_model)`` is precomputed once and stored
    as the non-trainable buffer ``pe``. Even feature columns hold
    ``sin(position * freq)`` and odd feature columns hold
    ``cos(position * freq)``, with ``freq = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Size of the feature dimension. Both even and odd values
            are supported.
        max_len: Number of positions to precompute. Inputs whose first
            dimension exceeds this are not supported.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim the angles to match instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(1)  # shape (max_len, 1, d_model); broadcasts over dim 1
        # Registered as a buffer: follows .to()/.cuda() and is saved in the
        # state dict, but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(0)`` positions.

        Args:
            x: Tensor whose dim 0 indexes sequence position and whose last
                dim is ``d_model``.
        """
        return x + self.pe[:x.size(0)]

class TransformerDecoderLayer(nn.Module):
    """One post-norm decoder layer: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is wrapped as
    ``LayerNorm(x + Dropout(sublayer(x)))``. Both attention modules are built
    with their defaults, so inputs are sequence-first.

    Args:
        d_model: Feature size of ``tgt`` and ``memory``.
        nhead: Number of attention heads in both attention modules.
        dim_ff: Hidden width of the position-wise feed-forward network.
        dropout: Dropout probability used inside the attention modules, on
            the feed-forward hidden activations, and on each residual branch.
    """

    def __init__(self, d_model, nhead, dim_ff, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Position-wise feed-forward network: d_model -> dim_ff -> d_model.
        self.linear1 = nn.Linear(d_model, dim_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_ff, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """Run the layer and return a tensor with the same shape as ``tgt``.

        Args:
            tgt: Decoder-side input; used as query, key and value for
                self-attention.
            memory: Encoder-side sequence; used as key and value for
                cross-attention.
            tgt_mask: Optional ``attn_mask`` for the self-attention call.
            memory_mask: Optional ``attn_mask`` for the cross-attention call.
        """
        # [0] keeps the attention output and drops the attention weights.
        self_attended = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask)[0]
        tgt = self.norm1(tgt + self.dropout1(self_attended))

        cross_attended = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask)[0]
        tgt = self.norm2(tgt + self.dropout2(cross_attended))

        # Apply the hidden-activation dropout that __init__ creates; it was
        # previously constructed but never used.
        ff_out = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        tgt = self.norm3(tgt + self.dropout3(ff_out))
        return tgt

class TransformerDecoder(nn.Module):
    """Token-level decoder: embedding, positional encoding, layer stack, vocab projection.

    Args:
        num_layers: How many ``TransformerDecoderLayer`` modules to stack.
        d_model: Embedding / hidden feature size.
        nhead: Attention heads per layer.
        dim_ff: Feed-forward hidden width per layer.
        vocab_size: Number of embedding rows and of output logits.
        dropout: Dropout probability forwarded to every layer.
        max_len: Number of positions handed to ``PositionalEncoding``.
    """

    def __init__(self, num_layers, d_model, nhead, dim_ff, vocab_size, dropout=0.1, max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)

        # Grow the stack one layer at a time, in order.
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(TransformerDecoderLayer(d_model, nhead, dim_ff, dropout))

        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """Map token ids to per-position vocabulary logits.

        Args:
            tgt: Integer token ids, sequence-first.
            memory: Encoder output attended to by every layer.
            tgt_mask: Optional mask passed to each layer's self-attention.
            memory_mask: Optional mask passed to each layer's cross-attention.

        Returns:
            Tensor shaped like ``tgt`` with a trailing ``vocab_size`` dimension.
        """
        hidden = self.pos_encoder(self.embedding(tgt))
        for block in self.layers:
            hidden = block(hidden, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
        logits = self.output_layer(hidden)
        return logits

# Hyperparameters for the demo run
d_model, nhead, dim_ff = 128, 4, 512
num_layers = 2
vocab_size = 5000
seq_len, batch_size = 10, 4

# Build the decoder stack
decoder = TransformerDecoder(
    num_layers=num_layers,
    d_model=d_model,
    nhead=nhead,
    dim_ff=dim_ff,
    vocab_size=vocab_size,
)

# Random stand-in inputs: token ids of shape (seq_len, batch_size) and an
# encoder memory of shape (seq_len, batch_size, d_model)
tgt = torch.randint(low=0, high=vocab_size, size=(seq_len, batch_size))
memory = torch.rand((seq_len, batch_size, d_model))

# Run one forward pass and show the logits and their shape
out = decoder(tgt, memory)
print(out.shape)
out

torch.Size([10, 4, 5000])


tensor([[[-0.2575,  0.2022,  0.5841,  ...,  0.1697,  0.0914, -0.2776],
         [ 0.8105,  0.0042, -0.1367,  ...,  0.1824, -0.4727, -0.3356],
         [ 0.8892, -0.1862, -0.4007,  ...,  0.0032,  0.4972, -0.5033],
         [-0.7009, -0.3311,  0.2239,  ...,  0.1335, -0.4804, -0.8168]],

        [[ 0.1692, -1.4079,  0.4325,  ...,  0.6126, -0.3137,  0.7255],
         [ 0.8316, -0.6884,  0.3590,  ...,  0.4932, -0.6200,  0.1858],
         [ 0.1351, -0.7370, -0.9747,  ..., -0.3245, -0.2708,  0.2289],
         [-0.1760, -0.5264, -0.3040,  ..., -0.6869, -0.6682,  0.2687]],

        [[ 0.0588,  0.2074, -0.1735,  ...,  0.1199, -0.0691, -0.6977],
         [ 0.2403,  0.3130,  0.1436,  ..., -0.2068, -0.0404, -0.2347],
         [ 0.5008, -0.0039,  0.0048,  ...,  0.0374, -0.5664,  0.6932],
         [ 0.4907,  0.5247,  0.6455,  ...,  0.3456, -0.2196, -1.2574]],

        ...,

        [[ 0.1468,  0.5852,  0.2892,  ..., -0.1733,  0.1776, -0.6742],
         [-0.5668,  0.1376, -0.6314,  ..., -0.4600, -0.80