Format code

This commit is contained in:
Samuel Sloniker 2022-05-20 17:16:00 -07:00
parent 4ddeefad07
commit 5378be9418
3 changed files with 22 additions and 7 deletions

View File

@@ -9,10 +9,16 @@ import gptc
parser = argparse.ArgumentParser(description="General Purpose Text Classifier") parser = argparse.ArgumentParser(description="General Purpose Text Classifier")
parser.add_argument("model", help="model to use") parser.add_argument("model", help="model to use")
parser.add_argument( parser.add_argument(
"-c", "--compile", help="compile raw model model to outfile", metavar="outfile" "-c",
"--compile",
help="compile raw model model to outfile",
metavar="outfile",
) )
parser.add_argument( parser.add_argument(
"-j", "--confidence", help="output confidence dict in json", action="store_true" "-j",
"--confidence",
help="output confidence dict in json",
action="store_true",
) )
args = parser.parse_args() args = parser.parse_args()

View File

@@ -21,7 +21,9 @@ class Classifier:
def __init__(self, model): def __init__(self, model):
if model.get("__version__", 0) != 3: if model.get("__version__", 0) != 3:
raise gptc.exceptions.UnsupportedModelError(f"unsupported model version") raise gptc.exceptions.UnsupportedModelError(
f"unsupported model version"
)
self.model = model self.model = model
def confidence(self, text): def confidence(self, text):
@@ -57,7 +59,8 @@ class Classifier:
except KeyError: except KeyError:
pass pass
probs = { probs = {
model["__names__"][category]: value for category, value in probs.items() model["__names__"][category]: value
for category, value in probs.items()
} }
total = sum(probs.values()) total = sum(probs.values())
probs = {category: value / total for category, value in probs.items()} probs = {category: value / total for category, value in probs.items()}

View File

@@ -39,9 +39,13 @@ def compile(raw_model):
categories_by_count[category] = {} categories_by_count[category] = {}
for word in text: for word in text:
try: try:
categories_by_count[category][word] += 1 / len(categories[category]) categories_by_count[category][word] += 1 / len(
categories[category]
)
except KeyError: except KeyError:
categories_by_count[category][word] = 1 / len(categories[category]) categories_by_count[category][word] = 1 / len(
categories[category]
)
word_weights = {} word_weights = {}
for category, words in categories_by_count.items(): for category, words in categories_by_count.items():
for word, value in words.items(): for word, value in words.items():
@@ -55,7 +59,9 @@ def compile(raw_model):
total = sum(weights.values()) total = sum(weights.values())
model[word] = [] model[word] = []
for category in names: for category in names:
model[word].append(round((weights.get(category, 0) / total) * 65535)) model[word].append(
round((weights.get(category, 0) / total) * 65535)
)
model["__names__"] = names model["__names__"] = names