Fix type annotations

Samuel Sloniker, 2023-04-17 18:16:20 -07:00
commit 9513025e60 (parent 2c3fc77ba6)
Signed by: kj7rrv (GPG Key ID: 1BB4029E66285A62)
2 changed files with 80 additions and 32 deletions

gptc/model.py

@@ -3,11 +3,79 @@
 import gptc.tokenizer
 from gptc.exceptions import InvalidModelError
 import gptc.weighting
-from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
+from typing import (
+    Iterable,
+    Mapping,
+    List,
+    Dict,
+    Union,
+    cast,
+    BinaryIO,
+    Any,
+    Tuple,
+    TypedDict,
+)
 import json
 import collections
+
+
+class ExplanationEntry(TypedDict):
+    weight: float
+    probabilities: Dict[str, float]
+    count: int
+
+
+Explanation = Dict[
+    str,
+    ExplanationEntry,
+]
+
+Log = List[Tuple[str, float, List[float]]]
+
+
+def convert_log(log: Log, names: List[str]) -> Explanation:
+    explanation: Explanation = {}
+
+    for word2, weight, word_probs in log:
+        if word2 in explanation:
+            explanation[word2]["count"] += 1
+        else:
+            explanation[word2] = {
+                "weight": weight,
+                "probabilities": {
+                    name: word_probs[index] for index, name in enumerate(names)
+                },
+                "count": 1,
+            }
+
+    return explanation
+
+
+class Confidences(collections.UserDict[str, float]):
+    def __init__(
+        self,
+        probs: Dict[str, float],
+        model: Model,
+        text: str,
+        max_ngram_length: int,
+    ):
+        collections.UserDict.__init__(self, probs)
+        self.model = model
+        self.text = text
+        self.max_ngram_length = max_ngram_length
+
+
+class TransparentConfidences(Confidences):
+    def __init__(
+        self,
+        probs: Dict[str, float],
+        explanation: Explanation,
+        model: Model,
+        text: str,
+        max_ngram_length: int,
+    ):
+        Confidences.__init__(self, probs, model, text, max_ngram_length)
+        self.explanation = explanation


 class Model:
     def __init__(
         self,
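The extracted `convert_log` helper can be exercised on its own. A minimal sketch of its behavior; the token names, weights, and probabilities below are made up for illustration:

    names = ["spam", "ham"]
    log: Log = [
        ("free", 2.1, [0.9, 0.1]),
        ("offer", 1.3, [0.8, 0.2]),
        ("free", 2.1, [0.9, 0.1]),  # repeated token: only "count" increments
    ]

    explanation = convert_log(log, names)
    # {"free":  {"weight": 2.1, "probabilities": {"spam": 0.9, "ham": 0.1}, "count": 2},
    #  "offer": {"weight": 1.3, "probabilities": {"spam": 0.8, "ham": 0.2}, "count": 1}}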
@@ -23,7 +91,7 @@ class Model:
     def confidence(
         self, text: str, max_ngram_length: int, transparent: bool = False
-    ) -> Dict[str, float]:
+    ) -> Confidences:
         """Classify text with confidence.

         Parameters
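Because `Confidences` subclasses `collections.UserDict[str, float]` (subscripting `UserDict` at runtime needs Python 3.9+), the widened return type stays drop-in compatible with the old `Dict[str, float]`: callers can keep indexing and iterating it like a plain dict, while the classification context rides along as attributes. A sketch, assuming `model` is an already-deserialized gptc `Model` and the output values are illustrative:

    conf = model.confidence("free offer inside", max_ngram_length=3)
    best = max(conf, key=conf.__getitem__)   # iterates keys like a dict
    print(best, conf[best])                  # e.g. "spam 0.97"
    print(conf.text, conf.max_ngram_length)  # inputs preserved as attributes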
@@ -56,7 +124,7 @@ class Model:
         if transparent:
             token_map = {tokens[i]: raw_tokens[i] for i in range(len(tokens))}
-            log = []
+            log: Log = []

         numbered_probs: Dict[int, float] = {}
@@ -71,7 +139,13 @@
            )

            if transparent:
-                log.append([token_map[word], weight, unweighted_numbers])
+                log.append(
+                    (
+                        token_map[word],
+                        weight,
+                        unweighted_numbers,
+                    )
+                )

            for category, value in enumerate(weighted_numbers):
                try:
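Appending a tuple instead of a list is what makes the `Log` alias checkable: a fixed-length `Tuple[str, float, List[float]]` gives each position its own type, which a homogeneous `List` annotation cannot express. A one-line illustration with made-up values:

    entry: Tuple[str, float, List[float]] = ("free", 2.1, [0.9, 0.1])
    word, weight, probs = entry  # mypy infers str, float, and List[float]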
@@ -88,19 +162,7 @@
         }

         if transparent:
-            explanation = {}
-            for word, weight, word_probs in log:
-                if word in explanation:
-                    explanation[word]["count"] += 1
-                else:
-                    explanation[word] = {
-                        "weight": weight,
-                        "probabilities": {
-                            name: word_probs[index]
-                            for index, name in enumerate(self.names)
-                        },
-                        "count": 1,
-                    }
+            explanation = convert_log(log, self.names)

             return TransparentConfidences(
                 probs, explanation, self, text, max_ngram_length
@@ -141,20 +203,6 @@ class Model:
         )


-class Confidences(collections.UserDict):
-    def __init__(self, probs, model, text, max_ngram_length):
-        collections.UserDict.__init__(self, probs)
-        self.model = model
-        self.text = text
-        self.max_ngram_length = max_ngram_length
-
-
-class TransparentConfidences(Confidences):
-    def __init__(self, probs, explanation, model, text, max_ngram_length):
-        Confidences.__init__(self, probs, model, text, max_ngram_length)
-        self.explanation = explanation
-
-
 def deserialize(encoded_model: BinaryIO) -> Model:
     prefix = encoded_model.read(14)

     if prefix != b"GPTC model v6\n":

gptc/weighting.py

@@ -39,7 +39,7 @@ def _standard_deviation(numbers: Sequence[float]) -> float:
     return math.sqrt(_mean(squared_deviations))


-def weight(numbers: Sequence[float]) -> List[float]:
+def weight(numbers: Sequence[float]) -> Tuple[float, List[float]]:
     standard_deviation = _standard_deviation(numbers)
     weight = standard_deviation * 2
     weighted_numbers = [i * weight for i in numbers]
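The new annotation implies `weight` now returns the scalar weight together with the weighted numbers, presumably via a `return weight, weighted_numbers` at the end of the function (the return statement itself is outside this hunk). A sketch of the assumed call site:

    # weight is 2x the standard deviation of the inputs; each output
    # element is the corresponding input scaled by that weight.
    weight, weighted_numbers = gptc.weighting.weight([0.9, 0.1])
    # weight == 0.8; weighted_numbers ≈ [0.72, 0.08]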