Compare commits
1 Commit
f38f4ca801
...
8a1cb6105e
Author | SHA1 | Date | |
---|---|---|---|
8a1cb6105e |
|
@@ -28,8 +28,8 @@ def compile(
|
|||
|
||||
"""
|
||||
|
||||
word_counts: Dict[int, Dict[str, int]] = {}
|
||||
category_lengths: Dict[str, int] = {}
|
||||
word_counts: Dict[int, Dict[int, int]] = {}
|
||||
category_lengths: Dict[int, int] = {}
|
||||
names: List[str] = []
|
||||
|
||||
for portion in raw_model:
|
||||
|
@@ -37,10 +37,12 @@ def compile(
|
|||
gptc.tokenizer.tokenize(portion["text"], max_ngram_length),
|
||||
hash_algorithm,
|
||||
)
|
||||
category = portion["category"]
|
||||
category_name = portion["category"]
|
||||
|
||||
if not category in names:
|
||||
names.append(category)
|
||||
if not category_name in names:
|
||||
names.append(category_name)
|
||||
|
||||
category = names.index(category_name)
|
||||
|
||||
category_lengths[category] = category_lengths.get(category, 0) + len(
|
||||
text
|
||||
|
@@ -64,7 +66,7 @@ def compile(
|
|||
}
|
||||
total = sum(weights.values())
|
||||
new_weights: List[int] = []
|
||||
for category in names:
|
||||
for category in range(len(names)):
|
||||
new_weights.append(
|
||||
round((weights.get(category, 0) / total) * 65535)
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user