Fix type annotations
This commit is contained in:
parent
2c3fc77ba6
commit
9513025e60
110
gptc/model.py
110
gptc/model.py
|
@ -3,11 +3,79 @@
|
|||
import gptc.tokenizer
|
||||
from gptc.exceptions import InvalidModelError
|
||||
import gptc.weighting
|
||||
from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
|
||||
from typing import (
|
||||
Iterable,
|
||||
Mapping,
|
||||
List,
|
||||
Dict,
|
||||
Union,
|
||||
cast,
|
||||
BinaryIO,
|
||||
Any,
|
||||
Tuple,
|
||||
TypedDict,
|
||||
)
|
||||
import json
|
||||
import collections
|
||||
|
||||
|
||||
class ExplanationEntry(TypedDict):
|
||||
weight: float
|
||||
probabilities: Dict[str, float]
|
||||
count: int
|
||||
|
||||
|
||||
Explanation = Dict[
|
||||
str,
|
||||
ExplanationEntry,
|
||||
]
|
||||
|
||||
Log = List[Tuple[str, float, List[float]]]
|
||||
|
||||
|
||||
def convert_log(log: Log, names: List[str]) -> Explanation:
|
||||
explanation: Explanation = {}
|
||||
for word2, weight, word_probs in log:
|
||||
if word2 in explanation:
|
||||
explanation[word2]["count"] += 1
|
||||
else:
|
||||
explanation[word2] = {
|
||||
"weight": weight,
|
||||
"probabilities": {
|
||||
name: word_probs[index] for index, name in enumerate(names)
|
||||
},
|
||||
"count": 1,
|
||||
}
|
||||
return explanation
|
||||
|
||||
|
||||
class Confidences(collections.UserDict[str, float]):
|
||||
def __init__(
|
||||
self,
|
||||
probs: Dict[str, float],
|
||||
model: Model,
|
||||
text: str,
|
||||
max_ngram_length: int,
|
||||
):
|
||||
collections.UserDict.__init__(self, probs)
|
||||
self.model = model
|
||||
self.text = text
|
||||
self.max_ngram_length = max_ngram_length
|
||||
|
||||
|
||||
class TransparentConfidences(Confidences):
|
||||
def __init__(
|
||||
self,
|
||||
probs: Dict[str, float],
|
||||
explanation: Explanation,
|
||||
model: Model,
|
||||
text: str,
|
||||
max_ngram_length: int,
|
||||
):
|
||||
Confidences.__init__(self, probs, model, text, max_ngram_length)
|
||||
self.explanation = explanation
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -23,7 +91,7 @@ class Model:
|
|||
|
||||
def confidence(
|
||||
self, text: str, max_ngram_length: int, transparent: bool = False
|
||||
) -> Dict[str, float]:
|
||||
) -> Confidences:
|
||||
"""Classify text with confidence.
|
||||
|
||||
Parameters
|
||||
|
@ -56,7 +124,7 @@ class Model:
|
|||
|
||||
if transparent:
|
||||
token_map = {tokens[i]: raw_tokens[i] for i in range(len(tokens))}
|
||||
log = []
|
||||
log: Log = []
|
||||
|
||||
numbered_probs: Dict[int, float] = {}
|
||||
|
||||
|
@ -71,7 +139,13 @@ class Model:
|
|||
)
|
||||
|
||||
if transparent:
|
||||
log.append([token_map[word], weight, unweighted_numbers])
|
||||
log.append(
|
||||
(
|
||||
token_map[word],
|
||||
weight,
|
||||
unweighted_numbers,
|
||||
)
|
||||
)
|
||||
|
||||
for category, value in enumerate(weighted_numbers):
|
||||
try:
|
||||
|
@ -88,19 +162,7 @@ class Model:
|
|||
}
|
||||
|
||||
if transparent:
|
||||
explanation = {}
|
||||
for word, weight, word_probs in log:
|
||||
if word in explanation:
|
||||
explanation[word]["count"] += 1
|
||||
else:
|
||||
explanation[word] = {
|
||||
"weight": weight,
|
||||
"probabilities": {
|
||||
name: word_probs[index]
|
||||
for index, name in enumerate(self.names)
|
||||
},
|
||||
"count": 1,
|
||||
}
|
||||
explanation = convert_log(log, self.names)
|
||||
|
||||
return TransparentConfidences(
|
||||
probs, explanation, self, text, max_ngram_length
|
||||
|
@ -141,20 +203,6 @@ class Model:
|
|||
)
|
||||
|
||||
|
||||
class Confidences(collections.UserDict):
|
||||
def __init__(self, probs, model, text, max_ngram_length):
|
||||
collections.UserDict.__init__(self, probs)
|
||||
self.model = model
|
||||
self.text = text
|
||||
self.max_ngram_length = max_ngram_length
|
||||
|
||||
|
||||
class TransparentConfidences(Confidences):
|
||||
def __init__(self, probs, explanation, model, text, max_ngram_length):
|
||||
Confidences.__init__(self, probs, model, text, max_ngram_length)
|
||||
self.explanation = explanation
|
||||
|
||||
|
||||
def deserialize(encoded_model: BinaryIO) -> Model:
|
||||
prefix = encoded_model.read(14)
|
||||
if prefix != b"GPTC model v6\n":
|
||||
|
|
|
@ -39,7 +39,7 @@ def _standard_deviation(numbers: Sequence[float]) -> float:
|
|||
return math.sqrt(_mean(squared_deviations))
|
||||
|
||||
|
||||
def weight(numbers: Sequence[float]) -> List[float]:
|
||||
def weight(numbers: Sequence[float]) -> Tuple[float, List[float]]:
|
||||
standard_deviation = _standard_deviation(numbers)
|
||||
weight = standard_deviation * 2
|
||||
weighted_numbers = [i * weight for i in numbers]
|
||||
|
|
Loading…
Reference in New Issue
Block a user