Compare commits

..

3 Commits

Author SHA1 Message Date
ff8cba84c7 format pack.py 2022-07-19 16:02:05 -07:00
8c6dd0bde9 Type checks for pack 2022-07-19 10:43:10 -07:00
5082c2226b Move pack to main module; format code 2022-07-18 16:03:58 -07:00
5 changed files with 40 additions and 22 deletions

View File

@@ -4,6 +4,7 @@
from gptc.compiler import compile as compile from gptc.compiler import compile as compile
from gptc.classifier import Classifier as Classifier from gptc.classifier import Classifier as Classifier
from gptc.pack import pack as pack
from gptc.exceptions import ( from gptc.exceptions import (
GPTCError as GPTCError, GPTCError as GPTCError,
ModelError as ModelError, ModelError as ModelError,

View File

@@ -13,7 +13,9 @@ def main() -> None:
) )
subparsers = parser.add_subparsers(dest="subparser_name", required=True) subparsers = parser.add_subparsers(dest="subparser_name", required=True)
compile_parser = subparsers.add_parser("compile", help="compile a raw model") compile_parser = subparsers.add_parser(
"compile", help="compile a raw model"
)
compile_parser.add_argument("model", help="raw model to compile") compile_parser.add_argument("model", help="raw model to compile")
compile_parser.add_argument( compile_parser.add_argument(
"--max-ngram-length", "--max-ngram-length",
@@ -46,14 +48,22 @@ def main() -> None:
action="store_true", action="store_true",
) )
pack_parser = subparsers.add_parser(
"pack", help="pack a model from a directory"
)
pack_parser.add_argument("model", help="directory containing model")
args = parser.parse_args() args = parser.parse_args()
with open(args.model, "r") as f:
model = json.load(f)
if args.subparser_name == "compile": if args.subparser_name == "compile":
with open(args.model, "r") as f:
model = json.load(f)
print(json.dumps(gptc.compile(model, args.max_ngram_length))) print(json.dumps(gptc.compile(model, args.max_ngram_length)))
else: elif args.subparser_name == "classify":
with open(args.model, "r") as f:
model = json.load(f)
classifier = gptc.Classifier(model, args.max_ngram_length) classifier = gptc.Classifier(model, args.max_ngram_length)
if sys.stdin.isatty(): if sys.stdin.isatty():
@@ -65,6 +75,8 @@ def main() -> None:
print(classifier.classify(text)) print(classifier.classify(text))
else: else:
print(json.dumps(classifier.confidence(text))) print(json.dumps(classifier.confidence(text)))
else:
print(json.dumps(gptc.pack(args.model, True)[0]))
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -27,7 +27,9 @@ class Classifier:
def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1): def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1):
if model.get("__version__", 0) != 3: if model.get("__version__", 0) != 3:
raise gptc.exceptions.UnsupportedModelError(f"unsupported model version") raise gptc.exceptions.UnsupportedModelError(
f"unsupported model version"
)
self.model = model self.model = model
model_ngrams = cast(int, model.get("__ngrams__", 1)) model_ngrams = cast(int, model.get("__ngrams__", 1))
self.max_ngram_length = min(max_ngram_length, model_ngrams) self.max_ngram_length = min(max_ngram_length, model_ngrams)

View File

@@ -8,7 +8,9 @@ CONFIG_T = Union[List[str], int, str]
MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]] MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]]
def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) -> MODEL: def compile(
raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1
) -> MODEL:
"""Compile a raw model. """Compile a raw model.
Parameters Parameters
@@ -47,9 +49,13 @@ def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) -
categories_by_count[category] = {} categories_by_count[category] = {}
for word in text: for word in text:
try: try:
categories_by_count[category][word] += 1 / len(categories[category]) categories_by_count[category][word] += 1 / len(
categories[category]
)
except KeyError: except KeyError:
categories_by_count[category][word] = 1 / len(categories[category]) categories_by_count[category][word] = 1 / len(
categories[category]
)
word_weights: Dict[str, Dict[str, float]] = {} word_weights: Dict[str, Dict[str, float]] = {}
for category, words in categories_by_count.items(): for category, words in categories_by_count.items():
for word, value in words.items(): for word, value in words.items():
@@ -63,7 +69,9 @@ def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) -
total = sum(weights.values()) total = sum(weights.values())
new_weights: List[int] = [] new_weights: List[int] = []
for category in names: for category in names:
new_weights.append(round((weights.get(category, 0) / total) * 65535)) new_weights.append(
round((weights.get(category, 0) / total) * 65535)
)
model[word] = new_weights model[word] = new_weights
model["__names__"] = names model["__names__"] = names

View File

@@ -2,20 +2,22 @@
import sys import sys
import os import os
import json from typing import List, Dict, Tuple
def pack(directory, print_exceptions=True): def pack(
directory: str, print_exceptions: bool = False
) -> Tuple[List[Dict[str, str]], List[Tuple[Exception]]]:
paths = os.listdir(directory) paths = os.listdir(directory)
texts = {} texts: Dict[str, List[str]] = {}
exceptions = [] exceptions = []
for path in paths: for path in paths:
texts[path] = [] texts[path] = []
try: try:
for file in os.listdir(os.path.join(sys.argv[1], path)): for file in os.listdir(os.path.join(directory, path)):
try: try:
with open(os.path.join(sys.argv[1], path, file)) as f: with open(os.path.join(directory, path, file)) as f:
texts[path].append(f.read()) texts[path].append(f.read())
except Exception as e: except Exception as e:
exceptions.append((e,)) exceptions.append((e,))
@@ -32,10 +34,3 @@ def pack(directory, print_exceptions=True):
raw_model += [{"category": category, "text": i} for i in cat_texts] raw_model += [{"category": category, "text": i} for i in cat_texts]
return raw_model, exceptions return raw_model, exceptions
if len(sys.argv) != 2:
print("usage: pack.py <path>", file=sys.stderr)
exit(1)
print(json.dumps(pack(sys.argv[1])[0]))