Compare commits
No commits in common. "25192ffddf055657a1fb5ce1b497bd2da02478d9" and "fc4665bb9e640d36a67553e3826a52632a4a6e98" have entirely different histories.
25192ffddf
...
fc4665bb9e
|
@ -63,12 +63,6 @@ returned.
|
||||||
|
|
||||||
For information about `max_ngram_length`, see section "Ngrams."
|
For information about `max_ngram_length`, see section "Ngrams."
|
||||||
|
|
||||||
### `Model.get(token)`
|
|
||||||
|
|
||||||
Return a confidence dict for the given token or ngram. This function is very
|
|
||||||
similar to `Model.confidence()`, except it treats the input as a single token
|
|
||||||
or ngram.
|
|
||||||
|
|
||||||
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
|
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
|
||||||
|
|
||||||
Compile a raw model (as a list, not JSON) and return the compiled model (as a
|
Compile a raw model (as a list, not JSON) and return the compiled model (as a
|
||||||
|
|
|
@ -6,7 +6,6 @@ from gptc.compiler import compile as compile
|
||||||
from gptc.classifier import Classifier as Classifier
|
from gptc.classifier import Classifier as Classifier
|
||||||
from gptc.pack import pack as pack
|
from gptc.pack import pack as pack
|
||||||
from gptc.model import Model as Model, deserialize as deserialize
|
from gptc.model import Model as Model, deserialize as deserialize
|
||||||
from gptc.tokenizer import normalize as normalize
|
|
||||||
from gptc.exceptions import (
|
from gptc.exceptions import (
|
||||||
GPTCError as GPTCError,
|
GPTCError as GPTCError,
|
||||||
ModelError as ModelError,
|
ModelError as ModelError,
|
||||||
|
|
|
@ -74,12 +74,14 @@ def main() -> None:
|
||||||
else:
|
else:
|
||||||
text = sys.stdin.read()
|
text = sys.stdin.read()
|
||||||
|
|
||||||
|
probabilities = model.confidence(text, args.max_ngram_length)
|
||||||
|
|
||||||
if args.category:
|
if args.category:
|
||||||
classifier = gptc.Classifier(model, args.max_ngram_length)
|
try:
|
||||||
print(classifier.classify(text))
|
print(sorted(probabilities.items(), key=lambda x: x[1])[-1][0])
|
||||||
|
except IndexError:
|
||||||
|
print(None)
|
||||||
else:
|
else:
|
||||||
probabilities = model.confidence(text, args.max_ngram_length)
|
|
||||||
print(json.dumps(probabilities))
|
print(json.dumps(probabilities))
|
||||||
else:
|
else:
|
||||||
print(json.dumps(gptc.pack(args.model, True)[0]))
|
print(json.dumps(gptc.pack(args.model, True)[0]))
|
||||||
|
|
|
@ -40,7 +40,9 @@ class Model:
|
||||||
model = self.weights
|
model = self.weights
|
||||||
|
|
||||||
tokens = gptc.tokenizer.hash(
|
tokens = gptc.tokenizer.hash(
|
||||||
gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length))
|
gptc.tokenizer.tokenize(
|
||||||
|
text, min(max_ngram_length, self.max_ngram_length)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
numbered_probs: Dict[int, float] = {}
|
numbered_probs: Dict[int, float] = {}
|
||||||
for word in tokens:
|
for word in tokens:
|
||||||
|
@ -62,18 +64,6 @@ class Model:
|
||||||
}
|
}
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
def get(self, token):
|
|
||||||
try:
|
|
||||||
weights = self.weights[
|
|
||||||
gptc.tokenizer.hash_single(gptc.tokenizer.normalize(token))
|
|
||||||
]
|
|
||||||
except KeyError:
|
|
||||||
return {}
|
|
||||||
return {
|
|
||||||
category: weights[index] / 65535
|
|
||||||
for index, category in enumerate(self.names)
|
|
||||||
}
|
|
||||||
|
|
||||||
def serialize(self) -> bytes:
|
def serialize(self) -> bytes:
|
||||||
out = b"GPTC model v4\n"
|
out = b"GPTC model v4\n"
|
||||||
out += (
|
out += (
|
||||||
|
@ -122,9 +112,9 @@ def deserialize(encoded_model: bytes) -> Model:
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise InvalidModelError()
|
raise InvalidModelError()
|
||||||
|
|
||||||
if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all(
|
if not (
|
||||||
[isinstance(name, str) for name in names]
|
isinstance(names, list) and isinstance(max_ngram_length, int)
|
||||||
):
|
) or not all([isinstance(name, str) for name in names]):
|
||||||
raise InvalidModelError()
|
raise InvalidModelError()
|
||||||
|
|
||||||
weight_code_length = 6 + 2 * len(names)
|
weight_code_length = 6 + 2 * len(names)
|
||||||
|
|
|
@ -39,15 +39,10 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
|
||||||
return ngrams
|
return ngrams
|
||||||
|
|
||||||
|
|
||||||
def hash_single(token: str) -> int:
|
def hash(tokens: List[str]) -> List[int]:
|
||||||
return int.from_bytes(
|
return [
|
||||||
|
int.from_bytes(
|
||||||
hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
|
hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
|
||||||
)
|
)
|
||||||
|
for token in tokens
|
||||||
|
]
|
||||||
def hash(tokens: List[str]) -> List[int]:
|
|
||||||
return [hash_single(token) for token in tokens]
|
|
||||||
|
|
||||||
|
|
||||||
def normalize(text: str) -> str:
|
|
||||||
return " ".join(tokenize(text, 1))
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user