Move pack to main module; format code
This commit is contained in:
parent
e711767d24
commit
5082c2226b
|
@ -4,6 +4,7 @@
|
||||||
|
|
||||||
from gptc.compiler import compile as compile
|
from gptc.compiler import compile as compile
|
||||||
from gptc.classifier import Classifier as Classifier
|
from gptc.classifier import Classifier as Classifier
|
||||||
|
from gptc.pack import pack as pack
|
||||||
from gptc.exceptions import (
|
from gptc.exceptions import (
|
||||||
GPTCError as GPTCError,
|
GPTCError as GPTCError,
|
||||||
ModelError as ModelError,
|
ModelError as ModelError,
|
||||||
|
|
|
@ -13,7 +13,9 @@ def main() -> None:
|
||||||
)
|
)
|
||||||
subparsers = parser.add_subparsers(dest="subparser_name", required=True)
|
subparsers = parser.add_subparsers(dest="subparser_name", required=True)
|
||||||
|
|
||||||
compile_parser = subparsers.add_parser("compile", help="compile a raw model")
|
compile_parser = subparsers.add_parser(
|
||||||
|
"compile", help="compile a raw model"
|
||||||
|
)
|
||||||
compile_parser.add_argument("model", help="raw model to compile")
|
compile_parser.add_argument("model", help="raw model to compile")
|
||||||
compile_parser.add_argument(
|
compile_parser.add_argument(
|
||||||
"--max-ngram-length",
|
"--max-ngram-length",
|
||||||
|
@ -46,14 +48,22 @@ def main() -> None:
|
||||||
action="store_true",
|
action="store_true",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pack_parser = subparsers.add_parser(
|
||||||
|
"pack", help="pack a model from a directory"
|
||||||
|
)
|
||||||
|
pack_parser.add_argument("model", help="directory containing model")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with open(args.model, "r") as f:
|
|
||||||
model = json.load(f)
|
|
||||||
|
|
||||||
if args.subparser_name == "compile":
|
if args.subparser_name == "compile":
|
||||||
|
with open(args.model, "r") as f:
|
||||||
|
model = json.load(f)
|
||||||
|
|
||||||
print(json.dumps(gptc.compile(model, args.max_ngram_length)))
|
print(json.dumps(gptc.compile(model, args.max_ngram_length)))
|
||||||
else:
|
elif args.subparser_name == "classify":
|
||||||
|
with open(args.model, "r") as f:
|
||||||
|
model = json.load(f)
|
||||||
|
|
||||||
classifier = gptc.Classifier(model, args.max_ngram_length)
|
classifier = gptc.Classifier(model, args.max_ngram_length)
|
||||||
|
|
||||||
if sys.stdin.isatty():
|
if sys.stdin.isatty():
|
||||||
|
@ -65,6 +75,8 @@ def main() -> None:
|
||||||
print(classifier.classify(text))
|
print(classifier.classify(text))
|
||||||
else:
|
else:
|
||||||
print(json.dumps(classifier.confidence(text)))
|
print(json.dumps(classifier.confidence(text)))
|
||||||
|
else:
|
||||||
|
print(json.dumps(gptc.pack(args.model, True)[0]))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -27,7 +27,9 @@ class Classifier:
|
||||||
|
|
||||||
def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1):
|
def __init__(self, model: gptc.compiler.MODEL, max_ngram_length: int = 1):
|
||||||
if model.get("__version__", 0) != 3:
|
if model.get("__version__", 0) != 3:
|
||||||
raise gptc.exceptions.UnsupportedModelError(f"unsupported model version")
|
raise gptc.exceptions.UnsupportedModelError(
|
||||||
|
f"unsupported model version"
|
||||||
|
)
|
||||||
self.model = model
|
self.model = model
|
||||||
model_ngrams = cast(int, model.get("__ngrams__", 1))
|
model_ngrams = cast(int, model.get("__ngrams__", 1))
|
||||||
self.max_ngram_length = min(max_ngram_length, model_ngrams)
|
self.max_ngram_length = min(max_ngram_length, model_ngrams)
|
||||||
|
|
|
@ -8,7 +8,9 @@ CONFIG_T = Union[List[str], int, str]
|
||||||
MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]]
|
MODEL = Dict[str, Union[WEIGHTS_T, CONFIG_T]]
|
||||||
|
|
||||||
|
|
||||||
def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) -> MODEL:
|
def compile(
|
||||||
|
raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1
|
||||||
|
) -> MODEL:
|
||||||
"""Compile a raw model.
|
"""Compile a raw model.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
@ -47,9 +49,13 @@ def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) -
|
||||||
categories_by_count[category] = {}
|
categories_by_count[category] = {}
|
||||||
for word in text:
|
for word in text:
|
||||||
try:
|
try:
|
||||||
categories_by_count[category][word] += 1 / len(categories[category])
|
categories_by_count[category][word] += 1 / len(
|
||||||
|
categories[category]
|
||||||
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
categories_by_count[category][word] = 1 / len(categories[category])
|
categories_by_count[category][word] = 1 / len(
|
||||||
|
categories[category]
|
||||||
|
)
|
||||||
word_weights: Dict[str, Dict[str, float]] = {}
|
word_weights: Dict[str, Dict[str, float]] = {}
|
||||||
for category, words in categories_by_count.items():
|
for category, words in categories_by_count.items():
|
||||||
for word, value in words.items():
|
for word, value in words.items():
|
||||||
|
@ -63,7 +69,9 @@ def compile(raw_model: Iterable[Mapping[str, str]], max_ngram_length: int = 1) -
|
||||||
total = sum(weights.values())
|
total = sum(weights.values())
|
||||||
new_weights: List[int] = []
|
new_weights: List[int] = []
|
||||||
for category in names:
|
for category in names:
|
||||||
new_weights.append(round((weights.get(category, 0) / total) * 65535))
|
new_weights.append(
|
||||||
|
round((weights.get(category, 0) / total) * 65535)
|
||||||
|
)
|
||||||
model[word] = new_weights
|
model[word] = new_weights
|
||||||
|
|
||||||
model["__names__"] = names
|
model["__names__"] = names
|
||||||
|
|
|
@ -2,10 +2,9 @@
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
def pack(directory, print_exceptions=True):
|
def pack(directory, print_exceptions=False):
|
||||||
paths = os.listdir(directory)
|
paths = os.listdir(directory)
|
||||||
texts = {}
|
texts = {}
|
||||||
exceptions = []
|
exceptions = []
|
||||||
|
@ -13,9 +12,9 @@ def pack(directory, print_exceptions=True):
|
||||||
for path in paths:
|
for path in paths:
|
||||||
texts[path] = []
|
texts[path] = []
|
||||||
try:
|
try:
|
||||||
for file in os.listdir(os.path.join(sys.argv[1], path)):
|
for file in os.listdir(os.path.join(directory, path)):
|
||||||
try:
|
try:
|
||||||
with open(os.path.join(sys.argv[1], path, file)) as f:
|
with open(os.path.join(directory, path, file)) as f:
|
||||||
texts[path].append(f.read())
|
texts[path].append(f.read())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
exceptions.append((e,))
|
exceptions.append((e,))
|
||||||
|
@ -32,10 +31,3 @@ def pack(directory, print_exceptions=True):
|
||||||
raw_model += [{"category": category, "text": i} for i in cat_texts]
|
raw_model += [{"category": category, "text": i} for i in cat_texts]
|
||||||
|
|
||||||
return raw_model, exceptions
|
return raw_model, exceptions
|
||||||
|
|
||||||
|
|
||||||
if len(sys.argv) != 2:
|
|
||||||
print("usage: pack.py <path>", file=sys.stderr)
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
print(json.dumps(pack(sys.argv[1])[0]))
|
|
Loading…
Reference in New Issue
Block a user