diff --git a/gptc/__init__.py b/gptc/__init__.py index f4d36d5..6ef26b3 100644 --- a/gptc/__init__.py +++ b/gptc/__init__.py @@ -4,6 +4,7 @@ from gptc.compiler import compile as compile from gptc.classifier import Classifier as Classifier +from gptc.pack import pack as pack from gptc.exceptions import ( GPTCError as GPTCError, ModelError as ModelError, diff --git a/gptc/__main__.py b/gptc/__main__.py index dc14d18..0f64989 100644 --- a/gptc/__main__.py +++ b/gptc/__main__.py @@ -13,7 +13,9 @@ def main() -> None: ) subparsers = parser.add_subparsers(dest="subparser_name", required=True) - compile_parser = subparsers.add_parser("compile", help="compile a raw model") + compile_parser = subparsers.add_parser( + "compile", help="compile a raw model" + ) compile_parser.add_argument("model", help="raw model to compile") compile_parser.add_argument( "--max-ngram-length", @@ -46,14 +48,22 @@ def main() -> None: action="store_true", ) + pack_parser = subparsers.add_parser( + "pack", help="pack a model from a directory" + ) + pack_parser.add_argument("model", help="directory containing model") + args = parser.parse_args() - with open(args.model, "r") as f: - model = json.load(f) - if args.subparser_name == "compile": + with open(args.model, "r") as f: + model = json.load(f) + print(json.dumps(gptc.compile(model, args.max_ngram_length))) - else: + elif args.subparser_name == "classify": + with open(args.model, "r") as f: + model = json.load(f) + classifier = gptc.Classifier(model, args.max_ngram_length) if sys.stdin.isatty(): @@ -65,6 +75,8 @@ def main() -> None: print(classifier.classify(text)) else: print(json.dumps(classifier.confidence(text))) + else: + print(json.dumps(gptc.pack(args.model, True)[0])) if __name__ == "__main__": diff --git a/gptc/classifier.py b/gptc/classifier.py index 2469ff3..d4cd1dc 100755 --- a/gptc/classifier.py +++ b/gptc/classifier.py @@ -27,7 +27,9 @@ class Classifier: def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1): if model.get("__version__", 0) != 3: - raise gptc.exceptions.UnsupportedModelError(f"unsupported model version") + raise gptc.exceptions.UnsupportedModelError( + f"unsupported model version" + ) self.model = model model_ngrams = cast(int, model.get("__ngrams__", 1)) self.max_ngram_length = min(max_ngram_length, model_ngrams) diff --git a/gptc/compiler.py b/gptc/compiler.py index fd4c12e..05b793b 100755 --- a/gptc/compiler.py +++ b/gptc/compiler.py @@ -8,7 +8,9 @@ CONFIG_T = Union[List[str], int, str] MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]] -def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) -> MODEL: +def compile( + raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1 +) -> MODEL: """Compile a raw model. Parameters @@ -47,9 +49,13 @@ def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) - categories_by_count[category] = {} for word in text: try: - categories_by_count[category][word] += 1 / len(categories[category]) + categories_by_count[category][word] += 1 / len( + categories[category] + ) except KeyError: - categories_by_count[category][word] = 1 / len(categories[category]) + categories_by_count[category][word] = 1 / len( + categories[category] + ) word_weights: Dict[str, Dict[str, float]] = {} for category, words in categories_by_count.items(): for word, value in words.items(): @@ -63,7 +69,9 @@ def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) - total = sum(weights.values()) new_weights: List[int] = [] for category in names: - new_weights.append(round((weights.get(category, 0) / total) * 65535)) + new_weights.append( + round((weights.get(category, 0) / total) * 65535) + ) model[word] = new_weights model["__names__"] = names diff --git a/utils/pack.py b/gptc/pack.py similarity index 69% rename from utils/pack.py rename to gptc/pack.py index b2f7589..4177ad8 100644 --- a/utils/pack.py +++ b/gptc/pack.py @@ -2,10 +2,9 @@ import sys import os -import json -def pack(directory, print_exceptions=True): +def pack(directory, print_exceptions=False): paths = os.listdir(directory) texts = {} exceptions = [] @@ -13,9 +12,9 @@ def pack(directory, print_exceptions=True): for path in paths: texts[path] = [] try: - for file in os.listdir(os.path.join(sys.argv[1], path)): + for file in os.listdir(os.path.join(directory, path)): try: - with open(os.path.join(sys.argv[1], path, file)) as f: + with open(os.path.join(directory, path, file)) as f: texts[path].append(f.read()) except Exception as e: exceptions.append((e,)) @@ -32,10 +31,3 @@ def pack(directory, print_exceptions=True): raw_model += [{"category": category, "text": i} for i in cat_texts] return raw_model, exceptions - - -if len(sys.argv) != 2: - print("usage: pack.py ", file=sys.stderr) - exit(1) - -print(json.dumps(pack(sys.argv[1])[0]))