diff --git a/gptc/__main__.py b/gptc/__main__.py index 6c3e99b..11cd795 100644 --- a/gptc/__main__.py +++ b/gptc/__main__.py @@ -74,14 +74,12 @@ def main() -> None: else: text = sys.stdin.read() - probabilities = model.confidence(text, args.max_ngram_length) if args.category: - try: - print(sorted(probabilities.items(), key=lambda x: x[1])[-1][0]) - except IndexError: - print(None) + classifier = gptc.Classifier(model, args.max_ngram_length) + print(classifier.classify(text)) else: + probabilities = model.confidence(text, args.max_ngram_length) print(json.dumps(probabilities)) else: print(json.dumps(gptc.pack(args.model, True)[0]))