Split compiler into two functions

This commit is contained in:
Samuel Sloniker 2023-01-15 09:39:35 -08:00
parent 071656c2d2
commit 75fdb5ba3c
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62

View File

@ -2,32 +2,14 @@
import gptc.tokenizer
import gptc.model
from typing import Iterable, Mapping, List, Dict, Union
from typing import Iterable, Mapping, List, Dict, Union, Tuple
def compile(
def _count_words(
raw_model: Iterable[Mapping[str, str]],
max_ngram_length: int = 1,
min_count: int = 1,
hash_algorithm: str = "sha256",
) -> gptc.model.Model:
"""Compile a raw model.
Parameters
----------
raw_model : list of dict
A raw GPTC model.
max_ngram_length : int
Maximum ngram length to compile with.
Returns
-------
dict
A compiled GPTC model.
"""
max_ngram_length: int,
hash_algorithm: str,
) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]:
word_counts: Dict[int, Dict[str, int]] = {}
category_lengths: Dict[str, int] = {}
names: List[str] = []
@ -55,6 +37,15 @@ def compile(
else:
word_counts[word] = {category: 1}
return word_counts, category_lengths, names
def _get_weights(
min_count: int,
word_counts: Dict[int, Dict[str, int]],
category_lengths: Dict[str, int],
names: List[str],
) -> Dict[int, List[int]]:
model: Dict[int, List[int]] = {}
for word, counts in word_counts.items():
if sum(counts.values()) >= min_count:
@ -69,5 +60,33 @@ def compile(
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights
return model
def compile(
    raw_model: Iterable[Mapping[str, str]],
    max_ngram_length: int = 1,
    min_count: int = 1,
    hash_algorithm: str = "sha256",
) -> gptc.model.Model:
    """Compile a raw model.

    Counts the words of the raw model, turns the counts into weights,
    and wraps the result in a model object.

    Parameters
    ----------
    raw_model : iterable of mapping
        A raw GPTC model.

    max_ngram_length : int
        Maximum ngram length to compile with.

    min_count : int
        Count threshold handed to the weighting step; a word is only
        considered there when its summed count is at least this value.

    hash_algorithm : str
        Hash algorithm name handed to the counting step and recorded on
        the compiled model.

    Returns
    -------
    gptc.model.Model
        A compiled GPTC model.

    """
    # Step 1: gather per-word counts, per-category lengths and the
    # category names from the raw model.
    counted = _count_words(raw_model, max_ngram_length, hash_algorithm)
    word_counts, category_lengths, names = counted

    # Step 2: convert the counts into the per-word weight lists.
    weights = _get_weights(min_count, word_counts, category_lengths, names)

    return gptc.model.Model(
        weights,
        names,
        max_ngram_length,
        hash_algorithm,
    )