You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
88 lines
2.6 KiB
88 lines
2.6 KiB
#!/usr/bin/env python3 |
|
# SPDX-License-Identifier: GPL-3.0-or-later |
|
|
|
import argparse |
|
import json |
|
import sys |
|
import gptc |
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    # Build the top-level parser with one subparser per subcommand.
    parser = argparse.ArgumentParser(
        description="General Purpose Text Classifier", prog="gptc"
    )
    # required=True: argparse itself rejects an invocation with no subcommand,
    # so the dispatch in main() only ever sees one of the four known names.
    subparsers = parser.add_subparsers(dest="subparser_name", required=True)

    compile_parser = subparsers.add_parser(
        "compile", help="compile a raw model"
    )
    compile_parser.add_argument("model", help="raw model to compile")
    compile_parser.add_argument(
        "out", help="name of file to write compiled model to"
    )
    compile_parser.add_argument(
        "--max-ngram-length",
        "-n",
        help="maximum ngram length",
        type=int,
        default=1,
    )
    compile_parser.add_argument(
        "--min-count",
        "-c",
        help="minimum use count for word/ngram to be included in model",
        type=int,
        default=1,
    )

    classify_parser = subparsers.add_parser("classify", help="classify text")
    classify_parser.add_argument("model", help="compiled model to use")
    classify_parser.add_argument(
        "--max-ngram-length",
        "-n",
        help="maximum ngram length",
        type=int,
        default=1,
    )

    check_parser = subparsers.add_parser(
        "check", help="check one word or ngram in model"
    )
    check_parser.add_argument("model", help="compiled model to use")
    check_parser.add_argument("token", help="token or ngram to check")

    pack_parser = subparsers.add_parser(
        "pack", help="pack a model from a directory"
    )
    pack_parser.add_argument("model", help="directory containing model")

    return parser


def _compile(args: argparse.Namespace) -> None:
    # Compile a raw JSON model file into the binary serialized format.
    with open(args.model, "r", encoding="utf-8") as input_file:
        model = json.load(input_file)

    # "wb" (not "wb+"): the output file is only ever written, never read back.
    with open(args.out, "wb") as output_file:
        gptc.Model.compile(
            model, args.max_ngram_length, args.min_count
        ).serialize(output_file)


def _classify(args: argparse.Namespace) -> None:
    # Classify text (from a prompt on a TTY, otherwise from stdin) and
    # print per-category confidences as JSON.
    with open(args.model, "rb") as model_file:
        model = gptc.Model.deserialize(model_file)

    if sys.stdin.isatty():
        # Interactive use: prompt for a single line.
        text = input("Text to analyse: ")
    else:
        # Piped input: consume everything on stdin.
        text = sys.stdin.read()

    print(json.dumps(model.confidence(text, args.max_ngram_length)))


def _check(args: argparse.Namespace) -> None:
    # Look up a single token or ngram in a compiled model and print the
    # result as JSON.
    with open(args.model, "rb") as model_file:
        model = gptc.Model.deserialize(model_file)
    print(json.dumps(model.get(args.token)))


def _pack(args: argparse.Namespace) -> None:
    # Pack a model directory into a single raw model and print it as JSON.
    # gptc.pack returns a tuple; only the first element is the model itself.
    print(json.dumps(gptc.pack(args.model, True)[0]))


def main() -> None:
    """Command-line entry point for the General Purpose Text Classifier.

    Parses the command line and dispatches to one of the subcommands:
    ``compile``, ``classify``, ``check``, or ``pack``.
    """
    args = _build_parser().parse_args()

    if args.subparser_name == "compile":
        _compile(args)
    elif args.subparser_name == "classify":
        _classify(args)
    elif args.subparser_name == "check":
        _check(args)
    else:
        # Only "pack" remains: the subparser is required, so an unknown
        # subcommand never reaches this point.
        _pack(args)
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|
|