Compare commits

4 Commits

SHA1        Message                                   Date
25192ffddf  Add ability to look up individual token   2022-11-26 18:17:02 -08:00
            (Closes #10)
548d670960  Use Classifier for --category              2022-11-26 17:50:26 -08:00
b3a43150d8  Split hash function                        2022-11-26 17:42:42 -08:00
08437a2696  Add normalize()                            2022-11-26 17:17:28 -08:00
5 changed files with 37 additions and 17 deletions

README.md

@@ -63,6 +63,12 @@ returned.
 For information about `max_ngram_length`, see section "Ngrams."
 
+### `Model.get(token)`
+
+Return a confidence dict for the given token or ngram. This function is very
+similar to `Model.confidence()`, except it treats the input as a single token
+or ngram.
+
 ### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
 
 Compile a raw model (as a list, not JSON) and return the compiled model (as a
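
A minimal usage sketch of the new lookup, assuming the raw-model format this README documents (a list; the dict keys, texts, categories, and printed values below are illustrative only):

    import gptc

    # Hypothetical raw model for illustration.
    raw_model = [
        {"text": "hello there world", "category": "greeting"},
        {"text": "goodbye cruel world", "category": "farewell"},
    ]
    model = gptc.compile(raw_model)

    # Confidence dict for one token, e.g. {"greeting": ..., "farewell": ...};
    # a token absent from the model yields {}.
    print(model.get("world"))
    print(model.get("unseen"))  # {}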

gptc/__init__.py

@@ -6,6 +6,7 @@ from gptc.compiler import compile as compile
 from gptc.classifier import Classifier as Classifier
 from gptc.pack import pack as pack
 from gptc.model import Model as Model, deserialize as deserialize
+from gptc.tokenizer import normalize as normalize
 from gptc.exceptions import (
     GPTCError as GPTCError,
     ModelError as ModelError,

gptc/__main__.py

@@ -74,14 +74,12 @@ def main() -> None:
         else:
             text = sys.stdin.read()
 
-        probabilities = model.confidence(text, args.max_ngram_length)
-
         if args.category:
-            try:
-                print(sorted(probabilities.items(), key=lambda x: x[1])[-1][0])
-            except IndexError:
-                print(None)
+            classifier = gptc.Classifier(model, args.max_ngram_length)
+            print(classifier.classify(text))
         else:
+            probabilities = model.confidence(text, args.max_ngram_length)
             print(json.dumps(probabilities))
     else:
         print(json.dumps(gptc.pack(args.model, True)[0]))
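
The `--category` path now delegates top-category selection to `Classifier` instead of sorting the confidence dict inline, and `model.confidence()` is only computed on the JSON output path rather than unconditionally. A rough sketch of the intended equivalence, assuming `Classifier.classify()` returns `None` when no category can be chosen (matching the old `IndexError` fallback); `model`, `text`, and `max_ngram_length` are assumed to be defined:

    # Old: pick the highest-confidence category by hand.
    probabilities = model.confidence(text, max_ngram_length)
    try:
        category = sorted(probabilities.items(), key=lambda x: x[1])[-1][0]
    except IndexError:
        category = None

    # New: delegate the same selection to Classifier.
    classifier = gptc.Classifier(model, max_ngram_length)
    category = classifier.classify(text)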

gptc/model.py

@@ -40,9 +40,7 @@ class Model:
         model = self.weights
 
         tokens = gptc.tokenizer.hash(
-            gptc.tokenizer.tokenize(
-                text, min(max_ngram_length, self.max_ngram_length)
-            )
+            gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length))
         )
         numbered_probs: Dict[int, float] = {}
         for word in tokens:
@@ -64,6 +62,18 @@
             }
         return probs
 
+    def get(self, token):
+        try:
+            weights = self.weights[
+                gptc.tokenizer.hash_single(gptc.tokenizer.normalize(token))
+            ]
+        except KeyError:
+            return {}
+        return {
+            category: weights[index] / 65535
+            for index, category in enumerate(self.names)
+        }
+
     def serialize(self) -> bytes:
         out = b"GPTC model v4\n"
         out += (
@@ -112,9 +122,9 @@ def deserialize(encoded_model: bytes) -> Model:
     except KeyError:
         raise InvalidModelError()
 
-    if not (
-        isinstance(names, list) and isinstance(max_ngram_length, int)
-    ) or not all([isinstance(name, str) for name in names]):
+    if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all(
+        [isinstance(name, str) for name in names]
+    ):
         raise InvalidModelError()
 
     weight_code_length = 6 + 2 * len(names)
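
`Model.get()` normalizes and hashes its input with the new tokenizer helpers, looks up the stored weights, and scales each 16-bit integer weight to a float between 0 and 1. A worked sketch of the scaling step, with illustrative category names and weights:

    names = ["news", "spam"]   # illustrative category names
    weights = [52428, 13107]   # illustrative 16-bit weights for one token

    confidences = {
        category: weights[index] / 65535
        for index, category in enumerate(names)
    }
    print(confidences)  # {'news': 0.8, 'spam': 0.2}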

gptc/tokenizer.py

@@ -39,10 +39,15 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
     return ngrams
 
 
+def hash_single(token: str) -> int:
+    return int.from_bytes(
+        hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
+    )
+
+
 def hash(tokens: List[str]) -> List[int]:
-    return [
-        int.from_bytes(
-            hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
-        )
-        for token in tokens
-    ]
+    return [hash_single(token) for token in tokens]
+
+
+def normalize(text: str) -> str:
+    return " ".join(tokenize(text, 1))
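
With this split, `hash()` is simply `hash_single()` mapped over a token list, and `normalize()` gives `Model.get()` a canonical form of arbitrary input before hashing. A self-contained sketch of the digest truncation used by `hash_single()`:

    import hashlib

    def hash_single(token: str) -> int:
        # First 6 bytes (48 bits) of the SHA-256 digest, as a big-endian integer.
        return int.from_bytes(
            hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
        )

    print(hash_single("hello"))          # deterministic 48-bit integer
    print(hash_single("hello") < 2**48)  # True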