Compare commits


No commits in common. "12f97ae765e1fba105a123ac56d377c7f9d66c6b" and "e4eb322aa7e0e134ddbced6aa38970820a44f60b" have entirely different histories.

3 changed files with 29 additions and 20 deletions


@@ -41,30 +41,39 @@ def compile(
     word_counts: Dict[int, Dict[str, int]] = {}
-    names = tuple(categories.keys())
+    names = []
     for category, text in categories.items():
+        if not category in names:
+            names.append(category)
         for word in text:
-            if word in word_counts:
-                try:
-                    word_counts[word][category] += 1
-                except KeyError:
-                    word_counts[word][category] = 1
-            else:
-                word_counts[word] = {category: 1}
+            try:
+                counts_for_word = word_counts[word]
+            except KeyError:
+                counts_for_word = {}
+                word_counts[word] = counts_for_word
 
-    category_lengths = {
-        category: len(text) for category, text in categories.items()
-    }
+            try:
+                word_counts[word][category] += 1
+            except KeyError:
+                word_counts[word][category] = 1
 
-    word_weights: Dict[int, Dict[str, float]] = {
-        word: {
-            category: value / category_lengths[category]
-            for category, value in values.items()
-        }
-        for word, values in word_counts.items()
-        if sum(values.values()) >= min_count
-    }
+    word_counts = {
+        word: counts
+        for word, counts in word_counts.items()
+        if sum(counts.values()) >= min_count
+    }
+
+    word_weights: Dict[int, Dict[str, float]] = {}
+    for word, values in word_counts.items():
+        for category, value in values.items():
+            try:
+                word_weights[word][category] = value / len(categories[category])
+            except KeyError:
+                word_weights[word] = {
+                    category: value / len(categories[category])
+                }
 
     model: Dict[int, List[int]] = {}
     for word, weights in word_weights.items():
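For illustration only, not part of the diff: a minimal standalone sketch of the same count, filter, and weight steps shown on the added (right-hand) side of the hunk above. The name compile_weights, the plain string tokens, and the dict helpers are illustrative simplifications and not GPTC's actual API, which hashes tokens and takes additional parameters.

from typing import Dict, List

def compile_weights(
    categories: Dict[str, List[str]], min_count: int = 1
) -> Dict[str, Dict[str, float]]:
    # Count how often each word appears in each category's text.
    word_counts: Dict[str, Dict[str, int]] = {}
    for category, text in categories.items():
        for word in text:
            counts_for_word = word_counts.setdefault(word, {})
            counts_for_word[category] = counts_for_word.get(category, 0) + 1

    # Keep only words that occur at least min_count times in total.
    word_counts = {
        word: counts
        for word, counts in word_counts.items()
        if sum(counts.values()) >= min_count
    }

    # Weight each count by the length of the category it came from.
    return {
        word: {
            category: count / len(categories[category])
            for category, count in counts.items()
        }
        for word, counts in word_counts.items()
    }

# Example usage with a tiny, made-up corpus:
print(compile_weights({"spam": ["buy", "now", "now"], "ham": ["hello", "now"]}))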


@@ -62,7 +62,7 @@ class Model:
         }
         return probs
 
-    def get(self, token: str) -> Dict[str, float]:
+    def get(self, token):
         try:
             weights = self.weights[
                 gptc.tokenizer.hash_single(gptc.tokenizer.normalize(token))
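For illustration only, not part of the diff: the lookup inside get normalizes and hashes the token before indexing the weights table. A hypothetical standalone version of that pattern is sketched below; the function name, the injected normalize/hash_single callables, and the empty-dict fallback for unknown tokens are all assumptions, since the hunk is truncated before the except clause.

from typing import Callable, Dict

def lookup_weights(
    weights: Dict[int, Dict[str, float]],
    normalize: Callable[[str], str],
    hash_single: Callable[[str], int],
    token: str,
) -> Dict[str, float]:
    # Normalize, hash, then index the weights table; return an empty
    # mapping when the token is unknown (assumed fallback behavior).
    try:
        return weights[hash_single(normalize(token))]
    except KeyError:
        return {}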


@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "gptc"
-version = "3.1.1"
+version = "3.1.0"
 description = "General-purpose text classifier"
 readme = "README.md"
 authors = [{ name = "Samuel Sloniker", email = "sam@kj7rrv.com"}]