Compare commits

..

No commits in common. "ff8cba84c76cbdecc05f04ecdcb3de3fbebebaf6" and "e711767d248d9270be46bd4d3bbbb7821e85ae57" have entirely different histories.

5 changed files with 22 additions and 40 deletions

View File

@ -4,7 +4,6 @@
from gptc.compiler import compile as compile
from gptc.classifier import Classifier as Classifier
from gptc.pack import pack as pack
from gptc.exceptions import (
GPTCError as GPTCError,
ModelError as ModelError,

View File

@ -13,9 +13,7 @@ def main() -> None:
)
subparsers = parser.add_subparsers(dest="subparser_name", required=True)
compile_parser = subparsers.add_parser(
"compile", help="compile a raw model"
)
compile_parser = subparsers.add_parser("compile", help="compile a raw model")
compile_parser.add_argument("model", help="raw model to compile")
compile_parser.add_argument(
"--max-ngram-length",
@ -48,22 +46,14 @@ def main() -> None:
action="store_true",
)
pack_parser = subparsers.add_parser(
"pack", help="pack a model from a directory"
)
pack_parser.add_argument("model", help="directory containing model")
args = parser.parse_args()
with open(args.model, "r") as f:
model = json.load(f)
if args.subparser_name == "compile":
with open(args.model, "r") as f:
model = json.load(f)
print(json.dumps(gptc.compile(model, args.max_ngram_length)))
elif args.subparser_name == "classify":
with open(args.model, "r") as f:
model = json.load(f)
else:
classifier = gptc.Classifier(model, args.max_ngram_length)
if sys.stdin.isatty():
@ -75,8 +65,6 @@ def main() -> None:
print(classifier.classify(text))
else:
print(json.dumps(classifier.confidence(text)))
else:
print(json.dumps(gptc.pack(args.model, True)[0]))
if __name__ == "__main__":

View File

@ -27,9 +27,7 @@ class Classifier:
def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1):
if model.get("__version__", 0) != 3:
raise gptc.exceptions.UnsupportedModelError(
f"unsupported model version"
)
raise gptc.exceptions.UnsupportedModelError(f"unsupported model version")
self.model = model
model_ngrams = cast(int, model.get("__ngrams__", 1))
self.max_ngram_length = min(max_ngram_length, model_ngrams)

View File

@ -8,9 +8,7 @@ CONFIG_T = Union[List[str], int, str]
MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]]
def compile(
raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1
) -> MODEL:
def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) -> MODEL:
"""Compile a raw model.
Parameters
@ -49,13 +47,9 @@ def compile(
categories_by_count[category] = {}
for word in text:
try:
categories_by_count[category][word] += 1 / len(
categories[category]
)
categories_by_count[category][word] += 1 / len(categories[category])
except KeyError:
categories_by_count[category][word] = 1 / len(
categories[category]
)
categories_by_count[category][word] = 1 / len(categories[category])
word_weights: Dict[str, Dict[str, float]] = {}
for category, words in categories_by_count.items():
for word, value in words.items():
@ -69,9 +63,7 @@ def compile(
total = sum(weights.values())
new_weights: List[int] = []
for category in names:
new_weights.append(
round((weights.get(category, 0) / total) * 65535)
)
new_weights.append(round((weights.get(category, 0) / total) * 65535))
model[word] = new_weights
model["__names__"] = names

View File

@ -2,22 +2,20 @@
import sys
import os
from typing import List, Dict, Tuple
import json
def pack(
directory: str, print_exceptions: bool = False
) -> Tuple[List[Dict[str, str]], List[Tuple[Exception]]]:
def pack(directory, print_exceptions=True):
paths = os.listdir(directory)
texts: Dict[str, List[str]] = {}
texts = {}
exceptions = []
for path in paths:
texts[path] = []
try:
for file in os.listdir(os.path.join(directory, path)):
for file in os.listdir(os.path.join(sys.argv[1], path)):
try:
with open(os.path.join(directory, path, file)) as f:
with open(os.path.join(sys.argv[1], path, file)) as f:
texts[path].append(f.read())
except Exception as e:
exceptions.append((e,))
@ -34,3 +32,10 @@ def pack(
raw_model += [{"category": category, "text": i} for i in cat_texts]
return raw_model, exceptions
if len(sys.argv) != 2:
print("usage: pack.py <path>", file=sys.stderr)
exit(1)
print(json.dumps(pack(sys.argv[1])[0]))