Decision Tree in Python, with Graphviz to Visualize

Following the last article, we can also use a decision tree to evaluate the relationship between breast cancer and the features in the dataset. Most of the code comes from the same book as the last article. Thanks to the authors: Andreas C. Mueller and Sarah Guido.

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
cancer=load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
	cancer.data, cancer.target, stratify=cancer.target, random_state=2017)
tree=DecisionTreeClassifier(random_state=0)
###Decision trees in scikit-learn are implemented in the DecisionTreeRegressor
###and DecisionTreeClassifier classes. Scikit-learn only implements pre-pruning, not post-pruning.
tree.fit(X_train,y_train)
print("accuracy on training set: %f" % tree.score(X_train, y_train))
print("accuracy on test set: %f" % tree.score(X_test, y_test))

###accuracy on training set: 1.000000###

###accuracy on test set: 0.916084###

###Apply pre-pruning, which stops growing the tree before it
###perfectly fits the training data.
tree01=DecisionTreeClassifier(max_depth=4,random_state=0)
tree01.fit(X_train,y_train)
print("accuracy on training set 01: %f" % tree01.score(X_train, y_train))
print("accuracy on test set 01: %f" % tree01.score(X_test, y_test))

###accuracy on training set 01: 0.990610###

###accuracy on test set 01: 0.937063###
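To see how the choice of max_depth trades training accuracy against test accuracy, a quick sketch (using the same split and seeds as above) can sweep several depth limits; the exact scores printed will depend on the scikit-learn version:

```python
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer

cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, stratify=cancer.target, random_state=2017)

# Compare (train accuracy, test accuracy) for several depth limits;
# None means the tree grows until every leaf is pure.
results = {}
for depth in [1, 2, 4, None]:
    t = DecisionTreeClassifier(max_depth=depth, random_state=0)
    t.fit(X_train, y_train)
    results[depth] = (t.score(X_train, y_train), t.score(X_test, y_test))
    print("max_depth=%s  train=%.4f  test=%.4f"
          % (depth, results[depth][0], results[depth][1]))
```

The unpruned tree always reaches 100% training accuracy, while a moderate depth limit usually gives the better test score, which is exactly the overfitting pattern seen in the two runs above.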


###visualize and analyze the tree model###
###build a file to visualize 
from sklearn.tree import export_graphviz
export_graphviz(tree, out_file="mytree.dot", class_names=["malignant", "benign"],
	feature_names=cancer.feature_names, impurity=False, filled=True)
###visualize the .dot file. The graphviz package (and its system binaries) needs to be installed separately first
import graphviz
with open("mytree.dot") as f:
	dot_graph=f.read()
graphviz.Source(dot_graph)

[Figure: graphviz rendering of the decision tree from mytree.dot]
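If graphviz is not installed, newer scikit-learn versions also offer export_text, which prints the same tree structure as plain text with no external dependency. A minimal sketch, refitting the pruned tree from above:

```python
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer

cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, stratify=cancer.target, random_state=2017)
tree01 = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_train, y_train)

# Plain-text rendering of the tree: each "|---" line is one split or leaf
text = export_text(tree01, feature_names=list(cancer.feature_names))
print(text)
```

This is handy for quickly checking which feature sits at the root without leaving the terminal.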

###We can also derive summaries of the workings of the tree.
###The most commonly used summary is feature importance
print(tree01.feature_importances_)
import matplotlib.pyplot as plt 
plt.plot(tree01.feature_importances_,'o')
plt.xticks(range(cancer.data.shape[1]),cancer.feature_names,rotation=90)
plt.ylim(0,1)
plt.show()

###Here, we see that the feature used at the top split (“worst radius”) is by far the most important feature.

[Figure: feature importances of the pruned tree, one point per feature]
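The raw feature_importances_ array is hard to scan by eye. A short sketch (refitting the same pruned tree) ranks the features so the top few splits stand out; the values are non-negative and sum to 1:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer

cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, stratify=cancer.target, random_state=2017)
tree01 = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_train, y_train)

# Sort features by importance, largest first, and show the top 5
imps = tree01.feature_importances_
order = np.argsort(imps)[::-1]
for i in order[:5]:
    print("%-25s %.3f" % (cancer.feature_names[i], imps[i]))
```

Most of the 30 features will show an importance of exactly 0 here, since a depth-4 tree can only ever use a handful of them.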
