Refactor word count dict in compiler
This makes future changes to the algorithm much simpler.
This commit is contained in:
parent
aea35ad059
commit
af1d1749d2
|
@ -38,7 +38,7 @@ def compile(
|
||||||
except KeyError:
|
except KeyError:
|
||||||
categories[category] = text
|
categories[category] = text
|
||||||
|
|
||||||
categories_by_count: Dict[str, Dict[str, float]] = {}
|
word_counts: Dict[str, Dict[str, float]] = {}
|
||||||
|
|
||||||
names = []
|
names = []
|
||||||
|
|
||||||
|
@ -46,23 +46,27 @@ def compile(
|
||||||
if not category in names:
|
if not category in names:
|
||||||
names.append(category)
|
names.append(category)
|
||||||
|
|
||||||
categories_by_count[category] = {}
|
|
||||||
for word in text:
|
for word in text:
|
||||||
try:
|
try:
|
||||||
categories_by_count[category][word] += 1 / len(
|
counts_for_word = word_counts[word]
|
||||||
categories[category]
|
|
||||||
)
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
categories_by_count[category][word] = 1 / len(
|
counts_for_word = {}
|
||||||
categories[category]
|
word_counts[word] = counts_for_word
|
||||||
)
|
|
||||||
word_weights: Dict[str, Dict[str, float]] = {}
|
|
||||||
for category, words in categories_by_count.items():
|
|
||||||
for word, value in words.items():
|
|
||||||
try:
|
try:
|
||||||
word_weights[word][category] = value
|
word_counts[word][category] += 1
|
||||||
except KeyError:
|
except KeyError:
|
||||||
word_weights[word] = {category: value}
|
word_counts[word][category] = 1
|
||||||
|
|
||||||
|
word_weights: Dict[str, Dict[str, float]] = {}
|
||||||
|
for word, values in word_counts.items():
|
||||||
|
for category, value in values.items():
|
||||||
|
try:
|
||||||
|
word_weights[word][category] = value / len(categories[category])
|
||||||
|
except KeyError:
|
||||||
|
word_weights[word] = {
|
||||||
|
category: value / len(categories[category])
|
||||||
|
}
|
||||||
|
|
||||||
model: MODEL = {}
|
model: MODEL = {}
|
||||||
for word, weights in word_weights.items():
|
for word, weights in word_weights.items():
|
||||||
|
@ -80,3 +84,4 @@ def compile(
|
||||||
model["__emoji__"] = int(gptc.tokenizer.has_emoji)
|
model["__emoji__"] = int(gptc.tokenizer.has_emoji)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user