parent
99ad07a876
commit
7f68dc6fc6
|
@ -5,6 +5,7 @@ from gptc.exceptions import InvalidModelError
|
|||
import gptc.weighting
|
||||
from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
|
||||
import json
|
||||
import collections
|
||||
|
||||
|
||||
class Model:
|
||||
|
@ -20,7 +21,9 @@ class Model:
|
|||
self.max_ngram_length = max_ngram_length
|
||||
self.hash_algorithm = hash_algorithm
|
||||
|
||||
def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
|
||||
def confidence(
|
||||
self, text: str, max_ngram_length: int, transparent: bool = False
|
||||
) -> Dict[str, float]:
|
||||
"""Classify text with confidence.
|
||||
|
||||
Parameters
|
||||
|
@ -40,19 +43,36 @@ class Model:
|
|||
"""
|
||||
|
||||
model = self.weights
|
||||
max_ngram_length = min(self.max_ngram_length, max_ngram_length)
|
||||
|
||||
raw_tokens = gptc.tokenizer.tokenize(
|
||||
text, min(max_ngram_length, self.max_ngram_length)
|
||||
)
|
||||
|
||||
tokens = gptc.tokenizer.hash(
|
||||
gptc.tokenizer.tokenize(
|
||||
text, min(max_ngram_length, self.max_ngram_length)
|
||||
),
|
||||
raw_tokens,
|
||||
self.hash_algorithm,
|
||||
)
|
||||
|
||||
if transparent:
|
||||
token_map = {tokens[i]: raw_tokens[i] for i in range(len(tokens))}
|
||||
log = []
|
||||
|
||||
numbered_probs: Dict[int, float] = {}
|
||||
|
||||
for word in tokens:
|
||||
try:
|
||||
weighted_numbers = gptc.weighting.weight(
|
||||
[i / 65535 for i in cast(List[float], model[word])]
|
||||
unweighted_numbers = [
|
||||
i / 65535 for i in cast(List[float], model[word])
|
||||
]
|
||||
|
||||
weight, weighted_numbers = gptc.weighting.weight(
|
||||
unweighted_numbers
|
||||
)
|
||||
|
||||
if transparent:
|
||||
log.append([token_map[word], weight, unweighted_numbers])
|
||||
|
||||
for category, value in enumerate(weighted_numbers):
|
||||
try:
|
||||
numbered_probs[category] += value
|
||||
|
@ -60,12 +80,30 @@ class Model:
|
|||
numbered_probs[category] = value
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
total = sum(numbered_probs.values())
|
||||
probs: Dict[str, float] = {
|
||||
self.names[category]: value / total
|
||||
for category, value in numbered_probs.items()
|
||||
}
|
||||
return probs
|
||||
|
||||
if transparent:
|
||||
explanation = {}
|
||||
for word, weight, word_probs in log:
|
||||
if word in explanation:
|
||||
explanation[word]["count"] += 1
|
||||
else:
|
||||
explanation[word] = {
|
||||
"weight": weight,
|
||||
"probabilities": word_probs,
|
||||
"count": 1,
|
||||
}
|
||||
|
||||
return TransparentConfidences(
|
||||
probs, explanation, self, text, max_ngram_length
|
||||
)
|
||||
else:
|
||||
return Confidences(probs, self, text, max_ngram_length)
|
||||
|
||||
def get(self, token: str) -> Dict[str, float]:
|
||||
try:
|
||||
|
@ -100,6 +138,20 @@ class Model:
|
|||
)
|
||||
|
||||
|
||||
class Confidences(collections.UserDict):
    """Dict-like classification result that remembers how it was produced.

    The mapping content is the ``probs`` passed to the constructor; the
    model, text and n-gram length are kept alongside it as plain attributes.

    Attributes
    ----------
    model
        The model object handed to the constructor.
    text
        The text handed to the constructor.
    max_ngram_length
        The n-gram length handed to the constructor.
    """

    def __init__(
        self,
        probs: Mapping[str, float],
        model: "Model",
        text: str,
        max_ngram_length: int,
    ) -> None:
        # UserDict copies ``probs`` into its own ``self.data`` dict, so later
        # changes to the caller's mapping do not leak into this result.
        super().__init__(probs)
        self.model = model
        self.text = text
        self.max_ngram_length = max_ngram_length
|
||||
|
||||
|
||||
class TransparentConfidences(Confidences):
    """A ``Confidences`` result that also carries a per-token explanation.

    Attributes
    ----------
    explanation
        The explanation mapping handed to the constructor, stored as-is
        (not copied).
    """

    def __init__(
        self,
        probs: Mapping[str, float],
        explanation: Mapping[str, Mapping[str, object]],
        model: "Model",
        text: str,
        max_ngram_length: int,
    ) -> None:
        # NOTE(review): ``explanation`` is presumably keyed by raw token
        # strings; confirm the key type against the tokenizer's output.
        super().__init__(probs, model, text, max_ngram_length)
        self.explanation = explanation
|
||||
|
||||
|
||||
def deserialize(encoded_model: BinaryIO) -> Model:
|
||||
prefix = encoded_model.read(14)
|
||||
if prefix != b"GPTC model v5\n":
|
||||
|
|
|
@ -43,4 +43,4 @@ def weight(numbers: Sequence[float]) -> List[float]:
|
|||
standard_deviation = _standard_deviation(numbers)
|
||||
weight = standard_deviation * 2
|
||||
weighted_numbers = [i * weight for i in numbers]
|
||||
return weighted_numbers
|
||||
return weight, weighted_numbers
|
||||
|
|
Loading…
Reference in New Issue
Block a user