Weighting
Weights words based on the standard deviation of the per-word confidences; closes #5
This commit is contained in:
parent
4d93b245e8
commit
34af3a8a0a
|
@ -1,6 +1,6 @@
|
||||||
# SPDX-License-Identifier: LGPL-3.0-or-later
|
# SPDX-License-Identifier: LGPL-3.0-or-later
|
||||||
|
|
||||||
import gptc.tokenizer, gptc.compiler, gptc.exceptions
|
import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,7 +63,10 @@ class Classifier:
|
||||||
probs = {}
|
probs = {}
|
||||||
for word in text:
|
for word in text:
|
||||||
try:
|
try:
|
||||||
for category, value in enumerate(model[word]):
|
weight, weighted_numbers = gptc.weighting.weight(
|
||||||
|
[i / 65535 for i in model[word]]
|
||||||
|
)
|
||||||
|
for category, value in enumerate(weighted_numbers):
|
||||||
try:
|
try:
|
||||||
probs[category] += value
|
probs[category] += value
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -71,8 +74,7 @@ class Classifier:
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
probs = {
|
probs = {
|
||||||
model["__names__"][category]: value / 65535
|
model["__names__"][category]: value for category, value in probs.items()
|
||||||
for category, value in probs.items()
|
|
||||||
}
|
}
|
||||||
total = sum(probs.values())
|
total = sum(probs.values())
|
||||||
probs = {category: value / total for category, value in probs.items()}
|
probs = {category: value / total for category, value in probs.items()}
|
||||||
|
|
45
gptc/weighting.py
Executable file
45
gptc/weighting.py
Executable file
|
@ -0,0 +1,45 @@
|
||||||
|
# SPDX-License-Identifier: LGPL-3.0-or-later
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def _mean(numbers):
|
||||||
|
"""Calculate the mean of a group of numbers
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
numbers : list of int or float
|
||||||
|
The numbers to calculate the mean of
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
float
|
||||||
|
The mean of the numbers
|
||||||
|
"""
|
||||||
|
return sum(numbers) / len(numbers)
|
||||||
|
|
||||||
|
|
||||||
|
def _standard_deviation(numbers):
|
||||||
|
"""Calculate the standard deviation of a group of numbers
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
numbers : list of int or float
|
||||||
|
The numbers to calculate the mean of
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
float
|
||||||
|
The standard deviation of the numbers
|
||||||
|
|
||||||
|
"""
|
||||||
|
mean = _mean(numbers)
|
||||||
|
squared_deviations = [(mean - i) ** 2 for i in numbers]
|
||||||
|
return math.sqrt(_mean(squared_deviations))
|
||||||
|
|
||||||
|
|
||||||
|
def weight(numbers):
|
||||||
|
standard_deviation = _standard_deviation(numbers)
|
||||||
|
weight = standard_deviation * 2
|
||||||
|
weighted_numbers = [i * weight for i in numbers]
|
||||||
|
return weight, weighted_numbers
|
Loading…
Reference in New Issue
Block a user