Weighting

Weights words based on the standard deviation of the per-word
confidences; closes #5
This commit is contained in:
Samuel Sloniker 2022-03-06 19:28:05 -08:00
parent 4d93b245e8
commit 34af3a8a0a
2 changed files with 51 additions and 4 deletions

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later # SPDX-License-Identifier: LGPL-3.0-or-later
import gptc.tokenizer, gptc.compiler, gptc.exceptions import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting
import warnings import warnings
@ -63,7 +63,10 @@ class Classifier:
probs = {} probs = {}
for word in text: for word in text:
try: try:
for category, value in enumerate(model[word]): weight, weighted_numbers = gptc.weighting.weight(
[i / 65535 for i in model[word]]
)
for category, value in enumerate(weighted_numbers):
try: try:
probs[category] += value probs[category] += value
except KeyError: except KeyError:
@ -71,8 +74,7 @@ class Classifier:
except KeyError: except KeyError:
pass pass
probs = { probs = {
model["__names__"][category]: value / 65535 model["__names__"][category]: value for category, value in probs.items()
for category, value in probs.items()
} }
total = sum(probs.values()) total = sum(probs.values())
probs = {category: value / total for category, value in probs.items()} probs = {category: value / total for category, value in probs.items()}

45
gptc/weighting.py Executable file
View File

@ -0,0 +1,45 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import math
def _mean(numbers):
"""Calculate the mean of a group of numbers
Parameters
----------
numbers : list of int or float
The numbers to calculate the mean of
Returns
-------
float
The mean of the numbers
"""
return sum(numbers) / len(numbers)
def _standard_deviation(numbers):
"""Calculate the standard deviation of a group of numbers
Parameters
----------
numbers : list of int or float
The numbers to calculate the mean of
Returns
-------
float
The standard deviation of the numbers
"""
mean = _mean(numbers)
squared_deviations = [(mean - i) ** 2 for i in numbers]
return math.sqrt(_mean(squared_deviations))
def weight(numbers):
standard_deviation = _standard_deviation(numbers)
weight = standard_deviation * 2
weighted_numbers = [i * weight for i in numbers]
return weight, weighted_numbers