Make more of compilation numeric

This commit is contained in:
Samuel Sloniker 2023-01-04 19:07:58 -08:00
parent 071656c2d2
commit 8a1cb6105e
Signed by: kj7rrv
GPG Key ID: 1BB4029E66285A62

View File

@@ -28,8 +28,8 @@ def compile(
""" """
word_counts: Dict[int, Dict[str, int]] = {} word_counts: Dict[int, Dict[int, int]] = {}
category_lengths: Dict[str, int] = {} category_lengths: Dict[int, int] = {}
names: List[str] = [] names: List[str] = []
for portion in raw_model: for portion in raw_model:
@@ -37,10 +37,12 @@ def compile(
gptc.tokenizer.tokenize(portion["text"], max_ngram_length), gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
hash_algorithm, hash_algorithm,
) )
category = portion["category"] category_name = portion["category"]
if not category in names: if not category_name in names:
names.append(category) names.append(category_name)
category = names.index(category_name)
category_lengths[category] = category_lengths.get(category, 0) + len( category_lengths[category] = category_lengths.get(category, 0) + len(
text text
@@ -64,7 +66,7 @@ def compile(
} }
total = sum(weights.values()) total = sum(weights.values())
new_weights: List[int] = [] new_weights: List[int] = []
for category in names: for category in range(len(names)):
new_weights.append( new_weights.append(
round((weights.get(category, 0) / total) * 65535) round((weights.get(category, 0) / total) * 65535)
) )