Add classification explanations

Closes #17
Samuel Sloniker 2023-04-16 15:35:53 -07:00
parent 99ad07a876
commit 7f68dc6fc6
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
2 changed files with 60 additions and 8 deletions

View File

@@ -5,6 +5,7 @@ from gptc.exceptions import InvalidModelError
 import gptc.weighting
 from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
 import json
+import collections
 
 
 class Model:
@@ -20,7 +21,9 @@ class Model:
         self.max_ngram_length = max_ngram_length
         self.hash_algorithm = hash_algorithm
 
-    def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
+    def confidence(
+        self, text: str, max_ngram_length: int, transparent: bool = False
+    ) -> Dict[str, float]:
         """Classify text with confidence.
 
         Parameters
@@ -40,19 +43,36 @@ class Model:
         """
 
         model = self.weights
+
+        max_ngram_length = min(self.max_ngram_length, max_ngram_length)
+
+        raw_tokens = gptc.tokenizer.tokenize(
+            text, min(max_ngram_length, self.max_ngram_length)
+        )
+
         tokens = gptc.tokenizer.hash(
-            gptc.tokenizer.tokenize(
-                text, min(max_ngram_length, self.max_ngram_length)
-            ),
+            raw_tokens,
             self.hash_algorithm,
         )
+
+        if transparent:
+            token_map = {tokens[i]: raw_tokens[i] for i in range(len(tokens))}
+            log = []
 
         numbered_probs: Dict[int, float] = {}
         for word in tokens:
             try:
-                weighted_numbers = gptc.weighting.weight(
-                    [i / 65535 for i in cast(List[float], model[word])]
+                unweighted_numbers = [
+                    i / 65535 for i in cast(List[float], model[word])
+                ]
+
+                weight, weighted_numbers = gptc.weighting.weight(
+                    unweighted_numbers
                 )
+
+                if transparent:
+                    log.append([token_map[word], weight, unweighted_numbers])
+
                 for category, value in enumerate(weighted_numbers):
                     try:
                         numbered_probs[category] += value
@@ -60,12 +80,30 @@ class Model:
                         numbered_probs[category] = value
             except KeyError:
                 pass
         total = sum(numbered_probs.values())
         probs: Dict[str, float] = {
             self.names[category]: value / total
             for category, value in numbered_probs.items()
         }
-        return probs
+
+        if transparent:
+            explanation = {}
+
+            for word, weight, word_probs in log:
+                if word in explanation:
+                    explanation[word]["count"] += 1
+                else:
+                    explanation[word] = {
+                        "weight": weight,
+                        "probabilities": word_probs,
+                        "count": 1,
+                    }
+
+            return TransparentConfidences(
+                probs, explanation, self, text, max_ngram_length
+            )
+        else:
+            return Confidences(probs, self, text, max_ngram_length)
 
     def get(self, token: str) -> Dict[str, float]:
         try:
@@ -100,6 +138,20 @@ class Model:
         )
 
 
+class Confidences(collections.UserDict):
+    def __init__(self, probs, model, text, max_ngram_length):
+        collections.UserDict.__init__(self, probs)
+        self.model = model
+        self.text = text
+        self.max_ngram_length = max_ngram_length
+
+
+class TransparentConfidences(Confidences):
+    def __init__(self, probs, explanation, model, text, max_ngram_length):
+        Confidences.__init__(self, probs, model, text, max_ngram_length)
+        self.explanation = explanation
+
+
 def deserialize(encoded_model: BinaryIO) -> Model:
     prefix = encoded_model.read(14)
     if prefix != b"GPTC model v5\n":
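
For context, a minimal usage sketch of the new transparent mode added in this file. The model file path and sample text below are placeholders for illustration and are not part of this commit; only confidence(), transparent=True, and the explanation structure come from the changes above.

# Hypothetical usage sketch (not part of the commit): load a serialized model
# with the deserialize() shown above, then classify with transparent=True.
with open("model.gptc", "rb") as f:  # "model.gptc" is a made-up path
    model = deserialize(f)

confidences = model.confidence("example text to classify", 5, transparent=True)

# TransparentConfidences behaves like the usual mapping of category -> confidence
print(dict(confidences))

# ...and also carries an explanation mapping each raw token to its weight,
# per-category probabilities, and occurrence count.
for token, info in confidences.explanation.items():
    print(token, info["weight"], info["probabilities"], info["count"])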

View File

@@ -43,4 +43,4 @@ def weight(numbers: Sequence[float]) -> List[float]:
     standard_deviation = _standard_deviation(numbers)
     weight = standard_deviation * 2
     weighted_numbers = [i * weight for i in numbers]
-    return weighted_numbers
+    return weight, weighted_numbers
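
This last hunk changes the return contract of gptc.weighting.weight(): callers now receive the scalar weight together with the weighted numbers, which is what lets confidence() log per-token weights above. A rough standalone sketch of that contract, using statistics.pstdev as a stand-in for the _standard_deviation() helper that is not shown in this diff:

import statistics
from typing import List, Sequence, Tuple


def weight_sketch(numbers: Sequence[float]) -> Tuple[float, List[float]]:
    # Stand-in for gptc.weighting.weight(): the weight is twice the standard
    # deviation of the inputs, and each input is scaled by that weight.
    standard_deviation = statistics.pstdev(numbers)  # assumption: population std dev
    weight = standard_deviation * 2
    weighted_numbers = [i * weight for i in numbers]
    return weight, weighted_numbers


weight_value, weighted = weight_sketch([0.2, 0.8])  # weight_value == 0.6
print(weight_value, weighted)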