Move pack to main module; format code

This commit is contained in:
Samuel Sloniker 2022-07-18 16:03:58 -07:00
parent e711767d24
commit 5082c2226b
5 changed files with 36 additions and 21 deletions

View File

@@ -4,6 +4,7 @@
from gptc.compiler import compile as compile from gptc.compiler import compile as compile
from gptc.classifier import Classifier as Classifier from gptc.classifier import Classifier as Classifier
from gptc.pack import pack as pack
from gptc.exceptions import ( from gptc.exceptions import (
GPTCError as GPTCError, GPTCError as GPTCError,
ModelError as ModelError, ModelError as ModelError,

View File

@@ -13,7 +13,9 @@ def main() -> None:
) )
subparsers = parser.add_subparsers(dest="subparser_name", required=True) subparsers = parser.add_subparsers(dest="subparser_name", required=True)
compile_parser = subparsers.add_parser("compile", help="compile a raw model") compile_parser = subparsers.add_parser(
"compile", help="compile a raw model"
)
compile_parser.add_argument("model", help="raw model to compile") compile_parser.add_argument("model", help="raw model to compile")
compile_parser.add_argument( compile_parser.add_argument(
"--max-ngram-length", "--max-ngram-length",
@@ -46,14 +48,22 @@ def main() -> None:
action="store_true", action="store_true",
) )
pack_parser = subparsers.add_parser(
"pack", help="pack a model from a directory"
)
pack_parser.add_argument("model", help="directory containing model")
args = parser.parse_args() args = parser.parse_args()
if args.subparser_name == "compile":
with open(args.model, "r") as f: with open(args.model, "r") as f:
model = json.load(f) model = json.load(f)
if args.subparser_name == "compile":
print(json.dumps(gptc.compile(model, args.max_ngram_length))) print(json.dumps(gptc.compile(model, args.max_ngram_length)))
else: elif args.subparser_name == "classify":
with open(args.model, "r") as f:
model = json.load(f)
classifier = gptc.Classifier(model, args.max_ngram_length) classifier = gptc.Classifier(model, args.max_ngram_length)
if sys.stdin.isatty(): if sys.stdin.isatty():
@@ -65,6 +75,8 @@ def main() -> None:
print(classifier.classify(text)) print(classifier.classify(text))
else: else:
print(json.dumps(classifier.confidence(text))) print(json.dumps(classifier.confidence(text)))
else:
print(json.dumps(gptc.pack(args.model, True)[0]))
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -27,7 +27,9 @@ class Classifier:
def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1): def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1):
if model.get("__version__", 0) != 3: if model.get("__version__", 0) != 3:
raise gptc.exceptions.UnsupportedModelError(f"unsupported model version") raise gptc.exceptions.UnsupportedModelError(
f"unsupported model version"
)
self.model = model self.model = model
model_ngrams = cast(int, model.get("__ngrams__", 1)) model_ngrams = cast(int, model.get("__ngrams__", 1))
self.max_ngram_length = min(max_ngram_length, model_ngrams) self.max_ngram_length = min(max_ngram_length, model_ngrams)

View File

@@ -8,7 +8,9 @@ CONFIG_T = Union[List[str], int, str]
MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]] MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]]
def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) -> MODEL: def compile(
raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1
) -> MODEL:
"""Compile a raw model. """Compile a raw model.
Parameters Parameters
@@ -47,9 +49,13 @@ def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) -
categories_by_count[category] = {} categories_by_count[category] = {}
for word in text: for word in text:
try: try:
categories_by_count[category][word] += 1 / len(categories[category]) categories_by_count[category][word] += 1 / len(
categories[category]
)
except KeyError: except KeyError:
categories_by_count[category][word] = 1 / len(categories[category]) categories_by_count[category][word] = 1 / len(
categories[category]
)
word_weights: Dict[str, Dict[str, float]] = {} word_weights: Dict[str, Dict[str, float]] = {}
for category, words in categories_by_count.items(): for category, words in categories_by_count.items():
for word, value in words.items(): for word, value in words.items():
@@ -63,7 +69,9 @@ def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) -
total = sum(weights.values()) total = sum(weights.values())
new_weights: List[int] = [] new_weights: List[int] = []
for category in names: for category in names:
new_weights.append(round((weights.get(category, 0) / total) * 65535)) new_weights.append(
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights model[word] = new_weights
model["__names__"] = names model["__names__"] = names

View File

@@ -2,10 +2,9 @@
import sys import sys
import os import os
import json
def pack(directory, print_exceptions=True): def pack(directory, print_exceptions=False):
paths = os.listdir(directory) paths = os.listdir(directory)
texts = {} texts = {}
exceptions = [] exceptions = []
@@ -13,9 +12,9 @@ def pack(directory, print_exceptions=True):
for path in paths: for path in paths:
texts[path] = [] texts[path] = []
try: try:
for file in os.listdir(os.path.join(sys.argv[1], path)): for file in os.listdir(os.path.join(directory, path)):
try: try:
with open(os.path.join(sys.argv[1], path, file)) as f: with open(os.path.join(directory, path, file)) as f:
texts[path].append(f.read()) texts[path].append(f.read())
except Exception as e: except Exception as e:
exceptions.append((e,)) exceptions.append((e,))
@@ -32,10 +31,3 @@ def pack(directory, print_exceptions=True):
raw_model += [{"category": category, "text": i} for i in cat_texts] raw_model += [{"category": category, "text": i} for i in cat_texts]
return raw_model, exceptions return raw_model, exceptions
if len(sys.argv) != 2:
print("usage: pack.py <path>", file=sys.stderr)
exit(1)
print(json.dumps(pack(sys.argv[1])[0]))