Fix type annotations
parent 2c3fc77ba6
commit 9513025e60
gptc/model.py | 110
gptc/model.py

@@ -3,11 +3,79 @@
 import gptc.tokenizer
 from gptc.exceptions import InvalidModelError
 import gptc.weighting
-from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
+from typing import (
+    Iterable,
+    Mapping,
+    List,
+    Dict,
+    Union,
+    cast,
+    BinaryIO,
+    Any,
+    Tuple,
+    TypedDict,
+)
 import json
 import collections
+
+
+class ExplanationEntry(TypedDict):
+    weight: float
+    probabilities: Dict[str, float]
+    count: int
+
+
+Explanation = Dict[
+    str,
+    ExplanationEntry,
+]
+
+Log = List[Tuple[str, float, List[float]]]
+
+
+def convert_log(log: Log, names: List[str]) -> Explanation:
+    explanation: Explanation = {}
+    for word2, weight, word_probs in log:
+        if word2 in explanation:
+            explanation[word2]["count"] += 1
+        else:
+            explanation[word2] = {
+                "weight": weight,
+                "probabilities": {
+                    name: word_probs[index] for index, name in enumerate(names)
+                },
+                "count": 1,
+            }
+    return explanation
+
+
+class Confidences(collections.UserDict[str, float]):
+    def __init__(
+        self,
+        probs: Dict[str, float],
+        model: Model,
+        text: str,
+        max_ngram_length: int,
+    ):
+        collections.UserDict.__init__(self, probs)
+        self.model = model
+        self.text = text
+        self.max_ngram_length = max_ngram_length
+
+
+class TransparentConfidences(Confidences):
+    def __init__(
+        self,
+        probs: Dict[str, float],
+        explanation: Explanation,
+        model: Model,
+        text: str,
+        max_ngram_length: int,
+    ):
+        Confidences.__init__(self, probs, model, text, max_ngram_length)
+        self.explanation = explanation
+
+
 class Model:
     def __init__(
         self,
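Reviewer note: the explanation-building loop that previously lived inline in Model.confidence (removed further down in this diff) is now the module-level convert_log helper, typed against the new Log and Explanation aliases. A minimal sketch of what it produces, with invented sample data:

    from gptc.model import convert_log  # helper added in this commit

    log = [
        ("good", 1.3, [0.9, 0.1]),
        ("good", 1.3, [0.9, 0.1]),  # repeated token: first entry kept, count bumped
    ]
    print(convert_log(log, ["positive", "negative"]))
    # {'good': {'weight': 1.3,
    #           'probabilities': {'positive': 0.9, 'negative': 0.1},
    #           'count': 2}}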
@@ -23,7 +91,7 @@ class Model:
 
     def confidence(
         self, text: str, max_ngram_length: int, transparent: bool = False
-    ) -> Dict[str, float]:
+    ) -> Confidences:
         """Classify text with confidence.
 
         Parameters
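The declared return type changes from a plain Dict[str, float] to the new Confidences class. Because Confidences subclasses collections.UserDict[str, float], callers that treat the result as a mapping keep working; a rough sketch of the call pattern (the model value is assumed to come from gptc.model.deserialize):

    confidences = model.confidence("example text", max_ngram_length=1)
    best_category = max(confidences, key=confidences.__getitem__)  # still dict-like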
@@ -56,7 +124,7 @@ class Model:
 
         if transparent:
             token_map = {tokens[i]: raw_tokens[i] for i in range(len(tokens))}
-            log = []
+            log: Log = []
 
         numbered_probs: Dict[int, float] = {}
 
@@ -71,7 +139,13 @@ class Model:
             )
 
             if transparent:
-                log.append([token_map[word], weight, unweighted_numbers])
+                log.append(
+                    (
+                        token_map[word],
+                        weight,
+                        unweighted_numbers,
+                    )
+                )
 
             for category, value in enumerate(weighted_numbers):
                 try:
@@ -88,19 +162,7 @@ class Model:
         }
 
         if transparent:
-            explanation = {}
-            for word, weight, word_probs in log:
-                if word in explanation:
-                    explanation[word]["count"] += 1
-                else:
-                    explanation[word] = {
-                        "weight": weight,
-                        "probabilities": {
-                            name: word_probs[index]
-                            for index, name in enumerate(self.names)
-                        },
-                        "count": 1,
-                    }
+            explanation = convert_log(log, self.names)
 
             return TransparentConfidences(
                 probs, explanation, self, text, max_ngram_length
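With transparent=True the method now returns the TransparentConfidences subclass, so the explanation is reachable as a typed attribute rather than rebuilt by callers. A sketch of inspecting it (sample text invented; field names follow ExplanationEntry above):

    conf = model.confidence("example text", max_ngram_length=1, transparent=True)
    for token, entry in conf.explanation.items():
        print(token, entry["weight"], entry["count"], entry["probabilities"])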
@@ -141,20 +203,6 @@ class Model:
         )
 
 
-class Confidences(collections.UserDict):
-    def __init__(self, probs, model, text, max_ngram_length):
-        collections.UserDict.__init__(self, probs)
-        self.model = model
-        self.text = text
-        self.max_ngram_length = max_ngram_length
-
-
-class TransparentConfidences(Confidences):
-    def __init__(self, probs, explanation, model, text, max_ngram_length):
-        Confidences.__init__(self, probs, model, text, max_ngram_length)
-        self.explanation = explanation
-
-
 def deserialize(encoded_model: BinaryIO) -> Model:
     prefix = encoded_model.read(14)
     if prefix != b"GPTC model v6\n":
gptc/weighting.py

@@ -39,7 +39,7 @@ def _standard_deviation(numbers: Sequence[float]) -> float:
     return math.sqrt(_mean(squared_deviations))
 
 
-def weight(numbers: Sequence[float]) -> List[float]:
+def weight(numbers: Sequence[float]) -> Tuple[float, List[float]]:
     standard_deviation = _standard_deviation(numbers)
     weight = standard_deviation * 2
     weighted_numbers = [i * weight for i in numbers]
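The new Tuple[float, List[float]] annotation records that weight() returns the scalar weight together with the weighted numbers, which is what lets Model.confidence log (token, weight, unweighted_numbers) tuples above. Presumed call-site shape, since the return statement itself is outside this hunk:

    import gptc.weighting

    weight, weighted_numbers = gptc.weighting.weight([0.25, 0.75])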