From be543134bc3202de964ad0b2545e740d0d2a87eb Mon Sep 17 00:00:00 2001 From: kj7rrv Date: Wed, 3 Nov 2021 06:38:22 -0700 Subject: [PATCH] Add Classifier.confidence() --- README.md | 17 +++++++++++---- gptc/__main__.py | 6 +++++- gptc/classifier.py | 52 +++++++++++++++++++++++++++++++++------------- 3 files changed, 55 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index e512e92..1176443 100644 --- a/README.md +++ b/README.md @@ -9,8 +9,15 @@ GPTC provides both a CLI tool and a Python library. `python -m gptc ` This will prompt for a string and classify it, outputting the category on -stdout (or "None" if it cannot determine -anything). +stdout (or "None" if it cannot determine anything). + +Alternatively, if you need confidence data, use: + + `python -m gptc -j ` + +This will print (in JSON) a dict of the format `{category: probability, +category: probability, ...}` to stdout. + ### Compiling models gptc -c|--compile @@ -19,9 +26,11 @@ anything). ### `gptc.Classifier(model)` Create a `Classifier` object using the given *compiled* model (as a dict, not JSON). +#### `Classifier.confidence(text)` +Classify `text`. Returns a dict of the format `{category: probability, +category: probability, ...}`. #### `Classifier.classify(text)` -Classify `text` with GPTC using the model used to instantiate the -`Classifier`. Returns the category into which the text is placed (as a +Classify `text`. Returns the category into which the text is placed (as a string), or `None` when it cannot classify the text. 
## `gptc.compile(raw_model)` Compile a raw model (as a list, not JSON) and return the compiled model (as a diff --git a/gptc/__main__.py b/gptc/__main__.py index 5818741..c8d3d7b 100644 --- a/gptc/__main__.py +++ b/gptc/__main__.py @@ -7,6 +7,7 @@ import gptc parser = argparse.ArgumentParser(description="General Purpose Text Classifier") parser.add_argument('model', help='model to use') parser.add_argument('-c', '--compile', help='compile raw model model to outfile', metavar='outfile') +parser.add_argument('-j', '--confidence', help='output confidence dict in json', action='store_true') args = parser.parse_args() with open(args.model, 'r') as f: @@ -20,4 +21,7 @@ else: text = input('Text to analyse: ') else: text = sys.stdin.read() - print(classifier.classify(text)) + if args.confidence: + print(json.dumps(classifier.confidence(text))) + else: + print(classifier.classify(text)) diff --git a/gptc/classifier.py b/gptc/classifier.py index 1949f99..73a0077 100755 --- a/gptc/classifier.py +++ b/gptc/classifier.py @@ -34,6 +34,40 @@ class Classifier: warnings.warn("model needed to be recompiled on-the-fly; please re-compile it and use the new compiled model in the future") self.model = gptc.compiler.compile(raw_model) + def confidence(self, text): + """Classify text with confidence. + + Parameters + ---------- + text : str + The text to classify + + Returns + ------- + dict + {category:probability, category:probability...} or {} if no words + matching any categories in the model were found + + """ + + model = self.model + + text = gptc.tokenizer.tokenize(text) + probs = {} + for word in text: + try: + total = sum(model[word].values()) + for category, value in model[word].items(): + try: + probs[category] += value / total + except KeyError: + probs[category] = value / total + except KeyError: + pass + total = sum(probs.values()) + probs = {category: value/total for category, value in probs.items()} + return probs + def classify(self, text): """Classify text. 
@@ -45,23 +79,11 @@ class Classifier: Returns ------- str or None - The most likely category, or None if no guess was made. + The most likely category, or None if no words matching any + category in the model were found. """ - - model = self.model - - text = gptc.tokenizer.tokenize(text) - probs = {} - for word in text: - try: - for category, value in model[word].items(): - try: - probs[category] += value - except KeyError: - probs[category] = value - except KeyError: - pass + probs = self.confidence(text) try: return sorted(probs.items(), key=lambda x: x[1])[-1][0] except IndexError: