Compare commits

...

4 Commits

Author SHA1 Message Date
25192ffddf
Add ability to look up individual token
Closes #10
2022-11-26 18:17:02 -08:00
548d670960
Use Classifier for --category 2022-11-26 17:50:26 -08:00
b3a43150d8
Split hash function 2022-11-26 17:42:42 -08:00
08437a2696
Add normalize() 2022-11-26 17:17:28 -08:00
5 changed files with 37 additions and 17 deletions

View File

@@ -63,6 +63,12 @@ returned.
For information about `max_ngram_length`, see section "Ngrams."
### `Model.get(token)`
Return a confidence dict for the given token or ngram. This function is very
similar to `Model.confidence()`, except it treats the input as a single token
or ngram.
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
Compile a raw model (as a list, not JSON) and return the compiled model (as a

View File

@@ -6,6 +6,7 @@ from gptc.compiler import compile as compile
from gptc.classifier import Classifier as Classifier
from gptc.pack import pack as pack
from gptc.model import Model as Model, deserialize as deserialize
from gptc.tokenizer import normalize as normalize
from gptc.exceptions import (
GPTCError as GPTCError,
ModelError as ModelError,

View File

@@ -74,14 +74,12 @@ def main() -> None:
else:
text = sys.stdin.read()
probabilities = model.confidence(text, args.max_ngram_length)
if args.category:
try:
print(sorted(probabilities.items(), key=lambda x: x[1])[-1][0])
except IndexError:
print(None)
classifier = gptc.Classifier(model, args.max_ngram_length)
print(classifier.classify(text))
else:
probabilities = model.confidence(text, args.max_ngram_length)
print(json.dumps(probabilities))
else:
print(json.dumps(gptc.pack(args.model, True)[0]))

View File

@@ -40,9 +40,7 @@ class Model:
model = self.weights
tokens = gptc.tokenizer.hash(
gptc.tokenizer.tokenize(
text, min(max_ngram_length, self.max_ngram_length)
)
gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length))
)
numbered_probs: Dict[int, float] = {}
for word in tokens:
@@ -64,6 +62,18 @@ class Model:
}
return probs
def get(self, token: str) -> Dict[str, float]:
    """Return a confidence dict for a single token or ngram.

    The token is normalized and hashed the same way the compiler hashed
    it, then looked up in the model's weights table.  Returns a mapping
    of category name -> confidence, or an empty dict when the token is
    not present in the model.
    """
    try:
        # Normalize then hash so the lookup key matches how tokens were
        # keyed at compile time (weights are indexed by token hashes).
        weights = self.weights[
            gptc.tokenizer.hash_single(gptc.tokenizer.normalize(token))
        ]
    except KeyError:
        # Unknown token: no information, so no confidences.
        return {}
    # Weights are stored as 16-bit integers; scale back into [0, 1].
    return {
        category: weights[index] / 65535
        for index, category in enumerate(self.names)
    }
def serialize(self) -> bytes:
out = b"GPTC model v4\n"
out += (
@@ -112,9 +122,9 @@ def deserialize(encoded_model: bytes) -> Model:
except KeyError:
raise InvalidModelError()
if not (
isinstance(names, list) and isinstance(max_ngram_length, int)
) or not all([isinstance(name, str) for name in names]):
if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all(
[isinstance(name, str) for name in names]
):
raise InvalidModelError()
weight_code_length = 6 + 2 * len(names)

View File

@@ -39,10 +39,15 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
return ngrams
def hash_single(token: str) -> int:
    """Hash one token to a 48-bit integer.

    Uses the first 6 bytes of the token's SHA-256 digest, interpreted
    big-endian, so hashes are stable across runs and platforms.
    """
    return int.from_bytes(
        hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
    )


def hash(tokens: List[str]) -> List[int]:
    """Hash every token in *tokens* with hash_single(), preserving order."""
    return [hash_single(token) for token in tokens]
def normalize(text: str) -> str:
    """Return *text* normalized: tokenized with ngram length 1 and
    re-joined with single spaces."""
    tokens = tokenize(text, 1)
    return " ".join(tokens)