Compare commits
No commits in common. "b4766cb61344e48ed379d1c827c7aabd1b3f0c99" and "7ecb7dd90ac2cd5e9f3c63db0886eab768cebfee" have entirely different histories.
b4766cb613
...
7ecb7dd90a
|
@ -5,6 +5,7 @@
|
||||||
from gptc.compiler import compile as compile
|
from gptc.compiler import compile as compile
|
||||||
from gptc.classifier import Classifier as Classifier
|
from gptc.classifier import Classifier as Classifier
|
||||||
from gptc.pack import pack as pack
|
from gptc.pack import pack as pack
|
||||||
|
from gptc.tokenizer import has_emoji as has_emoji
|
||||||
from gptc.model import Model as Model, deserialize as deserialize
|
from gptc.model import Model as Model, deserialize as deserialize
|
||||||
from gptc.exceptions import (
|
from gptc.exceptions import (
|
||||||
GPTCError as GPTCError,
|
GPTCError as GPTCError,
|
||||||
|
|
|
@ -29,6 +29,7 @@ class Classifier:
|
||||||
self.model = model
|
self.model = model
|
||||||
model_ngrams = model.max_ngram_length
|
model_ngrams = model.max_ngram_length
|
||||||
self.max_ngram_length = min(max_ngram_length, model_ngrams)
|
self.max_ngram_length = min(max_ngram_length, model_ngrams)
|
||||||
|
self.has_emoji = gptc.tokenizer.has_emoji and model.has_emoji
|
||||||
|
|
||||||
def confidence(self, text: str) -> Dict[str, float]:
|
def confidence(self, text: str) -> Dict[str, float]:
|
||||||
"""Classify text with confidence.
|
"""Classify text with confidence.
|
||||||
|
@ -48,7 +49,9 @@ class Classifier:
|
||||||
|
|
||||||
model = self.model.weights
|
model = self.model.weights
|
||||||
|
|
||||||
tokens = gptc.tokenizer.tokenize(text, self.max_ngram_length)
|
tokens = gptc.tokenizer.tokenize(
|
||||||
|
text, self.max_ngram_length, self.has_emoji
|
||||||
|
)
|
||||||
numbered_probs: Dict[int, float] = {}
|
numbered_probs: Dict[int, float] = {}
|
||||||
for word in tokens:
|
for word in tokens:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -12,10 +12,14 @@ class Model:
|
||||||
weights: Dict[int, List[int]],
|
weights: Dict[int, List[int]],
|
||||||
names: List[str],
|
names: List[str],
|
||||||
max_ngram_length: int,
|
max_ngram_length: int,
|
||||||
|
has_emoji: Union[None, bool] = None,
|
||||||
):
|
):
|
||||||
self.weights = weights
|
self.weights = weights
|
||||||
self.names = names
|
self.names = names
|
||||||
self.max_ngram_length = max_ngram_length
|
self.max_ngram_length = max_ngram_length
|
||||||
|
self.has_emoji = (
|
||||||
|
gptc.tokenizer.has_emoji if has_emoji is None else has_emoji
|
||||||
|
)
|
||||||
|
|
||||||
def serialize(self) -> bytes:
|
def serialize(self) -> bytes:
|
||||||
out = b"GPTC model v4\n"
|
out = b"GPTC model v4\n"
|
||||||
|
@ -24,16 +28,7 @@ class Model:
|
||||||
{
|
{
|
||||||
"names": self.names,
|
"names": self.names,
|
||||||
"max_ngram_length": self.max_ngram_length,
|
"max_ngram_length": self.max_ngram_length,
|
||||||
"has_emoji": True,
|
"has_emoji": self.has_emoji,
|
||||||
# Due to an oversight in development, version 3.0.0 still
|
|
||||||
# had the code used to make emoji support optional, even
|
|
||||||
# though the `emoji` library was made a hard dependency.
|
|
||||||
# Part of this code checked whether or not the model
|
|
||||||
# supports emoji; deserialization would not work in 3.0.0
|
|
||||||
# if the model was compiled without this field. Emoji are
|
|
||||||
# always supported with 3.0.0 and newer when GPTC has been
|
|
||||||
# installed correctly, so this value should always be True.
|
|
||||||
# Related: #11
|
|
||||||
}
|
}
|
||||||
).encode("utf-8")
|
).encode("utf-8")
|
||||||
+ b"\n"
|
+ b"\n"
|
||||||
|
@ -62,11 +57,14 @@ def deserialize(encoded_model: bytes) -> Model:
|
||||||
try:
|
try:
|
||||||
names = config["names"]
|
names = config["names"]
|
||||||
max_ngram_length = config["max_ngram_length"]
|
max_ngram_length = config["max_ngram_length"]
|
||||||
|
has_emoji = config["has_emoji"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise InvalidModelError()
|
raise InvalidModelError()
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
isinstance(names, list) and isinstance(max_ngram_length, int)
|
isinstance(names, list)
|
||||||
|
and isinstance(max_ngram_length, int)
|
||||||
|
and isinstance(has_emoji, bool)
|
||||||
) or not all([isinstance(name, str) for name in names]):
|
) or not all([isinstance(name, str) for name in names]):
|
||||||
raise InvalidModelError()
|
raise InvalidModelError()
|
||||||
|
|
||||||
|
@ -88,4 +86,4 @@ def deserialize(encoded_model: bytes) -> Model:
|
||||||
for code in weight_codes
|
for code in weight_codes
|
||||||
}
|
}
|
||||||
|
|
||||||
return Model(weights, names, max_ngram_length)
|
return Model(weights, names, max_ngram_length, has_emoji)
|
||||||
|
|
|
@ -3,26 +3,38 @@
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
import hashlib
|
import hashlib
|
||||||
import base64
|
import base64
|
||||||
import emoji
|
|
||||||
|
try:
|
||||||
|
import emoji
|
||||||
|
|
||||||
|
has_emoji = True
|
||||||
|
except ImportError:
|
||||||
|
has_emoji = False
|
||||||
|
|
||||||
|
|
||||||
def tokenize(text: str, max_ngram_length: int = 1) -> List[int]:
|
def tokenize(
|
||||||
text = text.lower()
|
text: str, max_ngram_length: int = 1, use_emoji: bool = True
|
||||||
parts = []
|
) -> List[int]:
|
||||||
highest_end = 0
|
"""Convert a string to a list of lemmas."""
|
||||||
for emoji_part in emoji.emoji_list(text):
|
converted_text: Union[str, List[str]] = text.lower()
|
||||||
parts += list(text[highest_end : emoji_part["match_start"]])
|
|
||||||
parts.append(emoji_part["emoji"])
|
if has_emoji and use_emoji:
|
||||||
highest_end = emoji_part["match_end"]
|
text = text.lower()
|
||||||
parts += list(text[highest_end:])
|
parts = []
|
||||||
converted_text = [part for part in parts if part]
|
highest_end = 0
|
||||||
|
for emoji_part in emoji.emoji_list(text):
|
||||||
|
parts += list(text[highest_end : emoji_part["match_start"]])
|
||||||
|
parts.append(emoji_part["emoji"])
|
||||||
|
highest_end = emoji_part["match_end"]
|
||||||
|
parts += list(text[highest_end:])
|
||||||
|
converted_text = [part for part in parts if part]
|
||||||
|
|
||||||
tokens = [""]
|
tokens = [""]
|
||||||
|
|
||||||
for char in converted_text:
|
for char in converted_text:
|
||||||
if char.isalpha() or char == "'":
|
if char.isalpha() or char == "'":
|
||||||
tokens[-1] += char
|
tokens[-1] += char
|
||||||
elif emoji.is_emoji(char):
|
elif has_emoji and emoji.is_emoji(char):
|
||||||
tokens.append(char)
|
tokens.append(char)
|
||||||
tokens.append("")
|
tokens.append("")
|
||||||
elif tokens[-1] != "":
|
elif tokens[-1] != "":
|
||||||
|
|
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "gptc"
|
name = "gptc"
|
||||||
version = "3.0.1"
|
version = "3.0.0"
|
||||||
description = "General-purpose text classifier"
|
description = "General-purpose text classifier"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [{ name = "Samuel Sloniker", email = "sam@kj7rrv.com"}]
|
authors = [{ name = "Samuel Sloniker", email = "sam@kj7rrv.com"}]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user