Fix emoji handling

This commit is contained in:
Samuel Sloniker 2022-07-20 14:05:36 -07:00
parent 185692790f
commit 9538cf8c22
4 changed files with 12 additions and 5 deletions

View File

@ -33,7 +33,9 @@ class Classifier:
self.model = model
model_ngrams = cast(int, model.get("__ngrams__", 1))
self.max_ngram_length = min(max_ngram_length, model_ngrams)
self.has_emoji = gptc.tokenizer.has_emoji and gptc.model_info.model_has_emoji(model)
self.has_emoji = (
gptc.tokenizer.has_emoji and gptc.model_info.model_has_emoji(model)
)
def confidence(self, text: str) -> Dict[str, float]:
"""Classify text with confidence.
@ -53,7 +55,9 @@ class Classifier:
model = self.model
tokens = gptc.tokenizer.tokenize(text, self.max_ngram_length)
tokens = gptc.tokenizer.tokenize(
text, self.max_ngram_length, self.has_emoji
)
numbered_probs: Dict[int, float] = {}
for word in tokens:
try:

View File

@ -77,5 +77,6 @@ def compile(
model["__names__"] = names
model["__ngrams__"] = max_ngram_length
model["__version__"] = 3
model["__emoji__"] = int(tokenizer.has_emoji)
return model

View File

@ -5,4 +5,4 @@ from typing import Dict, Union, cast, List
def model_has_emoji(model: gptc.compiler.MODEL) -> bool:
return cast(int, model.get("__emoji__]", 0)) == 1
return cast(int, model.get("__emoji__", 0)) == 1

View File

@ -10,11 +10,13 @@ except ImportError:
has_emoji = False
def tokenize(text: str, max_ngram_length: int = 1) -> List[str]:
def tokenize(
text: str, max_ngram_length: int = 1, use_emoji: bool = True
) -> List[str]:
"""Convert a string to a list of lemmas."""
converted_text: Union[str, List[str]] = text.lower()
if has_emoji:
if has_emoji and use_emoji:
parts = []
highest_end = 0
for emoji_part in emoji.emoji_list(text):