Merge compiler into model.py
This commit is contained in:
parent
a252a15e9d
commit
7b7ef39d0b
|
@ -1,92 +0,0 @@
|
|||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
from typing import Iterable, Mapping, List, Dict, Tuple
|
||||
import gptc.tokenizer
|
||||
import gptc.model
|
||||
|
||||
|
||||
def _count_words(
|
||||
raw_model: Iterable[Mapping[str, str]],
|
||||
max_ngram_length: int,
|
||||
hash_algorithm: str,
|
||||
) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]:
|
||||
word_counts: Dict[int, Dict[str, int]] = {}
|
||||
category_lengths: Dict[str, int] = {}
|
||||
names: List[str] = []
|
||||
|
||||
for portion in raw_model:
|
||||
text = gptc.tokenizer.hash_list(
|
||||
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
|
||||
hash_algorithm,
|
||||
)
|
||||
category = portion["category"]
|
||||
|
||||
if not category in names:
|
||||
names.append(category)
|
||||
|
||||
category_lengths[category] = category_lengths.get(category, 0) + len(
|
||||
text
|
||||
)
|
||||
|
||||
for word in text:
|
||||
if word in word_counts:
|
||||
try:
|
||||
word_counts[word][category] += 1
|
||||
except KeyError:
|
||||
word_counts[word][category] = 1
|
||||
else:
|
||||
word_counts[word] = {category: 1}
|
||||
|
||||
return word_counts, category_lengths, names
|
||||
|
||||
|
||||
def _get_weights(
|
||||
min_count: int,
|
||||
word_counts: Dict[int, Dict[str, int]],
|
||||
category_lengths: Dict[str, int],
|
||||
names: List[str],
|
||||
) -> Dict[int, List[int]]:
|
||||
model: Dict[int, List[int]] = {}
|
||||
for word, counts in word_counts.items():
|
||||
if sum(counts.values()) >= min_count:
|
||||
weights = {
|
||||
category: value / category_lengths[category]
|
||||
for category, value in counts.items()
|
||||
}
|
||||
total = sum(weights.values())
|
||||
new_weights: List[int] = []
|
||||
for category in names:
|
||||
new_weights.append(
|
||||
round((weights.get(category, 0) / total) * 65535)
|
||||
)
|
||||
model[word] = new_weights
|
||||
return model
|
||||
|
||||
|
||||
def compile_(
|
||||
raw_model: Iterable[Mapping[str, str]],
|
||||
max_ngram_length: int = 1,
|
||||
min_count: int = 1,
|
||||
hash_algorithm: str = "sha256",
|
||||
) -> gptc.model.Model:
|
||||
"""Compile a raw model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw_model : list of dict
|
||||
A raw GPTC model.
|
||||
|
||||
max_ngram_length : int
|
||||
Maximum ngram lenght to compile with.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A compiled GPTC model.
|
||||
|
||||
"""
|
||||
word_counts, category_lengths, names = _count_words(
|
||||
raw_model, max_ngram_length, hash_algorithm
|
||||
)
|
||||
model = _get_weights(min_count, word_counts, category_lengths, names)
|
||||
return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)
|
|
@ -1,6 +1,8 @@
|
|||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
from typing import (
|
||||
Iterable,
|
||||
Mapping,
|
||||
List,
|
||||
Dict,
|
||||
cast,
|
||||
|
@ -12,8 +14,63 @@ import json
|
|||
import gptc.tokenizer
|
||||
from gptc.exceptions import InvalidModelError
|
||||
import gptc.weighting
|
||||
import gptc.compiler
|
||||
|
||||
def _count_words(
|
||||
raw_model: Iterable[Mapping[str, str]],
|
||||
max_ngram_length: int,
|
||||
hash_algorithm: str,
|
||||
) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]:
|
||||
word_counts: Dict[int, Dict[str, int]] = {}
|
||||
category_lengths: Dict[str, int] = {}
|
||||
names: List[str] = []
|
||||
|
||||
for portion in raw_model:
|
||||
text = gptc.tokenizer.hash_list(
|
||||
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
|
||||
hash_algorithm,
|
||||
)
|
||||
category = portion["category"]
|
||||
|
||||
if not category in names:
|
||||
names.append(category)
|
||||
|
||||
category_lengths[category] = category_lengths.get(category, 0) + len(
|
||||
text
|
||||
)
|
||||
|
||||
for word in text:
|
||||
if word in word_counts:
|
||||
try:
|
||||
word_counts[word][category] += 1
|
||||
except KeyError:
|
||||
word_counts[word][category] = 1
|
||||
else:
|
||||
word_counts[word] = {category: 1}
|
||||
|
||||
return word_counts, category_lengths, names
|
||||
|
||||
|
||||
def _get_weights(
|
||||
min_count: int,
|
||||
word_counts: Dict[int, Dict[str, int]],
|
||||
category_lengths: Dict[str, int],
|
||||
names: List[str],
|
||||
) -> Dict[int, List[int]]:
|
||||
model: Dict[int, List[int]] = {}
|
||||
for word, counts in word_counts.items():
|
||||
if sum(counts.values()) >= min_count:
|
||||
weights = {
|
||||
category: value / category_lengths[category]
|
||||
for category, value in counts.items()
|
||||
}
|
||||
total = sum(weights.values())
|
||||
new_weights: List[int] = []
|
||||
for category in names:
|
||||
new_weights.append(
|
||||
round((weights.get(category, 0) / total) * 65535)
|
||||
)
|
||||
model[word] = new_weights
|
||||
return model
|
||||
|
||||
class ExplanationEntry(TypedDict):
|
||||
weight: float
|
||||
|
@ -183,7 +240,34 @@ class Model:
|
|||
+ b"".join([weight.to_bytes(2, "big") for weight in weights])
|
||||
)
|
||||
|
||||
compile = staticmethod(gptc.compiler.compile_)
|
||||
@staticmethod
|
||||
def compile(
|
||||
raw_model: Iterable[Mapping[str, str]],
|
||||
max_ngram_length: int = 1,
|
||||
min_count: int = 1,
|
||||
hash_algorithm: str = "sha256",
|
||||
) -> 'Model':
|
||||
"""Compile a raw model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw_model : list of dict
|
||||
A raw GPTC model.
|
||||
|
||||
max_ngram_length : int
|
||||
Maximum ngram lenght to compile with.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A compiled GPTC model.
|
||||
|
||||
"""
|
||||
word_counts, category_lengths, names = _count_words(
|
||||
raw_model, max_ngram_length, hash_algorithm
|
||||
)
|
||||
model = _get_weights(min_count, word_counts, category_lengths, names)
|
||||
return Model(model, names, max_ngram_length, hash_algorithm)
|
||||
|
||||
|
||||
def deserialize(encoded_model: BinaryIO) -> Model:
|
||||
|
|
Loading…
Reference in New Issue
Block a user