from minbpe import BasicTokenizer
tokenizer = BasicTokenizer()
tokenizer.train(very_long_training_string, vocab_size=4096)
tokenizer.encode("hello world")      # string -> tokens
tokenizer.decode([1000, 2000, 3000]) # tokens -> string
tokenizer.save("mymodel")            # writes mymodel.model and mymodel.vocab
tokenizer.load("mymodel.model")      # loads the model back; the vocab file is just for visualization
If you want the regex-based approach, which first splits the text by category, use the RegexTokenizer instead:
from minbpe import RegexTokenizer
tokenizer = RegexTokenizer()
tokenizer.train(very_long_training_string, vocab_size=32768)
tokenizer.encode("hello world")      # string -> tokens
tokenizer.decode([1000, 2000, 3000]) # tokens -> string
tokenizer.save("tok32k")             # writes tok32k.model and tok32k.vocab
tokenizer.load("tok32k.model")       # loads the model back from disk
def _build_vocab(self):
    # build the vocabulary; the base vocabulary is the 256 raw bytes
    vocab = {idx: bytes([idx]) for idx in range(256)}
    # (p0, p1) is a merged pair
    for (p0, p1), idx in self.merges.items():
        vocab[idx] = vocab[p0] + vocab[p1]
    for special, idx in self.special_tokens.items():
        vocab[idx] = special.encode("utf-8")
    return vocab
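To make the composition concrete, here is a tiny standalone illustration (the merge pairs below are made up, not values from a real training run): because merges is iterated in insertion order, both children of a pair already exist in vocab by the time their merged token is built.

# illustrative only: a hypothetical merges dict, not real trained values
merges = {(104, 105): 256, (256, 33): 257}  # bytes of "hi" -> 256, then 256 + "!" -> 257
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]
print(vocab[256])  # b'hi'
print(vocab[257])  # b'hi!'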
def save(self, file_prefix):
    """
    Saves two files: file_prefix.model and file_prefix.vocab
    - the model file is the one used by load()
    - the vocab file is just a pretty-printed version, for human inspection only
    """
    # write the model file
    model_file = file_prefix + ".model"
    with open(model_file, 'w') as f:
        # write the version, the pattern, and the merges
        f.write("minbpe v1\n")
        f.write(f"{self.pattern}\n")
        # write the special tokens
        f.write(f"{len(self.special_tokens)}\n")
        for special, idx in self.special_tokens.items():
            f.write(f"{special} {idx}\n")
        # the merges dictionary
        for idx1, idx2 in self.merges:
            f.write(f"{idx1} {idx2}\n")
    # write the vocab file: this one is only for looking at
    vocab_file = file_prefix + ".vocab"
    inverted_merges = {idx: pair for pair, idx in self.merges.items()}
    with open(vocab_file, "w", encoding="utf-8") as f:
        for idx, token in self.vocab.items():
            s = render_token(token)
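For completeness, a load() that reads this format back only needs the .model file. The following is a minimal sketch, assuming the file layout is exactly what save() writes above (version line, pattern line, special-token count, the special tokens, then one merge pair per line):

def load(self, model_file):
    # a minimal sketch, assuming the exact .model layout written by save() above
    assert model_file.endswith(".model")
    merges = {}
    special_tokens = {}
    idx = 256
    with open(model_file, 'r', encoding="utf-8") as f:
        version = f.readline().strip()
        assert version == "minbpe v1"
        self.pattern = f.readline().strip()
        num_special = int(f.readline().strip())
        for _ in range(num_special):
            special, special_idx = f.readline().strip().split()
            special_tokens[special] = int(special_idx)
        for line in f:
            # each remaining line is one merged pair, in the order the merges were learned
            idx1, idx2 = map(int, line.split())
            merges[(idx1, idx2)] = idx
            idx += 1
    self.merges = merges
    self.special_tokens = special_tokens
    self.vocab = self._build_vocab()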
# iteratively merge the most common pair to create new tokens
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
for i in range(num_merges):
    # count how often each consecutive pair occurs; returns a dict mapping pair -> count
    stats = get_stats(ids)
    # find the pair with the highest count
    pair = max(stats, key=stats.get)
    # mint a new token: assign it the next available id
    idx = 256 + i
    # replace all occurrences of pair in ids with idx
    ids = merge(ids, pair, idx)
    # save the merge
    merges[pair] = idx
    vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
    # prints
    if verbose:
        print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
# save class variables
self.merges = merges # used in encode()
self.vocab = vocab   # used in decode()
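The training loop relies on two small helpers, get_stats and merge, whose signatures are inferred from how they are called here (the optional counts argument anticipates the in-place accumulation used by the RegexTokenizer below). A minimal sketch:

def get_stats(ids, counts=None):
    # count how often each consecutive pair occurs,
    # e.g. [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    # replace every occurrence of pair in ids with the new token idx
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i+1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids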
def decode(self, ids):
    # decoding: given a list of integer ids, return the corresponding string
    text_bytes = b"".join(self.vocab[idx] for idx in ids)
    text = text_bytes.decode("utf-8", errors="replace")
    return text
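decode's counterpart, encode, is not shown in this excerpt. A minimal sketch of how it can be built from self.merges and the helpers above: convert the text to raw bytes, then repeatedly apply whichever learned merge was created earliest (lowest new token id) until nothing left in the sequence can be merged.

def encode(self, text):
    # a sketch: raw UTF-8 bytes first, then greedily apply the learned merges in training order
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        stats = get_stats(ids)
        # of all pairs present, pick the one that was merged earliest during training
        pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
        if pair not in self.merges:
            break  # no pair in the sequence was ever merged during training
        ids = merge(ids, pair, self.merges[pair])
    return ids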
# input text preprocessing
ids = [list(ch.encode("utf-8")) for ch in text_chunks]
# iteratively merge the most common pair to create new tokens
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
for i in range(num_merges):
    # count how often each consecutive pair occurs
    stats = {}
    for chunk_ids in ids:
        # passing in stats updates it in place, accumulating counts across chunks
        get_stats(chunk_ids, stats)
    # find the pair with the highest count
    pair = max(stats, key=stats.get)
    # mint a new token: assign it the next available id
    idx = 256 + i
    # replace all occurrences of pair in ids with idx
    ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
    # save the merge
    merges[pair] = idx
    vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
    # prints
    if verbose:
        print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
# save class variables
self.merges = merges # used in encode()
self.vocab = vocab   # used in decode()
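The text_chunks fed into this training loop come from splitting the raw text with the tokenizer's regex pattern beforehand, so merges never cross chunk boundaries (words, numbers, punctuation, whitespace). A small illustration with a deliberately simplified pattern, not the real GPT split pattern:

import re

# a simplified split pattern for illustration only; the real GPT patterns are more involved
simple_pattern = r" ?[A-Za-z]+| ?[0-9]+| ?[^A-Za-z0-9\s]+|\s+"
text = "hello world 123!"
text_chunks = re.findall(simple_pattern, text)
print(text_chunks)  # ['hello', ' world', ' 123', '!']
# each chunk is then turned into its own list of byte ids, exactly as in the training code above
ids = [list(ch.encode("utf-8")) for ch in text_chunks]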
def register_special_tokens(self, special_tokens):
    # special_tokens: a dictionary of str -> int
    # example: {"<|endoftext|>": 100257}
    self.special_tokens = special_tokens
    self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
def decode(self, ids):
    part_bytes = []
    for idx in ids:
        if idx in self.vocab:
            # an ordinary token: look up its bytes in the vocab
            part_bytes.append(self.vocab[idx])
        elif idx in self.inverse_special_tokens:
            # a special token: recover its string and encode it back to bytes
            part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
        else:
            raise ValueError(f"invalid token id: {idx}")
    text_bytes = b"".join(part_bytes)
    text = text_bytes.decode("utf-8", errors="replace")
    return text
def encode(self, text, allowed_special="none_raise"):
    """
    Unlike encode_ordinary, this function handles special tokens.
    allowed_special: can be "all" | "none" | "none_raise" or a custom set of special tokens.
    If "none_raise", an error is raised if any special token is encountered in the text.
    """
    # decode the user's desire w.r.t. handling of special tokens
    special = None
    if allowed_special == "all":
        special = self.special_tokens
    elif allowed_special == "none":
        special = {}
    elif allowed_special == "none_raise":
        special = {}
        assert all(token not in text for token in self.special_tokens)
    elif isinstance(allowed_special, set):
        special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
    else:
        raise ValueError(f"allowed_special={allowed_special} not understood")
    if not special:
        # no special tokens to worry about, use ordinary encoding
        return self.encode_ordinary(text)
    # otherwise, we have to be careful about possible special tokens in the text.
    # we handle them by splitting the text on exact matches of any special token.
    # we can use re.split for this. note that wrapping the pattern in ()
    # makes it a capturing group, so the special tokens themselves are kept in the result.
    special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
    special_chunks = re.split(special_pattern, text)
    # now all the special tokens are separated from the rest of the text.
    # each chunk of ordinary text is encoded separately, and the results are concatenated.
    ids = []
    for part in special_chunks:
        if part in special:
            # this is a special token, encode it separately as a special case
            ids.append(special[part])
        else:
            # this is an ordinary sequence, encode it normally
            ids.extend(self.encode_ordinary(part))
    return ids
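A hedged usage example of the special-token path; the id 100257 just follows the <|endoftext|> convention shown in register_special_tokens, and a trained RegexTokenizer is assumed:

# assumes a trained RegexTokenizer; the id 100257 follows the example in register_special_tokens
tokenizer.register_special_tokens({"<|endoftext|>": 100257})
text = "hello world<|endoftext|>goodbye"
ids = tokenizer.encode(text, allowed_special="all")  # the special token becomes the single id 100257
assert tokenizer.decode(ids) == text                 # decode maps it back via inverse_special_tokens
# with the default "none_raise", the same text would trip the assert inside encode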
def __init__(self):
    super().__init__(pattern=GPT4_SPLIT_PATTERN)
    # get the official tokenizer and its merges
    enc = tiktoken.get_encoding("cl100k_base")
    mergeable_ranks = enc._mergeable_ranks
    # the merges are those of gpt-4, but we have to recover them
    self.merges = recover_merges(mergeable_ranks)
    # reconstruct the vocab from the merges
    vocab = {idx: bytes([idx]) for idx in range(256)}
    for (p0, p1), idx in self.merges.items():
        vocab[idx] = vocab[p0] + vocab[p1]
    self.vocab = vocab
    # for some reason, the tokens corresponding to individual bytes are permuted into a different order,
    # so we record that permutation and its inverse
    self.byte_shuffle = {i: mergeable_ranks[bytes([i])] for i in range(256)}
    self.inverse_byte_shuffle = {v: k for k, v in self.byte_shuffle.items()}
    # register the special tokens
    self.register_special_tokens(GPT4_SPECIAL_TOKENS)
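The byte permutation can be verified directly against tiktoken: the rank of the single-byte token bytes([i]) in cl100k_base is generally not i, which is exactly why byte_shuffle and inverse_byte_shuffle are needed. A small diagnostic sketch (it reuses the same private _mergeable_ranks attribute accessed above):

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
mergeable_ranks = enc._mergeable_ranks
shuffled = [i for i in range(256) if mergeable_ranks[bytes([i])] != i]
print(f"{len(shuffled)} of the 256 single-byte tokens have a rank different from their byte value")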
def _encode_chunk(self, text_bytes):
    # before we start processing the bytes, we have to permute them
    text_bytes = bytes(self.byte_shuffle[b] for b in text_bytes)
    ids = super()._encode_chunk(text_bytes)
    return ids
def decode(self, ids):
    # we have to un-permute the bytes before decoding
    text_bytes = b"".join(self.vocab[idx] for idx in ids)
    text_bytes = bytes(self.inverse_byte_shuffle[b] for b in text_bytes)
    text = text_bytes.decode("utf-8", errors="replace")
    return text
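Because the merges and the byte permutation come straight from cl100k_base, GPT4Tokenizer should reproduce tiktoken's ids on ordinary text. A quick sanity check, assuming GPT4Tokenizer is importable from minbpe as in the usage examples at the top:

import tiktoken
from minbpe import GPT4Tokenizer

enc = tiktoken.get_encoding("cl100k_base")
tokenizer = GPT4Tokenizer()

text = "hello world!!!? (안녕하세요!) lol123 😉"
assert tokenizer.encode(text) == enc.encode(text)                # same ids as tiktoken
assert tokenizer.decode(tokenizer.encode(text)) == text          # and the round trip is lossless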