Compare commits

..

No commits in common. "f8dbc78b8231ae551f6ec61514644518190e5e21" and "a76c6d3da8dfc59aec9f0ba6916e2092443f3d72" have entirely different histories.

5 changed files with 86 additions and 124 deletions

View File

@ -43,13 +43,13 @@ example of the format. Any exceptions will be printed to stderr.
## Library
### `Model.serialize(file)`
### `Model.serialize()`
Write binary data representing the model to `file`.
Returns a `bytes` representing the model.
### `gptc.deserialize(encoded_model)`
Deserialize a `Model` from a file containing data from `Model.serialize()`.
Deserialize a `Model` from a `bytes` returned by `Model.serialize()`.
### `Model.confidence(text, max_ngram_length)`
@ -69,7 +69,7 @@ Return a confidence dict for the given token or ngram. This function is very
similar to `Model.confidence()`, except it treats the input as a single token
or ngram.
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1, hash_algorithm="sha256")`
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
Compile a raw model (as a list, not JSON) and return the compiled model (as a
`gptc.Model` object).
@ -79,26 +79,6 @@ For information about `max_ngram_length`, see section "Ngrams."
Words or ngrams used less than `min_count` times throughout the input text are
excluded from the model.
The hash algorithm should be left as the default, which may change with a minor
version update, but it can be changed by the application if needed. It is
stored in the model, so changing the algorithm does not affect compatibility.
The following algorithms are supported:
* `md5`
* `sha1`
* `sha224`
* `sha256`
* `sha384`
* `sha512`
* `sha3_224`
* `sha3_384`
* `sha3_256`
* `sha3_512`
* `shake_128`
* `shake_256`
* `blake2b`
* `blake2s`
### `gptc.pack(directory, print_exceptions=False)`
Pack the model in `directory` and return a tuple of the format:

View File

@ -66,10 +66,12 @@ def main() -> None:
with open(args.model, "r") as f:
model = json.load(f)
gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer)
sys.stdout.buffer.write(
gptc.compile(model, args.max_ngram_length, args.min_count).serialize()
)
elif args.subparser_name == "classify":
with open(args.model, "rb") as f:
model = gptc.deserialize(f)
model = gptc.deserialize(f.read())
if sys.stdin.isatty():
text = input("Text to analyse: ")
@ -85,7 +87,7 @@ def main() -> None:
print(json.dumps(probabilities))
elif args.subparser_name == "check":
with open(args.model, "rb") as f:
model = gptc.deserialize(f)
model = gptc.deserialize(f.read())
print(json.dumps(model.get(args.token)))
else:
print(json.dumps(gptc.pack(args.model, True)[0]))

View File

@ -9,7 +9,6 @@ def compile(
raw_model: Iterable[Mapping[str, str]],
max_ngram_length: int = 1,
min_count: int = 1,
hash_algorithm: str = "sha256",
) -> gptc.model.Model:
"""Compile a raw model.
@ -28,24 +27,23 @@ def compile(
"""
word_counts: Dict[int, Dict[str, int]] = {}
category_lengths: Dict[str, int] = {}
names: List[str] = []
categories: Dict[str, List[int]] = {}
for portion in raw_model:
text = gptc.tokenizer.hash(
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
hash_algorithm,
gptc.tokenizer.tokenize(portion["text"], max_ngram_length)
)
category = portion["category"]
try:
categories[category] += text
except KeyError:
categories[category] = text
if not category in names:
names.append(category)
word_counts: Dict[int, Dict[str, int]] = {}
category_lengths[category] = category_lengths.get(category, 0) + len(
text
)
names = list(categories.keys())
for category, text in categories.items():
for word in text:
if word in word_counts:
try:
@ -55,19 +53,27 @@ def compile(
else:
word_counts[word] = {category: 1}
model: Dict[int, List[int]] = {}
for word, counts in word_counts.items():
if sum(counts.values()) >= min_count:
weights = {
category: value / category_lengths[category]
for category, value in counts.items()
}
total = sum(weights.values())
new_weights: List[int] = []
for category in names:
new_weights.append(
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights
category_lengths = {
category: len(text) for category, text in categories.items()
}
return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)
word_weights: Dict[int, Dict[str, float]] = {
word: {
category: value / category_lengths[category]
for category, value in values.items()
}
for word, values in word_counts.items()
if sum(values.values()) >= min_count
}
model: Dict[int, List[int]] = {}
for word, weights in word_weights.items():
total = sum(weights.values())
new_weights: List[int] = []
for category in names:
new_weights.append(
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights
return gptc.model.Model(model, names, max_ngram_length)

View File

@ -3,7 +3,7 @@
import gptc.tokenizer
from gptc.exceptions import InvalidModelError
import gptc.weighting
from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
from typing import Iterable, Mapping, List, Dict, Union, cast
import json
@ -13,12 +13,10 @@ class Model:
weights: Dict[int, List[int]],
names: List[str],
max_ngram_length: int,
hash_algorithm: str,
):
self.weights = weights
self.names = names
self.max_ngram_length = max_ngram_length
self.hash_algorithm = hash_algorithm
def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
"""Classify text with confidence.
@ -42,10 +40,7 @@ class Model:
model = self.weights
tokens = gptc.tokenizer.hash(
gptc.tokenizer.tokenize(
text, min(max_ngram_length, self.max_ngram_length)
),
self.hash_algorithm,
gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length))
)
numbered_probs: Dict[int, float] = {}
for word in tokens:
@ -79,39 +74,42 @@ class Model:
for index, category in enumerate(self.names)
}
def serialize(self, file: BinaryIO):
file.write(b"GPTC model v5\n")
file.write(
def serialize(self) -> bytes:
out = b"GPTC model v4\n"
out += (
json.dumps(
{
"names": self.names,
"max_ngram_length": self.max_ngram_length,
"hash_algorithm": self.hash_algorithm,
"has_emoji": True,
# Due to an oversight in development, version 3.0.0 still
# had the code used to make emoji support optional, even
# though the `emoji` library was made a hard dependency.
# Part of this code checked whether or not the model
# supports emoji; deserialization would not work in 3.0.0
# if the model was compiled without this field. Emoji are
# always supported with 3.0.0 and newer when GPTC has been
# installed correctly, so this value should always be True.
# Related: #11
}
).encode("utf-8")
+ b"\n"
)
for word, weights in self.weights.items():
file.write(
word.to_bytes(6, "big")
+ b"".join([weight.to_bytes(2, "big") for weight in weights])
out += word.to_bytes(6, "big") + b"".join(
[weight.to_bytes(2, "big") for weight in weights]
)
return out
def deserialize(encoded_model: BinaryIO) -> Model:
prefix = encoded_model.read(14)
if prefix != b"GPTC model v5\n":
def deserialize(encoded_model: bytes) -> Model:
try:
prefix, config_json, encoded_weights = encoded_model.split(b"\n", 2)
except ValueError:
raise InvalidModelError()
config_json = b""
while True:
byte = encoded_model.read(1)
if byte == b"\n":
break
elif byte == b"":
raise InvalidModelError()
else:
config_json += byte
if prefix != b"GPTC model v4":
raise InvalidModelError()
try:
config = json.loads(config_json.decode("utf-8"))
@ -121,29 +119,30 @@ def deserialize(encoded_model: BinaryIO) -> Model:
try:
names = config["names"]
max_ngram_length = config["max_ngram_length"]
hash_algorithm = config["hash_algorithm"]
except KeyError:
raise InvalidModelError()
if not (
isinstance(names, list) and isinstance(max_ngram_length, int)
) or not all([isinstance(name, str) for name in names]):
if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all(
[isinstance(name, str) for name in names]
):
raise InvalidModelError()
weight_code_length = 6 + 2 * len(names)
weights: Dict[int : List[int]] = {}
if len(encoded_weights) % weight_code_length != 0:
raise InvalidModelError()
while True:
code = encoded_model.read(weight_code_length)
if not code:
break
elif len(code) != weight_code_length:
raise InvalidModelError()
weight_codes = [
encoded_weights[x : x + weight_code_length]
for x in range(0, len(encoded_weights), weight_code_length)
]
weights[int.from_bytes(code[:6], "big")] = [
weights = {
int.from_bytes(code[:6], "big"): [
int.from_bytes(value, "big")
for value in [code[x : x + 2] for x in range(6, len(code), 2)]
]
for code in weight_codes
}
return Model(weights, names, max_ngram_length, hash_algorithm)
return Model(weights, names, max_ngram_length)

View File

@ -1,13 +1,12 @@
# SPDX-License-Identifier: GPL-3.0-or-later
from typing import List, Union, Callable
from typing import List, Union
import hashlib
import emoji
import unicodedata
def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
text = unicodedata.normalize("NFKD", text).lower()
text = text.lower()
parts = []
highest_end = 0
for emoji_part in emoji.emoji_list(text):
@ -20,12 +19,7 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
tokens = [""]
for char in converted_text:
if (
char.isalpha()
or char.isnumeric()
or char == "'"
or (char in ",." and (" " + tokens[-1])[-1].isnumeric())
):
if char.isalpha() or char == "'":
tokens[-1] += char
elif emoji.is_emoji(char):
tokens.append(char)
@ -45,33 +39,14 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
return ngrams
def hash_single(token: str, hash_function: Callable) -> int:
def hash_single(token: str) -> int:
return int.from_bytes(
hash_function(token.encode("utf-8")).digest()[:6], "big"
hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
)
def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
if hash_algorithm in {
"sha224",
"md5",
"sha512",
"sha3_256",
"blake2s",
"sha3_224",
"sha1",
"sha256",
"sha384",
"shake_256",
"blake2b",
"sha3_512",
"shake_128",
"sha3_384",
}:
hash_function = getattr(hashlib, hash_algorithm)
return [hash_single(token, hash_function) for token in tokens]
else:
raise ValueError("not a valid hash function: " + hash_algorithm)
def hash(tokens: List[str]) -> List[int]:
return [hash_single(token) for token in tokens]
def normalize(text: str) -> str: