This commit is contained in:
Samuel Sloniker 2021-11-16 21:50:59 -08:00
parent 38ec27a3f4
commit 378a23608d
6 changed files with 5 additions and 4 deletions

View File

@ -22,7 +22,7 @@ class Classifier:
except: except:
model_version = 1 model_version = 1
if model_version == 2: if model_version == 3:
self.model = model self.model = model
else: else:
# The model is an unsupported version # The model is an unsupported version
@ -63,8 +63,9 @@ class Classifier:
probs[category] = value probs[category] = value
except KeyError: except KeyError:
pass pass
probs = {model['__names__'][category]: value/65535 for category, value in probs.items()}
total = sum(probs.values()) total = sum(probs.values())
probs = {model['__names__'][category]: value/total for category, value in probs.items()} probs = {category: value/total for category, value in probs.items()}
return probs return probs
def classify(self, text): def classify(self, text):

View File

@ -52,11 +52,11 @@ def compile(raw_model):
total = sum(weights.values()) total = sum(weights.values())
model[word] = [] model[word] = []
for category in names: for category in names:
model[word].append(weights.get(category, 0)/total) model[word].append(round((weights.get(category, 0)/total)*65535))
model['__names__'] = names model['__names__'] = names
model['__version__'] = 2 model['__version__'] = 3
model['__raw__'] = raw_model model['__raw__'] = raw_model
return model return model