gptc/gptc/__main__.py
2023-04-16 14:27:07 -07:00

89 lines
2.5 KiB
Python

#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-3.0-or-later
import argparse
import json
import sys
import gptc
def _build_parser() -> argparse.ArgumentParser:
    """Build the ``gptc`` argument parser with its four subcommands.

    The chosen subcommand name is stored in ``subparser_name``; a
    subcommand is mandatory.
    """
    parser = argparse.ArgumentParser(
        description="General Purpose Text Classifier", prog="gptc"
    )
    subparsers = parser.add_subparsers(dest="subparser_name", required=True)

    compile_parser = subparsers.add_parser(
        "compile", help="compile a raw model"
    )
    compile_parser.add_argument("model", help="raw model to compile")
    compile_parser.add_argument(
        "out", help="name of file to write compiled model to"
    )
    compile_parser.add_argument(
        "--max-ngram-length",
        "-n",
        help="maximum ngram length",
        type=int,
        default=1,
    )
    compile_parser.add_argument(
        "--min-count",
        "-c",
        help="minimum use count for word/ngram to be included in model",
        type=int,
        default=1,
    )

    classify_parser = subparsers.add_parser("classify", help="classify text")
    classify_parser.add_argument("model", help="compiled model to use")
    classify_parser.add_argument(
        "--max-ngram-length",
        "-n",
        help="maximum ngram length",
        type=int,
        default=1,
    )

    check_parser = subparsers.add_parser(
        "check", help="check one word or ngram in model"
    )
    check_parser.add_argument("model", help="compiled model to use")
    check_parser.add_argument("token", help="token or ngram to check")

    pack_parser = subparsers.add_parser(
        "pack", help="pack a model from a directory"
    )
    pack_parser.add_argument("model", help="directory containing model")

    return parser


def _run_compile(args: argparse.Namespace) -> None:
    """Compile the raw JSON model at ``args.model`` into ``args.out``."""
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding.
    with open(args.model, "r", encoding="utf-8") as f:
        raw_model = json.load(f)
    # Compile before opening the output file, so a failure during
    # compilation does not truncate an existing file at ``args.out``.
    compiled = gptc.compile(raw_model, args.max_ngram_length, args.min_count)
    with open(args.out, "wb") as f:
        compiled.serialize(f)


def _run_classify(args: argparse.Namespace) -> None:
    """Print, as JSON, the model's ``confidence`` result for text from stdin."""
    with open(args.model, "rb") as f:
        model = gptc.deserialize(f)
    # Interactive terminal: prompt for a single line.
    # Piped or redirected input: consume all of it.
    if sys.stdin.isatty():
        text = input("Text to analyse: ")
    else:
        text = sys.stdin.read()
    print(json.dumps(model.confidence(text, args.max_ngram_length)))


def _run_check(args: argparse.Namespace) -> None:
    """Print, as JSON, the model's entry for ``args.token``."""
    with open(args.model, "rb") as f:
        model = gptc.deserialize(f)
    print(json.dumps(model.get(args.token)))


def _run_pack(args: argparse.Namespace) -> None:
    """Print, as JSON, the first element of ``gptc.pack`` for ``args.model``."""
    # NOTE(review): the meaning of the ``True`` flag and of the remaining
    # elements returned by gptc.pack is defined in gptc.pack — confirm there.
    print(json.dumps(gptc.pack(args.model, True)[0]))


def main(argv: "list[str] | None" = None) -> None:
    """Run the ``gptc`` command-line interface.

    Args:
        argv: Command-line arguments, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: raised by argparse for ``--help`` or invalid arguments.
    """
    args = _build_parser().parse_args(argv)
    handlers = {
        "compile": _run_compile,
        "classify": _run_classify,
        "check": _run_check,
        "pack": _run_pack,
    }
    # The subcommand is required, so subparser_name is always one of these keys.
    handlers[args.subparser_name](args)
# Run the CLI when this module is executed directly (e.g. ``python -m gptc``,
# since this is the package's ``__main__`` module); importing it does nothing.
if __name__ == "__main__":
    main()