diff --git a/gptc/__main__.py b/gptc/__main__.py index c8d3d7b..38cc7b0 100644 --- a/gptc/__main__.py +++ b/gptc/__main__.py @@ -5,20 +5,24 @@ import sys import gptc parser = argparse.ArgumentParser(description="General Purpose Text Classifier") -parser.add_argument('model', help='model to use') -parser.add_argument('-c', '--compile', help='compile raw model model to outfile', metavar='outfile') -parser.add_argument('-j', '--confidence', help='output confidence dict in json', action='store_true') +parser.add_argument("model", help="model to use") +parser.add_argument( + "-c", "--compile", help="compile raw model model to outfile", metavar="outfile" +) +parser.add_argument( + "-j", "--confidence", help="output confidence dict in json", action="store_true" +) args = parser.parse_args() -with open(args.model, 'r') as f: +with open(args.model, "r") as f: raw_model = json.load(f) if args.compile: - with open(args.compile, 'w+') as f: + with open(args.compile, "w+") as f: json.dump(gptc.compile(raw_model), f) else: classifier = gptc.Classifier(raw_model) if sys.stdin.isatty(): - text = input('Text to analyse: ') + text = input("Text to analyse: ") else: text = sys.stdin.read() if args.confidence: diff --git a/gptc/classifier.py b/gptc/classifier.py index c424a93..85b7990 100755 --- a/gptc/classifier.py +++ b/gptc/classifier.py @@ -1,6 +1,7 @@ import gptc.tokenizer, gptc.compiler, gptc.exceptions import warnings + class Classifier: """A text classifier. @@ -18,7 +19,7 @@ class Classifier: def __init__(self, model): try: - model_version = model['__version__'] + model_version = model["__version__"] except: model_version = 1 @@ -27,11 +28,15 @@ class Classifier: else: # The model is an unsupported version try: - raw_model = model['__raw__'] + raw_model = model["__raw__"] except: - raise gptc.exceptions.UnsupportedModelError('this model is unsupported and does not contain a raw model for recompiling') + raise gptc.exceptions.UnsupportedModelError( + "this model is unsupported and does not contain a raw model for recompiling" + ) - warnings.warn("model needed to be recompiled on-the-fly; please re-compile it and use the new compiled model in the future") + warnings.warn( + "model needed to be recompiled on-the-fly; please re-compile it and use the new compiled model in the future" + ) self.model = gptc.compiler.compile(raw_model) def confidence(self, text): @@ -63,9 +68,12 @@ class Classifier: probs[category] = value except KeyError: pass - probs = {model['__names__'][category]: value/65535 for category, value in probs.items()} + probs = { + model["__names__"][category]: value / 65535 + for category, value in probs.items() + } total = sum(probs.values()) - probs = {category: value/total for category, value in probs.items()} + probs = {category: value / total for category, value in probs.items()} return probs def classify(self, text): diff --git a/gptc/compiler.py b/gptc/compiler.py index d7c34ce..a2cfc66 100755 --- a/gptc/compiler.py +++ b/gptc/compiler.py @@ -1,5 +1,6 @@ import gptc.tokenizer + def compile(raw_model): """Compile a raw model. @@ -18,15 +19,15 @@ def compile(raw_model): categories = {} for portion in raw_model: - text = gptc.tokenizer.tokenize(portion['text']) - category = portion['category'] + text = gptc.tokenizer.tokenize(portion["text"]) + category = portion["category"] try: categories[category] += text except KeyError: categories[category] = text categories_by_count = {} - + names = [] for category, text in categories.items(): @@ -36,27 +37,27 @@ def compile(raw_model): categories_by_count[category] = {} for word in text: try: - categories_by_count[category][word] += 1/len(categories[category]) + categories_by_count[category][word] += 1 / len(categories[category]) except KeyError: - categories_by_count[category][word] = 1/len(categories[category]) + categories_by_count[category][word] = 1 / len(categories[category]) word_weights = {} for category, words in categories_by_count.items(): for word, value in words.items(): try: word_weights[word][category] = value except KeyError: - word_weights[word] = {category:value} + word_weights[word] = {category: value} model = {} for word, weights in word_weights.items(): total = sum(weights.values()) model[word] = [] for category in names: - model[word].append(round((weights.get(category, 0)/total)*65535)) + model[word].append(round((weights.get(category, 0) / total) * 65535)) - model['__names__'] = names + model["__names__"] = names - model['__version__'] = 3 - model['__raw__'] = raw_model + model["__version__"] = 3 + model["__raw__"] = raw_model return model diff --git a/gptc/exceptions.py b/gptc/exceptions.py index c1ebddd..7d504e3 100644 --- a/gptc/exceptions.py +++ b/gptc/exceptions.py @@ -1,8 +1,10 @@ class GPTCError(BaseException): pass + class ModelError(GPTCError): pass + class UnsupportedModelError(ModelError): pass diff --git a/utils/pack.py b/utils/pack.py index 2d8f106..001cc17 100644 --- a/utils/pack.py +++ b/utils/pack.py @@ -3,7 +3,7 @@ import os import json if len(sys.argv) != 2: - print('usage: pack.py ', file=sys.stderr) + print("usage: pack.py ", file=sys.stderr) exit(1) paths = os.listdir(sys.argv[1]) @@ -24,6 +24,6 @@ for path in paths: raw_model = [] for category, cat_texts in texts.items(): - raw_model += [{'category': category, 'text': i} for i in cat_texts] + raw_model += [{"category": category, "text": i} for i in cat_texts] print(json.dumps(raw_model))