In [2]:
# Tiny hand-labelled corpus: each sentence is a list of (word, tag) tuples.
# Words and their tags are written as parallel rows and zipped together.
train_data = [
    list(zip(words, tags))
    for words, tags in [
        (["the", "cat", "sat"], ["DET", "NOUN", "VERB"]),
        (["the", "dog", "barked"], ["DET", "NOUN", "VERB"]),
        (["a", "dog", "sat"], ["DET", "NOUN", "VERB"]),
    ]
]

In [3]:
from collections import defaultdict
import math

# Raw count tables gathered from the tagged corpus:
#   transition[prev_tag][tag] - how often `tag` directly follows `prev_tag`
#   emission[tag][word]       - how often `word` was labelled `tag`
#   start_prob[tag]           - how often `tag` opens a sentence
#   tag_counts[tag]           - total occurrences of `tag`
transition = defaultdict(lambda: defaultdict(int))
emission = defaultdict(lambda: defaultdict(int))
start_prob = defaultdict(int)
tag_counts = defaultdict(int)

for sentence in train_data:
    tags = [tag for _, tag in sentence]

    # The first tag (if the sentence is non-empty) counts as a sentence start.
    if tags:
        start_prob[tags[0]] += 1

    for word, tag in sentence:
        tag_counts[tag] += 1
        emission[tag][word] += 1

    # Every adjacent pair of tags contributes one transition.
    for prev_tag, tag in zip(tags, tags[1:]):
        transition[prev_tag][tag] += 1

def normalize(d):
    """Return a new dict with each value of ``d`` divided by the sum of all values.

    The input mapping is not modified. An empty mapping yields an empty dict.
    """
    denominator = sum(d.values())
    scaled = {}
    for key, count in d.items():
        scaled[key] = count / denominator
    return scaled

# Turn the raw counts into probability distributions.
# start_prob becomes a plain dict; each row of emission / transition is
# replaced by its normalised copy while the outer tables stay in place.
start_prob = normalize(start_prob)
emission.update({tag: normalize(word_counts) for tag, word_counts in emission.items()})
transition.update({prev: normalize(next_counts) for prev, next_counts in transition.items()})


In [4]:
def viterbi(sentence, states, start_p, trans_p, emit_p, smoothing=1e-6):
    """Return the most likely state sequence for ``sentence`` under an HMM.

    Scores are accumulated as sums of log-probabilities instead of products
    of probabilities, so long sentences (or many unseen words) do not
    underflow to 0.0 and collapse every candidate path into a tie.

    Args:
        sentence: Indexable sequence of observations (e.g. a list of words).
        states: Iterable of candidate hidden states. Ties in score are broken
            by comparing the states themselves (the larger one wins), so the
            states must be mutually orderable.
        start_p: Mapping ``state -> probability`` of starting in that state.
            A state missing from the mapping is treated as probability 0.
        trans_p: Mapping ``prev_state -> {next_state: probability}``.
        emit_p: Mapping ``state -> {observation: probability}``.
        smoothing: Probability assumed for any transition or emission that is
            absent from ``trans_p`` / ``emit_p``.

    Returns:
        A list with one state per observation; ``[]`` for an empty sentence.

    Raises:
        ValueError: If ``states`` is empty while ``sentence`` is not.
    """
    n_obs = len(sentence)
    if n_obs == 0:
        return []

    states = list(states)  # allow one-shot iterables; we iterate several times
    if not states:
        raise ValueError("viterbi() requires at least one state")

    def log_prob(p):
        # log(0) is undefined; represent impossible events as -inf.
        return math.log(p) if p > 0 else -math.inf

    # Rows are read with .get() so that a state without a row neither raises
    # KeyError (plain dict) nor gets inserted into the caller's defaultdict.
    def trans_score(prev, curr):
        return log_prob(trans_p.get(prev, {}).get(curr, smoothing))

    def emit_score(state, obs):
        return log_prob(emit_p.get(state, {}).get(obs, smoothing))

    # scores[s]: best log-probability of any path ending in state s so far.
    scores = {
        s: log_prob(start_p.get(s, 0)) + emit_score(s, sentence[0])
        for s in states
    }
    # backpointers[k][s]: best predecessor of state s at observation k + 1.
    backpointers = []

    for t in range(1, n_obs):
        obs = sentence[t]
        next_scores = {}
        pointers = {}
        for curr in states:
            # The emission term is the same for every predecessor, so it is
            # added once after the best predecessor has been chosen.
            best_score, best_prev = max(
                (scores[prev] + trans_score(prev, curr), prev) for prev in states
            )
            next_scores[curr] = best_score + emit_score(curr, obs)
            pointers[curr] = best_prev
        scores = next_scores
        backpointers.append(pointers)

    # Pick the best final state, then walk the back-pointers to the start.
    _, last_state = max((scores[s], s) for s in states)
    tags = [last_state]
    for pointers in reversed(backpointers):
        tags.append(pointers[tags[-1]])
    tags.reverse()
    return tags


In [5]:
# Tag a sentence made of known words in a combination not seen in training.
test_sentence = ["a", "cat", "barked"]
# Candidate states are every tag that was counted during training.
states = list(tag_counts)

predicted_tags = viterbi(
    sentence=test_sentence,
    states=states,
    start_p=start_prob,
    trans_p=transition,
    emit_p=emission,
)
print("Sentence:", test_sentence)
print("Predicted Tags:", predicted_tags)


Sentence: ['a', 'cat', 'barked']
Predicted Tags: ['DET', 'NOUN', 'VERB']
