From 25192ffddf055657a1fb5ce1b497bd2da02478d9 Mon Sep 17 00:00:00 2001 From: Samuel Sloniker Date: Sat, 26 Nov 2022 18:17:02 -0800 Subject: [PATCH] Add ability to look up individual token Closes #10 --- README.md | 6 ++++++ gptc/model.py | 22 ++++++++++++++++------ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index d000ff0..a22febc 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,12 @@ returned. For information about `max_ngram_length`, see section "Ngrams." +### `Model.get(token)` + +Return a confidence dict for the given token or ngram. This function is very +similar to `Model.confidence()`, except it treats the input as a single token +or ngram. + ### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)` Compile a raw model (as a list, not JSON) and return the compiled model (as a diff --git a/gptc/model.py b/gptc/model.py index e1772ab..634d900 100644 --- a/gptc/model.py +++ b/gptc/model.py @@ -40,9 +40,7 @@ class Model: model = self.weights tokens = gptc.tokenizer.hash( - gptc.tokenizer.tokenize( - text, min(max_ngram_length, self.max_ngram_length) - ) + gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length)) ) numbered_probs: Dict[int, float] = {} for word in tokens: @@ -64,6 +62,18 @@ class Model: } return probs + def get(self, token): + try: + weights = self.weights[ + gptc.tokenizer.hash_single(gptc.tokenizer.normalize(token)) + ] + except KeyError: + return {} + return { + category: weights[index] / 65535 + for index, category in enumerate(self.names) + } + def serialize(self) -> bytes: out = b"GPTC model v4\n" out += ( @@ -112,9 +122,9 @@ def deserialize(encoded_model: bytes) -> Model: except KeyError: raise InvalidModelError() - if not ( - isinstance(names, list) and isinstance(max_ngram_length, int) - ) or not all([isinstance(name, str) for name in names]): + if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all( + [isinstance(name, str) for name in names] + ): raise InvalidModelError() weight_code_length = 6 + 2 * len(names)