You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
322 lines
9.0 KiB
322 lines
9.0 KiB
# SPDX-License-Identifier: GPL-3.0-or-later |
|
|
|
from typing import ( |
|
Iterable, |
|
Mapping, |
|
List, |
|
Dict, |
|
cast, |
|
BinaryIO, |
|
Tuple, |
|
TypedDict, |
|
) |
|
import json |
|
import gptc.tokenizer |
|
from gptc.exceptions import InvalidModelError |
|
import gptc.weighting |
|
|
|
def _count_words(
    raw_model: Iterable[Mapping[str, str]],
    max_ngram_length: int,
    hash_algorithm: str,
) -> Tuple[Dict[int, Dict[str, int]], Dict[str, int], List[str]]:
    """Count hashed-token occurrences per category in a raw model.

    Parameters
    ----------
    raw_model : iterable of mapping
        Portions of training data; each must have "text" and "category"
        keys.

    max_ngram_length : int
        Maximum ngram length to tokenize with.

    hash_algorithm : str
        Hash algorithm name passed through to the tokenizer.

    Returns
    -------
    tuple
        (word_counts, category_lengths, names): word_counts maps hashed
        token -> {category: count}; category_lengths maps category ->
        total token count; names lists categories in first-seen order.
    """
    word_counts: Dict[int, Dict[str, int]] = {}
    category_lengths: Dict[str, int] = {}
    names: List[str] = []

    for portion in raw_model:
        text = gptc.tokenizer.hash_list(
            gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
            hash_algorithm,
        )
        category = portion["category"]

        # Record categories in the order they first appear.
        if category not in names:
            names.append(category)

        category_lengths[category] = category_lengths.get(category, 0) + len(
            text
        )

        for word in text:
            counts = word_counts.setdefault(word, {})
            counts[category] = counts.get(category, 0) + 1

    return word_counts, category_lengths, names
|
|
|
|
|
def _get_weights( |
|
min_count: int, |
|
word_counts: Dict[int, Dict[str, int]], |
|
category_lengths: Dict[str, int], |
|
names: List[str], |
|
) -> Dict[int, List[int]]: |
|
model: Dict[int, List[int]] = {} |
|
for word, counts in word_counts.items(): |
|
if sum(counts.values()) >= min_count: |
|
weights = { |
|
category: value / category_lengths[category] |
|
for category, value in counts.items() |
|
} |
|
total = sum(weights.values()) |
|
new_weights: List[int] = [] |
|
for category in names: |
|
new_weights.append( |
|
round((weights.get(category, 0) / total) * 65535) |
|
) |
|
model[word] = new_weights |
|
return model |
|
|
|
class ExplanationEntry(TypedDict):
    """Per-token entry in a classification explanation."""

    # Weight the weighting scheme assigned to this token.
    weight: float
    # Per-category probability contribution of this token.
    probabilities: Dict[str, float]
    # How many times the token occurred in the classified text.
    count: int
|
|
|
|
|
# A full explanation: maps each token (as readable text) to its entry.
Explanation = Dict[str, ExplanationEntry]

# Raw classification log: (token text, weight, per-category numbers).
Log = List[Tuple[str, float, List[float]]]
|
|
|
|
|
class Confidences(dict[str, float]):
    """Mapping of category name to classification confidence."""

    def __init__(self, probs: Dict[str, float]):
        super().__init__(probs)
|
|
|
|
|
class TransparentConfidences(Confidences):
    """Confidences plus an explanation of how they were computed."""

    def __init__(
        self,
        probs: Dict[str, float],
        explanation: Explanation,
    ):
        # Attach the explanation before initializing the mapping itself.
        self.explanation = explanation
        super().__init__(probs)
|
|
|
|
|
def convert_log(log: Log, names: List[str]) -> Explanation:
    """Aggregate a raw classification log into an explanation dict.

    A repeated token keeps the weight and probabilities recorded the
    first time it was seen; only its count is incremented.
    """
    explanation: Explanation = {}
    for token, weight, numbers in log:
        existing = explanation.get(token)
        if existing is not None:
            existing["count"] += 1
            continue
        explanation[token] = {
            "weight": weight,
            "probabilities": {
                name: numbers[index] for index, name in enumerate(names)
            },
            "count": 1,
        }
    return explanation
|
|
|
|
|
class Model:
    """A compiled GPTC model: hashed-token weights plus metadata."""

    def __init__(
        self,
        weights: Dict[int, List[int]],
        names: List[str],
        max_ngram_length: int,
        hash_algorithm: str,
    ):
        """Store compiled weights and the parameters they were built with.

        Parameters
        ----------
        weights : dict
            Hashed token -> list of 16-bit weights, one per category in
            *names* order.

        names : list of str
            Category names, in the order used by *weights*.

        max_ngram_length : int
            Maximum ngram length the model was compiled with.

        hash_algorithm : str
            Hash algorithm used to hash tokens (e.g. "sha256").
        """
        self.weights = weights
        self.names = names
        self.max_ngram_length = max_ngram_length
        self.hash_algorithm = hash_algorithm

    def confidence(
        self, text: str, max_ngram_length: int, transparent: bool = False
    ) -> Confidences:
        """Classify text with confidence.

        Parameters
        ----------
        text : str
            The text to classify

        max_ngram_length : int
            The maximum ngram length to use in classifying

        transparent : bool
            When True, also record which tokens contributed and return a
            TransparentConfidences carrying that explanation.

        Returns
        -------
        dict
            {category:probability, category:probability...} or {} if no words
            matching any categories in the model were found

        """
        model = self.weights
        # Never use a longer ngram length than the model was compiled with.
        max_ngram_length = min(self.max_ngram_length, max_ngram_length)

        raw_tokens = gptc.tokenizer.tokenize(text, max_ngram_length)

        tokens = gptc.tokenizer.hash_list(
            raw_tokens,
            self.hash_algorithm,
        )

        if transparent:
            # Map hashes back to readable tokens for the explanation.
            token_map = dict(zip(tokens, raw_tokens))
            log: Log = []

        numbered_probs: Dict[int, float] = {}

        for word in tokens:
            # Narrow try: only the model lookup may legitimately miss;
            # previously the whole loop body was wrapped, which would have
            # silently swallowed KeyErrors from the weighting step too.
            try:
                word_weights = model[word]
            except KeyError:
                continue  # token not in the model; contributes nothing

            unweighted_numbers = [i / 65535 for i in word_weights]

            weight, weighted_numbers = gptc.weighting.weight(
                unweighted_numbers
            )

            if transparent:
                log.append(
                    (
                        token_map[word],
                        weight,
                        unweighted_numbers,
                    )
                )

            for category, value in enumerate(weighted_numbers):
                numbered_probs[category] = (
                    numbered_probs.get(category, 0.0) + value
                )

        total = sum(numbered_probs.values())
        # With no matching tokens, numbered_probs is empty, so probs is {}
        # and the division by total never happens.
        probs: Dict[str, float] = {
            self.names[category]: value / total
            for category, value in numbered_probs.items()
        }

        if transparent:
            explanation = convert_log(log, self.names)
            return TransparentConfidences(probs, explanation)

        return Confidences(probs)

    def get(self, token: str) -> Dict[str, float]:
        """Return *token*'s per-category weights scaled to [0, 1].

        Returns an empty dict when the token is not in the model.
        """
        hashed = gptc.tokenizer.hash_single(
            gptc.tokenizer.normalize(token), self.hash_algorithm
        )
        try:
            weights = self.weights[hashed]
        except KeyError:
            return {}
        return {
            category: weights[index] / 65535
            for index, category in enumerate(self.names)
        }

    def serialize(self, file: BinaryIO) -> None:
        """Write the model to *file* in the binary "GPTC model v6" format.

        Layout: a magic line, a JSON config line, then one fixed-size
        record per token — 6 bytes of token hash followed by a 2-byte
        big-endian weight per category.
        """
        file.write(b"GPTC model v6\n")
        file.write(
            json.dumps(
                {
                    "names": self.names,
                    "max_ngram_length": self.max_ngram_length,
                    "hash_algorithm": self.hash_algorithm,
                }
            ).encode("utf-8")
            + b"\n"
        )
        for word, weights in self.weights.items():
            file.write(
                word.to_bytes(6, "big")
                + b"".join([weight.to_bytes(2, "big") for weight in weights])
            )

    @staticmethod
    def compile(
        raw_model: Iterable[Mapping[str, str]],
        max_ngram_length: int = 1,
        min_count: int = 1,
        hash_algorithm: str = "sha256",
    ) -> "Model":
        """Compile a raw model.

        Parameters
        ----------
        raw_model : list of dict
            A raw GPTC model.

        max_ngram_length : int
            Maximum ngram length to compile with.

        min_count : int
            Minimum total occurrences a token needs to be kept.

        hash_algorithm : str
            Hash algorithm used to hash tokens.

        Returns
        -------
        Model
            A compiled GPTC model.

        """
        word_counts, category_lengths, names = _count_words(
            raw_model, max_ngram_length, hash_algorithm
        )
        weights = _get_weights(min_count, word_counts, category_lengths, names)
        return Model(weights, names, max_ngram_length, hash_algorithm)

    @staticmethod
    def deserialize(encoded_model: BinaryIO) -> "Model":
        """Read a model previously written by serialize().

        Raises
        ------
        InvalidModelError
            If the stream has a bad magic line, is truncated, or carries
            a malformed or incomplete JSON config.
        """
        prefix = encoded_model.read(14)
        if prefix != b"GPTC model v6\n":
            raise InvalidModelError()

        # Read the JSON config line byte-by-byte up to the newline.
        config_json = b""
        while True:
            byte = encoded_model.read(1)
            if byte == b"\n":
                break

            if byte == b"":
                # EOF before the config line ended.
                raise InvalidModelError()

            config_json += byte

        try:
            config = json.loads(config_json.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            raise InvalidModelError() from exc

        try:
            names = config["names"]
            max_ngram_length = config["max_ngram_length"]
            hash_algorithm = config["hash_algorithm"]
        except KeyError as exc:
            raise InvalidModelError() from exc

        if not (
            isinstance(names, list)
            and isinstance(max_ngram_length, int)
            # hash_algorithm was previously accepted unchecked
            and isinstance(hash_algorithm, str)
        ) or not all(isinstance(name, str) for name in names):
            raise InvalidModelError()

        # Each record: 6-byte token hash + one 2-byte weight per category.
        weight_code_length = 6 + 2 * len(names)

        weights: Dict[int, List[int]] = {}

        while True:
            code = encoded_model.read(weight_code_length)
            if not code:
                break
            if len(code) != weight_code_length:
                raise InvalidModelError()

            weights[int.from_bytes(code[:6], "big")] = [
                int.from_bytes(code[x : x + 2], "big")
                for x in range(6, len(code), 2)
            ]

        return Model(weights, names, max_ngram_length, hash_algorithm)
|
|
|