Compare commits

..

3 Commits

Author SHA1 Message Date
d8f3d2e701
Bump model version
99ad07a876 broke the model format,
although probably only in a few edge cases

Still enough of a change for a model version bump
2023-04-16 15:36:49 -07:00
7f68dc6fc6
Add classification explanations
Closes #17
2023-04-16 15:35:53 -07:00
99ad07a876
Casefold
Closes #14
2023-04-16 14:49:03 -07:00
3 changed files with 63 additions and 11 deletions

View File

@ -5,6 +5,7 @@ from gptc.exceptions import InvalidModelError
import gptc.weighting import gptc.weighting
from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
import json import json
import collections
class Model: class Model:
@ -20,7 +21,9 @@ class Model:
self.max_ngram_length = max_ngram_length self.max_ngram_length = max_ngram_length
self.hash_algorithm = hash_algorithm self.hash_algorithm = hash_algorithm
def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]: def confidence(
self, text: str, max_ngram_length: int, transparent: bool = False
) -> Dict[str, float]:
"""Classify text with confidence. """Classify text with confidence.
Parameters Parameters
@ -40,19 +43,36 @@ class Model:
""" """
model = self.weights model = self.weights
max_ngram_length = min(self.max_ngram_length, max_ngram_length)
raw_tokens = gptc.tokenizer.tokenize(
text, min(max_ngram_length, self.max_ngram_length)
)
tokens = gptc.tokenizer.hash( tokens = gptc.tokenizer.hash(
gptc.tokenizer.tokenize( raw_tokens,
text, min(max_ngram_length, self.max_ngram_length)
),
self.hash_algorithm, self.hash_algorithm,
) )
if transparent:
token_map = {tokens[i]: raw_tokens[i] for i in range(len(tokens))}
log = []
numbered_probs: Dict[int, float] = {} numbered_probs: Dict[int, float] = {}
for word in tokens: for word in tokens:
try: try:
weighted_numbers = gptc.weighting.weight( unweighted_numbers = [
[i / 65535 for i in cast(List[float], model[word])] i / 65535 for i in cast(List[float], model[word])
]
weight, weighted_numbers = gptc.weighting.weight(
unweighted_numbers
) )
if transparent:
log.append([token_map[word], weight, unweighted_numbers])
for category, value in enumerate(weighted_numbers): for category, value in enumerate(weighted_numbers):
try: try:
numbered_probs[category] += value numbered_probs[category] += value
@ -60,12 +80,30 @@ class Model:
numbered_probs[category] = value numbered_probs[category] = value
except KeyError: except KeyError:
pass pass
total = sum(numbered_probs.values()) total = sum(numbered_probs.values())
probs: Dict[str, float] = { probs: Dict[str, float] = {
self.names[category]: value / total self.names[category]: value / total
for category, value in numbered_probs.items() for category, value in numbered_probs.items()
} }
return probs
if transparent:
explanation = {}
for word, weight, word_probs in log:
if word in explanation:
explanation[word]["count"] += 1
else:
explanation[word] = {
"weight": weight,
"probabilities": word_probs,
"count": 1,
}
return TransparentConfidences(
probs, explanation, self, text, max_ngram_length
)
else:
return Confidences(probs, self, text, max_ngram_length)
def get(self, token: str) -> Dict[str, float]: def get(self, token: str) -> Dict[str, float]:
try: try:
@ -82,7 +120,7 @@ class Model:
} }
def serialize(self, file: BinaryIO) -> None: def serialize(self, file: BinaryIO) -> None:
file.write(b"GPTC model v5\n") file.write(b"GPTC model v6\n")
file.write( file.write(
json.dumps( json.dumps(
{ {
@ -100,9 +138,23 @@ class Model:
) )
class Confidences(collections.UserDict):
def __init__(self, probs, model, text, max_ngram_length):
collections.UserDict.__init__(self, probs)
self.model = model
self.text = text
self.max_ngram_length = max_ngram_length
class TransparentConfidences(Confidences):
def __init__(self, probs, explanation, model, text, max_ngram_length):
Confidences.__init__(self, probs, model, text, max_ngram_length)
self.explanation = explanation
def deserialize(encoded_model: BinaryIO) -> Model: def deserialize(encoded_model: BinaryIO) -> Model:
prefix = encoded_model.read(14) prefix = encoded_model.read(14)
if prefix != b"GPTC model v5\n": if prefix != b"GPTC model v6\n":
raise InvalidModelError() raise InvalidModelError()
config_json = b"" config_json = b""

View File

@ -7,7 +7,7 @@ import unicodedata
def tokenize(text: str, max_ngram_length: int = 1) -> List[str]: def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
text = unicodedata.normalize("NFKD", text).lower() text = unicodedata.normalize("NFKD", text).casefold()
parts = [] parts = []
highest_end = 0 highest_end = 0
for emoji_part in emoji.emoji_list(text): for emoji_part in emoji.emoji_list(text):

View File

@ -43,4 +43,4 @@ def weight(numbers: Sequence[float]) -> List[float]:
standard_deviation = _standard_deviation(numbers) standard_deviation = _standard_deviation(numbers)
weight = standard_deviation * 2 weight = standard_deviation * 2
weighted_numbers = [i * weight for i in numbers] weighted_numbers = [i * weight for i in numbers]
return weighted_numbers return weight, weighted_numbers