In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------------
# Step 1: Define the Attention class
# ------------------------------
class Attention(nn.Module):
    """Additive (concat-style) attention over a sequence of encoder outputs.

    Every encoder time step is scored against the decoder hidden state as
    ``v . tanh(W [hidden ; encoder_output])``. The scores are normalized with
    a softmax over the sequence axis and used to form a weighted sum of the
    encoder outputs (the context vector).

    Args:
        hidden_dim: Feature size shared by the decoder hidden state and the
            encoder outputs.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Projects the concatenated [hidden ; encoder_output] pair back down
        # to hidden_dim before scoring.
        self.attn = nn.Linear(hidden_dim * 2, hidden_dim)
        # Learned vector that collapses each energy vector to a scalar score.
        self.v = nn.Parameter(torch.rand(hidden_dim))

    def forward(self, hidden, encoder_outputs, mask=None):
        """Compute the context vector and attention weights.

        Args:
            hidden: Decoder hidden state, ``[batch_size, hidden_dim]``.
            encoder_outputs: Encoder states,
                ``[batch_size, seq_len, hidden_dim]``.
            mask: Optional ``[batch_size, seq_len]`` tensor (bool or numeric).
                Positions equal to zero/False (e.g. padding) are excluded and
                receive exactly zero weight. A row that masks out *every*
                position has no valid distribution and yields NaN weights.
                ``None`` (the default) attends over all positions.

        Returns:
            A ``(context, attention_weights)`` tuple where ``context`` is
            ``[batch_size, 1, hidden_dim]`` and ``attention_weights`` is
            ``[batch_size, seq_len]`` with each row summing to 1.
        """
        batch_size, seq_len = encoder_outputs.shape[0], encoder_outputs.shape[1]

        # Broadcast the hidden state across all time steps. expand() creates a
        # view instead of the copy repeat() would make; the cat below
        # materializes the values either way.
        query = hidden.unsqueeze(1).expand(-1, seq_len, -1)

        # Energy for every time step: [batch, seq_len, hidden_dim]
        energy = torch.tanh(self.attn(torch.cat((query, encoder_outputs), dim=2)))
        energy = energy.transpose(1, 2)  # [batch, hidden_dim, seq_len]

        # v: [hidden_dim] -> [batch, 1, hidden_dim] (view, no copy)
        v = self.v.expand(batch_size, 1, -1)

        # Raw attention scores: [batch, seq_len]
        attention_scores = torch.bmm(v, energy).squeeze(1)

        if mask is not None:
            # -inf becomes exactly 0 after the softmax, so masked positions
            # contribute nothing to the context vector.
            attention_scores = attention_scores.masked_fill(mask == 0, float("-inf"))

        # Normalize over the sequence axis: [batch, seq_len]
        attention_weights = F.softmax(attention_scores, dim=1)

        # Weighted sum of encoder outputs: [batch, 1, hidden_dim]
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        return context, attention_weights


# ------------------------------
# Step 2: Build a small, reproducible example
# ------------------------------
torch.manual_seed(0)
batch_size, seq_len, hidden_dim = 1, 4, 8

# Random draws happen in a fixed order so the seeded values are reproducible:
# first the encoder outputs (seq_len time steps of size hidden_dim per batch
# item), then the decoder hidden state.
encoder_outputs = torch.randn(batch_size, seq_len, hidden_dim)
decoder_hidden = torch.randn(batch_size, hidden_dim)

# ------------------------------
# Step 3: Build the attention module and apply it
# ------------------------------
attention = Attention(hidden_dim)
context, attn_weights = attention(decoder_hidden, encoder_outputs)

# ------------------------------
# Step 4: Print inputs and results
# ------------------------------
for _heading, _tensor in (
    ("Encoder Outputs:\n", encoder_outputs),
    ("\nDecoder Hidden State:\n", decoder_hidden),
    ("\nAttention Weights:\n", attn_weights),
    ("\nContext Vector:\n", context),
):
    print(_heading, _tensor)


Encoder Outputs:
 tensor([[[-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920, -0.3160,
          -2.1152],
         [ 0.3223, -1.2633,  0.3500,  0.3081,  0.1198,  1.2377,  1.1168,
          -0.2473],
         [-1.3527, -1.6959,  0.5667,  0.7935,  0.5988, -1.5551, -0.3414,
           1.8530],
         [ 0.7502, -0.5855, -0.1734,  0.1835,  1.3894,  1.5863,  0.9463,
          -0.8437]]])

Decoder Hidden State:
 tensor([[-0.5663,  0.3731, -0.8920, -1.5091,  0.3704,  1.4565,  0.9398,  0.7748]])

Attention Weights:
 tensor([[0.3385, 0.1583, 0.2507, 0.2526]], grad_fn=<SoftmaxBackward0>)

Context Vector:
 tensor([[[-0.4796, -1.1630,  0.0688,  0.1472,  0.8072,  0.4410,  0.2233,
          -0.5037]]], grad_fn=<BmmBackward0>)
