Installation
Installing Scikit-plot is simple. Just run:
pip install scikit-plot
and the installation is done.
Repository:
https://github.com/reiinakano/scikit-plot
It contains documentation and examples, as .py scripts and .ipynb notebooks.
Usage
Here are a few simple examples.
- For example, here is the complete code for plotting ROC curves, a standard evaluation metric for classifiers:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
nb = GaussianNB()
nb.fit(X_train, y_train)
predicted_probas = nb.predict_proba(X_test)

# The magic happens here
import matplotlib.pyplot as plt
import scikitplot as skplt
skplt.metrics.plot_roc(y_test, predicted_probas)
plt.show()
The result looks like this:
Figure: ROC curves
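To show what each plotted curve is made of, here is a minimal pure-Python sketch for the binary case. As we understand it, for a multiclass problem such as digits, plot_roc draws one such curve per class (that class treated as positive, its column of predicted_probas used as the score), plus averaged curves. The helpers roc_points and trapezoid_auc are hypothetical names for illustration; this is not scikit-plot's own implementation.

```python
def roc_points(y_true, scores):
    """Return (fpr, tpr) lists for binary labels (1 = positive) by lowering the
    decision threshold through every distinct score, highest first."""
    pos = sum(1 for t in y_true if t == 1)
    neg = len(y_true) - pos
    if pos == 0 or neg == 0:
        raise ValueError("ROC needs at least one positive and one negative sample")
    # Sort samples by descending score so each step lowers the threshold
    pairs = sorted(zip(scores, y_true), key=lambda p: -p[0])
    fpr, tpr = [0.0], [0.0]  # threshold above every score: nothing predicted positive
    tp = fp = 0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Samples with tied scores cross the threshold together
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        fpr.append(fp / neg)  # FP / (FP + TN)
        tpr.append(tp / pos)  # TP / (TP + FN), i.e. recall
    return fpr, tpr


def trapezoid_auc(x, y):
    """Area under the piecewise-linear curve through the points (x[k], y[k])."""
    return sum((x[k + 1] - x[k]) * (y[k + 1] + y[k]) / 2 for k in range(len(x) - 1))


fpr, tpr = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(trapezoid_auc(fpr, tpr))  # 0.75
```

Each distinct score acts as one threshold, so the curve has one point per distinct score plus the origin; the AUC shown in the legend is the area under these points.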
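- The P-R curve plots precision against recall, with recall on the horizontal axis and precision on the vertical axis. First, a quick explanation of the two terms: precision = TP / (TP + FP) is the fraction of samples predicted positive that really are positive, and recall = TP / (TP + FN) is the fraction of actual positives that the classifier finds, where TP, FP and FN are the counts of true positives, false positives and false negatives. Each point on the curve corresponds to one decision threshold on the predicted probability.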
import matplotlib.pyplot as plt
from sklearn.naive_bayes import GaussianNB
from sklearn.datasets import load_digits as load_data
import scikitplot as skplt

# Load dataset
X, y = load_data(return_X_y=True)

# Create classifier instance then fit
nb = GaussianNB()
nb.fit(X, y)

# Get predicted probabilities
# (on the training data itself, so the curve is optimistic; use a held-out
# split, as in the ROC example, for an honest estimate)
y_probas = nb.predict_proba(X)

skplt.metrics.plot_precision_recall_curve(y, y_probas, cmap='nipy_spectral')
plt.show()
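The definitions above can be turned into curve points with a short pure-Python sketch for the binary case: predict positive whenever the score is at least the threshold, then compute precision and recall at every distinct score. As we understand it, for multiclass data plot_precision_recall_curve draws one such one-vs-rest curve per class plus a micro-averaged curve. The function names here are hypothetical, and the empty-denominator conventions are our own choice.

```python
def precision_recall_at(y_true, scores, threshold):
    """Precision and recall when every score >= threshold is predicted positive."""
    tp = sum(1 for t, s in zip(y_true, scores) if s >= threshold and t == 1)
    fp = sum(1 for t, s in zip(y_true, scores) if s >= threshold and t != 1)
    fn = sum(1 for t, s in zip(y_true, scores) if s < threshold and t == 1)
    # Conventions for empty denominators (nothing predicted / no positives)
    precision = tp / (tp + fp) if tp + fp else 1.0  # TP / (TP + FP)
    recall = tp / (tp + fn) if tp + fn else 0.0     # TP / (TP + FN)
    return precision, recall


def pr_curve(y_true, scores):
    """One (precision, recall) point per distinct score, highest threshold first."""
    return [precision_recall_at(y_true, scores, thr)
            for thr in sorted(set(scores), reverse=True)]


for p, r in pr_curve([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]):
    print(f"recall={r:.2f} precision={p:.2f}")
```

Unlike ROC, precision does not change monotonically as the threshold drops (here it goes 1.00, 0.50, 0.67, 0.50), which is why P-R curves often look jagged.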
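- The confusion matrix is an important evaluation tool for classification. The code below uses a random forest to classify the handwritten digits dataset (load_digits, imported as load_data) and plots a normalized confusion matrix of the results.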
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_digits as load_data
from sklearn.model_selection import cross_val_predict
import matplotlib.pyplot as plt
import scikitplot as skplt

X, y = load_data(return_X_y=True)

# Create an instance of the RandomForestClassifier
classifier = RandomForestClassifier()

# Perform predictions: out-of-fold predictions from cross-validation,
# so each sample is predicted by a model that was not trained on it
predictions = cross_val_predict(classifier, X, y)

plot = skplt.metrics.plot_confusion_matrix(y, predictions, normalize=True)
plt.show()
Figure: normalized confusion matrix
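"Normalized" here means, as far as we understand scikit-plot's normalize=True, that each row (one true class) is divided by its total, so the diagonal reads as per-class recall. A minimal pure-Python sketch of that computation follows; confusion_matrix and normalize_rows are hypothetical helpers for illustration, not the library's code.

```python
def confusion_matrix(y_true, y_pred):
    """Counts with rows = true class, columns = predicted class."""
    labels = sorted(set(y_true) | set(y_pred))
    index = {label: k for k, label in enumerate(labels)}
    cm = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        cm[index[t]][index[p]] += 1
    return labels, cm


def normalize_rows(cm):
    """Divide each row by its total, so row i shows how samples of true class i
    were spread over the predicted classes (the diagonal is per-class recall)."""
    normalized = []
    for row in cm:
        total = sum(row)
        normalized.append([count / total if total else 0.0 for count in row])
    return normalized


labels, cm = confusion_matrix([0, 0, 0, 1, 1, 2], [0, 0, 1, 1, 1, 2])
print(cm)  # [[2, 1, 0], [0, 2, 0], [0, 0, 1]]
print(normalize_rows(cm))
```

Row normalization keeps classes with many samples from dominating the colour scale, which matters when the classes are imbalanced.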
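- Other plots, such as learning curves (skplt.estimators.plot_learning_curve), feature importances (skplt.estimators.plot_feature_importances) and the elbow plot for k-means clustering (skplt.cluster.plot_elbow_curve), also take only a few lines of code.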
Figure: learning curve and feature importances
Figure: k-means elbow plot
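The elbow plot, as we understand it, shows the k-means sum of squared errors (SSE) for each candidate number of clusters, and the "elbow" where the curve stops dropping sharply suggests a reasonable k. Below is a minimal pure-Python sketch of that idea. It uses Lloyd's algorithm with a deterministic farthest-point initialisation (a swapped-in choice, not scikit-learn's k-means++), and kmeans_sse and sq_dist are hypothetical names.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two points of equal dimension."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans_sse(points, k, max_iter=100):
    """Lloyd's k-means; returns the within-cluster sum of squared errors (SSE).

    Initialisation is deterministic farthest-point seeding (not k-means++):
    start from points[0], then repeatedly add the point farthest from its
    nearest existing center."""
    centers = [tuple(points[0])]
    while len(centers) < k:
        centers.append(tuple(max(points, key=lambda p: min(sq_dist(p, c) for c in centers))))
    for _ in range(max_iter):
        # Assignment step: each point joins its nearest center
        clusters = [[] for _ in centers]
        for p in points:
            nearest = min(range(len(centers)), key=lambda j: sq_dist(p, centers[j]))
            clusters[nearest].append(p)
        # Update step: move each center to its cluster mean (keep it if empty)
        new_centers = [
            tuple(sum(coord) / len(cl) for coord in zip(*cl)) if cl else centers[j]
            for j, cl in enumerate(clusters)
        ]
        if new_centers == centers:  # converged: assignments no longer change
            break
        centers = new_centers
    return sum(min(sq_dist(p, c) for c in centers) for p in points)


# Toy data: three tight, well-separated groups of three points
points = [(0, 0), (0, 1), (1, 0),
          (10, 10), (10, 11), (11, 10),
          (20, 0), (20, 1), (21, 0)]
for k in range(1, 5):
    print(k, round(kmeans_sse(points, k), 2))
```

On this toy data the SSE drops sharply up to k = 3 and barely moves at k = 4, which is exactly the bend the elbow plot is meant to reveal.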