From 7f68dc6fc6961495cc5d4ca9b80708adfbc95181 Mon Sep 17 00:00:00 2001
From: Samuel Sloniker <sam@kj7rrv.com>
Date: Sun, 16 Apr 2023 15:35:53 -0700
Subject: [PATCH] Add classification explanations

Closes #17
---
 gptc/model.py     | 66 ++++++++++++++++++++++++++++++++++++++++++-----
 gptc/weighting.py |  2 +-
 2 files changed, 60 insertions(+), 8 deletions(-)

diff --git a/gptc/model.py b/gptc/model.py
index 3b04827..b9613d3 100644
--- a/gptc/model.py
+++ b/gptc/model.py
@@ -5,6 +5,7 @@ from gptc.exceptions import InvalidModelError
 import gptc.weighting
 from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
 import json
+import collections
 
 
 class Model:
@@ -20,7 +21,9 @@ class Model:
         self.max_ngram_length = max_ngram_length
         self.hash_algorithm = hash_algorithm
 
-    def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
+    def confidence(
+        self, text: str, max_ngram_length: int, transparent: bool = False
+    ) -> Dict[str, float]:
         """Classify text with confidence.
 
         Parameters
@@ -40,19 +43,36 @@ class Model:
         """
 
         model = self.weights
+        max_ngram_length = min(self.max_ngram_length, max_ngram_length)
+
+        raw_tokens = gptc.tokenizer.tokenize(
+            text, min(max_ngram_length, self.max_ngram_length)
+        )
 
         tokens = gptc.tokenizer.hash(
-            gptc.tokenizer.tokenize(
-                text, min(max_ngram_length, self.max_ngram_length)
-            ),
+            raw_tokens,
             self.hash_algorithm,
         )
+
+        if transparent:
+            token_map = {tokens[i]: raw_tokens[i] for i in range(len(tokens))}
+            log = []
+
         numbered_probs: Dict[int, float] = {}
+
         for word in tokens:
             try:
-                weighted_numbers = gptc.weighting.weight(
-                    [i / 65535 for i in cast(List[float], model[word])]
+                unweighted_numbers = [
+                    i / 65535 for i in cast(List[float], model[word])
+                ]
+
+                weight, weighted_numbers = gptc.weighting.weight(
+                    unweighted_numbers
                 )
+
+                if transparent:
+                    log.append([token_map[word], weight, unweighted_numbers])
+
                 for category, value in enumerate(weighted_numbers):
                     try:
                         numbered_probs[category] += value
@@ -60,12 +80,30 @@ class Model:
                         numbered_probs[category] = value
             except KeyError:
                 pass
+
         total = sum(numbered_probs.values())
         probs: Dict[str, float] = {
             self.names[category]: value / total
             for category, value in numbered_probs.items()
         }
-        return probs
+
+        if transparent:
+            explanation = {}
+            for word, weight, word_probs in log:
+                if word in explanation:
+                    explanation[word]["count"] += 1
+                else:
+                    explanation[word] = {
+                        "weight": weight,
+                        "probabilities": word_probs,
+                        "count": 1,
+                    }
+
+            return TransparentConfidences(
+                probs, explanation, self, text, max_ngram_length
+            )
+        else:
+            return Confidences(probs, self, text, max_ngram_length)
 
     def get(self, token: str) -> Dict[str, float]:
         try:
@@ -100,6 +138,20 @@ class Model:
         )
 
 
+class Confidences(collections.UserDict):
+    def __init__(self, probs, model, text, max_ngram_length):
+        collections.UserDict.__init__(self, probs)
+        self.model = model
+        self.text = text
+        self.max_ngram_length = max_ngram_length
+
+
+class TransparentConfidences(Confidences):
+    def __init__(self, probs, explanation, model, text, max_ngram_length):
+        Confidences.__init__(self, probs, model, text, max_ngram_length)
+        self.explanation = explanation
+
+
 def deserialize(encoded_model: BinaryIO) -> Model:
     prefix = encoded_model.read(14)
     if prefix != b"GPTC model v5\n":
diff --git a/gptc/weighting.py b/gptc/weighting.py
index 18d6ac2..e4b0870 100755
--- a/gptc/weighting.py
+++ b/gptc/weighting.py
@@ -43,4 +43,4 @@ def weight(numbers: Sequence[float]) -> List[float]:
     standard_deviation = _standard_deviation(numbers)
     weight = standard_deviation * 2
     weighted_numbers = [i * weight for i in numbers]
-    return weighted_numbers
+    return weight, weighted_numbers