Clean up code

Samuel Sloniker 2023-04-17 20:59:39 -07:00
parent 9513025e60
commit a252a15e9d
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62
7 changed files with 78 additions and 93 deletions

View File

@@ -2,12 +2,11 @@
 """General-Purpose Text Classifier"""

-from gptc.compiler import compile as compile
-from gptc.pack import pack as pack
-from gptc.model import Model as Model, deserialize as deserialize
-from gptc.tokenizer import normalize as normalize
+from gptc.pack import pack
+from gptc.model import Model, deserialize
+from gptc.tokenizer import normalize
 from gptc.exceptions import (
-    GPTCError as GPTCError,
-    ModelError as ModelError,
-    InvalidModelError as InvalidModelError,
+    GPTCError,
+    ModelError,
+    InvalidModelError,
 )
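Dropping the redundant "X as X" aliases (the convention some type checkers treat as explicit re-export markers) leaves plain imports; the names exported from the package root are otherwise unchanged, except that gptc.compile is gone entirely and compilation now goes through gptc.Model.compile, as the CLI diff below shows. A minimal sketch of the surviving public API, with a placeholder model path:

    import gptc

    # "model.gptc" is a placeholder path, not part of the project
    with open("model.gptc", "rb") as model_file:
        model = gptc.deserialize(model_file)
    print(model.confidence("example text", 1))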

View File

@@ -59,16 +59,16 @@ def main() -> None:
     args = parser.parse_args()

     if args.subparser_name == "compile":
-        with open(args.model, "r") as f:
-            model = json.load(f)
+        with open(args.model, "r", encoding="utf-8") as input_file:
+            model = json.load(input_file)

-        with open(args.out, "wb+") as f:
-            gptc.compile(
+        with open(args.out, "wb+") as output_file:
+            gptc.Model.compile(
                 model, args.max_ngram_length, args.min_count
-            ).serialize(f)
+            ).serialize(output_file)
     elif args.subparser_name == "classify":
-        with open(args.model, "rb") as f:
-            model = gptc.deserialize(f)
+        with open(args.model, "rb") as model_file:
+            model = gptc.deserialize(model_file)

         if sys.stdin.isatty():
             text = input("Text to analyse: ")
@@ -77,8 +77,8 @@ def main() -> None:
         print(json.dumps(model.confidence(text, args.max_ngram_length)))
     elif args.subparser_name == "check":
-        with open(args.model, "rb") as f:
-            model = gptc.deserialize(f)
+        with open(args.model, "rb") as model_file:
+            model = gptc.deserialize(model_file)

         print(json.dumps(model.get(args.token)))
     else:
         print(json.dumps(gptc.pack(args.model, True)[0]))
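Besides making the UTF-8 encoding explicit and giving the file handles descriptive names, the compile branch now reaches the compiler through gptc.Model.compile, which the model.py diff below wires up as a staticmethod. A rough standalone equivalent of that branch, with placeholder file names and the default parameters from the compiler diff:

    import json
    import gptc

    # "raw_model.json" and "compiled.gptc" are placeholder names
    with open("raw_model.json", "r", encoding="utf-8") as input_file:
        raw_model = json.load(input_file)
    with open("compiled.gptc", "wb+") as output_file:
        gptc.Model.compile(
            raw_model, max_ngram_length=1, min_count=1
        ).serialize(output_file)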

View File

@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

-from typing import Iterable, Mapping, List, Dict, Tuple
 import gptc.tokenizer
 import gptc.model
+from typing import Iterable, Mapping, List, Dict, Union, Tuple


 def _count_words(
@@ -15,7 +15,7 @@ def _count_words(
     names: List[str] = []

     for portion in raw_model:
-        text = gptc.tokenizer.hash(
+        text = gptc.tokenizer.hash_list(
             gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
             hash_algorithm,
         )
@@ -63,7 +63,7 @@ def _get_weights(
     return model


-def compile(
+def compile_(
     raw_model: Iterable[Mapping[str, str]],
     max_ngram_length: int = 1,
     min_count: int = 1,
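The trailing underscore in compile_ is the usual PEP 8 device for avoiding a clash with the compile builtin. Judging only from this hunk, the raw model is an iterable of mappings carrying at least a "text" key; the "name" key in the sketch below is an assumption based on the names list built in _count_words, not something this diff confirms:

    # hedged sketch; the "name" key is assumed, only "text" appears above
    raw_model = [
        {"name": "spam", "text": "buy now limited offer"},
        {"name": "ham", "text": "see you at the meeting"},
    ]
    model = compile_(raw_model, max_ngram_length=2, min_count=1)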

View File

@@ -1,22 +1,18 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

-import gptc.tokenizer
-from gptc.exceptions import InvalidModelError
-import gptc.weighting
 from typing import (
     Iterable,
     Mapping,
     List,
     Dict,
     Union,
     cast,
     BinaryIO,
     Any,
     Tuple,
     TypedDict,
 )
 import json
-import collections
+import gptc.tokenizer
+from gptc.exceptions import InvalidModelError
+import gptc.weighting
+import gptc.compiler
 class ExplanationEntry(TypedDict):
@@ -33,6 +29,21 @@ Explanation = Dict[
 Log = List[Tuple[str, float, List[float]]]


+class Confidences(dict[str, float]):
+    def __init__(self, probs: Dict[str, float]):
+        dict.__init__(self, probs)
+
+
+class TransparentConfidences(Confidences):
+    def __init__(
+        self,
+        probs: Dict[str, float],
+        explanation: Explanation,
+    ):
+        self.explanation = explanation
+        Confidences.__init__(self, probs)
+
+
 def convert_log(log: Log, names: List[str]) -> Explanation:
     explanation: Explanation = {}
     for word2, weight, word_probs in log:
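The relocated Confidences is now a bare dict subclass (note that subscripting dict itself requires Python 3.9 or later) with no back-references to the model, text, or n-gram length, so results behave like ordinary dictionaries. A small illustration with hypothetical category names:

    probs = Confidences({"spam": 0.9, "ham": 0.1})
    best = max(probs, key=probs.get)  # "spam"

    transparent = TransparentConfidences({"spam": 0.9, "ham": 0.1}, {})
    transparent.explanation  # per-token breakdown, empty in this toy example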
@@ -49,33 +60,6 @@ def convert_log(log: Log, names: List[str]) -> Explanation:
     return explanation


-class Confidences(collections.UserDict[str, float]):
-    def __init__(
-        self,
-        probs: Dict[str, float],
-        model: Model,
-        text: str,
-        max_ngram_length: int,
-    ):
-        collections.UserDict.__init__(self, probs)
-        self.model = model
-        self.text = text
-        self.max_ngram_length = max_ngram_length
-
-
-class TransparentConfidences(Confidences):
-    def __init__(
-        self,
-        probs: Dict[str, float],
-        explanation: Explanation,
-        model: Model,
-        text: str,
-        max_ngram_length: int,
-    ):
-        Confidences.__init__(self, probs, model, text, max_ngram_length)
-        self.explanation = explanation


 class Model:
     def __init__(
         self,
@@ -117,7 +101,7 @@ class Model:
             text, min(max_ngram_length, self.max_ngram_length)
         )
-        tokens = gptc.tokenizer.hash(
+        tokens = gptc.tokenizer.hash_list(
             raw_tokens,
             self.hash_algorithm,
         )
@@ -163,12 +147,9 @@ class Model:
         if transparent:
             explanation = convert_log(log, self.names)

-            return TransparentConfidences(
-                probs, explanation, self, text, max_ngram_length
-            )
-        else:
-            return Confidences(probs, self, text, max_ngram_length)
+            return TransparentConfidences(probs, explanation)
+
+        return Confidences(probs)

     def get(self, token: str) -> Dict[str, float]:
         try:
@@ -202,6 +183,8 @@ class Model:
             + b"".join([weight.to_bytes(2, "big") for weight in weights])
         )
+
+    compile = staticmethod(gptc.compiler.compile_)
 def deserialize(encoded_model: BinaryIO) -> Model:
     prefix = encoded_model.read(14)
@@ -213,26 +196,27 @@ def deserialize(encoded_model: BinaryIO) -> Model:
         byte = encoded_model.read(1)

         if byte == b"\n":
             break
-        elif byte == b"":
+
+        if byte == b"":
             raise InvalidModelError()
-        else:
-            config_json += byte
+
+        config_json += byte

     try:
         config = json.loads(config_json.decode("utf-8"))
-    except (UnicodeDecodeError, json.JSONDecodeError):
-        raise InvalidModelError()
+    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
+        raise InvalidModelError() from exc

     try:
         names = config["names"]
         max_ngram_length = config["max_ngram_length"]
         hash_algorithm = config["hash_algorithm"]
-    except KeyError:
-        raise InvalidModelError()
+    except KeyError as exc:
+        raise InvalidModelError() from exc

     if not (
         isinstance(names, list) and isinstance(max_ngram_length, int)
-    ) or not all([isinstance(name, str) for name in names]):
+    ) or not all(isinstance(name, str) for name in names):
         raise InvalidModelError()

     weight_code_length = 6 + 2 * len(names)
@@ -243,7 +227,7 @@ def deserialize(encoded_model: BinaryIO) -> Model:
         code = encoded_model.read(weight_code_length)

         if not code:
             break
-        elif len(code) != weight_code_length:
+        if len(code) != weight_code_length:
             raise InvalidModelError()

         weights[int.from_bytes(code[:6], "big")] = [
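The switch to raise InvalidModelError() from exc makes the causal link explicit: the original decoding or parsing failure is attached as __cause__, so tracebacks read "the above exception was the direct cause" instead of the implicit "during handling ... another exception occurred" chaining produced by the old bare raise. A minimal standalone illustration of the pattern (parse_config is an invented helper, not GPTC API):

    import json

    def parse_config(raw: bytes) -> dict:
        try:
            return json.loads(raw.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            # the decode/parse error is preserved as __cause__
            raise ValueError("invalid config") from exc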

View File

@@ -7,7 +7,7 @@ from typing import List, Dict, Tuple

 def pack(
     directory: str, print_exceptions: bool = False
-) -> Tuple[List[Dict[str, str]], List[Tuple[Exception]]]:
+) -> Tuple[List[Dict[str, str]], List[Tuple[OSError]]]:
     paths = os.listdir(directory)
     texts: Dict[str, List[str]] = {}
     exceptions = []
@@ -17,16 +17,18 @@ def pack(
         try:
             for file in os.listdir(os.path.join(directory, path)):
                 try:
-                    with open(os.path.join(directory, path, file)) as f:
-                        texts[path].append(f.read())
-                except Exception as e:
-                    exceptions.append((e,))
+                    with open(
+                        os.path.join(directory, path, file), encoding="utf-8"
+                    ) as input_file:
+                        texts[path].append(input_file.read())
+                except OSError as error:
+                    exceptions.append((error,))
                     if print_exceptions:
-                        print(e, file=sys.stderr)
-        except Exception as e:
-            exceptions.append((e,))
+                        print(error, file=sys.stderr)
+        except OSError as error:
+            exceptions.append((error,))
             if print_exceptions:
-                print(e, file=sys.stderr)
+                print(error, file=sys.stderr)

     raw_model = []
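Narrowing except Exception to except OSError means only filesystem errors (missing files, permission problems, and the like) are collected into the returned list and optionally printed; genuine bugs now propagate instead of being silently swallowed. A hedged sketch of calling pack() on the layout it expects, one subdirectory per category containing UTF-8 text files ("samples" is a placeholder directory):

    import gptc

    raw_model, errors = gptc.pack("samples", print_exceptions=True)
    for (error,) in errors:
        print("skipped:", error)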

View File

@@ -1,9 +1,9 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

-from typing import List, Union, Callable, Any, cast
-import unicodedata
+from typing import List, cast
 import hashlib
 import emoji
+import unicodedata


 def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
@@ -37,12 +37,12 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
     if max_ngram_length == 1:
         return tokens
-    else:
-        ngrams = []
-        for ngram_length in range(1, max_ngram_length + 1):
-            for index in range(len(tokens) + 1 - ngram_length):
-                ngrams.append(" ".join(tokens[index : index + ngram_length]))
-        return ngrams
+
+    ngrams = []
+    for ngram_length in range(1, max_ngram_length + 1):
+        for index in range(len(tokens) + 1 - ngram_length):
+            ngrams.append(" ".join(tokens[index : index + ngram_length]))
+    return ngrams


 def _hash_single(token: str, hash_function: type) -> int:
@@ -69,15 +69,15 @@ def _get_hash_function(hash_algorithm: str) -> type:
         "sha3_384",
     }:
         return cast(type, getattr(hashlib, hash_algorithm))
-    else:
-        raise ValueError("not a valid hash function: " + hash_algorithm)
+    raise ValueError("not a valid hash function: " + hash_algorithm)


 def hash_single(token: str, hash_algorithm: str) -> int:
     return _hash_single(token, _get_hash_function(hash_algorithm))


-def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
+def hash_list(tokens: List[str], hash_algorithm: str) -> List[int]:
     hash_function = _get_hash_function(hash_algorithm)
     return [_hash_single(token, hash_function) for token in tokens]
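Renaming hash to hash_list stops the module from shadowing the hash builtin, and the dedented n-gram loop simply drops a needless else after return. For reference, what the loop above produces for a hypothetical three-token input with max_ngram_length=2:

    tokens = ["the", "quick", "fox"]
    # ngram_length 1 contributes: "the", "quick", "fox"
    # ngram_length 2 contributes: "the quick", "quick fox"
    # result: ["the", "quick", "fox", "the quick", "quick fox"]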

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: GPL-3.0-or-later

 import math
-from typing import Sequence, Union, Tuple, List
+from typing import Sequence, Tuple, List


 def _mean(numbers: Sequence[float]) -> float:
@@ -41,6 +41,6 @@ def _standard_deviation(numbers: Sequence[float]) -> float:

 def weight(numbers: Sequence[float]) -> Tuple[float, List[float]]:
     standard_deviation = _standard_deviation(numbers)
-    weight = standard_deviation * 2
-    weighted_numbers = [i * weight for i in numbers]
-    return weight, weighted_numbers
+    weight_assigned = standard_deviation * 2
+    weighted_numbers = [i * weight_assigned for i in numbers]
+    return weight_assigned, weighted_numbers
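Renaming the local variable avoids shadowing the enclosing weight() function itself. As a worked example, assuming _standard_deviation computes the population standard deviation: for [0.2, 0.8] the mean is 0.5 and the standard deviation is 0.3, so the function returns a weight of 0.6 and weighted values of approximately [0.12, 0.48]:

    weight_value, weighted_numbers = weight([0.2, 0.8])
    # weight_value == 0.6; weighted_numbers approx [0.12, 0.48]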