Source code for tkseem._base

import os
import sys
import mmap
import pickle
import itertools
import numpy as np
from tqdm import tqdm
from pathlib import Path
from .util import split_on_binary
from collections import Counter, defaultdict

class BaseTokenizer:
    """
    Base Tokenizer that implements the basic functionalities of a tokenizer
    """

    def __init__(
        self, unk_token="<UNK>", pad_token="<PAD>", vocab_size=10000, special_tokens=[],
    ):
        """Constructor

        Args:
            unk_token (str, optional): unkown symbol. Defaults to "<UNK>".
            pad_token (str, optional): pad symbol. Defaults to "<PAD>".
            vocab_size (int, optional): max vocab size. Defaults to 10000.
            special_tokens (list, optional): user defined special tokens. Defaults to [].
        """
        self.vocab_size = vocab_size
        self.unk_token = unk_token
        self.pad_token = pad_token
        self.special_tokens = special_tokens
        self.rel_path = os.path.dirname(__file__)
        cach_dict_path = os.path.join(self.rel_path, "dictionaries/cached.pl")
        self.cached = pickle.load(open(cach_dict_path, "rb"))

    def _get_tokens_frequency_quickly(self, file_path):
        """
        Get the tokens frequency quickly using memory mapping

        Args:
            file_path (str): the directory of the data to read
        
        Returns:
            Dict: frequency based dictionary   
        """
        encoding = "utf8"
        with open(file_path, "r", encoding=encoding, errors="ignore") as f:
            with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as m:
                m.read(0)
                i = 0
                size_to_read = int(1e9)
                freq = Counter([])
                pbar = tqdm(total=int(m.size() / size_to_read))
                while i < m.size():
                    cur_txt = ""
                    data = m.read(size_to_read)
                    i += size_to_read
                    try:
                        cur_txt = data.decode(encoding)
                    except:
                        cur_txt = (data + m.read(1)).decode(encoding)
                        i += 1
                    freq.update(cur_txt.split(" "))
                    pbar.update(1)
        return freq

    def _get_tokens_frequency(self, file_path):
        """
        Get tokens frequency using a dictionary

        Args:
            file_path (str): file path to read
        Returns:
            dict : dict containing frequency
        """
        text = open(file_path, "r").read()
        tokens_frequency = defaultdict(int)
        for word in text.split(" "):
            tokens_frequency[word] += 1
        return dict(tokens_frequency)

    def _split_word(self, word, number_of_subwords):
        """Split a word into a specific number of sub-words

        Args:
            word (str): word input
            number_of_subwords (int): number of subtokens to generate from the word 
        
        Returns:
            list: list of subwords 
        """
        assert number_of_subwords > 0

        def _split(_word, _number_of_subwords):
            groups = []
            if _number_of_subwords == 1:
                groups.append(["##" + _word])
            else:
                for i in range(1, len(_word), 1):
                    groups.extend(
                        ["##" + _word[:i], *group]
                        for group in _split(_word[i:], _number_of_subwords - 1)
                        if len(group) == _number_of_subwords - 1
                    )
            return groups

        groups_of_subwords = _split(word, number_of_subwords)
        out_groups = []
        for group in groups_of_subwords:
            group[0] = group[0].replace("##", "")
            out_groups.append(group)
        return out_groups

    def _split_word_cached(self, word, number_of_subwords):
        """Faster version of word splitting

        Args:
            word (word): word to be split
            number_of_subwords (int): number of subwords to split the word to

        Returns:
            list: subwords
        """
        if number_of_subwords == 1:
            return [[word]]
        n = len(word) - 1
        all_binaries = self.cached[n, number_of_subwords - 1]
        return [split_on_binary(word, binary) for binary in all_binaries]

    def _tokenize_from_dict_deprecated(self, text, freq_dict, cache=False, max_size=20):
        """Tokenize using frequency based approach given a dictionary

        Args:
            text (str): input string
            freq_dict (dict): frequency dictionary
            cache (bool, optional): faster approach. Defaults to False.
            max_size (int, optional): maximum word size. Defaults to 20.

        Returns:
            [type]: [description]
        """
        assert freq_dict
        tokens = []
        output_tokens = []
        for word in text.split():
            if len(word) >= max_size:
                print(f"{word} is too long ...")
                output_tokens.append(self.unk_token)
                continue
            if word in freq_dict:
                output_tokens.append(word)
            else:
                groups_of_valid_subwords = []
                for i in range(2, len(word) + 1, 1):
                    if cache:
                        groups_of_subwords = self._split_word_cached(word, i)
                    else:
                        groups_of_subwords = self._split_word(word, i)

                    # filter out groups
                    groups_of_valid_subwords = list(
                        filter(
                            lambda group: all(
                                subword in freq_dict.keys() for subword in group
                            ),
                            groups_of_subwords,
                        )
                    )
                    if groups_of_valid_subwords:
                        break
                if len(groups_of_valid_subwords) == 0:
                    output_tokens.append(self.unk_token)
                else:
                    sorted_groups_of_valid_subwords = sorted(
                        groups_of_valid_subwords,
                        key=lambda group: sum(freq_dict[subword] for subword in group),
                    )
                    tokens = sorted_groups_of_valid_subwords[-1]
                    for token in tokens:
                        output_tokens.append(str(token))
        return output_tokens

    #https://github.com/google-research/bert/blob/eedf5716ce1268e56f0a50264a88cafad334ac61/tokenization.py#L308
    def _tokenize_from_dict(self, text, freq_dict, use_cache, max_cache_size, max_word_size=20):
        """Tokenize using frequency based approach given a dictionary

        Args:
            text (str): text to tokenize
            freq_dict (dict): a frequency dictionary
            use_cache (bool): whether to use caching 
            max_cache_size (int): max size for the caching dictionary
            max_word_size (int, optional): max word size. Defaults to 20.

        Returns:
            [type]: [description]
        """

        output_tokens = []
        cache = {}
        num_tokens = 0 
        num_found_tokens = 0 
        for token in text.split():
            num_tokens += 1
            chars = list(token)
            if len(chars) > max_word_size:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []

            if use_cache:
                if token in cache:
                    output_tokens.extend(cache[token])
                    num_found_tokens += 1
                    continue
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in freq_dict:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end
            if is_bad:
                sub_tokens = [self.unk_token]
            output_tokens.extend(sub_tokens)           
            if use_cache:
                if len(cache) < max_cache_size:
                    cache[token] = sub_tokens
        # print('Percentage of cached tokens  = ', num_found_tokens/num_tokens)
        return output_tokens

    def _truncate_dict(self, freq_dict):
        """Truncate a frequency dictionary and add reserved tokens

        Args:
            freq_dict (dict): frequency dictionary

        Returns:
            dict: truncated dictionary based on the vocab size
        """
        sorted_tokens_frequency = {
            k: v for k, v in sorted(freq_dict.items(), key=lambda x: x[1], reverse=True)
        }

        limited_tokens_frequency = dict()
        limited_tokens_frequency[self.unk_token] = -1
        limited_tokens_frequency[self.pad_token] = -1
        for token in self.special_tokens:
            limited_tokens_frequency[token] = -1
        limited_tokens_frequency.update(
            {
                k: v
                for k, v in list(sorted_tokens_frequency.items())[
                    : self.vocab_size - len(limited_tokens_frequency)
                ]
            }
        )
        return limited_tokens_frequency

    def token_to_id(self, piece):
        """ Get tokens list

        Returns:
            list: tokens 
        """
        return list(self.vocab.keys()).index(piece)

    def id_to_token(self, id):
        """convert id to token

        Args:
            id (int): input id

        Returns:
            str: token
        """
        return list(self.vocab.keys())[id]

    def tokenize(self, text, use_cache = False, max_cache_size = 1000):
        """tokenize

        Args:
            text (str): input text
            use_cache (bool, optional): speed up using caching. Defaults to False.
            max_cache_size (int, optional): max cacne size. Defaults to 1000.

        Returns:
            list: output list of tokens
        """
        output_tokens = self._tokenize_from_dict(text, self.vocab, use_cache, max_cache_size = max_cache_size)
        return output_tokens

    def detokenize(self, tokens):
        """ Convert tokens to a string

        Args:
            tokens (list): list of tokens

        Returns:
            str: detokenized string
        """
        detokenized = "".join(tokens).replace("##", "")
        return detokenized

    def decode(self, encoded):
        """ Decode ids

        Args:
            encoded (list): list of ids to decode

        Returns:
            list: tokens
        """
        decoded = [self.id_to_token(id) for id in encoded]
        return decoded

    def encode(self, text):
        """ Convert string to a list of ids

        Args:
            text (str): input string

        Returns:
            list: list of ids
        """
        tokens = self.tokenize(text)
        encoded = [self.token_to_id(token) for token in tokens]
        return encoded

    def encode_sentences(self, sentences, boundries=("", ""), out_length=None):
        """
        Encode a list of sentences using the trained model

        Args:
            sentences (list): list of sentences
            boundries (tuple): boundries for each sentence. 
            out_length (int, optional): specify the max length of encodings. Defaults to 100.

        Returns:
            [np.array]: numpy array of encodings
        """
        encodings = []
        for sent in sentences:
            encoded = self.encode(boundries[0] + " " + sent + " " + boundries[1])
            encodings.append(encoded)

        pad_id = self.encode(self.pad_token)[0]

        # pad to equal size from https://stackoverflow.com/a/38619333
        encodings = np.array(
            list(itertools.zip_longest(*encodings, fillvalue=pad_id))
        ).T

        # increase pad if necessary
        if not (out_length is None):
            if out_length > encodings.shape[1]:
                encodings = np.pad(
                    encodings,
                    [(0, 0), (0, out_length)],
                    constant_values=pad_id,
                    mode="constant",
                )
        encodings = encodings[..., :out_length]

        return encodings

    def load_model(self, file_path):
        """Load a saved model as a frequency dictionary

        Args:
            file_path (str): file path of the dictionary
        """
        print("Loading as pickle file ...")
        self.vocab = pickle.load(open(file_path, "rb"))

    def save_model(self, file_path):
        """Save a model as a freqency dictionary

        Args:
            file_path (str): file path to save the model
        """
        assert self.vocab
        with open(f"{file_path}", "wb") as pickle_file:
            print("Saving as pickle file ...")
            pickle.dump(self.vocab, pickle_file)
            
    def __str__(self):
        return f"{self.__class__.__name__}"