Clean up code

Author: Samuel Sloniker, 2023-04-17 20:59:39 -07:00
parent 9513025e60
commit a252a15e9d
Signed by: kj7rrv, GPG key ID: 1BB4029E66285A62
7 changed files with 78 additions and 93 deletions

gptc/__init__.py

@@ -2,12 +2,11 @@
 """General-Purpose Text Classifier"""
 
-from gptc.compiler import compile as compile
-from gptc.pack import pack as pack
-from gptc.model import Model as Model, deserialize as deserialize
-from gptc.tokenizer import normalize as normalize
+from gptc.pack import pack
+from gptc.model import Model, deserialize
+from gptc.tokenizer import normalize
 from gptc.exceptions import (
-    GPTCError as GPTCError,
-    ModelError as ModelError,
-    InvalidModelError as InvalidModelError,
+    GPTCError,
+    ModelError,
+    InvalidModelError,
 )
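Note: the removed `import X as X` spelling is the explicit re-export idiom recognized by strict type checkers (mypy's no-implicit-reexport mode); plain imports re-export the same names at runtime. A minimal sketch of the resulting public API, where the raw-model keys are an assumption rather than something this file shows:

    import gptc

    raw_model = [{"category": "spam", "text": "win a free prize now"}]
    model = gptc.Model.compile(raw_model)  # replaces the removed gptc.compile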

gptc/__main__.py

@@ -59,16 +59,16 @@ def main() -> None:
     args = parser.parse_args()
 
     if args.subparser_name == "compile":
-        with open(args.model, "r") as f:
-            model = json.load(f)
+        with open(args.model, "r", encoding="utf-8") as input_file:
+            model = json.load(input_file)
 
-        with open(args.out, "wb+") as f:
-            gptc.compile(
+        with open(args.out, "wb+") as output_file:
+            gptc.Model.compile(
                 model, args.max_ngram_length, args.min_count
-            ).serialize(f)
+            ).serialize(output_file)
     elif args.subparser_name == "classify":
-        with open(args.model, "rb") as f:
-            model = gptc.deserialize(f)
+        with open(args.model, "rb") as model_file:
+            model = gptc.deserialize(model_file)
 
         if sys.stdin.isatty():
             text = input("Text to analyse: ")
@@ -77,8 +77,8 @@ def main() -> None:
         print(json.dumps(model.confidence(text, args.max_ngram_length)))
     elif args.subparser_name == "check":
-        with open(args.model, "rb") as f:
-            model = gptc.deserialize(f)
+        with open(args.model, "rb") as model_file:
+            model = gptc.deserialize(model_file)
 
         print(json.dumps(model.get(args.token)))
     else:
         print(json.dumps(gptc.pack(args.model, True)[0]))
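For reference, the rewritten "compile" branch is roughly equivalent to the following Python; the file names are placeholders:

    import json
    import gptc

    with open("raw_model.json", "r", encoding="utf-8") as input_file:
        raw_model = json.load(input_file)

    with open("model.gptc", "wb+") as output_file:
        gptc.Model.compile(raw_model, 1, 1).serialize(output_file)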

gptc/compiler.py

@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+from typing import Iterable, Mapping, List, Dict, Tuple
+
 import gptc.tokenizer
 import gptc.model
-from typing import Iterable, Mapping, List, Dict, Union, Tuple
 
 
 def _count_words(
@@ -15,7 +15,7 @@ def _count_words(
     names: List[str] = []
 
     for portion in raw_model:
-        text = gptc.tokenizer.hash(
+        text = gptc.tokenizer.hash_list(
             gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
             hash_algorithm,
         )
@@ -63,7 +63,7 @@ def _get_weights(
     return model
 
 
-def compile(
+def compile_(
     raw_model: Iterable[Mapping[str, str]],
     max_ngram_length: int = 1,
     min_count: int = 1,
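The trailing underscore follows the PEP 8 convention for names that would otherwise shadow a builtin (`compile`). Direct callers would now look roughly like this sketch; the "category"/"text" keys are assumptions about the raw-model format:

    import gptc.compiler

    model = gptc.compiler.compile_(
        [{"category": "example", "text": "some sample text"}],
        max_ngram_length=2,
        min_count=1,
    )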

gptc/model.py

@@ -1,22 +1,18 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
-import gptc.tokenizer
-from gptc.exceptions import InvalidModelError
-import gptc.weighting
 from typing import (
-    Iterable,
-    Mapping,
     List,
     Dict,
-    Union,
     cast,
     BinaryIO,
-    Any,
     Tuple,
     TypedDict,
 )
 import json
-import collections
+import gptc.tokenizer
+from gptc.exceptions import InvalidModelError
+import gptc.weighting
+import gptc.compiler
 
 
 class ExplanationEntry(TypedDict):
@@ -33,6 +29,21 @@ Explanation = Dict[
 Log = List[Tuple[str, float, List[float]]]
 
 
+class Confidences(dict[str, float]):
+    def __init__(self, probs: Dict[str, float]):
+        dict.__init__(self, probs)
+
+
+class TransparentConfidences(Confidences):
+    def __init__(
+        self,
+        probs: Dict[str, float],
+        explanation: Explanation,
+    ):
+        self.explanation = explanation
+        Confidences.__init__(self, probs)
+
+
 def convert_log(log: Log, names: List[str]) -> Explanation:
     explanation: Explanation = {}
     for word2, weight, word_probs in log:
@@ -49,33 +60,6 @@ def convert_log(log: Log, names: List[str]) -> Explanation:
     return explanation
 
 
-class Confidences(collections.UserDict[str, float]):
-    def __init__(
-        self,
-        probs: Dict[str, float],
-        model: Model,
-        text: str,
-        max_ngram_length: int,
-    ):
-        collections.UserDict.__init__(self, probs)
-        self.model = model
-        self.text = text
-        self.max_ngram_length = max_ngram_length
-
-
-class TransparentConfidences(Confidences):
-    def __init__(
-        self,
-        probs: Dict[str, float],
-        explanation: Explanation,
-        model: Model,
-        text: str,
-        max_ngram_length: int,
-    ):
-        Confidences.__init__(self, probs, model, text, max_ngram_length)
-        self.explanation = explanation
-
-
 class Model:
     def __init__(
         self,
@@ -117,7 +101,7 @@ class Model:
             text, min(max_ngram_length, self.max_ngram_length)
         )
 
-        tokens = gptc.tokenizer.hash(
+        tokens = gptc.tokenizer.hash_list(
            raw_tokens,
            self.hash_algorithm,
        )
@@ -163,12 +147,9 @@ class Model:
         if transparent:
             explanation = convert_log(log, self.names)
 
-            return TransparentConfidences(
-                probs, explanation, self, text, max_ngram_length
-            )
-        else:
-            return Confidences(probs, self, text, max_ngram_length)
+            return TransparentConfidences(probs, explanation)
+
+        return Confidences(probs)
 
     def get(self, token: str) -> Dict[str, float]:
         try:
@@ -202,6 +183,8 @@ class Model:
             + b"".join([weight.to_bytes(2, "big") for weight in weights])
         )
 
+    compile = staticmethod(gptc.compiler.compile_)
+
 
 def deserialize(encoded_model: BinaryIO) -> Model:
     prefix = encoded_model.read(14)
@@ -213,26 +196,27 @@ def deserialize(encoded_model: BinaryIO) -> Model:
         byte = encoded_model.read(1)
         if byte == b"\n":
             break
-        elif byte == b"":
+
+        if byte == b"":
             raise InvalidModelError()
-        else:
-            config_json += byte
+
+        config_json += byte
 
     try:
         config = json.loads(config_json.decode("utf-8"))
-    except (UnicodeDecodeError, json.JSONDecodeError):
-        raise InvalidModelError()
+    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
+        raise InvalidModelError() from exc
 
     try:
         names = config["names"]
         max_ngram_length = config["max_ngram_length"]
         hash_algorithm = config["hash_algorithm"]
-    except KeyError:
-        raise InvalidModelError()
+    except KeyError as exc:
+        raise InvalidModelError() from exc
 
     if not (
         isinstance(names, list) and isinstance(max_ngram_length, int)
-    ) or not all([isinstance(name, str) for name in names]):
+    ) or not all(isinstance(name, str) for name in names):
         raise InvalidModelError()
 
     weight_code_length = 6 + 2 * len(names)
@@ -243,7 +227,7 @@ def deserialize(encoded_model: BinaryIO) -> Model:
         code = encoded_model.read(weight_code_length)
         if not code:
             break
-        elif len(code) != weight_code_length:
+        if len(code) != weight_code_length:
             raise InvalidModelError()
 
         weights[int.from_bytes(code[:6], "big")] = [
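Taken together: confidence results are now plain `dict` subclasses with no back-references to the model or input text, the compiler is reachable as `Model.compile`, and `deserialize` chains `InvalidModelError` to its original cause. A usage sketch, with a placeholder file name and input text:

    import gptc

    with open("model.gptc", "rb") as model_file:
        model = gptc.deserialize(model_file)

    confidences = model.confidence("an example text", 1)
    best_category = max(confidences, key=confidences.get)
    # a TransparentConfidences result additionally carries only .explanation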

gptc/pack.py

@@ -7,7 +7,7 @@ from typing import List, Dict, Tuple
 
 def pack(
     directory: str, print_exceptions: bool = False
-) -> Tuple[List[Dict[str, str]], List[Tuple[Exception]]]:
+) -> Tuple[List[Dict[str, str]], List[Tuple[OSError]]]:
     paths = os.listdir(directory)
     texts: Dict[str, List[str]] = {}
     exceptions = []
@@ -17,16 +17,18 @@ def pack(
 
         try:
             for file in os.listdir(os.path.join(directory, path)):
                 try:
-                    with open(os.path.join(directory, path, file)) as f:
-                        texts[path].append(f.read())
-                except Exception as e:
-                    exceptions.append((e,))
+                    with open(
+                        os.path.join(directory, path, file), encoding="utf-8"
+                    ) as input_file:
+                        texts[path].append(input_file.read())
+                except OSError as error:
+                    exceptions.append((error,))
                     if print_exceptions:
-                        print(e, file=sys.stderr)
-        except Exception as e:
-            exceptions.append((e,))
+                        print(error, file=sys.stderr)
+        except OSError as error:
+            exceptions.append((error,))
             if print_exceptions:
-                print(e, file=sys.stderr)
+                print(error, file=sys.stderr)
 
     raw_model = []
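`pack` expects one subdirectory per category, each holding plain-text sample files; only `OSError` (the usual failure mode of `os.listdir` and `open`) is now caught and reported. A sketch with a placeholder directory name:

    # corpus/
    #     ham/   one text file per sample
    #     spam/
    import gptc

    raw_model, errors = gptc.pack("corpus", print_exceptions=True)
    # errors is a list of 1-tuples, each wrapping an OSError that was skipped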

gptc/tokenizer.py

@@ -1,9 +1,9 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-from typing import List, Union, Callable, Any, cast
+import unicodedata
+from typing import List, cast
 import hashlib
 import emoji
-import unicodedata
 
 
 def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
@@ -37,7 +37,7 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
 
     if max_ngram_length == 1:
         return tokens
-    else:
-        ngrams = []
-        for ngram_length in range(1, max_ngram_length + 1):
-            for index in range(len(tokens) + 1 - ngram_length):
+
+    ngrams = []
+    for ngram_length in range(1, max_ngram_length + 1):
+        for index in range(len(tokens) + 1 - ngram_length):
@@ -69,7 +69,7 @@ def _get_hash_function(hash_algorithm: str) -> type:
         "sha3_384",
     }:
         return cast(type, getattr(hashlib, hash_algorithm))
-    else:
-        raise ValueError("not a valid hash function: " + hash_algorithm)
+
+    raise ValueError("not a valid hash function: " + hash_algorithm)
@@ -77,7 +77,7 @@ def hash_single(token: str, hash_algorithm: str) -> int:
     return _hash_single(token, _get_hash_function(hash_algorithm))
 
 
-def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
+def hash_list(tokens: List[str], hash_algorithm: str) -> List[int]:
     hash_function = _get_hash_function(hash_algorithm)
     return [_hash_single(token, hash_function) for token in tokens]
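Like `compile_`, the `hash_list` name avoids shadowing a builtin (`hash`). A sketch of the tokenize-then-hash pipeline, assuming "sha256" is among the accepted algorithm names:

    import gptc.tokenizer

    tokens = gptc.tokenizer.tokenize("win a free prize now", max_ngram_length=2)
    hashed = gptc.tokenizer.hash_list(tokens, "sha256")  # formerly tokenizer.hash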

gptc/weighting.py

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import math
-from typing import Sequence, Union, Tuple, List
+from typing import Sequence, Tuple, List
 
 
 def _mean(numbers: Sequence[float]) -> float:
@@ -41,6 +41,6 @@ def _standard_deviation(numbers: Sequence[float]) -> float:
 
 def weight(numbers: Sequence[float]) -> Tuple[float, List[float]]:
     standard_deviation = _standard_deviation(numbers)
-    weight = standard_deviation * 2
-    weighted_numbers = [i * weight for i in numbers]
-    return weight, weighted_numbers
+    weight_assigned = standard_deviation * 2
+    weighted_numbers = [i * weight_assigned for i in numbers]
+    return weight_assigned, weighted_numbers
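Renaming the local variable removes the shadowing of the enclosing `weight` function itself; the math is unchanged: each number is scaled by twice the standard deviation of the row, so a flat, uninformative row is weighted toward zero. A small worked example:

    from gptc.weighting import weight

    weight_assigned, weighted = weight([0.5, 0.5])
    # standard deviation is 0, so the row contributes nothing
    assert weight_assigned == 0.0 and weighted == [0.0, 0.0]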