diff --git a/gptc/classifier.py b/gptc/classifier.py
index 1a0e541..2b655b3 100755
--- a/gptc/classifier.py
+++ b/gptc/classifier.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 
-import gptc.tokenizer, gptc.compiler, gptc.exceptions
+import gptc.tokenizer, gptc.compiler, gptc.exceptions, gptc.weighting
 import warnings
 
 
@@ -63,7 +63,10 @@ class Classifier:
         probs = {}
         for word in text:
             try:
-                for category, value in enumerate(model[word]):
+                weight, weighted_numbers = gptc.weighting.weight(
+                    [i / 65535 for i in model[word]]
+                )
+                for category, value in enumerate(weighted_numbers):
                     try:
                         probs[category] += value
                     except KeyError:
@@ -71,8 +74,7 @@ class Classifier:
             except KeyError:
                 pass
         probs = {
-            model["__names__"][category]: value / 65535
-            for category, value in probs.items()
+            model["__names__"][category]: value for category, value in probs.items()
         }
         total = sum(probs.values())
         probs = {category: value / total for category, value in probs.items()}
diff --git a/gptc/weighting.py b/gptc/weighting.py
new file mode 100755
index 0000000..047de6a
--- /dev/null
+++ b/gptc/weighting.py
@@ -0,0 +1,65 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+
+import math
+
+
+def _mean(numbers):
+    """Calculate the mean of a group of numbers
+
+    Parameters
+    ----------
+    numbers : list of int or float
+        The numbers to calculate the mean of
+
+    Returns
+    -------
+    float
+        The mean of the numbers
+    """
+    return sum(numbers) / len(numbers)
+
+
+def _standard_deviation(numbers):
+    """Calculate the population standard deviation of a group of numbers
+
+    Parameters
+    ----------
+    numbers : list of int or float
+        The numbers to calculate the standard deviation of
+
+    Returns
+    -------
+    float
+        The standard deviation of the numbers
+    """
+    mean = _mean(numbers)
+    squared_deviations = [(mean - i) ** 2 for i in numbers]
+    return math.sqrt(_mean(squared_deviations))
+
+
+def weight(numbers):
+    """Scale a group of numbers by twice their standard deviation
+
+    Numbers that are all equal have a standard deviation of zero (up to
+    floating-point rounding), so they are scaled to roughly zero.
+
+    Parameters
+    ----------
+    numbers : list of int or float
+        The numbers to weight (must not be empty)
+
+    Returns
+    -------
+    float
+        The weight: twice the standard deviation of the numbers
+    list of float
+        The numbers, each multiplied by the weight
+
+    Raises
+    ------
+    ZeroDivisionError
+        If numbers is empty
+    """
+    factor = _standard_deviation(numbers) * 2
+    weighted_numbers = [i * factor for i in numbers]
+    return factor, weighted_numbers