Compare commits

..

8 Commits

5 changed files with 123 additions and 85 deletions

View File

@ -43,13 +43,13 @@ example of the format. Any exceptions will be printed to stderr.
## Library
### `Model.serialize()`
### `Model.serialize(file)`
Returns a `bytes` representing the model.
Write binary data representing the model to `file`.
### `gptc.deserialize(encoded_model)`
Deserialize a `Model` from a `bytes` returned by `Model.serialize()`.
Deserialize a `Model` from a file containing data from `Model.serialize()`.
### `Model.confidence(text, max_ngram_length)`
@ -69,7 +69,7 @@ Return a confidence dict for the given token or ngram. This function is very
similar to `Model.confidence()`, except it treats the input as a single token
or ngram.
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1, hash_algorithm="sha256")`
Compile a raw model (as a list, not JSON) and return the compiled model (as a
`gptc.Model` object).
@ -79,6 +79,26 @@ For information about `max_ngram_length`, see section "Ngrams."
Words or ngrams used less than `min_count` times throughout the input text are
excluded from the model.
The hash algorithm should be left as the default, which may change with a minor
version update, but it can be changed by the application if needed. It is
stored in the model, so changing the algorithm does not affect compatibility.
The following algorithms are supported:
* `md5`
* `sha1`
* `sha224`
* `sha256`
* `sha384`
* `sha512`
* `sha3_224`
* `sha3_384`
* `sha3_256`
* `sha3_512`
* `shake_128`
* `shake_256`
* `blake2b`
* `blake2s`
### `gptc.pack(directory, print_exceptions=False)`
Pack the model in `directory` and return a tuple of the format:

View File

@ -66,12 +66,10 @@ def main() -> None:
with open(args.model, "r") as f:
model = json.load(f)
sys.stdout.buffer.write(
gptc.compile(model, args.max_ngram_length, args.min_count).serialize()
)
gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer)
elif args.subparser_name == "classify":
with open(args.model, "rb") as f:
model = gptc.deserialize(f.read())
model = gptc.deserialize(f)
if sys.stdin.isatty():
text = input("Text to analyse: ")
@ -87,7 +85,7 @@ def main() -> None:
print(json.dumps(probabilities))
elif args.subparser_name == "check":
with open(args.model, "rb") as f:
model = gptc.deserialize(f.read())
model = gptc.deserialize(f)
print(json.dumps(model.get(args.token)))
else:
print(json.dumps(gptc.pack(args.model, True)[0]))

View File

@ -9,6 +9,7 @@ def compile(
raw_model: Iterable[Mapping[str, str]],
max_ngram_length: int = 1,
min_count: int = 1,
hash_algorithm: str = "sha256",
) -> gptc.model.Model:
"""Compile a raw model.
@ -27,23 +28,24 @@ def compile(
"""
categories: Dict[str, List[int]] = {}
word_counts: Dict[int, Dict[str, int]] = {}
category_lengths: Dict[str, int] = {}
names: List[str] = []
for portion in raw_model:
text = gptc.tokenizer.hash(
gptc.tokenizer.tokenize(portion["text"], max_ngram_length)
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
hash_algorithm,
)
category = portion["category"]
try:
categories[category] += text
except KeyError:
categories[category] = text
word_counts: Dict[int, Dict[str, int]] = {}
if not category in names:
names.append(category)
names = list(categories.keys())
category_lengths[category] = category_lengths.get(category, 0) + len(
text
)
for category, text in categories.items():
for word in text:
if word in word_counts:
try:
@ -53,27 +55,19 @@ def compile(
else:
word_counts[word] = {category: 1}
category_lengths = {
category: len(text) for category, text in categories.items()
}
word_weights: Dict[int, Dict[str, float]] = {
word: {
category: value / category_lengths[category]
for category, value in values.items()
}
for word, values in word_counts.items()
if sum(values.values()) >= min_count
}
model: Dict[int, List[int]] = {}
for word, weights in word_weights.items():
total = sum(weights.values())
new_weights: List[int] = []
for category in names:
new_weights.append(
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights
for word, counts in word_counts.items():
if sum(counts.values()) >= min_count:
weights = {
category: value / category_lengths[category]
for category, value in counts.items()
}
total = sum(weights.values())
new_weights: List[int] = []
for category in names:
new_weights.append(
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights
return gptc.model.Model(model, names, max_ngram_length)
return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)

View File

@ -3,7 +3,7 @@
import gptc.tokenizer
from gptc.exceptions import InvalidModelError
import gptc.weighting
from typing import Iterable, Mapping, List, Dict, Union, cast
from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
import json
@ -13,10 +13,12 @@ class Model:
weights: Dict[int, List[int]],
names: List[str],
max_ngram_length: int,
hash_algorithm: str,
):
self.weights = weights
self.names = names
self.max_ngram_length = max_ngram_length
self.hash_algorithm = hash_algorithm
def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
"""Classify text with confidence.
@ -40,7 +42,10 @@ class Model:
model = self.weights
tokens = gptc.tokenizer.hash(
gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length))
gptc.tokenizer.tokenize(
text, min(max_ngram_length, self.max_ngram_length)
),
self.hash_algorithm,
)
numbered_probs: Dict[int, float] = {}
for word in tokens:
@ -74,42 +79,39 @@ class Model:
for index, category in enumerate(self.names)
}
def serialize(self) -> bytes:
out = b"GPTC model v4\n"
out += (
def serialize(self, file: BinaryIO):
file.write(b"GPTC model v5\n")
file.write(
json.dumps(
{
"names": self.names,
"max_ngram_length": self.max_ngram_length,
"has_emoji": True,
# Due to an oversight in development, version 3.0.0 still
# had the code used to make emoji support optional, even
# though the `emoji` library was made a hard dependency.
# Part of this code checked whether or not the model
# supports emoji; deserialization would not work in 3.0.0
# if the model was compiled without this field. Emoji are
# always supported with 3.0.0 and newer when GPTC has been
# installed correctly, so this value should always be True.
# Related: #11
"hash_algorithm": self.hash_algorithm,
}
).encode("utf-8")
+ b"\n"
)
for word, weights in self.weights.items():
out += word.to_bytes(6, "big") + b"".join(
[weight.to_bytes(2, "big") for weight in weights]
file.write(
word.to_bytes(6, "big")
+ b"".join([weight.to_bytes(2, "big") for weight in weights])
)
return out
def deserialize(encoded_model: bytes) -> Model:
try:
prefix, config_json, encoded_weights = encoded_model.split(b"\n", 2)
except ValueError:
def deserialize(encoded_model: BinaryIO) -> Model:
prefix = encoded_model.read(14)
if prefix != b"GPTC model v5\n":
raise InvalidModelError()
if prefix != b"GPTC model v4":
raise InvalidModelError()
config_json = b""
while True:
byte = encoded_model.read(1)
if byte == b"\n":
break
elif byte == b"":
raise InvalidModelError()
else:
config_json += byte
try:
config = json.loads(config_json.decode("utf-8"))
@ -119,30 +121,29 @@ def deserialize(encoded_model: bytes) -> Model:
try:
names = config["names"]
max_ngram_length = config["max_ngram_length"]
hash_algorithm = config["hash_algorithm"]
except KeyError:
raise InvalidModelError()
if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all(
[isinstance(name, str) for name in names]
):
if not (
isinstance(names, list) and isinstance(max_ngram_length, int)
) or not all([isinstance(name, str) for name in names]):
raise InvalidModelError()
weight_code_length = 6 + 2 * len(names)
if len(encoded_weights) % weight_code_length != 0:
raise InvalidModelError()
weights: Dict[int : List[int]] = {}
weight_codes = [
encoded_weights[x : x + weight_code_length]
for x in range(0, len(encoded_weights), weight_code_length)
]
while True:
code = encoded_model.read(weight_code_length)
if not code:
break
elif len(code) != weight_code_length:
raise InvalidModelError()
weights = {
int.from_bytes(code[:6], "big"): [
weights[int.from_bytes(code[:6], "big")] = [
int.from_bytes(value, "big")
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
]
for code in weight_codes
}
return Model(weights, names, max_ngram_length)
return Model(weights, names, max_ngram_length, hash_algorithm)

View File

@ -1,12 +1,13 @@
# SPDX-License-Identifier: GPL-3.0-or-later
from typing import List, Union
from typing import List, Union, Callable
import hashlib
import emoji
import unicodedata
def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
text = text.lower()
text = unicodedata.normalize("NFKD", text).lower()
parts = []
highest_end = 0
for emoji_part in emoji.emoji_list(text):
@ -19,7 +20,12 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
tokens = [""]
for char in converted_text:
if char.isalpha() or char == "'":
if (
char.isalpha()
or char.isnumeric()
or char == "'"
or (char in ",." and (" " + tokens[-1])[-1].isnumeric())
):
tokens[-1] += char
elif emoji.is_emoji(char):
tokens.append(char)
@ -39,14 +45,33 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
return ngrams
def hash_single(token: str) -> int:
def hash_single(token: str, hash_function: Callable) -> int:
return int.from_bytes(
hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
hash_function(token.encode("utf-8")).digest()[:6], "big"
)
def hash(tokens: List[str]) -> List[int]:
return [hash_single(token) for token in tokens]
def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
if hash_algorithm in {
"sha224",
"md5",
"sha512",
"sha3_256",
"blake2s",
"sha3_224",
"sha1",
"sha256",
"sha384",
"shake_256",
"blake2b",
"sha3_512",
"shake_128",
"sha3_384",
}:
hash_function = getattr(hashlib, hash_algorithm)
return [hash_single(token, hash_function) for token in tokens]
else:
raise ValueError("not a valid hash function: " + hash_algorithm)
def normalize(text: str) -> str: