diff --git a/gptc/compiler.py b/gptc/compiler.py deleted file mode 100755 index 5a1166b..0000000 --- a/gptc/compiler.py +++ /dev/null @@ -1,92 +0,0 @@ -# SPDX-License-Identifier: GPL-3.0-or-later - -from typing import Iterable, Mapping, List, Dict, Tuple -import gptc.tokenizer -import gptc.model - - -def _count_words( - raw_model: Iterable[Mapping[str, str]], - max_ngram_length: int, - hash_algorithm: str, -) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]: - word_counts: Dict[int, Dict[str, int]] = {} - category_lengths: Dict[str, int] = {} - names: List[str] = [] - - for portion in raw_model: - text = gptc.tokenizer.hash_list( - gptc.tokenizer.tokenize(portion["text"], max_ngram_length), - hash_algorithm, - ) - category = portion["category"] - - if not category in names: - names.append(category) - - category_lengths[category] = category_lengths.get(category, 0) + len( - text - ) - - for word in text: - if word in word_counts: - try: - word_counts[word][category] += 1 - except KeyError: - word_counts[word][category] = 1 - else: - word_counts[word] = {category: 1} - - return word_counts, category_lengths, names - - -def _get_weights( - min_count: int, - word_counts: Dict[int, Dict[str, int]], - category_lengths: Dict[str, int], - names: List[str], -) -> Dict[int, List[int]]: - model: Dict[int, List[int]] = {} - for word, counts in word_counts.items(): - if sum(counts.values()) >= min_count: - weights = { - category: value / category_lengths[category] - for category, value in counts.items() - } - total = sum(weights.values()) - new_weights: List[int] = [] - for category in names: - new_weights.append( - round((weights.get(category, 0) / total) * 65535) - ) - model[word] = new_weights - return model - - -def compile_( - raw_model: Iterable[Mapping[str, str]], - max_ngram_length: int = 1, - min_count: int = 1, - hash_algorithm: str = "sha256", -) -> gptc.model.Model: - """Compile a raw model. - - Parameters - ---------- - raw_model : list of dict - A raw GPTC model. - - max_ngram_length : int - Maximum ngram lenght to compile with. - - Returns - ------- - dict - A compiled GPTC model. - - """ - word_counts, category_lengths, names = _count_words( - raw_model, max_ngram_length, hash_algorithm - ) - model = _get_weights(min_count, word_counts, category_lengths, names) - return gptc.model.Model(model, names, max_ngram_length, hash_algorithm) diff --git a/gptc/model.py b/gptc/model.py index b17197c..b3813b1 100644 --- a/gptc/model.py +++ b/gptc/model.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later from typing import ( + Iterable, + Mapping, List, Dict, cast, @@ -12,8 +14,63 @@ import json import gptc.tokenizer from gptc.exceptions import InvalidModelError import gptc.weighting -import gptc.compiler +def _count_words( + raw_model: Iterable[Mapping[str, str]], + max_ngram_length: int, + hash_algorithm: str, +) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]: + word_counts: Dict[int, Dict[str, int]] = {} + category_lengths: Dict[str, int] = {} + names: List[str] = [] + + for portion in raw_model: + text = gptc.tokenizer.hash_list( + gptc.tokenizer.tokenize(portion["text"], max_ngram_length), + hash_algorithm, + ) + category = portion["category"] + + if not category in names: + names.append(category) + + category_lengths[category] = category_lengths.get(category, 0) + len( + text + ) + + for word in text: + if word in word_counts: + try: + word_counts[word][category] += 1 + except KeyError: + word_counts[word][category] = 1 + else: + word_counts[word] = {category: 1} + + return word_counts, category_lengths, names + + +def _get_weights( + min_count: int, + word_counts: Dict[int, Dict[str, int]], + category_lengths: Dict[str, int], + names: List[str], +) -> Dict[int, List[int]]: + model: Dict[int, List[int]] = {} + for word, counts in word_counts.items(): + if sum(counts.values()) >= min_count: + weights = { + category: value / category_lengths[category] + for category, value in counts.items() + } + total = sum(weights.values()) + new_weights: List[int] = [] + for category in names: + new_weights.append( + round((weights.get(category, 0) / total) * 65535) + ) + model[word] = new_weights + return model class ExplanationEntry(TypedDict): weight: float @@ -183,7 +240,34 @@ class Model: + b"".join([weight.to_bytes(2, "big") for weight in weights]) ) - compile = staticmethod(gptc.compiler.compile_) + @staticmethod + def compile( + raw_model: Iterable[Mapping[str, str]], + max_ngram_length: int = 1, + min_count: int = 1, + hash_algorithm: str = "sha256", + ) -> 'Model': + """Compile a raw model. + + Parameters + ---------- + raw_model : list of dict + A raw GPTC model. + + max_ngram_length : int + Maximum ngram lenght to compile with. + + Returns + ------- + dict + A compiled GPTC model. + + """ + word_counts, category_lengths, names = _count_words( + raw_model, max_ngram_length, hash_algorithm + ) + model = _get_weights(min_count, word_counts, category_lengths, names) + return Model(model, names, max_ngram_length, hash_algorithm) def deserialize(encoded_model: BinaryIO) -> Model: