diff --git a/gptc/model.py b/gptc/model.py
index 54695f4..8270b74 100644
--- a/gptc/model.py
+++ b/gptc/model.py
@@ -3,11 +3,79 @@
 import gptc.tokenizer
 from gptc.exceptions import InvalidModelError
 import gptc.weighting
-from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
+from typing import (
+    Iterable,
+    Mapping,
+    List,
+    Dict,
+    Union,
+    cast,
+    BinaryIO,
+    Any,
+    Tuple,
+    TypedDict,
+)
 import json
 import collections
 
 
+class ExplanationEntry(TypedDict):
+    weight: float
+    probabilities: Dict[str, float]
+    count: int
+
+
+Explanation = Dict[
+    str,
+    ExplanationEntry,
+]
+
+Log = List[Tuple[str, float, List[float]]]
+
+
+def convert_log(log: Log, names: List[str]) -> Explanation:
+    explanation: Explanation = {}
+    for word2, weight, word_probs in log:
+        if word2 in explanation:
+            explanation[word2]["count"] += 1
+        else:
+            explanation[word2] = {
+                "weight": weight,
+                "probabilities": {
+                    name: word_probs[index] for index, name in enumerate(names)
+                },
+                "count": 1,
+            }
+    return explanation
+
+
+class Confidences(collections.UserDict[str, float]):
+    def __init__(
+        self,
+        probs: Dict[str, float],
+        model: Model,
+        text: str,
+        max_ngram_length: int,
+    ):
+        collections.UserDict.__init__(self, probs)
+        self.model = model
+        self.text = text
+        self.max_ngram_length = max_ngram_length
+
+
+class TransparentConfidences(Confidences):
+    def __init__(
+        self,
+        probs: Dict[str, float],
+        explanation: Explanation,
+        model: Model,
+        text: str,
+        max_ngram_length: int,
+    ):
+        Confidences.__init__(self, probs, model, text, max_ngram_length)
+        self.explanation = explanation
+
+
 class Model:
     def __init__(
         self,
@@ -23,7 +91,7 @@ class Model:
 
     def confidence(
         self, text: str, max_ngram_length: int, transparent: bool = False
-    ) -> Dict[str, float]:
+    ) -> Confidences:
         """Classify text with confidence.
 
         Parameters
@@ -56,7 +124,7 @@ class Model:
 
         if transparent:
             token_map = {tokens[i]: raw_tokens[i] for i in range(len(tokens))}
-            log = []
+            log: Log = []
 
         numbered_probs: Dict[int, float] = {}
 
@@ -71,7 +139,13 @@ class Model:
                 )
 
                 if transparent:
-                    log.append([token_map[word], weight, unweighted_numbers])
+                    log.append(
+                        (
+                            token_map[word],
+                            weight,
+                            unweighted_numbers,
+                        )
+                    )
 
                 for category, value in enumerate(weighted_numbers):
                     try:
@@ -88,19 +162,7 @@ class Model:
         }
 
         if transparent:
-            explanation = {}
-            for word, weight, word_probs in log:
-                if word in explanation:
-                    explanation[word]["count"] += 1
-                else:
-                    explanation[word] = {
-                        "weight": weight,
-                        "probabilities": {
-                            name: word_probs[index]
-                            for index, name in enumerate(self.names)
-                        },
-                        "count": 1,
-                    }
+            explanation = convert_log(log, self.names)
 
             return TransparentConfidences(
                 probs, explanation, self, text, max_ngram_length
@@ -141,20 +203,6 @@ class Model:
         )
 
 
-class Confidences(collections.UserDict):
-    def __init__(self, probs, model, text, max_ngram_length):
-        collections.UserDict.__init__(self, probs)
-        self.model = model
-        self.text = text
-        self.max_ngram_length = max_ngram_length
-
-
-class TransparentConfidences(Confidences):
-    def __init__(self, probs, explanation, model, text, max_ngram_length):
-        Confidences.__init__(self, probs, model, text, max_ngram_length)
-        self.explanation = explanation
-
-
 def deserialize(encoded_model: BinaryIO) -> Model:
     prefix = encoded_model.read(14)
     if prefix != b"GPTC model v6\n":
diff --git a/gptc/weighting.py b/gptc/weighting.py
index e4b0870..d033af1 100755
--- a/gptc/weighting.py
+++ b/gptc/weighting.py
@@ -39,7 +39,7 @@ def _standard_deviation(numbers: Sequence[float]) -> float:
     return math.sqrt(_mean(squared_deviations))
 
 
-def weight(numbers: Sequence[float]) -> List[float]:
+def weight(numbers: Sequence[float]) -> Tuple[float, List[float]]:
     standard_deviation = _standard_deviation(numbers)
     weight = standard_deviation * 2
     weighted_numbers = [i * weight for i in numbers]