Initial commit of V2
V2 is a major upgrade that (hopefully) makes it more scalable by moving the majority of the work to compile time instead of runtime.
This commit is contained in:
parent
72f380b0dd
commit
384167dcdb
|
@ -9,68 +9,61 @@ def listify(text):
|
||||||
|
|
||||||
|
|
||||||
def compile(raw_model):
|
def compile(raw_model):
|
||||||
model = {}
|
categories = {}
|
||||||
|
|
||||||
for portion in raw_model:
|
for portion in raw_model:
|
||||||
text = listify(portion['text'])
|
text = listify(portion['text'])
|
||||||
category = portion['category']
|
category = portion['category']
|
||||||
|
try:
|
||||||
|
categories[category] += text
|
||||||
|
except KeyError:
|
||||||
|
categories[category] = text
|
||||||
|
|
||||||
|
categories_by_count = {}
|
||||||
|
|
||||||
|
for category, text in categories.items():
|
||||||
|
categories_by_count[category] = {}
|
||||||
for word in text:
|
for word in text:
|
||||||
try:
|
try:
|
||||||
model[category].append(word)
|
categories_by_count[category][word] += 1/len(categories[category])
|
||||||
except:
|
except KeyError:
|
||||||
model[category] = [word]
|
categories_by_count[category][word] = 1/len(categories[category])
|
||||||
model[category].sort()
|
word_weights = {}
|
||||||
all_models = [ { 'text': model, 'stopword': i/10} for i in range(0, 21) ]
|
for category, words in categories_by_count.items():
|
||||||
for test_model in all_models:
|
for word, value in words.items():
|
||||||
correct = 0
|
try:
|
||||||
classifier = Classifier(test_model)
|
word_weights[word][category] = value
|
||||||
for text in raw_model:
|
except KeyError:
|
||||||
if classifier.check(text['text']) == text['category']:
|
word_weights[word] = {category:value}
|
||||||
correct += 1
|
|
||||||
test_model['correct'] = correct
|
return word_weights
|
||||||
print('tested a model')
|
|
||||||
best = all_models[0]
|
|
||||||
for test_model in all_models:
|
|
||||||
if test_model['correct'] > best['correct']:
|
|
||||||
best = test_model
|
|
||||||
del best['correct']
|
|
||||||
return best
|
|
||||||
return {'text': model}
|
|
||||||
|
|
||||||
|
|
||||||
class Classifier:
|
class Classifier:
|
||||||
def __init__(self, model, supress_uncompiled_model_warning=False):
|
def __init__(self, model, supress_uncompiled_model_warning=False):
|
||||||
if type(model['text']) == dict:
|
if type(model) == dict:
|
||||||
self.model = model
|
self.model = model
|
||||||
else:
|
else:
|
||||||
self.model = compile(model)
|
self.model = compile(model)
|
||||||
if not supress_uncompiled_model_warning:
|
if not supress_uncompiled_model_warning:
|
||||||
print('WARNING: model was not compiled', file=sys.stderr)
|
print('WARNING: model was not compiled', file=sys.stderr)
|
||||||
print('In development, this is OK, but precompiling the model is preferred for production use.', file=sys.stderr)
|
print('This makes everything slow, because compiling models takes far longer than using them.', file=sys.stderr)
|
||||||
self.warn = supress_uncompiled_model_warning
|
self.warn = supress_uncompiled_model_warning
|
||||||
|
|
||||||
def check(self, text):
|
def check(self, text):
|
||||||
model = self.model
|
model = self.model
|
||||||
stopword_value = 0.5
|
|
||||||
try:
|
|
||||||
stopword_value = model['stopword']
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
stopwords = spacy.lang.en.stop_words.STOP_WORDS
|
|
||||||
model = model['text']
|
|
||||||
text = listify(text)
|
text = listify(text)
|
||||||
probs = {}
|
probs = {}
|
||||||
for word in text:
|
for word in text:
|
||||||
for category in model.keys():
|
try:
|
||||||
for catword in model[category]:
|
for category, value in model[word].items():
|
||||||
if word == catword:
|
try:
|
||||||
weight = ( stopword_value if word in stopwords else 1 ) / len(model[category])
|
probs[category] += value
|
||||||
try:
|
except KeyError:
|
||||||
probs[category] += weight
|
probs[category] = value
|
||||||
except:
|
except KeyError:
|
||||||
probs[category] = weight
|
pass
|
||||||
most_likely = ['unknown', 0]
|
try:
|
||||||
for category in probs.keys():
|
return sorted(probs.items(), key=lambda x: x[1])[-1][0]
|
||||||
if probs[category] > most_likely[1]:
|
except IndexError:
|
||||||
most_likely = [category, probs[category]]
|
return None
|
||||||
return most_likely[0]
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="General Purpose Text Classifier")
|
parser = argparse.ArgumentParser(description="General Purpose Text Classifier")
|
||||||
parser.add_argument('model', help='model to use')
|
parser.add_argument('model', help='model to use')
|
||||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user