Compare commits

..

No commits in common. "f8dbc78b8231ae551f6ec61514644518190e5e21" and "a76c6d3da8dfc59aec9f0ba6916e2092443f3d72" have entirely different histories.

5 changed files with 86 additions and 124 deletions

View File

@ -43,13 +43,13 @@ example of the format. Any exceptions will be printed to stderr.
## Library ## Library
### `Model.serialize(file)` ### `Model.serialize()`
Write binary data representing the model to `file`. Returns a `bytes` representing the model.
### `gptc.deserialize(encoded_model)` ### `gptc.deserialize(encoded_model)`
Deserialize a `Model` from a file containing data from `Model.serialize()`. Deserialize a `Model` from a `bytes` returned by `Model.serialize()`.
### `Model.confidence(text, max_ngram_length)` ### `Model.confidence(text, max_ngram_length)`
@ -69,7 +69,7 @@ Return a confidence dict for the given token or ngram. This function is very
similar to `Model.confidence()`, except it treats the input as a single token similar to `Model.confidence()`, except it treats the input as a single token
or ngram. or ngram.
### `gptc.compile(raw_model, max_ngram_length=1, min_count=1, hash_algorithm="sha256")` ### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
Compile a raw model (as a list, not JSON) and return the compiled model (as a Compile a raw model (as a list, not JSON) and return the compiled model (as a
`gptc.Model` object). `gptc.Model` object).
@ -79,26 +79,6 @@ For information about `max_ngram_length`, see section "Ngrams."
Words or ngrams used less than `min_count` times throughout the input text are Words or ngrams used less than `min_count` times throughout the input text are
excluded from the model. excluded from the model.
The hash algorithm should be left as the default, which may change with a minor
version update, but it can be changed by the application if needed. It is
stored in the model, so changing the algorithm does not affect compatibility.
The following algorithms are supported:
* `md5`
* `sha1`
* `sha224`
* `sha256`
* `sha384`
* `sha512`
* `sha3_224`
* `sha3_384`
* `sha3_256`
* `sha3_512`
* `shake_128`
* `shake_256`
* `blake2b`
* `blake2s`
### `gptc.pack(directory, print_exceptions=False)` ### `gptc.pack(directory, print_exceptions=False)`
Pack the model in `directory` and return a tuple of the format: Pack the model in `directory` and return a tuple of the format:

View File

@ -66,10 +66,12 @@ def main() -> None:
with open(args.model, "r") as f: with open(args.model, "r") as f:
model = json.load(f) model = json.load(f)
gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer) sys.stdout.buffer.write(
gptc.compile(model, args.max_ngram_length, args.min_count).serialize()
)
elif args.subparser_name == "classify": elif args.subparser_name == "classify":
with open(args.model, "rb") as f: with open(args.model, "rb") as f:
model = gptc.deserialize(f) model = gptc.deserialize(f.read())
if sys.stdin.isatty(): if sys.stdin.isatty():
text = input("Text to analyse: ") text = input("Text to analyse: ")
@ -85,7 +87,7 @@ def main() -> None:
print(json.dumps(probabilities)) print(json.dumps(probabilities))
elif args.subparser_name == "check": elif args.subparser_name == "check":
with open(args.model, "rb") as f: with open(args.model, "rb") as f:
model = gptc.deserialize(f) model = gptc.deserialize(f.read())
print(json.dumps(model.get(args.token))) print(json.dumps(model.get(args.token)))
else: else:
print(json.dumps(gptc.pack(args.model, True)[0])) print(json.dumps(gptc.pack(args.model, True)[0]))

View File

@ -9,7 +9,6 @@ def compile(
raw_model: Iterable[Mapping[str, str]], raw_model: Iterable[Mapping[str, str]],
max_ngram_length: int = 1, max_ngram_length: int = 1,
min_count: int = 1, min_count: int = 1,
hash_algorithm: str = "sha256",
) -> gptc.model.Model: ) -> gptc.model.Model:
"""Compile a raw model. """Compile a raw model.
@ -28,24 +27,23 @@ def compile(
""" """
word_counts: Dict[int, Dict[str, int]] = {} categories: Dict[str, List[int]] = {}
category_lengths: Dict[str, int] = {}
names: List[str] = []
for portion in raw_model: for portion in raw_model:
text = gptc.tokenizer.hash( text = gptc.tokenizer.hash(
gptc.tokenizer.tokenize(portion["text"], max_ngram_length), gptc.tokenizer.tokenize(portion["text"], max_ngram_length)
hash_algorithm,
) )
category = portion["category"] category = portion["category"]
try:
categories[category] += text
except KeyError:
categories[category] = text
if not category in names: word_counts: Dict[int, Dict[str, int]] = {}
names.append(category)
category_lengths[category] = category_lengths.get(category, 0) + len( names = list(categories.keys())
text
)
for category, text in categories.items():
for word in text: for word in text:
if word in word_counts: if word in word_counts:
try: try:
@ -55,13 +53,21 @@ def compile(
else: else:
word_counts[word] = {category: 1} word_counts[word] = {category: 1}
model: Dict[int, List[int]] = {} category_lengths = {
for word, counts in word_counts.items(): category: len(text) for category, text in categories.items()
if sum(counts.values()) >= min_count:
weights = {
category: value / category_lengths[category]
for category, value in counts.items()
} }
word_weights: Dict[int, Dict[str, float]] = {
word: {
category: value / category_lengths[category]
for category, value in values.items()
}
for word, values in word_counts.items()
if sum(values.values()) >= min_count
}
model: Dict[int, List[int]] = {}
for word, weights in word_weights.items():
total = sum(weights.values()) total = sum(weights.values())
new_weights: List[int] = [] new_weights: List[int] = []
for category in names: for category in names:
@ -70,4 +76,4 @@ def compile(
) )
model[word] = new_weights model[word] = new_weights
return gptc.model.Model(model, names, max_ngram_length, hash_algorithm) return gptc.model.Model(model, names, max_ngram_length)

View File

@ -3,7 +3,7 @@
import gptc.tokenizer import gptc.tokenizer
from gptc.exceptions import InvalidModelError from gptc.exceptions import InvalidModelError
import gptc.weighting import gptc.weighting
from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO from typing import Iterable, Mapping, List, Dict, Union, cast
import json import json
@ -13,12 +13,10 @@ class Model:
weights: Dict[int, List[int]], weights: Dict[int, List[int]],
names: List[str], names: List[str],
max_ngram_length: int, max_ngram_length: int,
hash_algorithm: str,
): ):
self.weights = weights self.weights = weights
self.names = names self.names = names
self.max_ngram_length = max_ngram_length self.max_ngram_length = max_ngram_length
self.hash_algorithm = hash_algorithm
def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]: def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
"""Classify text with confidence. """Classify text with confidence.
@ -42,10 +40,7 @@ class Model:
model = self.weights model = self.weights
tokens = gptc.tokenizer.hash( tokens = gptc.tokenizer.hash(
gptc.tokenizer.tokenize( gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length))
text, min(max_ngram_length, self.max_ngram_length)
),
self.hash_algorithm,
) )
numbered_probs: Dict[int, float] = {} numbered_probs: Dict[int, float] = {}
for word in tokens: for word in tokens:
@ -79,39 +74,42 @@ class Model:
for index, category in enumerate(self.names) for index, category in enumerate(self.names)
} }
def serialize(self, file: BinaryIO): def serialize(self) -> bytes:
file.write(b"GPTC model v5\n") out = b"GPTC model v4\n"
file.write( out += (
json.dumps( json.dumps(
{ {
"names": self.names, "names": self.names,
"max_ngram_length": self.max_ngram_length, "max_ngram_length": self.max_ngram_length,
"hash_algorithm": self.hash_algorithm, "has_emoji": True,
# Due to an oversight in development, version 3.0.0 still
# had the code used to make emoji support optional, even
# though the `emoji` library was made a hard dependency.
# Part of this code checked whether or not the model
# supports emoji; deserialization would not work in 3.0.0
# if the model was compiled without this field. Emoji are
# always supported with 3.0.0 and newer when GPTC has been
# installed correctly, so this value should always be True.
# Related: #11
} }
).encode("utf-8") ).encode("utf-8")
+ b"\n" + b"\n"
) )
for word, weights in self.weights.items(): for word, weights in self.weights.items():
file.write( out += word.to_bytes(6, "big") + b"".join(
word.to_bytes(6, "big") [weight.to_bytes(2, "big") for weight in weights]
+ b"".join([weight.to_bytes(2, "big") for weight in weights])
) )
return out
def deserialize(encoded_model: BinaryIO) -> Model: def deserialize(encoded_model: bytes) -> Model:
prefix = encoded_model.read(14) try:
if prefix != b"GPTC model v5\n": prefix, config_json, encoded_weights = encoded_model.split(b"\n", 2)
except ValueError:
raise InvalidModelError() raise InvalidModelError()
config_json = b"" if prefix != b"GPTC model v4":
while True:
byte = encoded_model.read(1)
if byte == b"\n":
break
elif byte == b"":
raise InvalidModelError() raise InvalidModelError()
else:
config_json += byte
try: try:
config = json.loads(config_json.decode("utf-8")) config = json.loads(config_json.decode("utf-8"))
@ -121,29 +119,30 @@ def deserialize(encoded_model: BinaryIO) -> Model:
try: try:
names = config["names"] names = config["names"]
max_ngram_length = config["max_ngram_length"] max_ngram_length = config["max_ngram_length"]
hash_algorithm = config["hash_algorithm"]
except KeyError: except KeyError:
raise InvalidModelError() raise InvalidModelError()
if not ( if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all(
isinstance(names, list) and isinstance(max_ngram_length, int) [isinstance(name, str) for name in names]
) or not all([isinstance(name, str) for name in names]): ):
raise InvalidModelError() raise InvalidModelError()
weight_code_length = 6 + 2 * len(names) weight_code_length = 6 + 2 * len(names)
weights: Dict[int : List[int]] = {} if len(encoded_weights) % weight_code_length != 0:
while True:
code = encoded_model.read(weight_code_length)
if not code:
break
elif len(code) != weight_code_length:
raise InvalidModelError() raise InvalidModelError()
weights[int.from_bytes(code[:6], "big")] = [ weight_codes = [
encoded_weights[x : x + weight_code_length]
for x in range(0, len(encoded_weights), weight_code_length)
]
weights = {
int.from_bytes(code[:6], "big"): [
int.from_bytes(value, "big") int.from_bytes(value, "big")
for value in [code[x : x + 2] for x in range(6, len(code), 2)] for value in [code[x : x + 2] for x in range(6, len(code), 2)]
] ]
for code in weight_codes
}
return Model(weights, names, max_ngram_length, hash_algorithm) return Model(weights, names, max_ngram_length)

View File

@ -1,13 +1,12 @@
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
from typing import List, Union, Callable from typing import List, Union
import hashlib import hashlib
import emoji import emoji
import unicodedata
def tokenize(text: str, max_ngram_length: int = 1) -> List[str]: def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
text = unicodedata.normalize("NFKD", text).lower() text = text.lower()
parts = [] parts = []
highest_end = 0 highest_end = 0
for emoji_part in emoji.emoji_list(text): for emoji_part in emoji.emoji_list(text):
@ -20,12 +19,7 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
tokens = [""] tokens = [""]
for char in converted_text: for char in converted_text:
if ( if char.isalpha() or char == "'":
char.isalpha()
or char.isnumeric()
or char == "'"
or (char in ",." and (" " + tokens[-1])[-1].isnumeric())
):
tokens[-1] += char tokens[-1] += char
elif emoji.is_emoji(char): elif emoji.is_emoji(char):
tokens.append(char) tokens.append(char)
@ -45,33 +39,14 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
return ngrams return ngrams
def hash_single(token: str, hash_function: Callable) -> int: def hash_single(token: str) -> int:
return int.from_bytes( return int.from_bytes(
hash_function(token.encode("utf-8")).digest()[:6], "big" hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
) )
def hash(tokens: List[str], hash_algorithm: str) -> List[int]: def hash(tokens: List[str]) -> List[int]:
if hash_algorithm in { return [hash_single(token) for token in tokens]
"sha224",
"md5",
"sha512",
"sha3_256",
"blake2s",
"sha3_224",
"sha1",
"sha256",
"sha384",
"shake_256",
"blake2b",
"sha3_512",
"shake_128",
"sha3_384",
}:
hash_function = getattr(hashlib, hash_algorithm)
return [hash_single(token, hash_function) for token in tokens]
else:
raise ValueError("not a valid hash function: " + hash_algorithm)
def normalize(text: str) -> str: def normalize(text: str) -> str: