Model v3
This commit is contained in:
parent
38ec27a3f4
commit
378a23608d
|
@ -22,7 +22,7 @@ class Classifier:
|
|||
except:
|
||||
model_version = 1
|
||||
|
||||
if model_version == 2:
|
||||
if model_version == 3:
|
||||
self.model = model
|
||||
else:
|
||||
# The model is an unsupported version
|
||||
|
@ -63,8 +63,9 @@ class Classifier:
|
|||
probs[category] = value
|
||||
except KeyError:
|
||||
pass
|
||||
probs = {model['__names__'][category]: value/65535 for category, value in probs.items()}
|
||||
total = sum(probs.values())
|
||||
probs = {model['__names__'][category]: value/total for category, value in probs.items()}
|
||||
probs = {category: value/total for category, value in probs.items()}
|
||||
return probs
|
||||
|
||||
def classify(self, text):
|
|
@ -52,11 +52,11 @@ def compile(raw_model):
|
|||
total = sum(weights.values())
|
||||
model[word] = []
|
||||
for category in names:
|
||||
model[word].append(weights.get(category, 0)/total)
|
||||
model[word].append(round((weights.get(category, 0)/total)*65535))
|
||||
|
||||
model['__names__'] = names
|
||||
|
||||
model['__version__'] = 2
|
||||
model['__version__'] = 3
|
||||
model['__raw__'] = raw_model
|
||||
|
||||
return model
|
Loading…
Reference in New Issue
Block a user