Compare commits


No commits in common. "d8f3d2e7013d200201e4aa6759c7bf7045562735" and "f38f4ca8013baeb56ff05acea4eeddcc2664ece7" have entirely different histories.

3 changed files with 11 additions and 63 deletions

View File

@@ -5,7 +5,6 @@ from gptc.exceptions import InvalidModelError
 import gptc.weighting
 from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
 import json
-import collections
 
 
 class Model:
@@ -21,9 +20,7 @@ class Model:
         self.max_ngram_length = max_ngram_length
         self.hash_algorithm = hash_algorithm
 
-    def confidence(
-        self, text: str, max_ngram_length: int, transparent: bool = False
-    ) -> Dict[str, float]:
+    def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
         """Classify text with confidence.
 
         Parameters
@@ -43,36 +40,19 @@ class Model:
         """
         model = self.weights
 
-        max_ngram_length = min(self.max_ngram_length, max_ngram_length)
-        raw_tokens = gptc.tokenizer.tokenize(
-            text, min(max_ngram_length, self.max_ngram_length)
-        )
-        tokens = gptc.tokenizer.hash(
-            raw_tokens,
+        tokens = gptc.tokenizer.hash(
+            gptc.tokenizer.tokenize(
+                text, min(max_ngram_length, self.max_ngram_length)
+            ),
             self.hash_algorithm,
         )
-        if transparent:
-            token_map = {tokens[i]: raw_tokens[i] for i in range(len(tokens))}
-            log = []
 
         numbered_probs: Dict[int, float] = {}
         for word in tokens:
             try:
-                unweighted_numbers = [
-                    i / 65535 for i in cast(List[float], model[word])
-                ]
-                weight, weighted_numbers = gptc.weighting.weight(
-                    unweighted_numbers
-                )
-                if transparent:
-                    log.append([token_map[word], weight, unweighted_numbers])
+                weighted_numbers = gptc.weighting.weight(
+                    [i / 65535 for i in cast(List[float], model[word])]
+                )
                 for category, value in enumerate(weighted_numbers):
                     try:
                         numbered_probs[category] += value
@@ -80,30 +60,12 @@ class Model:
                         numbered_probs[category] = value
             except KeyError:
                 pass
 
         total = sum(numbered_probs.values())
         probs: Dict[str, float] = {
             self.names[category]: value / total
             for category, value in numbered_probs.items()
         }
-        if transparent:
-            explanation = {}
-            for word, weight, word_probs in log:
-                if word in explanation:
-                    explanation[word]["count"] += 1
-                else:
-                    explanation[word] = {
-                        "weight": weight,
-                        "probabilities": word_probs,
-                        "count": 1,
-                    }
-            return TransparentConfidences(
-                probs, explanation, self, text, max_ngram_length
-            )
-        else:
-            return Confidences(probs, self, text, max_ngram_length)
+        return probs
 
     def get(self, token: str) -> Dict[str, float]:
         try:
@@ -120,7 +82,7 @@ class Model:
         }
 
     def serialize(self, file: BinaryIO) -> None:
-        file.write(b"GPTC model v6\n")
+        file.write(b"GPTC model v5\n")
         file.write(
             json.dumps(
                 {
@@ -138,23 +100,9 @@ class Model:
             )
 
 
-class Confidences(collections.UserDict):
-    def __init__(self, probs, model, text, max_ngram_length):
-        collections.UserDict.__init__(self, probs)
-        self.model = model
-        self.text = text
-        self.max_ngram_length = max_ngram_length
-
-
-class TransparentConfidences(Confidences):
-    def __init__(self, probs, explanation, model, text, max_ngram_length):
-        Confidences.__init__(self, probs, model, text, max_ngram_length)
-        self.explanation = explanation
-
-
 def deserialize(encoded_model: BinaryIO) -> Model:
     prefix = encoded_model.read(14)
-    if prefix != b"GPTC model v6\n":
+    if prefix != b"GPTC model v5\n":
         raise InvalidModelError()
 
     config_json = b""

View File

@@ -7,7 +7,7 @@ import unicodedata
 
 
 def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
-    text = unicodedata.normalize("NFKD", text).casefold()
+    text = unicodedata.normalize("NFKD", text).lower()
     parts = []
     highest_end = 0
    for emoji_part in emoji.emoji_list(text):
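The only change in this file (presumably gptc/tokenizer.py) swaps casefold() back to lower(). The two differ for characters with Unicode full case foldings, so the two sides of this comparison can tokenize identical text differently; a small self-contained illustration:

    import unicodedata

    # casefold() performs Unicode full case folding, which is more aggressive
    # than lower(): e.g. the German sharp s folds to "ss".
    text = "Große STRAẞE"
    print(unicodedata.normalize("NFKD", text).casefold())  # -> "grosse strasse"
    print(unicodedata.normalize("NFKD", text).lower())     # -> "große straße"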

View File

@@ -43,4 +43,4 @@ def weight(numbers: Sequence[float]) -> List[float]:
     standard_deviation = _standard_deviation(numbers)
     weight = standard_deviation * 2
     weighted_numbers = [i * weight for i in numbers]
-    return weight, weighted_numbers
+    return weighted_numbers
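Finally, weight() (presumably in gptc/weighting.py) now returns only the weighted list, which matches the -> List[float] annotation shown in the hunk context; the scalar weight it also returned before existed only to feed the transparent-mode log removed above. A sketch of the resulting function, assuming the _standard_deviation helper (not shown in this diff) computes a population standard deviation, approximated here with statistics.pstdev:

    import statistics
    from typing import List, Sequence

    def weight(numbers: Sequence[float]) -> List[float]:
        # _standard_deviation is not shown in this diff; statistics.pstdev
        # (population standard deviation) is assumed here as a stand-in.
        standard_deviation = statistics.pstdev(numbers)
        weight = standard_deviation * 2
        # Scale every category's score by the spread: tokens whose scores
        # are nearly uniform across categories are damped toward zero.
        return [i * weight for i in numbers]

    print(weight([0.9, 0.1]))  # [0.72, 0.08]  (pstdev = 0.4, weight = 0.8)
    print(weight([0.5, 0.5]))  # [0.0, 0.0]    (no spread, no signal)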