diff --git a/README.md b/README.md
index 6e1a5a9..33f2531 100644
--- a/README.md
+++ b/README.md
@@ -27,9 +27,10 @@ argument, and it treats the input as a single token or ngram.
 
 ### Compiling models
 
-    gptc compile [-n <max_ngram_length>] [-c <min_count>] <raw model file>
+    gptc compile [-n <max_ngram_length>] [-c <min_count>] <raw model file> <compiled model file>
 
-This will print the compiled model encoded in binary format to stdout.
+This will write the compiled model encoded in binary format to `<compiled model
+file>`.
 
 If `-c` is specified, words and ngrams used less than `min_count` times will
 be excluded from the compiled model.
diff --git a/gptc/__main__.py b/gptc/__main__.py
index c55bacb..9586299 100644
--- a/gptc/__main__.py
+++ b/gptc/__main__.py
@@ -13,8 +13,13 @@ def main() -> None:
     )
     subparsers = parser.add_subparsers(dest="subparser_name", required=True)
 
-    compile_parser = subparsers.add_parser("compile", help="compile a raw model")
+    compile_parser = subparsers.add_parser(
+        "compile", help="compile a raw model"
+    )
     compile_parser.add_argument("model", help="raw model to compile")
+    compile_parser.add_argument(
+        "out", help="name of file to write compiled model to"
+    )
     compile_parser.add_argument(
         "--max-ngram-length",
         "-n",
@@ -53,11 +58,15 @@ def main() -> None:
         action="store_true",
     )
 
-    check_parser = subparsers.add_parser("check", help="check one word or ngram in model")
+    check_parser = subparsers.add_parser(
+        "check", help="check one word or ngram in model"
+    )
     check_parser.add_argument("model", help="compiled model to use")
     check_parser.add_argument("token", help="token or ngram to check")
 
-    pack_parser = subparsers.add_parser("pack", help="pack a model from a directory")
+    pack_parser = subparsers.add_parser(
+        "pack", help="pack a model from a directory"
+    )
     pack_parser.add_argument("model", help="directory containing model")
 
     args = parser.parse_args()
@@ -66,7 +75,10 @@ def main() -> None:
         with open(args.model, "r") as f:
             model = json.load(f)
 
-        gptc.compile(model, args.max_ngram_length, args.min_count).serialize(sys.stdout.buffer)
+        with open(args.out, "wb+") as f:
+            gptc.compile(
+                model, args.max_ngram_length, args.min_count
+            ).serialize(f)
     elif args.subparser_name == "classify":
         with open(args.model, "rb") as f:
             model = gptc.deserialize(f)
@@ -76,7 +88,6 @@ def main() -> None:
         else:
             text = sys.stdin.read()
 
-
         if args.category:
             classifier = gptc.Classifier(model, args.max_ngram_length)
             print(classifier.classify(text))
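
For reference, a rough library-level sketch of what the patched `compile` path now does and how the resulting file feeds the existing `classify` path, using only the calls visible in this diff (`gptc.compile(...).serialize()`, `gptc.deserialize()`, `gptc.Classifier`); the file names and the numeric arguments below are illustrative placeholders, not values taken from the project:

    import json

    import gptc

    # Load the raw JSON model (the CLI's positional `model` argument).
    with open("raw_model.json", "r") as f:
        raw_model = json.load(f)

    # Compile and write the binary model to the new positional `out` file.
    # The 1s stand in for args.max_ngram_length and args.min_count.
    with open("compiled.gptc", "wb+") as f:
        gptc.compile(raw_model, 1, 1).serialize(f)

    # Reading the compiled model back mirrors the existing `classify` branch.
    with open("compiled.gptc", "rb") as f:
        model = gptc.deserialize(f)
    classifier = gptc.Classifier(model, 1)
    print(classifier.classify("example text"))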