from dtreeviz.trees import dtreeviz
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
# Train a depth-limited decision tree on the Iris dataset and visualize it
# with dtreeviz.
iris = load_iris()

# Seed the split as well as the classifier so the held-out score and the
# rendered tree are reproducible from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=42
)

clf = DecisionTreeClassifier(random_state=42, max_depth=6)
clf.fit(X_train, y_train)

# Report the mean accuracy on the held-out split. A bare `clf.score(...)`
# expression only displays its value in a REPL/notebook; in a script the
# result is silently discarded, so print it explicitly.
print(clf.score(X_test, y_test))

# Build the tree visualization from the training data, labelling splits with
# the dataset's feature names and leaves with its class names.
viz = dtreeviz(
    clf,
    X_train,
    y_train,
    target_name='variety',
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
)
viz.view()