Switch gptc/__init__.py to import compile and Classifier from gptc.gptc
This commit is contained in:
parent
aceeef846f
commit
63e26e97de
70
gptc/__init__.py
Executable file → Normal file
70
gptc/__init__.py
Executable file → Normal file
|
@ -1,69 +1 @@
|
||||||
#!/usr/bin/env python3
|
from gptc.gptc import compile, Classifier
|
||||||
import sys
|
|
||||||
import spacy
|
|
||||||
|
|
||||||
# Load the spaCy 'en_core_web_sm' pipeline once, at import time; listify()
# runs every text through it to obtain token lemmas.
nlp = spacy.load('en_core_web_sm')
|
|
||||||
|
|
||||||
def listify(text):
    """Convert a string to a list of lowercased lemmas.

    The text is run through the module-level spaCy pipeline ``nlp``. Only
    tokens whose lemma starts with an ASCII letter are kept; every other
    token is dropped.

    Parameters
    ----------
    text : str
        The text to split into lemmas.

    Returns
    -------
    list of str
        The lowercased lemma of every kept token, in document order.

    """
    letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
    lemmas = []
    for token in nlp(text):
        lemma = token.lemma_
        # An empty lemma has no first character; skip it instead of letting
        # lemma[0] raise IndexError.
        if lemma and lemma[0] in letters:
            lemmas.append(lemma.lower())
    return lemmas
|
|
||||||
|
|
||||||
|
|
||||||
def compile(raw_model):
    """Compile a raw model.

    Parameters
    ----------
    raw_model : list of dict
        A raw GPTC model. Each entry is read through its 'text' and
        'category' keys.

    Returns
    -------
    dict
        A compiled GPTC model: each word maps to a dict of
        ``{category: weight}``, where the weight is the share of that
        category's lemmas that are this word.

    """
    # Step 1: pool the lemmas of every portion under its category.
    lemmas_by_category = {}
    for entry in raw_model:
        lemmas = listify(entry['text'])
        label = entry['category']
        if label in lemmas_by_category:
            lemmas_by_category[label] += lemmas
        else:
            lemmas_by_category[label] = lemmas

    # Step 2: per category, accumulate 1/total once per occurrence of a word.
    shares_by_category = {}
    for label, lemmas in lemmas_by_category.items():
        shares = shares_by_category[label] = {}
        total = len(lemmas)
        for lemma in lemmas:
            # The division stays inside the loop so that a category with no
            # lemmas never divides by zero.
            if lemma in shares:
                shares[lemma] += 1/total
            else:
                shares[lemma] = 1/total

    # Step 3: invert the mapping from category -> word to word -> category.
    compiled = {}
    for label, shares in shares_by_category.items():
        for lemma, weight in shares.items():
            compiled.setdefault(lemma, {})[label] = weight

    return compiled
|
|
||||||
|
|
||||||
|
|
||||||
class Classifier:
    """A text classifier.

    Parameters
    ----------
    model : dict or list
        A compiled or raw GPTC model. Please don't use raw models here.
    supress_uncompiled_model_warning : bool, optional
        If true, the warning normally printed to stderr when a raw model
        has to be compiled is not printed.

    Attributes
    ----------
    model : dict
        The model used. This is always a compiled model.
    warn : bool
        The value passed as `supress_uncompiled_model_warning`.

    """

    def __init__(self, model, supress_uncompiled_model_warning=False):
        # isinstance rather than an exact type comparison, so that dict
        # subclasses (e.g. OrderedDict) count as already-compiled models
        # instead of being passed to compile().
        if isinstance(model, dict):
            self.model = model
        else:
            self.model = compile(model)
            if not supress_uncompiled_model_warning:
                print('WARNING: model was not compiled', file=sys.stderr)
                print('This makes everything slow, because compiling models takes far longer than using them.', file=sys.stderr)
        self.warn = supress_uncompiled_model_warning

    def classify(self, text):
        """Classify text.

        Parameters
        ----------
        text : str
            The text to classify

        Returns
        -------
        str or None
            The most likely category, or None if no guess was made.

        """
        model = self.model
        probs = {}
        for word in listify(text):
            if word not in model:
                # Words the model has never seen contribute nothing.
                continue
            for category, value in model[word].items():
                if category in probs:
                    probs[category] += value
                else:
                    probs[category] = value

        if not probs:
            return None
        # The sort is stable and ascending, so when scores tie the category
        # that entered probs last is the one returned.
        return sorted(probs.items(), key=lambda item: item[1])[-1][0]
|
|
||||||
|
|
113
gptc/gptc.py
Executable file
113
gptc/gptc.py
Executable file
|
@ -0,0 +1,113 @@
|
||||||
|
'''Main module for GPTC.'''
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import spacy
|
||||||
|
|
||||||
|
# Load the spaCy 'en_core_web_sm' pipeline once, at import time; _listify()
# runs every text through it to obtain token lemmas.
nlp = spacy.load('en_core_web_sm')
|
||||||
|
|
||||||
|
def _listify(text):
    """Convert a string to a list of lemmas. Internal use only.

    The text is run through the module-level spaCy pipeline ``nlp``. Only
    tokens whose lemma starts with an ASCII letter are kept; every other
    token is dropped.

    Parameters
    ----------
    text : str
        The text to split into lemmas.

    Returns
    -------
    list of str
        The lowercased lemma of every kept token, in document order.

    """
    letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
    lemmas = []
    for token in nlp(text):
        lemma = token.lemma_
        # An empty lemma has no first character; skip it instead of letting
        # lemma[0] raise IndexError.
        if lemma and lemma[0] in letters:
            lemmas.append(lemma.lower())
    return lemmas
|
||||||
|
|
||||||
|
|
||||||
|
def compile(raw_model):
    """Compile a raw model.

    Parameters
    ----------
    raw_model : list of dict
        A raw GPTC model. Each entry is read through its 'text' and
        'category' keys.

    Returns
    -------
    dict
        A compiled GPTC model: each word maps to a dict of
        ``{category: weight}``, where the weight is the share of that
        category's lemmas that are this word.

    """
    # Step 1: pool the lemmas of every portion under its category.
    lemmas_by_category = {}
    for entry in raw_model:
        lemmas = _listify(entry['text'])
        label = entry['category']
        if label in lemmas_by_category:
            lemmas_by_category[label] += lemmas
        else:
            lemmas_by_category[label] = lemmas

    # Step 2: per category, accumulate 1/total once per occurrence of a word.
    shares_by_category = {}
    for label, lemmas in lemmas_by_category.items():
        shares = shares_by_category[label] = {}
        total = len(lemmas)
        for lemma in lemmas:
            # The division stays inside the loop so that a category with no
            # lemmas never divides by zero.
            if lemma in shares:
                shares[lemma] += 1/total
            else:
                shares[lemma] = 1/total

    # Step 3: invert the mapping from category -> word to word -> category.
    compiled = {}
    for label, shares in shares_by_category.items():
        for lemma, weight in shares.items():
            compiled.setdefault(lemma, {})[label] = weight

    return compiled
|
||||||
|
|
||||||
|
|
||||||
|
class Classifier:
    """A text classifier.

    Parameters
    ----------
    model : dict or list
        A compiled or raw GPTC model. Please don't use raw models here.
    supress_uncompiled_model_warning : bool, optional
        If true, the warning normally printed to stderr when a raw model
        has to be compiled is not printed.

    Attributes
    ----------
    model : dict
        The model used. This is always a compiled model.
    warn : bool
        The value passed as `supress_uncompiled_model_warning`.

    """

    def __init__(self, model, supress_uncompiled_model_warning=False):
        # isinstance rather than an exact type comparison, so that dict
        # subclasses (e.g. OrderedDict) count as already-compiled models
        # instead of being passed to compile().
        if isinstance(model, dict):
            self.model = model
        else:
            self.model = compile(model)
            if not supress_uncompiled_model_warning:
                print('WARNING: model was not compiled', file=sys.stderr)
                print('This makes everything slow, because compiling models takes far longer than using them.', file=sys.stderr)
        self.warn = supress_uncompiled_model_warning

    def classify(self, text):
        """Classify text.

        Parameters
        ----------
        text : str
            The text to classify

        Returns
        -------
        str or None
            The most likely category, or None if no guess was made.

        """
        model = self.model
        probs = {}
        for word in _listify(text):
            if word not in model:
                # Words the model has never seen contribute nothing.
                continue
            for category, value in model[word].items():
                if category in probs:
                    probs[category] += value
                else:
                    probs[category] = value

        if not probs:
            return None
        # The sort is stable and ascending, so when scores tie the category
        # that entered probs last is the one returned.
        return sorted(probs.items(), key=lambda item: item[1])[-1][0]
|
Loading…
Reference in New Issue
Block a user