Compare commits
No commits in common. "f8dbc78b8231ae551f6ec61514644518190e5e21" and "a76c6d3da8dfc59aec9f0ba6916e2092443f3d72" have entirely different histories.
f8dbc78b82 ... a76c6d3da8

README.md | 28
--- a/README.md
+++ b/README.md
@@ -43,13 +43,13 @@ example of the format. Any exceptions will be printed to stderr.
 ## Library
 
-### `Model.serialize(file)`
+### `Model.serialize()`
 
-Write binary data representing the model to `file`.
+Returns a `bytes` representing the model.
 
 ### `gptc.deserialize(encoded_model)`
 
-Deserialize a `Model` from a file containing data from `Model.serialize()`.
+Deserialize a `Model` from a `bytes` returned by `Model.serialize()`.
 
 ### `Model.confidence(text, max_ngram_length)`
@@ -69,7 +69,7 @@ Return a confidence dict for the given token or ngram. This function is very
 similar to `Model.confidence()`, except it treats the input as a single token
 or ngram.
 
-### `gptc.compile(raw_model, max_ngram_length=1, min_count=1, hash_algorithm="sha256")`
+### `gptc.compile(raw_model, max_ngram_length=1, min_count=1)`
 
 Compile a raw model (as a list, not JSON) and return the compiled model (as a
 `gptc.Model` object).
@@ -79,26 +79,6 @@ For information about `max_ngram_length`, see section "Ngrams."
 Words or ngrams used less than `min_count` times throughout the input text are
 excluded from the model.
 
-The hash algorithm should be left as the default, which may change with a minor
-version update, but it can be changed by the application if needed. It is
-stored in the model, so changing the algorithm does not affect compatibility.
-The following algorithms are supported:
-
-* `md5`
-* `sha1`
-* `sha224`
-* `sha256`
-* `sha384`
-* `sha512`
-* `sha3_224`
-* `sha3_384`
-* `sha3_256`
-* `sha3_512`
-* `shake_128`
-* `shake_256`
-* `blake2b`
-* `blake2s`
-
 ### `gptc.pack(directory, print_exceptions=False)`
 
 Pack the model in `directory` and return a tuple of the format:
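Taken together, the README changes above document the move from file-object serialization to plain bytes. A minimal round-trip sketch under the new API follows; the raw-model contents and the file name are invented for illustration:

import gptc

# A hypothetical two-category raw model, shaped as the README describes.
raw_model = [
    {"category": "spam", "text": "win a free prize now"},
    {"category": "ham", "text": "the meeting notes are attached"},
]

model = gptc.compile(raw_model, max_ngram_length=1, min_count=1)

# Model.serialize() now returns bytes instead of writing to a file object.
with open("model.gptc", "wb") as f:
    f.write(model.serialize())

# gptc.deserialize() correspondingly takes the bytes, not an open file.
with open("model.gptc", "rb") as f:
    restored = gptc.deserialize(f.read())

print(restored.confidence("free prize", 1))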
@@ -66,10 +66,12 @@ def main() -> None:
         with open(args.model, "r") as f:
             model = json.load(f)
 
-        gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer)
+        sys.stdout.buffer.write(
+            gptc.compile(model, args.max_ngram_length, args.min_count).serialize()
+        )
     elif args.subparser_name == "classify":
         with open(args.model, "rb") as f:
-            model = gptc.deserialize(f)
+            model = gptc.deserialize(f.read())
 
         if sys.stdin.isatty():
             text = input("Text to analyse: ")
@@ -85,7 +87,7 @@ def main() -> None:
     print(json.dumps(probabilities))
 elif args.subparser_name == "check":
     with open(args.model, "rb") as f:
-        model = gptc.deserialize(f)
+        model = gptc.deserialize(f.read())
     print(json.dumps(model.get(args.token)))
 else:
     print(json.dumps(gptc.pack(args.model, True)[0]))
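These two hunks (apparently GPTC's CLI entry point) adapt main() to the bytes-based API: the compiled model is written to binary stdout in one call, and model files are read fully before deserializing. The same pattern outside of main(), as a sketch with an invented input file name:

import json
import sys

import gptc

# Read a raw model: a JSON list of {"category": ..., "text": ...} mappings.
with open("raw_model.json", "r") as f:
    raw_model = json.load(f)

# serialize() returns the whole model as bytes, so the compiled model can be
# written to the binary stdout stream in a single call, as main() does above.
sys.stdout.buffer.write(gptc.compile(raw_model).serialize())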
@@ -9,7 +9,6 @@ def compile(
     raw_model: Iterable[Mapping[str, str]],
     max_ngram_length: int = 1,
     min_count: int = 1,
-    hash_algorithm: str = "sha256",
 ) -> gptc.model.Model:
     """Compile a raw model.
 
@@ -28,24 +27,23 @@ def compile(
     """
 
-    word_counts: Dict[int, Dict[str, int]] = {}
-    category_lengths: Dict[str, int] = {}
-    names: List[str] = []
+    categories: Dict[str, List[int]] = {}
 
     for portion in raw_model:
         text = gptc.tokenizer.hash(
-            gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
-            hash_algorithm,
+            gptc.tokenizer.tokenize(portion["text"], max_ngram_length)
         )
         category = portion["category"]
 
-        if not category in names:
-            names.append(category)
+        try:
+            categories[category] += text
+        except KeyError:
+            categories[category] = text
 
-        category_lengths[category] = category_lengths.get(category, 0) + len(
-            text
-        )
+    word_counts: Dict[int, Dict[str, int]] = {}
+
+    names = list(categories.keys())
 
+    for category, text in categories.items():
         for word in text:
             if word in word_counts:
                 try:
@@ -55,19 +53,27 @@ def compile(
             else:
                 word_counts[word] = {category: 1}
 
-    model: Dict[int, List[int]] = {}
-    for word, counts in word_counts.items():
-        if sum(counts.values()) >= min_count:
-            weights = {
-                category: value / category_lengths[category]
-                for category, value in counts.items()
-            }
-            total = sum(weights.values())
-            new_weights: List[int] = []
-            for category in names:
-                new_weights.append(
-                    round((weights.get(category, 0) / total) * 65535)
-                )
-            model[word] = new_weights
+    category_lengths = {
+        category: len(text) for category, text in categories.items()
+    }
+
+    word_weights: Dict[int, Dict[str, float]] = {
+        word: {
+            category: value / category_lengths[category]
+            for category, value in values.items()
+        }
+        for word, values in word_counts.items()
+        if sum(values.values()) >= min_count
+    }
+
+    model: Dict[int, List[int]] = {}
+    for word, weights in word_weights.items():
+        total = sum(weights.values())
+        new_weights: List[int] = []
+        for category in names:
+            new_weights.append(
+                round((weights.get(category, 0) / total) * 65535)
+            )
+        model[word] = new_weights
 
-    return gptc.model.Model(model, names, max_ngram_length, hash_algorithm)
+    return gptc.model.Model(model, names, max_ngram_length)
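The compiler rewrite above keeps the same quantization, now expressed via dict comprehensions: each word's per-category count is normalized by that category's total token count, and each category's share of the word's total weight is scaled into a 16-bit integer. A worked sketch of that arithmetic with made-up counts:

# Hypothetical counts for one word in a two-category model.
counts = {"spam": 6, "ham": 2}                # occurrences of the word per category
category_lengths = {"spam": 100, "ham": 400}  # total tokens per category
names = ["spam", "ham"]

# Normalize by category length, as compile() does.
weights = {category: counts[category] / category_lengths[category] for category in counts}
total = sum(weights.values())

# Scale each category's share of the total into the 16-bit range 0..65535.
new_weights = [round((weights.get(category, 0) / total) * 65535) for category in names]
print(new_weights)  # [60494, 5041], i.e. this word points strongly at "spam"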
@@ -3,7 +3,7 @@
 import gptc.tokenizer
 from gptc.exceptions import InvalidModelError
 import gptc.weighting
-from typing import Iterable, Mapping, List, Dict, Union, cast, BinaryIO
+from typing import Iterable, Mapping, List, Dict, Union, cast
 import json
 
@@ -13,12 +13,10 @@ class Model:
         weights: Dict[int, List[int]],
         names: List[str],
         max_ngram_length: int,
-        hash_algorithm: str,
     ):
         self.weights = weights
         self.names = names
         self.max_ngram_length = max_ngram_length
-        self.hash_algorithm = hash_algorithm
 
     def confidence(self, text: str, max_ngram_length: int) -> Dict[str, float]:
         """Classify text with confidence.
@@ -42,10 +40,7 @@ class Model:
         model = self.weights
 
         tokens = gptc.tokenizer.hash(
-            gptc.tokenizer.tokenize(
-                text, min(max_ngram_length, self.max_ngram_length)
-            ),
-            self.hash_algorithm,
+            gptc.tokenizer.tokenize(text, min(max_ngram_length, self.max_ngram_length))
         )
         numbered_probs: Dict[int, float] = {}
         for word in tokens:
@@ -79,39 +74,42 @@ class Model:
             for index, category in enumerate(self.names)
         }
 
-    def serialize(self, file: BinaryIO):
-        file.write(b"GPTC model v5\n")
-        file.write(
+    def serialize(self) -> bytes:
+        out = b"GPTC model v4\n"
+        out += (
             json.dumps(
                 {
                     "names": self.names,
                     "max_ngram_length": self.max_ngram_length,
-                    "hash_algorithm": self.hash_algorithm,
+                    "has_emoji": True,
+                    # Due to an oversight in development, version 3.0.0 still
+                    # had the code used to make emoji support optional, even
+                    # though the `emoji` library was made a hard dependency.
+                    # Part of this code checked whether or not the model
+                    # supports emoji; deserialization would not work in 3.0.0
+                    # if the model was compiled without this field. Emoji are
+                    # always supported with 3.0.0 and newer when GPTC has been
+                    # installed correctly, so this value should always be True.
+                    # Related: #11
                 }
             ).encode("utf-8")
             + b"\n"
         )
         for word, weights in self.weights.items():
-            file.write(
-                word.to_bytes(6, "big")
-                + b"".join([weight.to_bytes(2, "big") for weight in weights])
+            out += word.to_bytes(6, "big") + b"".join(
+                [weight.to_bytes(2, "big") for weight in weights]
             )
+        return out
 
 
-def deserialize(encoded_model: BinaryIO) -> Model:
-    prefix = encoded_model.read(14)
-    if prefix != b"GPTC model v5\n":
+def deserialize(encoded_model: bytes) -> Model:
+    try:
+        prefix, config_json, encoded_weights = encoded_model.split(b"\n", 2)
+    except ValueError:
         raise InvalidModelError()
 
-    config_json = b""
-    while True:
-        byte = encoded_model.read(1)
-        if byte == b"\n":
-            break
-        elif byte == b"":
-            raise InvalidModelError()
-        else:
-            config_json += byte
+    if prefix != b"GPTC model v4":
+        raise InvalidModelError()
 
     try:
         config = json.loads(config_json.decode("utf-8"))
@@ -121,29 +119,30 @@ def deserialize(encoded_model: BinaryIO) -> Model:
     try:
         names = config["names"]
         max_ngram_length = config["max_ngram_length"]
-        hash_algorithm = config["hash_algorithm"]
     except KeyError:
         raise InvalidModelError()
 
-    if not (
-        isinstance(names, list) and isinstance(max_ngram_length, int)
-    ) or not all([isinstance(name, str) for name in names]):
+    if not (isinstance(names, list) and isinstance(max_ngram_length, int)) or not all(
+        [isinstance(name, str) for name in names]
+    ):
         raise InvalidModelError()
 
     weight_code_length = 6 + 2 * len(names)
 
-    weights: Dict[int : List[int]] = {}
-
-    while True:
-        code = encoded_model.read(weight_code_length)
-        if not code:
-            break
-        elif len(code) != weight_code_length:
-            raise InvalidModelError()
+    if len(encoded_weights) % weight_code_length != 0:
+        raise InvalidModelError()
 
-        weights[int.from_bytes(code[:6], "big")] = [
-            int.from_bytes(value, "big")
-            for value in [code[x : x + 2] for x in range(6, len(code), 2)]
-        ]
+    weight_codes = [
+        encoded_weights[x : x + weight_code_length]
+        for x in range(0, len(encoded_weights), weight_code_length)
+    ]
+
+    weights = {
+        int.from_bytes(code[:6], "big"): [
+            int.from_bytes(value, "big")
+            for value in [code[x : x + 2] for x in range(6, len(code), 2)]
+        ]
+        for code in weight_codes
+    }
 
-    return Model(weights, names, max_ngram_length, hash_algorithm)
+    return Model(weights, names, max_ngram_length)
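The serialize()/deserialize() pair above defines the v4 container: a `GPTC model v4` prefix line, one line of JSON config, then fixed-width records of a 6-byte token hash followed by one 2-byte weight per category. A hand-rolled sketch of that layout, with all values invented:

import json

# One invented record: token hash 42, weights 60494 ("spam") and 5041 ("ham").
config = {"names": ["spam", "ham"], "max_ngram_length": 1, "has_emoji": True}
record = (42).to_bytes(6, "big") + (60494).to_bytes(2, "big") + (5041).to_bytes(2, "big")
encoded = b"GPTC model v4\n" + json.dumps(config).encode("utf-8") + b"\n" + record

# Split exactly as deserialize() does: prefix line, config line, raw weights.
prefix, config_json, encoded_weights = encoded.split(b"\n", 2)
assert prefix == b"GPTC model v4"
names = json.loads(config_json.decode("utf-8"))["names"]

# Each record is 6 hash bytes plus 2 bytes per category.
weight_code_length = 6 + 2 * len(names)
assert len(encoded_weights) % weight_code_length == 0
code = encoded_weights[:weight_code_length]
word = int.from_bytes(code[:6], "big")
weights = [int.from_bytes(code[x : x + 2], "big") for x in range(6, len(code), 2)]
print(word, weights)  # 42 [60494, 5041]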
@@ -1,13 +1,12 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-from typing import List, Union, Callable
+from typing import List, Union
 import hashlib
 import emoji
-import unicodedata
 
 
 def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
-    text = unicodedata.normalize("NFKD", text).lower()
+    text = text.lower()
     parts = []
     highest_end = 0
     for emoji_part in emoji.emoji_list(text):
@@ -20,12 +19,7 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
     tokens = [""]
 
     for char in converted_text:
-        if (
-            char.isalpha()
-            or char.isnumeric()
-            or char == "'"
-            or (char in ",." and (" " + tokens[-1])[-1].isnumeric())
-        ):
+        if char.isalpha() or char == "'":
             tokens[-1] += char
         elif emoji.is_emoji(char):
             tokens.append(char)
@@ -45,33 +39,14 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
     return ngrams
 
 
-def hash_single(token: str, hash_function: Callable) -> int:
+def hash_single(token: str) -> int:
     return int.from_bytes(
-        hash_function(token.encode("utf-8")).digest()[:6], "big"
+        hashlib.sha256(token.encode("utf-8")).digest()[:6], "big"
     )
 
 
-def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
-    if hash_algorithm in {
-        "sha224",
-        "md5",
-        "sha512",
-        "sha3_256",
-        "blake2s",
-        "sha3_224",
-        "sha1",
-        "sha256",
-        "sha384",
-        "shake_256",
-        "blake2b",
-        "sha3_512",
-        "shake_128",
-        "sha3_384",
-    }:
-        hash_function = getattr(hashlib, hash_algorithm)
-        return [hash_single(token, hash_function) for token in tokens]
-    else:
-        raise ValueError("not a valid hash function: " + hash_algorithm)
+def hash(tokens: List[str]) -> List[int]:
+    return [hash_single(token) for token in tokens]
 
 
 def normalize(text: str) -> str:
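With this change the tokenizer's character filter keeps only letters and apostrophes (numerals and digit punctuation are dropped), and token hashing is fixed to SHA-256, taking the first 6 bytes of the digest as an integer. A rough standalone illustration of the new filter; the real tokenize() additionally splits out emoji and builds ngrams:

def rough_tokens(text: str) -> list:
    # Mirror of the new filter: letters and apostrophes extend the current
    # token; anything else (including digits now) ends it.
    tokens = [""]
    for char in text.lower():
        if char.isalpha() or char == "'":
            tokens[-1] += char
        else:
            tokens.append("")
    return [token for token in tokens if token]

print(rough_tokens("It's 2023, don't panic!"))  # tokens: it's, don't, panic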