parent
548d670960
commit
25192ffddf
|
@ -63,6 +63,12 @@ returned.
|
|||
|
||||
For information about `max_ngram_length`, see section "Ngrams."
|
||||
|
||||
### `Model.get(token)`
|
||||
|
||||
Return a confidence dict for the given token or ngram. This function is very
|
||||
similar to `Model.confidence()`, except it treats the input as a single token
|
||||
or ngram.
|
||||
|
||||
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
|
||||
|
||||
Compile a raw model (as a list, not JSON) and return the compiled model (as a
|
||||
|
|
|
@ -40,9 +40,7 @@ class Model:
|
|||
model = self.weights
|
||||
|
||||
tokens = gptc.tokenizer.hash(
|
||||
gptc.tokenizer.tokenize(
|
||||
text, min(max_ngram_length, self.max_ngram_length)
|
||||
)
|
||||
gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length))
|
||||
)
|
||||
numbered_probs: Dict[int, float] = {}
|
||||
for word in tokens:
|
||||
|
@ -64,6 +62,18 @@ class Model:
|
|||
}
|
||||
return probs
|
||||
|
||||
def get(self, token):
|
||||
try:
|
||||
weights = self.weights[
|
||||
gptc.tokenizer.hash_single(gptc.tokenizer.normalize(token))
|
||||
]
|
||||
except KeyError:
|
||||
return {}
|
||||
return {
|
||||
category: weights[index] / 65535
|
||||
for index, category in enumerate(self.names)
|
||||
}
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
out = b"GPTC model v4\n"
|
||||
out += (
|
||||
|
@ -112,9 +122,9 @@ def deserialize(encoded_model: bytes) -> Model:
|
|||
except KeyError:
|
||||
raise InvalidModelError()
|
||||
|
||||
if not (
|
||||
isinstance(names, list) and isinstance(max_ngram_length, int)
|
||||
) or not all([isinstance(name, str) for name in names]):
|
||||
if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all(
|
||||
[isinstance(name, str) for name in names]
|
||||
):
|
||||
raise InvalidModelError()
|
||||
|
||||
weight_code_length = 6 + 2 * len(names)
|
||||
|
|
Loading…
Reference in New Issue
Block a user