bindings/python/scripts/sentencepiece_extractor.py (36 changes: 27 additions & 9 deletions)
@@ -5,6 +5,8 @@
 from os.path import exists
 from tempfile import NamedTemporaryFile
 from typing import Dict, List, Tuple
+from itertools import chain
+from multiprocessing import Pool, cpu_count
 
 from requests import get
 from sentencepiece import SentencePieceProcessor
@@ -25,23 +27,39 @@ def __init__(self, model: str):
         self.sp = SentencePieceProcessor()
         self.sp.Load(model)
 
-    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
+    def extract(self, num_thread=0) -> Tuple[Dict[str, int], List[Tuple]]:
         sp = self.sp
         vocab = {sp.id_to_piece(index): index for index in trange(sp.GetPieceSize())}
 
-        # Merges
-        merges = []
-        for piece_l in tqdm(vocab.keys(), total=sp.GetPieceSize()):
-            for piece_r in vocab.keys():
-                merge = f"{piece_l}{piece_r}"
-                piece_id = vocab.get(merge, None)
-                if piece_id:
-                    merges += [(piece_l, piece_r, piece_id)]
+        if num_thread <= 0:
+            num_thread = cpu_count()
+
+        # Merges
+        results = []
+        with Pool(num_thread) as pool:
+            results = pool.starmap(self.extract_merges, [(key, vocab) for key in vocab.keys()])
+
+        # Flatten and filter empty lists
+        merges = list(chain.from_iterable(filter(None, results)))
+
         merges = sorted(merges, key=lambda val: val[2])
         merges = [(val[0], val[1]) for val in merges]
 
         return vocab, merges
 
+    """
+    Multiprocessing method for merges.
+    """
+    @staticmethod
+    def extract_merges(piece_l, vocab):
+        merges = []
+        for piece_r in vocab.keys():
+            merge = f"{piece_l}{piece_r}"
+            piece_id = vocab.get(merge, None)
+            if piece_id:
+                merges += [(piece_l, piece_r, piece_id)]
+
+        return merges
+
 class YouTokenToMeExtractor:
     """
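
For context, the pattern this diff introduces can be sketched in isolation. The snippet below is a minimal, self-contained reproduction of the approach, not the PR's file: each left-hand piece is fanned out to a worker via Pool.starmap, and the per-piece result lists are flattened with chain.from_iterable. The VOCAB dict is a hypothetical toy; in the actual script it comes from the loaded SentencePiece model.

from itertools import chain
from multiprocessing import Pool, cpu_count

# Hypothetical toy vocabulary; in the script this is built as
# {sp.id_to_piece(i): i for i in range(sp.GetPieceSize())}.
VOCAB = {"a": 1, "b": 2, "c": 3, "ab": 4, "bc": 5, "abc": 6}

def extract_merges(piece_l, vocab):
    # Same worker logic as the diff: pair piece_l with every right-hand
    # piece and keep pairs whose concatenation is itself a vocab entry.
    merges = []
    for piece_r in vocab.keys():
        piece_id = vocab.get(f"{piece_l}{piece_r}")
        if piece_id:
            merges.append((piece_l, piece_r, piece_id))
    return merges

if __name__ == "__main__":
    with Pool(cpu_count()) as pool:
        # starmap pickles (piece_l, vocab) for every task, so the whole
        # vocab is shipped to the worker pool once per key.
        results = pool.starmap(extract_merges, [(key, VOCAB) for key in VOCAB])
    merges = list(chain.from_iterable(filter(None, results)))
    merges = [(l, r) for l, r, _ in sorted(merges, key=lambda m: m[2])]
    print(merges)  # [('a', 'b'), ('b', 'c'), ('a', 'bc'), ('ab', 'c')]

Two behavioural details carry over from the original loop: the truthiness check (if piece_id) also filters out a merge whose id is 0, and filter(None, ...) then drops workers that found no merges at all. Against a real model, a hypothetical invocation of the updated extractor would look like SentencePieceExtractor("spm.model").extract(num_thread=4), with num_thread <= 0 falling back to cpu_count().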