Compare commits

..

3 Commits

Author SHA1 Message Date
d8f3d2e701
Bump model version
99ad07a876 broke the model format,
although probably only in a few edge cases

Still enough of a change for a model version bump
2023-04-16 15:36:49 -07:00
7f68dc6fc6
Add classification explanations
Closes #17
2023-04-16 15:35:53 -07:00
99ad07a876
Casefold
Closes #14
2023-04-16 14:49:03 -07:00
3 changed files with 63 additions and 11 deletions

View File

@ -5,6 +5,7 @@ from gptc.exceptions import InvalidModelError
import gptc.weighting
from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
import json
import collections
class Model:
@ -20,7 +21,9 @@ class Model:
self.max_ngram_length = max_ngram_length
self.hash_algorithm = hash_algorithm
def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
def confidence(
self, text: str, max_ngram_length: int, transparent: bool = False
) -> Dict[str, float]:
"""Classify text with confidence.
Parameters
@ -40,19 +43,36 @@ class Model:
"""
model = self.weights
max_ngram_length = min(self.max_ngram_length, max_ngram_length)
raw_tokens = gptc.tokenizer.tokenize(
text, min(max_ngram_length, self.max_ngram_length)
)
tokens = gptc.tokenizer.hash(
gptc.tokenizer.tokenize(
text, min(max_ngram_length, self.max_ngram_length)
),
raw_tokens,
self.hash_algorithm,
)
if transparent:
token_map = {tokens[i]: raw_tokens[i] for i in range(len(tokens))}
log = []
numbered_probs: Dict[int, float] = {}
for word in tokens:
try:
weighted_numbers = gptc.weighting.weight(
[i / 65535 for i in cast(List[float], model[word])]
unweighted_numbers = [
i / 65535 for i in cast(List[float], model[word])
]
weight, weighted_numbers = gptc.weighting.weight(
unweighted_numbers
)
if transparent:
log.append([token_map[word], weight, unweighted_numbers])
for category, value in enumerate(weighted_numbers):
try:
numbered_probs[category] += value
@ -60,12 +80,30 @@ class Model:
numbered_probs[category] = value
except KeyError:
pass
total = sum(numbered_probs.values())
probs: Dict[str, float] = {
self.names[category]: value / total
for category, value in numbered_probs.items()
}
return probs
if transparent:
explanation = {}
for word, weight, word_probs in log:
if word in explanation:
explanation[word]["count"] += 1
else:
explanation[word] = {
"weight": weight,
"probabilities": word_probs,
"count": 1,
}
return TransparentConfidences(
probs, explanation, self, text, max_ngram_length
)
else:
return Confidences(probs, self, text, max_ngram_length)
def get(self, token: str) -> Dict[str, float]:
try:
@ -82,7 +120,7 @@ class Model:
}
def serialize(self, file: BinaryIO) -> None:
file.write(b"GPTC model v5\n")
file.write(b"GPTC model v6\n")
file.write(
json.dumps(
{
@ -100,9 +138,23 @@ class Model:
)
class Confidences(collections.UserDict):
    """Classification result that behaves like a category-to-confidence dict.

    The mapping contents are the probabilities handed to the constructor;
    the extra attributes record what produced them so a caller holding only
    the result can still see the model, input text, and n-gram limit used.
    """

    def __init__(
        self,
        probs: Mapping[str, float],
        model: "Model",
        text: str,
        max_ngram_length: int,
    ) -> None:
        """Store the probabilities and the context they were computed in.

        Parameters
        ----------
        probs
            Mapping of category name to confidence; becomes the dict data.
        model
            The model that produced ``probs``.
        text
            The text that was classified.
        max_ngram_length
            The n-gram length limit supplied by the classifier.
        """
        # super() rather than naming UserDict explicitly, so the class keeps
        # working if its bases are ever rearranged.
        super().__init__(probs)
        self.model = model
        self.text = text
        self.max_ngram_length = max_ngram_length
class TransparentConfidences(Confidences):
    """A ``Confidences`` result that also carries a per-token explanation."""

    def __init__(
        self,
        probs: Mapping[str, float],
        explanation: Mapping[str, Mapping[str, object]],
        model: "Model",
        text: str,
        max_ngram_length: int,
    ) -> None:
        """Store the probabilities, their context, and the explanation.

        Parameters
        ----------
        probs
            Mapping of category name to confidence; becomes the dict data.
        explanation
            Per-token details of how the result was reached. As built by
            the classifier, each token maps to a dict with the keys
            ``"weight"``, ``"probabilities"`` and ``"count"``.
        model
            The model that produced ``probs``.
        text
            The text that was classified.
        max_ngram_length
            The n-gram length limit supplied by the classifier.
        """
        # Delegate the shared fields to the parent via super() instead of
        # naming the base class explicitly.
        super().__init__(probs, model, text, max_ngram_length)
        self.explanation = explanation
def deserialize(encoded_model: BinaryIO) -> Model:
prefix = encoded_model.read(14)
if prefix != b"GPTC model v5\n":
if prefix != b"GPTC model v6\n":
raise InvalidModelError()
config_json = b""

View File

@ -7,7 +7,7 @@ import unicodedata
def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
text = unicodedata.normalize("NFKD", text).lower()
text = unicodedata.normalize("NFKD", text).casefold()
parts = []
highest_end = 0
for emoji_part in emoji.emoji_list(text):

View File

@ -43,4 +43,4 @@ def weight(numbers: Sequence[float]) -> List[float]:
standard_deviation = _standard_deviation(numbers)
weight = standard_deviation * 2
weighted_numbers = [i * weight for i in numbers]
return weighted_numbers
return weight, weighted_numbers