Add ability to look up individual token

Closes #10
Author: Samuel Sloniker, 2022-11-26 18:17:02 -08:00
parent 548d670960
commit 25192ffddf
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
2 changed files with 22 additions and 6 deletions

@@ -63,6 +63,12 @@ returned.
 
 For information about `max_ngram_length`, see section "Ngrams."
 
+### `Model.get(token)`
+
+Return a confidence dict for the given token or ngram. This function is very
+similar to `Model.confidence()`, except it treats the input as a single token
+or ngram.
+
 ### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
 
 Compile a raw model (as a list, not JSON) and return the compiled model (as a
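The new `Model.get()` documented above can be exercised roughly as follows. This is a minimal sketch, assuming the module-level `deserialize()` shown in the code diff below is exposed as `gptc.deserialize`; the file name, category names, and printed values are illustrative only:

```python
import gptc

# Load a compiled model; "model.gptc" is a hypothetical path.
with open("model.gptc", "rb") as f:
    model = gptc.deserialize(f.read())

# Look up a single token instead of classifying a whole document.
print(model.get("hello"))   # e.g. {"good": 0.72, "bad": 0.28}

# A token that is not in the model yields an empty dict.
print(model.get("qwxzv"))   # {}
```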

@@ -40,9 +40,7 @@ class Model:
         model = self.weights
 
         tokens = gptc.tokenizer.hash(
-            gptc.tokenizer.tokenize(
-                text, min(max_ngram_length, self.max_ngram_length)
-            )
+            gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length))
         )
         numbered_probs: Dict[int, float] = {}
         for word in tokens:
@@ -64,6 +62,18 @@ class Model:
         }
         return probs
 
+    def get(self, token):
+        try:
+            weights = self.weights[
+                gptc.tokenizer.hash_single(gptc.tokenizer.normalize(token))
+            ]
+        except KeyError:
+            return {}
+        return {
+            category: weights[index] / 65535
+            for index, category in enumerate(self.names)
+        }
+
     def serialize(self) -> bytes:
         out = b"GPTC model v4\n"
         out += (
@@ -112,9 +122,9 @@ def deserialize(encoded_model: bytes) -> Model:
     except KeyError:
         raise InvalidModelError()
 
-    if not (
-        isinstance(names, list) and isinstance(max_ngram_length, int)
-    ) or not all([isinstance(name, str) for name in names]):
+    if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all(
+        [isinstance(name, str) for name in names]
+    ):
         raise InvalidModelError()
 
     weight_code_length = 6 + 2 * len(names)
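A note on the arithmetic inside `get()`: the compiled format appears to store one 16-bit weight per category for each token (consistent with the `/ 65535` above and with `weight_code_length = 6 + 2 * len(names)`, which reads as a 6-byte token hash plus 2 bytes per category), so the lookup simply rescales the stored integers to floats in [0, 1]. A small worked sketch with made-up values:

```python
# Made-up values illustrating what Model.get() computes for one stored token.
names = ["good", "bad"]   # hypothetical category names
weights = [49152, 16384]  # 16-bit weights as stored in the compiled model

confidences = {
    category: weights[index] / 65535
    for index, category in enumerate(names)
}
print(confidences)  # roughly {'good': 0.75, 'bad': 0.25}
```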