Weighting
Weights words based on the standard deviation of the per-word confidences; closes #5
This commit is contained in:
parent
4d93b245e8
commit
34af3a8a0a
|
@ -1,6 +1,6 @@
|
|||
# SPDX-License-Identifier: LGPL-3.0-or-later
|
||||
|
||||
import gptc.tokenizer, gptc.compiler, gptc.exceptions
|
||||
import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting
|
||||
import warnings
|
||||
|
||||
|
||||
|
@ -63,7 +63,10 @@ class Classifier:
|
|||
probs = {}
|
||||
for word in text:
|
||||
try:
|
||||
for category, value in enumerate(model[word]):
|
||||
weight, weighted_numbers = gptc.weighting.weight(
|
||||
[i / 65535 for i in model[word]]
|
||||
)
|
||||
for category, value in enumerate(weighted_numbers):
|
||||
try:
|
||||
probs[category] += value
|
||||
except KeyError:
|
||||
|
@ -71,8 +74,7 @@ class Classifier:
|
|||
except KeyError:
|
||||
pass
|
||||
probs = {
|
||||
model["__names__"][category]: value / 65535
|
||||
for category, value in probs.items()
|
||||
model["__names__"][category]: value for category, value in probs.items()
|
||||
}
|
||||
total = sum(probs.values())
|
||||
probs = {category: value / total for category, value in probs.items()}
|
||||
|
|
45
gptc/weighting.py
Executable file
45
gptc/weighting.py
Executable file
|
@ -0,0 +1,45 @@
|
|||
# SPDX-License-Identifier: LGPL-3.0-or-later
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def _mean(numbers):
|
||||
"""Calculate the mean of a group of numbers
|
||||
|
||||
Parameters
|
||||
----------
|
||||
numbers : list of int or float
|
||||
The numbers to calculate the mean of
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
The mean of the numbers
|
||||
"""
|
||||
return sum(numbers) / len(numbers)
|
||||
|
||||
|
||||
def _standard_deviation(numbers):
|
||||
"""Calculate the standard deviation of a group of numbers
|
||||
|
||||
Parameters
|
||||
----------
|
||||
numbers : list of int or float
|
||||
The numbers to calculate the mean of
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
The standard deviation of the numbers
|
||||
|
||||
"""
|
||||
mean = _mean(numbers)
|
||||
squared_deviations = [(mean - i) ** 2 for i in numbers]
|
||||
return math.sqrt(_mean(squared_deviations))
|
||||
|
||||
|
||||
def weight(numbers):
|
||||
standard_deviation = _standard_deviation(numbers)
|
||||
weight = standard_deviation * 2
|
||||
weighted_numbers = [i * weight for i in numbers]
|
||||
return weight, weighted_numbers
|
Loading…
Reference in New Issue
Block a user