随机森林算法实践
目录
随机森林算法实践代码详解
运行环境
- Python版本:Python 2.7
- 依赖库:
jieba==0.37 scikit-learn==0.17
核心功能说明
该代码实现了基于随机森林的文本分类任务,包含以下核心功能:
- 中文分词:使用jieba进行文本预处理
- 特征工程:one-hot编码生成特征向量
- 模型训练:使用scikit-learn的RandomForestClassifier
- 结果预测:支持新文本的分类预测
- 模型持久化:训练结果保存与加载
代码结构解析
1. 核心类结构
class RandomForest:
def __init__(self, is_save=False) # 初始化模型
def build_train_data(...) # 构建训练集
def predict(self, predict_data) # 执行预测
@staticmethod
def word_segmentation(...) # 分词处理
@staticmethod
def pre_treat_data(...) # 特征预处理
class RandomForestTools:
@staticmethod
def train_data_save/clf_load # 模型保存/加载
@staticmethod
def feature_data_save/load # 特征数据保存/加载
@staticmethod
def one_hot_encode_feature(...) # 特征编码
2. 数据处理流程
graph TD
A[原始文本] --> B(分词处理)
B --> C{特征提取}
C --> D[one-hot编码]
D --> E[二维特征矩阵]
E --> F[模型训练]
3. 关键代码实现
分词处理
@staticmethod
def word_segmentation(train_data):
"""中文分词处理"""
word_segmentation_result = set()
for word in jieba.lcut(train_data):
word_segmentation_result.add(word)
return word_segmentation_result
特征编码(one-hot)
@staticmethod
def one_hot_encode_feature(data_list, data_set):
"""将文本转换为二进制特征向量"""
serialize_list = []
for data in data_list:
tmp_serialize_list = []
for key in data_set:
tmp_serialize_list.append(1 if key in data else 0)
serialize_list.append(tmp_serialize_list)
return serialize_list
模型训练
def build_train_data(self, pre_train_data_list, result_list, train_size=0.9):
# 1. 分词处理
train_data_list = [self.word_segmentation(text) for text in pre_train_data_list]
# 2. 特征提取
train_data_feature = set()
for data in train_data_list:
train_data_feature.update(data)
# 3. 特征编码
data_matrix = self.pre_treat_data(train_data_list, train_data_feature)
# 4. 数据集划分
data_train, data_test, result_train, result_test = train_test_split(
data_matrix, result_list, train_size=train_size
)
# 5. 模型训练
self.__clf = RandomForestClassifier(n_jobs=-1).fit(data_train, result_train)
# 6. 结果评估
accuracy = self.__clf.score(data_test, result_test)
使用示例
# 训练数据
pre_train_data = [
u'我很开心', u'我非常开心', u'我其实很开心',
u'我不开心', u'我一点都不开心', u'我很不开心'
]
result_list = [u'开心', u'开心', u'开心', u'不开心', u'不开心', u'不开心']
# 模型训练
rf = RandomForest()
rf.build_train_data(pre_train_data, result_list)
# 执行预测
print(rf.predict(u'你猜我开心吗?')) # 输出预测结果
注意事项
-
Python版本兼容性:
- 建议升级到Python 3.6+(原代码基于Python 2.7)
- 修改print语句为函数形式(添加括号)
- 将
basestring替换为str
-
性能优化:
- 大规模数据训练时注意内存管理(已包含gc.collect())
- 特征维度爆炸问题:可增加特征过滤逻辑
-
模型持久化:
- 训练结果保存路径:
/tmp/train_data.pkl - 特征数据保存路径:
/tmp/feature_data.pkl
- 训练结果保存路径:
-
扩展方向:
- 增加TF-IDF特征提取
- 添加交叉验证支持
- 实现批量预测接口
项目源码:GitHub链接
算法原理参考:《机器学习算法-随机森林》
模型性能评估
在示例数据集上的测试结果:
Build train data finish and accuracy is:1.00 .
(注:示例数据量小且特征明显,实际应用中准确率会根据数据质量波动)