diff --git a/gptc/classifier.py b/gptc/classifier.py index 22de86c..9736962 100755 --- a/gptc/classifier.py +++ b/gptc/classifier.py @@ -33,7 +33,9 @@ class Classifier: self.model = model model_ngrams = cast(int, model.get("__ngrams__", 1)) self.max_ngram_length = min(max_ngram_length, model_ngrams) - self.has_emoji = gptc.tokenizer.has_emoji and gptc.model_info.model_has_emoji(model) + self.has_emoji = ( + gptc.tokenizer.has_emoji and gptc.model_info.model_has_emoji(model) + ) def confidence(self, text: str) -> Dict[str, float]: """Classify text with confidence. @@ -53,7 +55,9 @@ class Classifier: model = self.model - tokens = gptc.tokenizer.tokenize(text, self.max_ngram_length) + tokens = gptc.tokenizer.tokenize( + text, self.max_ngram_length, self.has_emoji + ) numbered_probs: Dict[int, float] = {} for word in tokens: try: diff --git a/gptc/compiler.py b/gptc/compiler.py index 05b793b..185ab55 100755 --- a/gptc/compiler.py +++ b/gptc/compiler.py @@ -77,5 +77,6 @@ def compile( model["__names__"] = names model["__ngrams__"] = max_ngram_length model["__version__"] = 3 + model["__emoji__"] = int(tokenizer.has_emoji) return model diff --git a/gptc/model_info.py b/gptc/model_info.py index be9d3b1..6372e6c 100755 --- a/gptc/model_info.py +++ b/gptc/model_info.py @@ -5,4 +5,4 @@ from typing import Dict, Union, cast, List def model_has_emoji(model: gptc.compiler.MODEL) -> bool: - return cast(int, model.get("__emoji__]", 0)) == 1 + return cast(int, model.get("__emoji__", 0)) == 1 diff --git a/gptc/tokenizer.py b/gptc/tokenizer.py index 7763e3c..fe09223 100644 --- a/gptc/tokenizer.py +++ b/gptc/tokenizer.py @@ -10,11 +10,13 @@ except ImportError: has_emoji = False -def tokenize(text: str, max_ngram_length: int = 1) -> List[str]: +def tokenize( + text: str, max_ngram_length: int = 1, use_emoji: bool = True +) -> List[str]: """Convert a string to a list of lemmas.""" converted_text: Union[str, List[str]] = text.lower() - if has_emoji: + if has_emoji and use_emoji: parts = [] highest_end = 0 for emoji_part in emoji.emoji_list(text):