Merge compiler into model.py

This commit is contained in:
Samuel Sloniker 2023-04-17 21:15:18 -07:00
parent a252a15e9d
commit 7b7ef39d0b
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
2 changed files with 86 additions and 94 deletions

View File

@ -1,92 +0,0 @@
# SPDX-License-Identifier: GPL-3.0-or-later
from typing import Iterable, Mapping, List, Dict, Tuple
import gptc.tokenizer
import gptc.model
def _count_words(
raw_model: Iterable[Mapping[str, str]],
max_ngram_length: int,
hash_algorithm: str,
) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]:
word_counts: Dict[int, Dict[str, int]] = {}
category_lengths: Dict[str, int] = {}
names: List[str] = []
for portion in raw_model:
text = gptc.tokenizer.hash_list(
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
hash_algorithm,
)
category = portion["category"]
if not category in names:
names.append(category)
category_lengths[category] = category_lengths.get(category, 0) + len(
text
)
for word in text:
if word in word_counts:
try:
word_counts[word][category] += 1
except KeyError:
word_counts[word][category] = 1
else:
word_counts[word] = {category: 1}
return word_counts, category_lengths, names
def _get_weights(
min_count: int,
word_counts: Dict[int, Dict[str, int]],
category_lengths: Dict[str, int],
names: List[str],
) -> Dict[int, List[int]]:
model: Dict[int, List[int]] = {}
for word, counts in word_counts.items():
if sum(counts.values()) >= min_count:
weights = {
category: value / category_lengths[category]
for category, value in counts.items()
}
total = sum(weights.values())
new_weights: List[int] = []
for category in names:
new_weights.append(
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights
return model
def compile_(
raw_model: Iterable[Mapping[str, str]],
max_ngram_length: int = 1,
min_count: int = 1,
hash_algorithm: str = "sha256",
) -> gptc.model.Model:
"""Compile a raw model.
Parameters
----------
raw_model : list of dict
A raw GPTC model.
max_ngram_length : int
Maximum ngram lenght to compile with.
Returns
-------
dict
A compiled GPTC model.
"""
word_counts, category_lengths, names = _count_words(
raw_model, max_ngram_length, hash_algorithm
)
model = _get_weights(min_count, word_counts, category_lengths, names)
return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: GPL-3.0-or-later
from typing import (
Iterable,
Mapping,
List,
Dict,
cast,
@ -12,8 +14,63 @@ import json
import gptc.tokenizer
from gptc.exceptions import InvalidModelError
import gptc.weighting
import gptc.compiler
def _count_words(
raw_model: Iterable[Mapping[str, str]],
max_ngram_length: int,
hash_algorithm: str,
) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]:
word_counts: Dict[int, Dict[str, int]] = {}
category_lengths: Dict[str, int] = {}
names: List[str] = []
for portion in raw_model:
text = gptc.tokenizer.hash_list(
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
hash_algorithm,
)
category = portion["category"]
if not category in names:
names.append(category)
category_lengths[category] = category_lengths.get(category, 0) + len(
text
)
for word in text:
if word in word_counts:
try:
word_counts[word][category] += 1
except KeyError:
word_counts[word][category] = 1
else:
word_counts[word] = {category: 1}
return word_counts, category_lengths, names
def _get_weights(
min_count: int,
word_counts: Dict[int, Dict[str, int]],
category_lengths: Dict[str, int],
names: List[str],
) -> Dict[int, List[int]]:
model: Dict[int, List[int]] = {}
for word, counts in word_counts.items():
if sum(counts.values()) >= min_count:
weights = {
category: value / category_lengths[category]
for category, value in counts.items()
}
total = sum(weights.values())
new_weights: List[int] = []
for category in names:
new_weights.append(
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights
return model
class ExplanationEntry(TypedDict):
weight: float
@ -183,7 +240,34 @@ class Model:
+ b"".join([weight.to_bytes(2, "big") for weight in weights])
)
compile = staticmethod(gptc.compiler.compile_)
@staticmethod
def compile(
raw_model: Iterable[Mapping[str, str]],
max_ngram_length: int = 1,
min_count: int = 1,
hash_algorithm: str = "sha256",
) -> 'Model':
"""Compile a raw model.
Parameters
----------
raw_model : list of dict
A raw GPTC model.
max_ngram_length : int
Maximum ngram lenght to compile with.
Returns
-------
dict
A compiled GPTC model.
"""
word_counts, category_lengths, names = _count_words(
raw_model, max_ngram_length, hash_algorithm
)
model = _get_weights(min_count, word_counts, category_lengths, names)
return Model(model, names, max_ngram_length, hash_algorithm)
def deserialize(encoded_model: BinaryIO) -> Model: