# -*- coding: utf-8 -*-
'''
Multinomial naive Bayes text classifier with Laplace (add-one) smoothing.

Documents are iterables of terms (tokens). ``Classy.train`` adds one labelled
document at a time; ``Classy.classify`` returns the most probable label
together with its unnormalised log-posterior score.

>>> c = Classy()
>>> c.train(['cpu', 'RAM', 'ALU', 'io', 'bridge', 'disk'], 'architecture')
True
>>> c.train(['monitor', 'mouse', 'keyboard', 'microphone', 'headphones'], 'input_devices')
True
>>> c.train(['desk', 'chair', 'cabinet', 'lamp'], 'office furniture')
True
>>> my_office = ['cpu', 'monitor', 'mouse', 'chair']
>>> c.classify(my_office)[0]
'input_devices'
>>> c.classify(['desk', 'chair', 'lamp'])[0]
'office furniture'
'''

from collections import Counter
import math


class ClassifierNotTrainedException(Exception):
    """Raised when the classifier is used before any document was trained."""

    def __str__(self):
        return "Classifier is not trained."


class Classy(object):
    """Incrementally trainable multinomial naive Bayes classifier."""

    def __init__(self):
        # class_id -> {term: number of occurrences in that class's documents}
        self.term_count_store = {}
        self.data = {
            # class_id -> total number of term occurrences seen for the class
            'class_term_count': {},
            # class_id -> log prior, log(docs in class / all docs)
            'beta_priors': {},
            # class_id -> number of documents trained for the class
            'class_doc_count': {},
        }
        self.total_term_count = 0
        self.total_doc_count = 0

    def train(self, document_source, class_id):
        '''
        Trains the classifier with one labelled document.

        :param document_source: iterable of terms (tokens). It is consumed
            exactly once, so one-shot iterables such as generators work.
        :param class_id: hashable label of the document's class.
        :returns: ``True`` (kept for backward compatibility).
        '''
        count = Counter(document_source)
        # Derive the token total from the Counter instead of
        # len(document_source): the source may already be exhausted.
        term_total = sum(count.values())

        class_terms = self.term_count_store.setdefault(class_id, {})
        for term, freq in count.items():
            class_terms[term] = class_terms.get(term, 0) + freq

        class_term_count = self.data['class_term_count']
        class_term_count[class_id] = class_term_count.get(class_id, 0) + term_total
        class_doc_count = self.data['class_doc_count']
        class_doc_count[class_id] = class_doc_count.get(class_id, 0) + 1

        self.total_term_count += term_total
        self.total_doc_count += 1
        # Every class's prior depends on total_doc_count, so refresh them all.
        self.compute_beta_priors()
        return True

    def classify(self, document_input):
        '''
        Returns the most probable class for a document.

        Each class ``c`` is scored as ``log P(c) + sum_t n_t * log P(t|c)``,
        where ``n_t`` is how often term ``t`` occurs in the document and
        ``P(t|c) = (count(t, c) + 1) / (terms(c) + |V|)`` is the add-one
        smoothed term probability over the training vocabulary ``V``.
        Terms never seen in training carry no class information and are
        ignored.

        :param document_input: iterable of terms to classify.
        :returns: ``(class_id, score)`` for the highest-scoring class; on a
            tie the class that was trained first wins.
        :raises ClassifierNotTrainedException: if nothing was trained yet.
        '''
        if not self.total_doc_count:
            raise ClassifierNotTrainedException()

        term_freq = Counter(document_input)
        vocabulary = set()
        for class_terms in self.term_count_store.values():
            vocabulary.update(class_terms)
        vocabulary_size = len(vocabulary)
        # Iterate distinct terms once each (weighted by their frequency) and
        # skip out-of-vocabulary terms instead of aborting the whole sum.
        known_terms = [(term, freq) for term, freq in term_freq.items()
                       if term in vocabulary]

        scores = []
        for class_id in self.data['class_doc_count']:
            class_terms = self.term_count_store.get(class_id, {})
            # Non-zero whenever known_terms is non-empty (|V| >= 1 then).
            denominator = self.data['class_term_count'][class_id] + vocabulary_size
            summation = 0.0
            for term, freq in known_terms:
                probability = (class_terms.get(term, 0) + 1) / denominator
                summation += freq * math.log(probability)
            scores.append((class_id, summation + self.data['beta_priors'][class_id]))
        # max() keeps the first maximal item, i.e. the earliest-trained class.
        return max(scores, key=lambda pair: pair[1])

    def compute_beta_priors(self):
        '''
        Recomputes the log prior ``log(N_c / N)`` of every class, where
        ``N_c`` is the class's document count and ``N`` all documents.

        :raises ClassifierNotTrainedException: if nothing was trained yet.
        '''
        if not self.total_doc_count:
            raise ClassifierNotTrainedException()

        for class_id, doc_count in self.data['class_doc_count'].items():
            self.data['beta_priors'][class_id] = math.log(doc_count / self.total_doc_count)
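# 本文转自罗兵博客园博客，原文链接：http://www.cnblogs.com/hhh5460/p/4319427.html，如需转载请自行联系原作者
# (Reposted from Luo Bing's cnblogs blog; original: http://www.cnblogs.com/hhh5460/p/4319427.html. Please contact the original author before republishing.)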