Clean up code
This commit is contained in:
parent
9513025e60
commit
a252a15e9d
|
@ -2,12 +2,11 @@
|
|||
|
||||
"""General-Purpose Text Classifier"""
|
||||
|
||||
from gptc.compiler import compile as compile
|
||||
from gptc.pack import pack as pack
|
||||
from gptc.model import Model as Model, deserialize as deserialize
|
||||
from gptc.tokenizer import normalize as normalize
|
||||
from gptc.pack import pack
|
||||
from gptc.model import Model, deserialize
|
||||
from gptc.tokenizer import normalize
|
||||
from gptc.exceptions import (
|
||||
GPTCError as GPTCError,
|
||||
ModelError as ModelError,
|
||||
InvalidModelError as InvalidModelError,
|
||||
GPTCError,
|
||||
ModelError,
|
||||
InvalidModelError,
|
||||
)
|
||||
|
|
|
@ -59,16 +59,16 @@ def main() -> None:
|
|||
args = parser.parse_args()
|
||||
|
||||
if args.subparser_name == "compile":
|
||||
with open(args.model, "r") as f:
|
||||
model = json.load(f)
|
||||
with open(args.model, "r", encoding="utf-8") as input_file:
|
||||
model = json.load(input_file)
|
||||
|
||||
with open(args.out, "wb+") as f:
|
||||
gptc.compile(
|
||||
with open(args.out, "wb+") as output_file:
|
||||
gptc.Model.compile(
|
||||
model, args.max_ngram_length, args.min_count
|
||||
).serialize(f)
|
||||
).serialize(output_file)
|
||||
elif args.subparser_name == "classify":
|
||||
with open(args.model, "rb") as f:
|
||||
model = gptc.deserialize(f)
|
||||
with open(args.model, "rb") as model_file:
|
||||
model = gptc.deserialize(model_file)
|
||||
|
||||
if sys.stdin.isatty():
|
||||
text = input("Text to analyse: ")
|
||||
|
@ -77,8 +77,8 @@ def main() -> None:
|
|||
|
||||
print(json.dumps(model.confidence(text, args.max_ngram_length)))
|
||||
elif args.subparser_name == "check":
|
||||
with open(args.model, "rb") as f:
|
||||
model = gptc.deserialize(f)
|
||||
with open(args.model, "rb") as model_file:
|
||||
model = gptc.deserialize(model_file)
|
||||
print(json.dumps(model.get(args.token)))
|
||||
else:
|
||||
print(json.dumps(gptc.pack(args.model, True)[0]))
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
from typing import Iterable, Mapping, List, Dict, Tuple
|
||||
import gptc.tokenizer
|
||||
import gptc.model
|
||||
from typing import Iterable, Mapping, List, Dict, Union, Tuple
|
||||
|
||||
|
||||
def _count_words(
|
||||
|
@ -15,7 +15,7 @@ def _count_words(
|
|||
names: List[str] = []
|
||||
|
||||
for portion in raw_model:
|
||||
text = gptc.tokenizer.hash(
|
||||
text = gptc.tokenizer.hash_list(
|
||||
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
|
||||
hash_algorithm,
|
||||
)
|
||||
|
@ -63,7 +63,7 @@ def _get_weights(
|
|||
return model
|
||||
|
||||
|
||||
def compile(
|
||||
def compile_(
|
||||
raw_model: Iterable[Mapping[str, str]],
|
||||
max_ngram_length: int = 1,
|
||||
min_count: int = 1,
|
||||
|
|
|
@ -1,22 +1,18 @@
|
|||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
import gptc.tokenizer
|
||||
from gptc.exceptions import InvalidModelError
|
||||
import gptc.weighting
|
||||
from typing import (
|
||||
Iterable,
|
||||
Mapping,
|
||||
List,
|
||||
Dict,
|
||||
Union,
|
||||
cast,
|
||||
BinaryIO,
|
||||
Any,
|
||||
Tuple,
|
||||
TypedDict,
|
||||
)
|
||||
import json
|
||||
import collections
|
||||
import gptc.tokenizer
|
||||
from gptc.exceptions import InvalidModelError
|
||||
import gptc.weighting
|
||||
import gptc.compiler
|
||||
|
||||
|
||||
class ExplanationEntry(TypedDict):
|
||||
|
@ -33,6 +29,21 @@ Explanation = Dict[
|
|||
Log = List[Tuple[str, float, List[float]]]
|
||||
|
||||
|
||||
class Confidences(dict[str, float]):
|
||||
def __init__(self, probs: Dict[str, float]):
|
||||
dict.__init__(self, probs)
|
||||
|
||||
|
||||
class TransparentConfidences(Confidences):
|
||||
def __init__(
|
||||
self,
|
||||
probs: Dict[str, float],
|
||||
explanation: Explanation,
|
||||
):
|
||||
self.explanation = explanation
|
||||
Confidences.__init__(self, probs)
|
||||
|
||||
|
||||
def convert_log(log: Log, names: List[str]) -> Explanation:
|
||||
explanation: Explanation = {}
|
||||
for word2, weight, word_probs in log:
|
||||
|
@ -49,33 +60,6 @@ def convert_log(log: Log, names: List[str]) -> Explanation:
|
|||
return explanation
|
||||
|
||||
|
||||
class Confidences(collections.UserDict[str, float]):
|
||||
def __init__(
|
||||
self,
|
||||
probs: Dict[str, float],
|
||||
model: Model,
|
||||
text: str,
|
||||
max_ngram_length: int,
|
||||
):
|
||||
collections.UserDict.__init__(self, probs)
|
||||
self.model = model
|
||||
self.text = text
|
||||
self.max_ngram_length = max_ngram_length
|
||||
|
||||
|
||||
class TransparentConfidences(Confidences):
|
||||
def __init__(
|
||||
self,
|
||||
probs: Dict[str, float],
|
||||
explanation: Explanation,
|
||||
model: Model,
|
||||
text: str,
|
||||
max_ngram_length: int,
|
||||
):
|
||||
Confidences.__init__(self, probs, model, text, max_ngram_length)
|
||||
self.explanation = explanation
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -117,7 +101,7 @@ class Model:
|
|||
text, min(max_ngram_length, self.max_ngram_length)
|
||||
)
|
||||
|
||||
tokens = gptc.tokenizer.hash(
|
||||
tokens = gptc.tokenizer.hash_list(
|
||||
raw_tokens,
|
||||
self.hash_algorithm,
|
||||
)
|
||||
|
@ -163,12 +147,9 @@ class Model:
|
|||
|
||||
if transparent:
|
||||
explanation = convert_log(log, self.names)
|
||||
return TransparentConfidences(probs, explanation)
|
||||
|
||||
return TransparentConfidences(
|
||||
probs, explanation, self, text, max_ngram_length
|
||||
)
|
||||
else:
|
||||
return Confidences(probs, self, text, max_ngram_length)
|
||||
return Confidences(probs)
|
||||
|
||||
def get(self, token: str) -> Dict[str, float]:
|
||||
try:
|
||||
|
@ -202,6 +183,8 @@ class Model:
|
|||
+ b"".join([weight.to_bytes(2, "big") for weight in weights])
|
||||
)
|
||||
|
||||
compile = staticmethod(gptc.compiler.compile_)
|
||||
|
||||
|
||||
def deserialize(encoded_model: BinaryIO) -> Model:
|
||||
prefix = encoded_model.read(14)
|
||||
|
@ -213,26 +196,27 @@ def deserialize(encoded_model: BinaryIO) -> Model:
|
|||
byte = encoded_model.read(1)
|
||||
if byte == b"\n":
|
||||
break
|
||||
elif byte == b"":
|
||||
|
||||
if byte == b"":
|
||||
raise InvalidModelError()
|
||||
else:
|
||||
config_json += byte
|
||||
|
||||
config_json += byte
|
||||
|
||||
try:
|
||||
config = json.loads(config_json.decode("utf-8"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
raise InvalidModelError()
|
||||
except (UnicodeDecodeError, json.JSONDecodeError) as exc:
|
||||
raise InvalidModelError() from exc
|
||||
|
||||
try:
|
||||
names = config["names"]
|
||||
max_ngram_length = config["max_ngram_length"]
|
||||
hash_algorithm = config["hash_algorithm"]
|
||||
except KeyError:
|
||||
raise InvalidModelError()
|
||||
except KeyError as exc:
|
||||
raise InvalidModelError() from exc
|
||||
|
||||
if not (
|
||||
isinstance(names, list) and isinstance(max_ngram_length, int)
|
||||
) or not all([isinstance(name, str) for name in names]):
|
||||
) or not all(isinstance(name, str) for name in names):
|
||||
raise InvalidModelError()
|
||||
|
||||
weight_code_length = 6 + 2 * len(names)
|
||||
|
@ -243,7 +227,7 @@ def deserialize(encoded_model: BinaryIO) -> Model:
|
|||
code = encoded_model.read(weight_code_length)
|
||||
if not code:
|
||||
break
|
||||
elif len(code) != weight_code_length:
|
||||
if len(code) != weight_code_length:
|
||||
raise InvalidModelError()
|
||||
|
||||
weights[int.from_bytes(code[:6], "big")] = [
|
||||
|
|
20
gptc/pack.py
20
gptc/pack.py
|
@ -7,7 +7,7 @@ from typing import List, Dict, Tuple
|
|||
|
||||
def pack(
|
||||
directory: str, print_exceptions: bool = False
|
||||
) -> Tuple[List[Dict[str, str]], List[Tuple[Exception]]]:
|
||||
) -> Tuple[List[Dict[str, str]], List[Tuple[OSError]]]:
|
||||
paths = os.listdir(directory)
|
||||
texts: Dict[str, List[str]] = {}
|
||||
exceptions = []
|
||||
|
@ -17,16 +17,18 @@ def pack(
|
|||
try:
|
||||
for file in os.listdir(os.path.join(directory, path)):
|
||||
try:
|
||||
with open(os.path.join(directory, path, file)) as f:
|
||||
texts[path].append(f.read())
|
||||
except Exception as e:
|
||||
exceptions.append((e,))
|
||||
with open(
|
||||
os.path.join(directory, path, file), encoding="utf-8"
|
||||
) as input_file:
|
||||
texts[path].append(input_file.read())
|
||||
except OSError as error:
|
||||
exceptions.append((error,))
|
||||
if print_exceptions:
|
||||
print(e, file=sys.stderr)
|
||||
except Exception as e:
|
||||
exceptions.append((e,))
|
||||
print(error, file=sys.stderr)
|
||||
except OSError as error:
|
||||
exceptions.append((error,))
|
||||
if print_exceptions:
|
||||
print(e, file=sys.stderr)
|
||||
print(error, file=sys.stderr)
|
||||
|
||||
raw_model = []
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
from typing import List, Union, Callable, Any, cast
|
||||
import unicodedata
|
||||
from typing import List, cast
|
||||
import hashlib
|
||||
import emoji
|
||||
import unicodedata
|
||||
|
||||
|
||||
def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
|
||||
|
@ -37,12 +37,12 @@ def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
|
|||
|
||||
if max_ngram_length == 1:
|
||||
return tokens
|
||||
else:
|
||||
ngrams = []
|
||||
for ngram_length in range(1, max_ngram_length + 1):
|
||||
for index in range(len(tokens) + 1 - ngram_length):
|
||||
ngrams.append(" ".join(tokens[index : index + ngram_length]))
|
||||
return ngrams
|
||||
|
||||
ngrams = []
|
||||
for ngram_length in range(1, max_ngram_length + 1):
|
||||
for index in range(len(tokens) + 1 - ngram_length):
|
||||
ngrams.append(" ".join(tokens[index : index + ngram_length]))
|
||||
return ngrams
|
||||
|
||||
|
||||
def _hash_single(token: str, hash_function: type) -> int:
|
||||
|
@ -69,15 +69,15 @@ def _get_hash_function(hash_algorithm: str) -> type:
|
|||
"sha3_384",
|
||||
}:
|
||||
return cast(type, getattr(hashlib, hash_algorithm))
|
||||
else:
|
||||
raise ValueError("not a valid hash function: " + hash_algorithm)
|
||||
|
||||
raise ValueError("not a valid hash function: " + hash_algorithm)
|
||||
|
||||
|
||||
def hash_single(token: str, hash_algorithm: str) -> int:
|
||||
return _hash_single(token, _get_hash_function(hash_algorithm))
|
||||
|
||||
|
||||
def hash(tokens: List[str], hash_algorithm: str) -> List[int]:
|
||||
def hash_list(tokens: List[str], hash_algorithm: str) -> List[int]:
|
||||
hash_function = _get_hash_function(hash_algorithm)
|
||||
return [_hash_single(token, hash_function) for token in tokens]
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
|
||||
import math
|
||||
from typing import Sequence, Union, Tuple, List
|
||||
from typing import Sequence, Tuple, List
|
||||
|
||||
|
||||
def _mean(numbers: Sequence[float]) -> float:
|
||||
|
@ -41,6 +41,6 @@ def _standard_deviation(numbers: Sequence[float]) -> float:
|
|||
|
||||
def weight(numbers: Sequence[float]) -> Tuple[float, List[float]]:
|
||||
standard_deviation = _standard_deviation(numbers)
|
||||
weight = standard_deviation * 2
|
||||
weighted_numbers = [i * weight for i in numbers]
|
||||
return weight, weighted_numbers
|
||||
weight_assigned = standard_deviation * 2
|
||||
weighted_numbers = [i * weight_assigned for i in numbers]
|
||||
return weight_assigned, weighted_numbers
|
||||
|
|
Loading…
Reference in New Issue
Block a user