In [1]:
from collections import defaultdict

# Function to learn BPE from a corpus
def learn_bpe(corpus, num_merges=3):
    vocab = defaultdict(int)

    # Initialize vocabulary with character pairs for each word in the corpus
    for sentence in corpus.split('.'):
        words = sentence.strip().split()
        for word in words:
            chars = ['<'] + list(word) + ['>']  # Add start and end markers
            for i in range(len(chars) - 1):
                pair = (chars[i], chars[i+1])
                vocab[pair] += 1

    merges = []
    for _ in range(num_merges):
        if not vocab:
            break
        # Find the most frequent pair
        most_frequent = max(vocab, key=lambda x: vocab[x])
        merges.append(most_frequent)

        # Merge the pair and update vocabulary
        new_char = ''.join(most_frequent)
        new_vocab = defaultdict(int)
        for pair in vocab:
            count = vocab[pair]
            if pair == most_frequent:
                continue
            new_pair = list(pair)
            if new_pair[0] == most_frequent[0] and new_pair[1] == most_frequent[1]:
                new_pair[0] = new_char
                new_pair.pop(1)
            new_vocab[tuple(new_pair)] += count
        vocab = new_vocab

    return merges

# Function to apply BPE to a given word
def apply_bpe(text, merges):
    chars = ['<'] + list(text) + ['>']  # Add boundary markers to the word
    for merge in reversed(merges):  # Apply the merges in reverse order
        merged = ''.join(merge)
        new_chars = []
        i = 0
        while i < len(chars) - 1:
            if (chars[i], chars[i+1]) == merge:  # Merge if pair matches
                new_chars.append(merged)
                i += 2
            else:
                new_chars.append(chars[i])
                i += 1
        if i < len(chars):
            new_chars.append(chars[-1])
        chars = new_chars

    return chars

# Example usage
corpus = "ab bc bcd cde"
merges = learn_bpe(corpus, num_merges=3)  # Learn merges from the corpus
print("Learned Merges:", merges)

# Apply BPE to a specific word
bpe_representation = apply_bpe("bcd", merges)
print("BPE Representation for 'bcd':", bpe_representation)


Learned Merges: [('<', 'b'), ('b', 'c'), ('c', 'd')]
BPE Representation for 'bcd': ['<b', 'cd', '>']
