Reformat code with black
This commit is contained in:
parent
252cbaeb9d
commit
4b1e82514f
|
@ -5,20 +5,24 @@ import sys
|
||||||
import gptc
|
import gptc
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="General Purpose Text Classifier")
|
parser = argparse.ArgumentParser(description="General Purpose Text Classifier")
|
||||||
parser.add_argument('model', help='model to use')
|
parser.add_argument("model", help="model to use")
|
||||||
parser.add_argument('-c', '--compile', help='compile raw model model to outfile', metavar='outfile')
|
parser.add_argument(
|
||||||
parser.add_argument('-j', '--confidence', help='output confidence dict in json', action='store_true')
|
"-c", "--compile", help="compile raw model model to outfile", metavar="outfile"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-j", "--confidence", help="output confidence dict in json", action="store_true"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with open(args.model, 'r') as f:
|
with open(args.model, "r") as f:
|
||||||
raw_model = json.load(f)
|
raw_model = json.load(f)
|
||||||
if args.compile:
|
if args.compile:
|
||||||
with open(args.compile, 'w+') as f:
|
with open(args.compile, "w+") as f:
|
||||||
json.dump(gptc.compile(raw_model), f)
|
json.dump(gptc.compile(raw_model), f)
|
||||||
else:
|
else:
|
||||||
classifier = gptc.Classifier(raw_model)
|
classifier = gptc.Classifier(raw_model)
|
||||||
if sys.stdin.isatty():
|
if sys.stdin.isatty():
|
||||||
text = input('Text to analyse: ')
|
text = input("Text to analyse: ")
|
||||||
else:
|
else:
|
||||||
text = sys.stdin.read()
|
text = sys.stdin.read()
|
||||||
if args.confidence:
|
if args.confidence:
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import gptc.tokenizer, gptc.compiler, gptc.exceptions
|
import gptc.tokenizer, gptc.compiler, gptc.exceptions
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
class Classifier:
|
class Classifier:
|
||||||
"""A text classifier.
|
"""A text classifier.
|
||||||
|
|
||||||
|
@ -18,7 +19,7 @@ class Classifier:
|
||||||
|
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
try:
|
try:
|
||||||
model_version = model['__version__']
|
model_version = model["__version__"]
|
||||||
except:
|
except:
|
||||||
model_version = 1
|
model_version = 1
|
||||||
|
|
||||||
|
@ -27,11 +28,15 @@ class Classifier:
|
||||||
else:
|
else:
|
||||||
# The model is an unsupported version
|
# The model is an unsupported version
|
||||||
try:
|
try:
|
||||||
raw_model = model['__raw__']
|
raw_model = model["__raw__"]
|
||||||
except:
|
except:
|
||||||
raise gptc.exceptions.UnsupportedModelError('this model is unsupported and does not contain a raw model for recompiling')
|
raise gptc.exceptions.UnsupportedModelError(
|
||||||
|
"this model is unsupported and does not contain a raw model for recompiling"
|
||||||
|
)
|
||||||
|
|
||||||
warnings.warn("model needed to be recompiled on-the-fly; please re-compile it and use the new compiled model in the future")
|
warnings.warn(
|
||||||
|
"model needed to be recompiled on-the-fly; please re-compile it and use the new compiled model in the future"
|
||||||
|
)
|
||||||
self.model = gptc.compiler.compile(raw_model)
|
self.model = gptc.compiler.compile(raw_model)
|
||||||
|
|
||||||
def confidence(self, text):
|
def confidence(self, text):
|
||||||
|
@ -63,7 +68,10 @@ class Classifier:
|
||||||
probs[category] = value
|
probs[category] = value
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
probs = {model['__names__'][category]: value/65535 for category, value in probs.items()}
|
probs = {
|
||||||
|
model["__names__"][category]: value / 65535
|
||||||
|
for category, value in probs.items()
|
||||||
|
}
|
||||||
total = sum(probs.values())
|
total = sum(probs.values())
|
||||||
probs = {category: value / total for category, value in probs.items()}
|
probs = {category: value / total for category, value in probs.items()}
|
||||||
return probs
|
return probs
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import gptc.tokenizer
|
import gptc.tokenizer
|
||||||
|
|
||||||
|
|
||||||
def compile(raw_model):
|
def compile(raw_model):
|
||||||
"""Compile a raw model.
|
"""Compile a raw model.
|
||||||
|
|
||||||
|
@ -18,8 +19,8 @@ def compile(raw_model):
|
||||||
categories = {}
|
categories = {}
|
||||||
|
|
||||||
for portion in raw_model:
|
for portion in raw_model:
|
||||||
text = gptc.tokenizer.tokenize(portion['text'])
|
text = gptc.tokenizer.tokenize(portion["text"])
|
||||||
category = portion['category']
|
category = portion["category"]
|
||||||
try:
|
try:
|
||||||
categories[category] += text
|
categories[category] += text
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -54,9 +55,9 @@ def compile(raw_model):
|
||||||
for category in names:
|
for category in names:
|
||||||
model[word].append(round((weights.get(category, 0) / total) * 65535))
|
model[word].append(round((weights.get(category, 0) / total) * 65535))
|
||||||
|
|
||||||
model['__names__'] = names
|
model["__names__"] = names
|
||||||
|
|
||||||
model['__version__'] = 3
|
model["__version__"] = 3
|
||||||
model['__raw__'] = raw_model
|
model["__raw__"] = raw_model
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
class GPTCError(BaseException):
|
class GPTCError(BaseException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ModelError(GPTCError):
|
class ModelError(GPTCError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedModelError(ModelError):
|
class UnsupportedModelError(ModelError):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import json
|
import json
|
||||||
|
|
||||||
if len(sys.argv) != 2:
|
if len(sys.argv) != 2:
|
||||||
print('usage: pack.py <path>', file=sys.stderr)
|
print("usage: pack.py <path>", file=sys.stderr)
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
paths = os.listdir(sys.argv[1])
|
paths = os.listdir(sys.argv[1])
|
||||||
|
@ -24,6 +24,6 @@ for path in paths:
|
||||||
raw_model = []
|
raw_model = []
|
||||||
|
|
||||||
for category, cat_texts in texts.items():
|
for category, cat_texts in texts.items():
|
||||||
raw_model += [{'category': category, 'text': i} for i in cat_texts]
|
raw_model += [{"category": category, "text": i} for i in cat_texts]
|
||||||
|
|
||||||
print(json.dumps(raw_model))
|
print(json.dumps(raw_model))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user