from __future__ import annotations
from collections.abc import Sequence
from typing import Final, ReadOnly, TypeAlias
# Type aliases for better readability and type safety
ByteToken: TypeAlias = bytes
MergePair: TypeAlias = tuple[ByteToken, ByteToken]
Vocab: TypeAlias = dict[TokenId, ByteToken]
merges: Sequence[MergePair],  # ReadOnly[] is only valid on TypedDict items (PEP 705); Sequence already signals read-only use
# Keep the pre-tokenization pattern plus private copies of vocab and merges,
# so later mutation of the caller's objects cannot change this tokenizer.
self.pattern: Final[str] = pattern
self.vocab: Final[Vocab] = dict(vocab)
# Merges are frozen in order; a merge's index in this tuple is its rank.
self.merges: Final[tuple[MergePair, ...]] = tuple(merges)
# Reverse lookup from byte tokens to token IDs
# NOTE(review): if two IDs map to the same bytes, the ID seen last while
# iterating self.vocab wins.
self.byte_to_token_id: Final[dict[ByteToken, TokenId]] = {
v: k for k, v in self.vocab.items()
# Pre-compile the pre-tokenization regex once. MULTILINE (^/$ match at line
# boundaries) and DOTALL ('.' also matches newlines) change what the pattern
# matches; they are not performance optimizations.
self._regex: Final[re.Pattern[str]] = re.compile(
pattern, re.MULTILINE | re.DOTALL
# Build merge lookup table for O(1) access: pair -> rank (lower rank =
# earlier merge, preferred when applying merges).
self._merge_ranks: Final[dict[MergePair, int]] = {
pair: idx for idx, pair in enumerate(self.merges)
def _apply_merges(self, tokens: list[ByteToken]) -> list[ByteToken]:
    """Greedily apply the learned merges to one piece's byte tokens.

    Each pass finds the adjacent pair with the lowest merge rank (the
    leftmost one on ties) and merges it. Passes repeat until no adjacent
    pair has a rank in ``self._merge_ranks``.

    Args:
        tokens: Byte tokens for a single pre-tokenized piece. The list is
            modified in place.

    Returns:
        The same list object, now holding the merged tokens.
    """
    ranks = self._merge_ranks
    no_rank = float("inf")
    n = len(tokens)
    while n >= 2:
        # 1) Scan current sequence to find the best (lowest-rank) adjacent pair.
        best_pos = -1
        best_rank = no_rank
        for i in range(n - 1):
            # tuple[bytes, bytes] is hashable; dict.get is O(1)
            rank = ranks.get((tokens[i], tokens[i + 1]), no_rank)
            if rank < best_rank:  # strict '<' keeps the leftmost tie
                best_rank = rank
                best_pos = i
        # 2) If none found, we are done.
        if best_pos < 0:
            break
        # 3) Merge the best pair once at best_pos.
        merged = tokens[best_pos] + tokens[best_pos + 1]
        tokens[best_pos : best_pos + 2] = [merged]
        n -= 1  # sequence shrinks by one
    return tokens
def encode(self, text: str) -> list[TokenId]:
    """Encode ``text`` into a list of token IDs.

    ``text`` is split into pieces with the compiled pre-tokenization regex.
    Each piece is UTF-8 encoded into one token per byte. The learned merges
    are then applied within that piece, and every resulting byte token is
    mapped to its ID.

    Args:
        text: The string to encode.

    Returns:
        Token IDs in order. Text not matched by the regex contributes no
        tokens.

    Raises:
        KeyError: If a byte token produced by merging has no ID in the
            vocabulary.
    """
    token_to_id = self.byte_to_token_id
    token_ids: list[TokenId] = []
    for match in self._regex.finditer(text):
        piece = match.group()
        # Direct byte array creation: one single-byte token per UTF-8 byte.
        byte_tokens = [bytes([b]) for b in piece.encode("utf-8")]
        byte_tokens = self._apply_merges(byte_tokens)
        token_ids.extend(token_to_id[token] for token in byte_tokens)
    return token_ids
def decode(self, encoding: Sequence[TokenId], *, errors: str = "strict") -> str:
    """Decode a sequence of token IDs back into text.

    Each ID's byte string is looked up in ``self.vocab``. The byte strings
    are concatenated in order and decoded as UTF-8.

    Args:
        encoding: Token IDs to decode.
        errors: Error handler passed to ``bytes.decode``.
            - The default ``"strict"`` raises on malformed UTF-8.
            - ``"replace"`` substitutes U+FFFD instead, which is useful
              when the IDs split a multi-byte character (e.g. a truncated
              or arbitrary ID sequence).

    Returns:
        The decoded string.

    Raises:
        KeyError: If an ID is not in the vocabulary.
        UnicodeDecodeError: If the bytes are not valid UTF-8 and
            ``errors="strict"``.
    """
    byte_sequence = b"".join(self.vocab[token_id] for token_id in encoding)
    return byte_sequence.decode("utf-8", errors=errors)
# ----------------------------------------------------------------
def find_most_frequent_pair(tokens: list[list[bytes]]) -> tuple[bytes, bytes] | None:
"""Find the adjacent pair of byte tokens that occurs most often in ``tokens``.

``tokens`` holds one list of byte tokens per pre-tokenized piece. Pairs are
counted only inside a piece, never across piece boundaries. Pairs tied at
the maximum count are ordered lexicographically, so the choice is
deterministic.

NOTE(review): the return annotation allows None, and the training loop
checks the result against None. Confirm that None is returned when there is
no adjacent pair. Also confirm whether the lexicographically smallest or
largest tied pair is meant to win.
"""
pair_counts: dict[tuple[bytes, bytes], int] = {}
for token_list in tokens:
# Count each adjacent (left, right) pair inside this piece.
for i in range(len(token_list) - 1):
pair = (token_list[i], token_list[i + 1])
pair_counts[pair] = pair_counts.get(pair, 0) + 1
# NOTE(review): max() raises ValueError on an empty dict (e.g. every piece
# has fewer than two tokens); make sure the empty case returns None first.
max_count = max(pair_counts.values())
# Get all pairs with the maximum count and sort them for deterministic tie-breaking
max_pairs = [pair for pair, count in pair_counts.items() if count == max_count]
# Sort pairs lexicographically to ensure deterministic results
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
"""Given the path to an input corpus, train a BPE tokenizer and
output its vocabulary and merges.
corpus (str): Path to BPE tokenizer training data.
vocab_size (int): Total number of items in the tokenizer's vocabulary.
pattern (str): A regex pattern to pre-tokenize the input text.
tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
The trained tokenizer vocabulary, a mapping from int (token ID in the vocabulary) to bytes (token bytes).
BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
representing that <token1> was merged with <token2>.
Merges are ordered by order of creation.
# Read corpus from file path
with open(corpus, "r", encoding="utf-8") as f:
regex = re.compile(pattern, re.MULTILINE | re.DOTALL)
# Pre-tokenize text and convert to byte tokens
tokens: list[list[bytes]] = []
for match in regex.finditer(text):
byte_tokens = [bytes([b]) for b in piece.encode("utf-8")]
tokens.append(byte_tokens)
vocab: dict[int, bytes] = {}
# Assign token IDs to base vocabulary (all possible single bytes 0-255)
vocab[idx] = bytes([idx])
merges: list[tuple[bytes, bytes]] = []
next_token_id = len(vocab)
while next_token_id < vocab_size:
# Find most frequent adjacent pair
most_frequent_pair = find_most_frequent_pair(tokens)
if most_frequent_pair is None:
merges.append(most_frequent_pair)
# Create new merged token
merged_token = most_frequent_pair[0] + most_frequent_pair[1]
vocab[next_token_id] = merged_token
# Replace all occurrences of the pair with the new token
for token_list in tokens:
while i < len(token_list) - 1:
if (token_list[i], token_list[i + 1]) == most_frequent_pair:
token_list[i : i + 2] = [merged_token]