diff --git a/gptc/compiler.py b/gptc/compiler.py index c299a4b..130f79b 100755 --- a/gptc/compiler.py +++ b/gptc/compiler.py @@ -2,32 +2,14 @@ import gptc.tokenizer import gptc.model -from typing import Iterable, Mapping, List, Dict, Union +from typing import Iterable, Mapping, List, Dict, Union, Tuple -def compile( +def _count_words( raw_model: Iterable[Mapping[str, str]], - max_ngram_length: int = 1, - min_count: int = 1, - hash_algorithm: str = "sha256", -) -> gptc.model.Model: - """Compile a raw model. - - Parameters - ---------- - raw_model : list of dict - A raw GPTC model. - - max_ngram_length : int - Maximum ngram lenght to compile with. - - Returns - ------- - dict - A compiled GPTC model. - - """ - + max_ngram_length: int, + hash_algorithm: str, +) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]: word_counts: Dict[int, Dict[str, int]] = {} category_lengths: Dict[str, int] = {} names: List[str] = [] @@ -55,6 +37,15 @@ def compile( else: word_counts[word] = {category: 1} + return word_counts, category_lengths, names + + +def _get_weights( + min_count: int, + word_counts: Dict[int, Dict[str, int]], + category_lengths: Dict[str, int], + names: List[str], +) -> Dict[int, List[int]]: model: Dict[int, List[int]] = {} for word, counts in word_counts.items(): if sum(counts.values()) >= min_count: @@ -69,5 +60,33 @@ def compile( round((weights.get(category, 0) / total) * 65535) ) model[word] = new_weights + return model + +def compile( + raw_model: Iterable[Mapping[str, str]], + max_ngram_length: int = 1, + min_count: int = 1, + hash_algorithm: str = "sha256", +) -> gptc.model.Model: + """Compile a raw model. + + Parameters + ---------- + raw_model : list of dict + A raw GPTC model. + + max_ngram_length : int + Maximum ngram lenght to compile with. + + Returns + ------- + dict + A compiled GPTC model. + + """ + word_counts, category_lengths, names = _count_words( + raw_model, max_ngram_length, hash_algorithm + ) + model = _get_weights(min_count, word_counts, category_lengths, names) return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)