Split compiler into two functions
This commit is contained in:
parent
071656c2d2
commit
75fdb5ba3c
|
@ -2,32 +2,14 @@
|
|||
|
||||
import gptc.tokenizer
|
||||
import gptc.model
|
||||
from typing import Iterable, Mapping, List, Dict, Union
|
||||
from typing import Iterable, Mapping, List, Dict, Union, Tuple
|
||||
|
||||
|
||||
def compile(
|
||||
def _count_words(
|
||||
raw_model: Iterable[Mapping[str, str]],
|
||||
max_ngram_length: int = 1,
|
||||
min_count: int = 1,
|
||||
hash_algorithm: str = "sha256",
|
||||
) -> gptc.model.Model:
|
||||
"""Compile a raw model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw_model : list of dict
|
||||
A raw GPTC model.
|
||||
|
||||
max_ngram_length : int
|
||||
Maximum ngram lenght to compile with.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A compiled GPTC model.
|
||||
|
||||
"""
|
||||
|
||||
max_ngram_length: int,
|
||||
hash_algorithm: str,
|
||||
) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]:
|
||||
word_counts: Dict[int, Dict[str, int]] = {}
|
||||
category_lengths: Dict[str, int] = {}
|
||||
names: List[str] = []
|
||||
|
@ -55,6 +37,15 @@ def compile(
|
|||
else:
|
||||
word_counts[word] = {category: 1}
|
||||
|
||||
return word_counts, category_lengths, names
|
||||
|
||||
|
||||
def _get_weights(
|
||||
min_count: int,
|
||||
word_counts: Dict[int, Dict[str, int]],
|
||||
category_lengths: Dict[str, int],
|
||||
names: List[str],
|
||||
) -> Dict[int, List[int]]:
|
||||
model: Dict[int, List[int]] = {}
|
||||
for word, counts in word_counts.items():
|
||||
if sum(counts.values()) >= min_count:
|
||||
|
@ -69,5 +60,33 @@ def compile(
|
|||
round((weights.get(category, 0) / total) * 65535)
|
||||
)
|
||||
model[word] = new_weights
|
||||
return model
|
||||
|
||||
|
||||
def compile(
|
||||
raw_model: Iterable[Mapping[str, str]],
|
||||
max_ngram_length: int = 1,
|
||||
min_count: int = 1,
|
||||
hash_algorithm: str = "sha256",
|
||||
) -> gptc.model.Model:
|
||||
"""Compile a raw model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw_model : list of dict
|
||||
A raw GPTC model.
|
||||
|
||||
max_ngram_length : int
|
||||
Maximum ngram lenght to compile with.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A compiled GPTC model.
|
||||
|
||||
"""
|
||||
word_counts, category_lengths, names = _count_words(
|
||||
raw_model, max_ngram_length, hash_algorithm
|
||||
)
|
||||
model = _get_weights(min_count, word_counts, category_lengths, names)
|
||||
return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)
|
||||
|
|
Loading…
Reference in New Issue
Block a user