Compare commits

..

No commits in common. "25192ffddf055657a1fb5ce1b497bd2da02478d9" and "fc4665bb9e640d36a67553e3826a52632a4a6e98" have entirely different histories.

5 changed files with 17 additions and 37 deletions

View File

@@ -63,12 +63,6 @@ returned.
For information about `max_ngram_length`, see section "Ngrams." For information about `max_ngram_length`, see section "Ngrams."
### `Model.get(token)`
Return a confidence dict for the given token or ngram. This function is very
similar to `Model.confidence()`, except it treats the input as a single token
or ngram.
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)` ### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
Compile a raw model (as a list, not JSON) and return the compiled model (as a Compile a raw model (as a list, not JSON) and return the compiled model (as a

View File

@@ -6,7 +6,6 @@ from gptc.compiler import compile as compile
from gptc.classifier import Classifier as Classifier from gptc.classifier import Classifier as Classifier
from gptc.pack import pack as pack from gptc.pack import pack as pack
from gptc.model import Model as Model, deserialize as deserialize from gptc.model import Model as Model, deserialize as deserialize
from gptc.tokenizer import normalize as normalize
from gptc.exceptions import ( from gptc.exceptions import (
GPTCError as GPTCError, GPTCError as GPTCError,
ModelError as ModelError, ModelError as ModelError,

View File

@@ -74,12 +74,14 @@ def main() -> None:
else: else:
text = sys.stdin.read() text = sys.stdin.read()
probabilities = model.confidence(text, args.max_ngram_length)
if args.category: if args.category:
classifier = gptc.Classifier(model, args.max_ngram_length) try:
print(classifier.classify(text)) print(sorted(probabilities.items(), key=lambda x: x[1])[-1][0])
except IndexError:
print(None)
else: else:
probabilities = model.confidence(text, args.max_ngram_length)
print(json.dumps(probabilities)) print(json.dumps(probabilities))
else: else:
print(json.dumps(gptc.pack(args.model, True)[0])) print(json.dumps(gptc.pack(args.model, True)[0]))

View File

@@ -40,7 +40,9 @@ class Model:
model = self.weights model = self.weights
tokens = gptc.tokenizer.hash( tokens = gptc.tokenizer.hash(
gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length)) gptc.tokenizer.tokenize(
text, min(max_ngram_length, self.max_ngram_length)
)
) )
numbered_probs: Dict[int, float] = {} numbered_probs: Dict[int, float] = {}
for word in tokens: for word in tokens:
@@ -62,18 +64,6 @@ class Model:
} }
return probs return probs
def get(self, token):
try:
weights = self.weights[
gptc.tokenizer.hash_single(gptc.tokenizer.normalize(token))
]
except KeyError:
return {}
return {
category: weights[index] / 65535
for index, category in enumerate(self.names)
}
def serialize(self) -> bytes: def serialize(self) -> bytes:
out = b"GPTC model v4\n" out = b"GPTC model v4\n"
out += ( out += (
@@ -122,9 +112,9 @@ def deserialize(encoded_model: bytes) -> Model:
except KeyError: except KeyError:
raise InvalidModelError() raise InvalidModelError()
if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all( if not (
[isinstance(name, str) for name in names] isinstance(names, list) and isinstance(max_ngram_length, int)
): ) or not all([isinstance(name, str) for name in names]):
raise InvalidModelError() raise InvalidModelError()
weight_code_length = 6 + 2 * len(names) weight_code_length = 6 + 2 * len(names)

View File

@@ -39,15 +39,10 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
return ngrams return ngrams
def hash_single(token: str) -> int:
return int.from_bytes(
hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
)
def hash(tokens: List[str]) -> List[int]: def hash(tokens: List[str]) -> List[int]:
return [hash_single(token) for token in tokens] return [
int.from_bytes(
hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
def normalize(text: str) -> str: )
return " ".join(tokenize(text, 1)) for token in tokens
]