Format code
This commit is contained in:
parent
4ddeefad07
commit
5378be9418
|
@ -9,10 +9,16 @@ import gptc
|
||||||
parser = argparse.ArgumentParser(description="General Purpose Text Classifier")
|
parser = argparse.ArgumentParser(description="General Purpose Text Classifier")
|
||||||
parser.add_argument("model", help="model to use")
|
parser.add_argument("model", help="model to use")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-c", "--compile", help="compile raw model model to outfile", metavar="outfile"
|
"-c",
|
||||||
|
"--compile",
|
||||||
|
help="compile raw model model to outfile",
|
||||||
|
metavar="outfile",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-j", "--confidence", help="output confidence dict in json", action="store_true"
|
"-j",
|
||||||
|
"--confidence",
|
||||||
|
help="output confidence dict in json",
|
||||||
|
action="store_true",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,9 @@ class Classifier:
|
||||||
|
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
if model.get("__version__", 0) != 3:
|
if model.get("__version__", 0) != 3:
|
||||||
raise gptc.exceptions.UnsupportedModelError(f"unsupported model version")
|
raise gptc.exceptions.UnsupportedModelError(
|
||||||
|
f"unsupported model version"
|
||||||
|
)
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def confidence(self, text):
|
def confidence(self, text):
|
||||||
|
@ -57,7 +59,8 @@ class Classifier:
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
probs = {
|
probs = {
|
||||||
model["__names__"][category]: value for category, value in probs.items()
|
model["__names__"][category]: value
|
||||||
|
for category, value in probs.items()
|
||||||
}
|
}
|
||||||
total = sum(probs.values())
|
total = sum(probs.values())
|
||||||
probs = {category: value / total for category, value in probs.items()}
|
probs = {category: value / total for category, value in probs.items()}
|
||||||
|
|
|
@ -39,9 +39,13 @@ def compile(raw_model):
|
||||||
categories_by_count[category] = {}
|
categories_by_count[category] = {}
|
||||||
for word in text:
|
for word in text:
|
||||||
try:
|
try:
|
||||||
categories_by_count[category][word] += 1 / len(categories[category])
|
categories_by_count[category][word] += 1 / len(
|
||||||
|
categories[category]
|
||||||
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
categories_by_count[category][word] = 1 / len(categories[category])
|
categories_by_count[category][word] = 1 / len(
|
||||||
|
categories[category]
|
||||||
|
)
|
||||||
word_weights = {}
|
word_weights = {}
|
||||||
for category, words in categories_by_count.items():
|
for category, words in categories_by_count.items():
|
||||||
for word, value in words.items():
|
for word, value in words.items():
|
||||||
|
@ -55,7 +59,9 @@ def compile(raw_model):
|
||||||
total = sum(weights.values())
|
total = sum(weights.values())
|
||||||
model[word] = []
|
model[word] = []
|
||||||
for category in names:
|
for category in names:
|
||||||
model[word].append(round((weights.get(category, 0) / total) * 65535))
|
model[word].append(
|
||||||
|
round((weights.get(category, 0) / total) * 65535)
|
||||||
|
)
|
||||||
|
|
||||||
model["__names__"] = names
|
model["__names__"] = names
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user