diff --git a/bindings/python/scripts/sentencepiece_extractor.py b/bindings/python/scripts/sentencepiece_extractor.py
index a7bce9b49..89a691ccc 100644
--- a/bindings/python/scripts/sentencepiece_extractor.py
+++ b/bindings/python/scripts/sentencepiece_extractor.py
@@ -5,6 +5,8 @@
 from os.path import exists
 from tempfile import NamedTemporaryFile
 from typing import Dict, List, Tuple
+from itertools import chain
+from multiprocessing import Pool, cpu_count
 
 from requests import get
 from sentencepiece import SentencePieceProcessor
@@ -25,23 +27,39 @@ def __init__(self, model: str):
         self.sp = SentencePieceProcessor()
         self.sp.Load(model)
 
-    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
+    def extract(self, num_thread: int = 0) -> Tuple[Dict[str, int], List[Tuple]]:
         sp = self.sp
         vocab = {sp.id_to_piece(index): index for index in trange(sp.GetPieceSize())}
 
-        # Merges
-        merges = []
-        for piece_l in tqdm(vocab.keys(), total=sp.GetPieceSize()):
-            for piece_r in vocab.keys():
-                merge = f"{piece_l}{piece_r}"
-                piece_id = vocab.get(merge, None)
-                if piece_id:
-                    merges += [(piece_l, piece_r, piece_id)]
+        if num_thread <= 0:
+            num_thread = cpu_count()
+
+        # Merges
+        results = []
+        with Pool(num_thread) as pool:
+            results = pool.starmap(self.extract_merges, [(key, vocab) for key in vocab.keys()])
+
+        # Flatten and filter empty lists
+        merges = list(chain.from_iterable(filter(None, results)))
+
         merges = sorted(merges, key=lambda val: val[2])
         merges = [(val[0], val[1]) for val in merges]
 
         return vocab, merges
 
+    @staticmethod
+    def extract_merges(piece_l: str, vocab: Dict[str, int]) -> List[Tuple]:
+        """
+        Worker for the multiprocessing pool: collect every merge whose left piece is `piece_l`.
+        """
+        merges = []
+        for piece_r in vocab.keys():
+            merge = f"{piece_l}{piece_r}"
+            piece_id = vocab.get(merge, None)
+            if piece_id:
+                merges += [(piece_l, piece_r, piece_id)]
+
+        return merges
 
 class YouTokenToMeExtractor:
     """
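
For reviewers, a minimal usage sketch of the new parameter (not part of the diff; the model path is a placeholder, and `SentencePieceExtractor` is the class whose `__init__` appears in the hunk context above):

```python
# Usage sketch, assuming the script is importable as a module.
# "path/to/spm.model" is a placeholder, not a path from this PR.
from sentencepiece_extractor import SentencePieceExtractor

extractor = SentencePieceExtractor(model="path/to/spm.model")

# num_thread defaults to 0, which extract() maps to cpu_count();
# any positive value pins the pool size explicitly.
vocab, merges = extractor.extract(num_thread=4)

print(f"{len(vocab)} pieces, {len(merges)} merges")
```

One design note: `Pool.starmap` pickles each argument tuple, so `vocab` is serialized once per piece; for very large vocabularies, sharing it through a pool initializer instead would avoid that repeated overhead.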