Tokenizer training and inference code

1. Tokenizer training and inference code

from __future__ import annotations

from collections import Counter
from collections.abc import Mapping, Sequence
from typing import Final, ReadOnly, TypeAlias

import regex as re


# Type aliases for better readability and type safety
9
TokenId: TypeAlias = int
10
ByteToken: TypeAlias = bytes
11
MergePair: TypeAlias = tuple[ByteToken, ByteToken]
12
Vocab: TypeAlias = dict[TokenId, ByteToken]
13
14
15
class BPETokenizer:
    """Byte-level BPE tokenizer built from a trained vocabulary and merge list.

    Encoding splits text into pieces with a pre-tokenization regex, UTF-8
    encodes each piece into single-byte tokens, and then repeatedly merges the
    adjacent pair whose merge appears earliest in ``merges``. Pieces are
    encoded independently, so merges never cross a piece boundary.
    """

    def __init__(
        self,
        pattern: str,
        vocab: Mapping[TokenId, ByteToken],
        merges: Sequence[MergePair],
    ) -> None:
        """Create a tokenizer.

        Args:
            pattern: Pre-tokenization regex, compiled with MULTILINE and DOTALL.
            vocab: Token ID -> token bytes. Copied, so later changes to the
                caller's mapping do not affect this tokenizer.
            merges: BPE merges in the order they were learned; earlier merges
                take priority when encoding.
        """
        self.pattern: Final[str] = pattern
        self.vocab: Final[Vocab] = dict(vocab)
        self.merges: Final[tuple[MergePair, ...]] = tuple(merges)

        # Reverse lookup used by encode(). If several IDs map to the same
        # bytes, the ID seen last while iterating ``vocab`` wins.
        self.byte_to_token_id: Final[dict[ByteToken, TokenId]] = {
            token: token_id for token_id, token in self.vocab.items()
        }

        self._regex: Final[re.Pattern[str]] = re.compile(
            pattern, re.MULTILINE | re.DOTALL
        )

        # Merge priority (lower rank = applied first). setdefault keeps the
        # rank of a pair's first occurrence, so a pair repeated later in
        # ``merges`` cannot push an early merge back to a lower priority.
        merge_ranks: dict[MergePair, int] = {}
        for rank, pair in enumerate(self.merges):
            merge_ranks.setdefault(pair, rank)
        self._merge_ranks: Final[dict[MergePair, int]] = merge_ranks

    def _apply_merges(self, tokens: list[ByteToken]) -> list[ByteToken]:
        """Merge ``tokens`` in place until no adjacent pair has a merge rank.

        Each round merges the lowest-rank adjacent pair, choosing the leftmost
        one on ties. Every round rescans the whole list, so the cost is
        quadratic in the number of tokens.

        Returns:
            The same list object, now holding the merged tokens.
        """
        ranks = self._merge_ranks
        no_merge = float("inf")
        while len(tokens) > 1:
            pair_ranks = [
                ranks.get(pair, no_merge) for pair in zip(tokens, tokens[1:])
            ]
            best_rank = min(pair_ranks)
            if best_rank == no_merge:
                break
            # list.index returns the first match, i.e. the leftmost best pair.
            pos = pair_ranks.index(best_rank)
            tokens[pos : pos + 2] = [tokens[pos] + tokens[pos + 1]]
        return tokens

    def encode(self, text: str) -> list[TokenId]:
        """Encode ``text`` into a list of token IDs.

        Raises:
            KeyError: If a token produced by the merges has no ID in ``vocab``
                (for example when ``vocab`` and ``merges`` disagree).
        """
        if not text:
            return []

        to_id = self.byte_to_token_id
        token_ids: list[TokenId] = []
        for match in self._regex.finditer(text):
            # Start from one token per UTF-8 byte of the piece, then merge.
            byte_tokens = [bytes([b]) for b in match.group(0).encode("utf-8")]
            token_ids.extend(to_id[token] for token in self._apply_merges(byte_tokens))
        return token_ids

    def decode(self, encoding: Sequence[TokenId], *, errors: str = "strict") -> str:
        """Decode token IDs back into text.

        Args:
            encoding: Token IDs to decode.
            errors: Error handler passed to ``bytes.decode``. The default
                ``"strict"`` raises ``UnicodeDecodeError`` when the joined
                bytes are not valid UTF-8 (e.g. a sequence cut in the middle of
                a multi-byte character); ``"replace"`` inserts U+FFFD instead.

        Raises:
            KeyError: If an ID is not in ``vocab``.
        """
        byte_sequence = b"".join(self.vocab[token_id] for token_id in encoding)
        return byte_sequence.decode("utf-8", errors=errors)


# ----------------------------------------------------------------


def find_most_frequent_pair(
    tokens: list[list[bytes]],
    weights: Sequence[int] | None = None,
) -> tuple[bytes, bytes] | None:
    """Return the most frequent adjacent pair of tokens.

    Pairs are counted inside each token list only (never across lists), and
    overlapping occurrences all count: ``[a, a, a]`` contributes two
    ``(a, a)``.

    Args:
        tokens: Pre-tokenized pieces, each a list of byte tokens.
        weights: Optional number of occurrences of each entry in ``tokens``.
            Entry ``i`` contributes ``weights[i]`` times; entries with a
            non-positive weight contribute nothing. Defaults to 1 per entry.

    Returns:
        The pair with the highest count. Ties go to the lexicographically
        greatest pair, so the result does not depend on input order. Returns
        ``None`` if no pair has a positive count.

    Raises:
        ValueError: If ``weights`` is given and its length differs from
            ``len(tokens)``.
    """
    pair_counts: Counter[tuple[bytes, bytes]] = Counter()
    if weights is None:
        for token_list in tokens:
            pair_counts.update(zip(token_list, token_list[1:]))
    else:
        if len(weights) != len(tokens):
            raise ValueError(
                f"weights has {len(weights)} entries but tokens has {len(tokens)}"
            )
        for token_list, weight in zip(tokens, weights):
            if weight <= 0:
                continue
            for pair in zip(token_list, token_list[1:]):
                pair_counts[pair] += weight

    if not pair_counts:
        return None

    # Highest count first; on ties, the lexicographically greatest pair wins.
    return max(pair_counts, key=lambda pair: (pair_counts[pair], pair))


def train_bpe(
    corpus: str,
    vocab_size: int,
    pattern: str,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Train a byte-level BPE tokenizer on a text file.

    The vocabulary starts with the 256 single-byte tokens (IDs 0-255). Each
    training step merges the most frequent adjacent pair (ties broken as in
    ``find_most_frequent_pair``) into a new token with the next free ID. This
    repeats until the vocabulary has ``vocab_size`` entries or no adjacent
    pairs remain.

    Args:
        corpus: Path to a UTF-8 text file with the training data.
        vocab_size: Desired total vocabulary size, including the 256 byte
            tokens. Values of 256 or less yield just the byte tokens and no
            merges.
        pattern: Regex used to pre-tokenize the text (compiled with MULTILINE
            and DOTALL); merges never cross the boundary between two matches.

    Returns:
        ``(vocab, merges)``: ``vocab`` maps token ID to token bytes, and
        ``merges`` lists each merged pair ``(token1, token2)`` in the order
        the merges were created.
    """
    with open(corpus, encoding="utf-8") as f:
        text = f.read()

    regex = re.compile(pattern, re.MULTILINE | re.DOTALL)

    # Equal pieces always split and merge identically, so train on each
    # distinct piece once and weight it by how often it occurs. The pair
    # counts, and therefore the learned merges, are the same as processing
    # every occurrence separately, but each step touches far fewer lists on
    # typical text where pieces repeat.
    piece_counts = Counter(match.group(0) for match in regex.finditer(text))
    words: list[list[bytes]] = [
        [bytes([b]) for b in piece.encode("utf-8")] for piece in piece_counts
    ]
    weights = list(piece_counts.values())

    vocab: dict[int, bytes] = {idx: bytes([idx]) for idx in range(256)}
    merges: list[tuple[bytes, bytes]] = []
    next_token_id = len(vocab)

    while next_token_id < vocab_size:
        best_pair = find_most_frequent_pair(words, weights)
        if best_pair is None:
            break

        first, second = best_pair
        merged_token = first + second
        merges.append(best_pair)
        vocab[next_token_id] = merged_token
        next_token_id += 1

        # Replace non-overlapping occurrences left to right in every word.
        for word in words:
            i = 0
            while i < len(word) - 1:
                if word[i] == first and word[i + 1] == second:
                    word[i : i + 2] = [merged_token]
                else:
                    i += 1

    return vocab, merges