Reformat code with black

This commit is contained in:
Samuel Sloniker 2022-03-05 09:42:52 -08:00
parent 252cbaeb9d
commit 4b1e82514f
5 changed files with 39 additions and 24 deletions

View File

@ -5,20 +5,24 @@ import sys
import gptc import gptc
parser = argparse.ArgumentParser(description="General Purpose Text Classifier") parser = argparse.ArgumentParser(description="General Purpose Text Classifier")
parser.add_argument('model', help='model to use') parser.add_argument("model", help="model to use")
parser.add_argument('-c', '--compile', help='compile raw model model to outfile', metavar='outfile') parser.add_argument(
parser.add_argument('-j', '--confidence', help='output confidence dict in json', action='store_true') "-c", "--compile", help="compile raw model model to outfile", metavar="outfile"
)
parser.add_argument(
"-j", "--confidence", help="output confidence dict in json", action="store_true"
)
args = parser.parse_args() args = parser.parse_args()
with open(args.model, 'r') as f: with open(args.model, "r") as f:
raw_model = json.load(f) raw_model = json.load(f)
if args.compile: if args.compile:
with open(args.compile, 'w+') as f: with open(args.compile, "w+") as f:
json.dump(gptc.compile(raw_model), f) json.dump(gptc.compile(raw_model), f)
else: else:
classifier = gptc.Classifier(raw_model) classifier = gptc.Classifier(raw_model)
if sys.stdin.isatty(): if sys.stdin.isatty():
text = input('Text to analyse: ') text = input("Text to analyse: ")
else: else:
text = sys.stdin.read() text = sys.stdin.read()
if args.confidence: if args.confidence:

View File

@ -1,6 +1,7 @@
import gptc.tokenizer, gptc.compiler, gptc.exceptions import gptc.tokenizer, gptc.compiler, gptc.exceptions
import warnings import warnings
class Classifier: class Classifier:
"""A text classifier. """A text classifier.
@ -18,7 +19,7 @@ class Classifier:
def __init__(self, model): def __init__(self, model):
try: try:
model_version = model['__version__'] model_version = model["__version__"]
except: except:
model_version = 1 model_version = 1
@ -27,11 +28,15 @@ class Classifier:
else: else:
# The model is an unsupported version # The model is an unsupported version
try: try:
raw_model = model['__raw__'] raw_model = model["__raw__"]
except: except:
raise gptc.exceptions.UnsupportedModelError('this model is unsupported and does not contain a raw model for recompiling') raise gptc.exceptions.UnsupportedModelError(
"this model is unsupported and does not contain a raw model for recompiling"
)
warnings.warn("model needed to be recompiled on-the-fly; please re-compile it and use the new compiled model in the future") warnings.warn(
"model needed to be recompiled on-the-fly; please re-compile it and use the new compiled model in the future"
)
self.model = gptc.compiler.compile(raw_model) self.model = gptc.compiler.compile(raw_model)
def confidence(self, text): def confidence(self, text):
@ -63,9 +68,12 @@ class Classifier:
probs[category] = value probs[category] = value
except KeyError: except KeyError:
pass pass
probs = {model['__names__'][category]: value/65535 for category, value in probs.items()} probs = {
model["__names__"][category]: value / 65535
for category, value in probs.items()
}
total = sum(probs.values()) total = sum(probs.values())
probs = {category: value/total for category, value in probs.items()} probs = {category: value / total for category, value in probs.items()}
return probs return probs
def classify(self, text): def classify(self, text):

View File

@ -1,5 +1,6 @@
import gptc.tokenizer import gptc.tokenizer
def compile(raw_model): def compile(raw_model):
"""Compile a raw model. """Compile a raw model.
@ -18,8 +19,8 @@ def compile(raw_model):
categories = {} categories = {}
for portion in raw_model: for portion in raw_model:
text = gptc.tokenizer.tokenize(portion['text']) text = gptc.tokenizer.tokenize(portion["text"])
category = portion['category'] category = portion["category"]
try: try:
categories[category] += text categories[category] += text
except KeyError: except KeyError:
@ -36,27 +37,27 @@ def compile(raw_model):
categories_by_count[category] = {} categories_by_count[category] = {}
for word in text: for word in text:
try: try:
categories_by_count[category][word] += 1/len(categories[category]) categories_by_count[category][word] += 1 / len(categories[category])
except KeyError: except KeyError:
categories_by_count[category][word] = 1/len(categories[category]) categories_by_count[category][word] = 1 / len(categories[category])
word_weights = {} word_weights = {}
for category, words in categories_by_count.items(): for category, words in categories_by_count.items():
for word, value in words.items(): for word, value in words.items():
try: try:
word_weights[word][category] = value word_weights[word][category] = value
except KeyError: except KeyError:
word_weights[word] = {category:value} word_weights[word] = {category: value}
model = {} model = {}
for word, weights in word_weights.items(): for word, weights in word_weights.items():
total = sum(weights.values()) total = sum(weights.values())
model[word] = [] model[word] = []
for category in names: for category in names:
model[word].append(round((weights.get(category, 0)/total)*65535)) model[word].append(round((weights.get(category, 0) / total) * 65535))
model['__names__'] = names model["__names__"] = names
model['__version__'] = 3 model["__version__"] = 3
model['__raw__'] = raw_model model["__raw__"] = raw_model
return model return model

View File

@ -1,8 +1,10 @@
class GPTCError(BaseException): class GPTCError(BaseException):
pass pass
class ModelError(GPTCError): class ModelError(GPTCError):
pass pass
class UnsupportedModelError(ModelError): class UnsupportedModelError(ModelError):
pass pass

View File

@ -3,7 +3,7 @@ import os
import json import json
if len(sys.argv) != 2: if len(sys.argv) != 2:
print('usage: pack.py <path>', file=sys.stderr) print("usage: pack.py <path>", file=sys.stderr)
exit(1) exit(1)
paths = os.listdir(sys.argv[1]) paths = os.listdir(sys.argv[1])
@ -24,6 +24,6 @@ for path in paths:
raw_model = [] raw_model = []
for category, cat_texts in texts.items(): for category, cat_texts in texts.items():
raw_model += [{'category': category, 'text': i} for i in cat_texts] raw_model += [{"category": category, "text": i} for i in cat_texts]
print(json.dumps(raw_model)) print(json.dumps(raw_model))