Compare commits

..

No commits in common. "25192ffddf055657a1fb5ce1b497bd2da02478d9" and "fc4665bb9e640d36a67553e3826a52632a4a6e98" have entirely different histories.

5 changed files with 17 additions and 37 deletions

View File

@@ -63,12 +63,6 @@ returned.
For information about `max_ngram_length`, see section "Ngrams."
### `Model.get(token)`
Return a confidence dict for the given token or ngram. This function is very
similar to `Model.confidence()`, except it treats the input as a single token
or ngram.
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
Compile a raw model (as a list, not JSON) and return the compiled model (as a

View File

@@ -6,7 +6,6 @@ from gptc.compiler import compile as compile
from gptc.classifier import Classifier as Classifier
from gptc.pack import pack as pack
from gptc.model import Model as Model, deserialize as deserialize
from gptc.tokenizer import normalize as normalize
from gptc.exceptions import (
GPTCError as GPTCError,
ModelError as ModelError,

View File

@@ -74,12 +74,14 @@ def main() -> None:
else:
text = sys.stdin.read()
probabilities = model.confidence(text, args.max_ngram_length)
if args.category:
classifier = gptc.Classifier(model, args.max_ngram_length)
print(classifier.classify(text))
try:
print(sorted(probabilities.items(), key=lambda x: x[1])[-1][0])
except IndexError:
print(None)
else:
probabilities = model.confidence(text, args.max_ngram_length)
print(json.dumps(probabilities))
else:
print(json.dumps(gptc.pack(args.model, True)[0]))

View File

@@ -40,7 +40,9 @@ class Model:
model = self.weights
tokens = gptc.tokenizer.hash(
gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length))
gptc.tokenizer.tokenize(
text, min(max_ngram_length, self.max_ngram_length)
)
)
numbered_probs: Dict[int, float] = {}
for word in tokens:
@@ -62,18 +64,6 @@ class Model:
}
return probs
def get(self, token):
try:
weights = self.weights[
gptc.tokenizer.hash_single(gptc.tokenizer.normalize(token))
]
except KeyError:
return {}
return {
category: weights[index] / 65535
for index, category in enumerate(self.names)
}
def serialize(self) -> bytes:
out = b"GPTC model v4\n"
out += (
@@ -122,9 +112,9 @@ def deserialize(encoded_model: bytes) -> Model:
except KeyError:
raise InvalidModelError()
if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all(
[isinstance(name, str) for name in names]
):
if not (
isinstance(names, list) and isinstance(max_ngram_length, int)
) or not all([isinstance(name, str) for name in names]):
raise InvalidModelError()
weight_code_length = 6 + 2 * len(names)

View File

@@ -39,15 +39,10 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
return ngrams
def hash_single(token: str) -> int:
return int.from_bytes(
def hash(tokens: List[str]) -> List[int]:
return [
int.from_bytes(
hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
)
def hash(tokens: List[str]) -> List[int]:
return [hash_single(token) for token in tokens]
def normalize(text: str) -> str:
return " ".join(tokenize(text, 1))
for token in tokens
]