📅  最后修改于: 2023-12-03 15:28:49.224000             🧑  作者: Mango
本文主要介绍《门|门CS 2013》中的问题4,涉及到机器学习中的朴素贝叶斯分类器,要求实现一个简单的文本分类器,使用朴素贝叶斯算法计算文档属于不同类别的概率。
给定n个文档和它们所属的类别,现在需要实现一个简单的文本分类器,使用朴素贝叶斯算法计算一个新文档属于每个类别的概率,并返回分类结果。
首先需要统计每个类别中每个单词的频率,以及每个类别中所有单词的总数,这个过程可以通过遍历所有文档,使用Python的Counter类实现。
然后计算每个单词在不同类别中出现的概率,可以使用贝叶斯公式计算,其中要特别注意对概率进行平滑,避免因为某个单词在某个类别中没有出现而导致概率为0的情况。
对于每个新文档,计算它属于每个类别的概率,选择概率最大的类别作为该文档的分类结果,这个过程可以通过遍历文档中所有单词,使用之前计算出的每个单词在不同类别中出现的概率和类别的先验概率(即已有的文档中每个类别所占比例)计算得出。
返回结果格式使用markdown,包括以下内容:
问题描述:使用朴素贝叶斯算法实现文本分类器,输入为n个文档和它们所属的类别,输出为一个新文档属于每个类别的概率和分类结果。
输入格式:n个文档,每个文档为一个字符串,类别为一个整数,输入格式为列表[(doc1, label1), (doc2, label2), ..., (docn, labeln)]。
输出格式:输出为字典,包含以下两个键值对:
probabilities:一个字典,包括该文档属于每个类别的概率。
predict:一个整数,代表该文档的分类结果。
代码片段:实现分类器的Python代码片段,按markdown格式标明。
from collections import Counter
import math
class NaiveBayesClassifier:
def __init__(self):
self.word_counts = {} # 每个类别中每个单词的出现次数
self.total_counts = {} # 每个类别中所有单词的总数
self.categories = set() # 所有类别
self.doc_counts = Counter() # 每个类别中文档的数量
self.total_docs = 0 # 所有文档的数量
def fit(self, X, y):
# 统计词频和文档数量
for doc, cat in zip(X, y):
if cat not in self.word_counts:
self.word_counts[cat] = Counter()
self.total_counts[cat] = 0
self.doc_counts[cat] = 0
for word in doc.split():
self.word_counts[cat][word] += 1
self.total_counts[cat] += 1
self.doc_counts[cat] += 1
self.categories.add(cat)
self.total_docs += 1
# 计算每个单词在每个类别中出现的概率
self.word_probs = {}
for cat in self.categories:
self.word_probs[cat] = {}
for word in self.word_counts[cat]:
self.word_probs[cat][word] = (self.word_counts[cat][word] + 1) / (self.total_counts[cat] + len(self.word_counts[cat]))
def predict(self, X):
probs = {}
max_prob = -math.inf
argmax_cat = None
# 计算该文档属于每个类别的概率
for cat in self.categories:
prior_prob = self.doc_counts[cat] / self.total_docs
probs[cat] = prior_prob
for word in X.split():
if word in self.word_probs[cat]:
probs[cat] *= self.word_probs[cat][word]
else:
probs[cat] *= 1 / (self.total_counts[cat] + len(self.word_counts[cat]))
# 选择概率最大的类别作为分类结果
if probs[cat] > max_prob:
max_prob = probs[cat]
argmax_cat = cat
return {"probabilities": probs, "predict": argmax_cat}
其中,X
为输入文档列表,每个文档为字符串,y
为对应的标签,为整数。fit
方法用于训练分类器,predict
方法用于对新文档进行分类。使用时,先实例化一个NaiveBayesClassifier
对象,然后调用其fit
方法传入训练数据,最后调用predict
方法传入新文档进行分类。