Visual representation of a decision tree

In this example we are going to plot a decision tree from DecisionTree.jl using the Buchheim layout from NetworkLayout.jl.

using CairoMakie
using Graphs
using GraphMakie
import MLJ
using NetworkLayout
using DecisionTree

The following code, which walks the tree and creates a SimpleDiGraph, was taken and slightly modified from syntaxtree.jl. Thanks!

The model is a DecisionTree object. maxdepth defines the maximum depth of the final tree generated.

import Base.convert

"""
    convert(::Type{SimpleDiGraph}, model::DecisionTree.DecisionTreeClassifier; maxdepth=depth(model))

Walk the fitted classifier and build a `SimpleDiGraph` of its decision tree.

Returns a tuple `(g, properties)` where `g` is the graph and `properties` holds one
`(type, label)` entry per vertex, in vertex order. `maxdepth` limits how deep the
tree is traversed; `-1` (like the default) means the model's full depth.
"""
function Base.convert(::Type{SimpleDiGraph}, model::DecisionTree.DecisionTreeClassifier; maxdepth=depth(model))
    # -1 is an explicit "no limit" sentinel equivalent to the default.
    effective_depth = maxdepth == -1 ? depth(model) : maxdepth
    graph = SimpleDiGraph()
    props = Any[]
    walk_tree!(model.root.node, graph, effective_depth, props)
    return graph, props
end

"""
    walk_tree!(node::DecisionTree.Node, g, depthLeft, properties)

Recursively insert `node` and its subtree into graph `g`, pushing one
`(type, label)` tuple per created vertex onto `properties`. When the depth
budget `depthLeft` hits zero, a truncation placeholder `(Nothing, "...")`
is emitted instead of descending further. Returns the vertex index created
for `node`.
"""
function walk_tree!(node::DecisionTree.Node, g, depthLeft, properties)
    add_vertex!(g)

    # Depth budget exhausted: emit a placeholder vertex and stop the recursion.
    if depthLeft == 0
        push!(properties, (Nothing, "..."))
        return vertices(g)[end]
    end
    depthLeft -= 1

    parent = vertices(g)[end]
    threshold = node.featval
    # Numeric split thresholds are rounded for a compact label; string-valued
    # features are shown verbatim.
    shown = threshold isa AbstractString ? threshold : round(threshold; sigdigits=2)
    push!(properties, (Node, "Feature $(node.featid) < $shown ?"))

    # Descend into the left subtree first, then the right, wiring each child
    # vertex back to this one.
    for subtree in (node.left, node.right)
        child = walk_tree!(subtree, g, depthLeft, properties)
        add_edge!(g, parent, child)
    end

    return parent
end

"""
    walk_tree!(leaf::DecisionTree.Leaf, g, depthLeft, properties)

Add a single vertex for `leaf` to `g` and record its majority class label in
`properties`. Returns the new vertex index.
"""
function walk_tree!(leaf::DecisionTree.Leaf, g, depthLeft, properties)
    add_vertex!(g)
    # Label the leaf with its majority class only. (A "matches/total" ratio was
    # once appended here; the dead `n_matches` computation has been removed.)
    push!(properties,(Leaf,"$(leaf.majority)"))
    return vertices(g)[end]
end

Ooof, quite a bit of code!

Makie @recipe

Now let's define a Makie recipe for the plot, to make plotting easy.

# Makie recipe declaring the `plotdecisiontree` plotting function together with
# its user-tweakable attributes (colors and the maximum displayed tree depth).
@recipe(PlotDecisionTree) do scene
    Attributes(
        nodecolormap = :darktest,       # colormap used to assign one color per leaf class
        textcolor = RGBf(0.5,0.5,0.5),  # label color for decision (non-leaf) nodes
        leafcolor = :darkgreen,         # fallback label color for leaves
        nodecolor = :white,             # fill color of the node markers
        maxdepth = -1,                  # -1 = draw the full tree (see `convert` above)
    )
end

import GraphMakie.graphplot
import Makie.plot!

"""
    graphplot(model::DecisionTreeClassifier; kwargs...)

Convenience entry point: render `model` through the `plotdecisiontree` recipe,
hide the axis decorations and spines, and return the resulting figure.
"""
function GraphMakie.graphplot(model::DecisionTreeClassifier; kwargs...)
    fig, axis, _ = plotdecisiontree(model; kwargs...)
    hidedecorations!(axis)
    hidespines!(axis)
    return fig
end

# Recipe implementation: converts the classifier into a graph plus per-vertex
# properties, derives one label and one label color per vertex, and draws the
# result with `graphplot!` using the Buchheim tree layout. All inputs are
# wrapped as Observables (via `@lift`/`map`), so the plot updates reactively
# when the model or any attribute changes.
function plot!(plt::PlotDecisionTree{<:Tuple{DecisionTreeClassifier}})

    @extract plt leafcolor,textcolor,nodecolormap,nodecolor,maxdepth
    model = plt[1]

    # convert to graph; `convert` returns (graph, properties), split into two
    # separate Observables below
    tmpObs = @lift convert(SimpleDiGraph,$model;maxdepth=$maxdepth)
    graph = @lift $tmpObs[1]
    properties = @lift $tmpObs[2]

    # extract labels (second element of each (type, label) property tuple)
    labels = @lift [string(p[2]) for p in $properties]

    # set the colors, first for nodes & cutoff-nodes, then for leaves
    nlabels_color = map(properties, labels, leafcolor,textcolor,nodecolormap) do properties,labels,leafcolor,textcolor,nodecolormap

        # set colors for the individual elements
        leaf_ix = findall([p[1] == Leaf for p in properties])
        # NOTE(review): labels are split on " : "; only the part before the
        # separator (the class name) determines a leaf's color group
        leafValues = [p[1] for p in split.(labels[leaf_ix]," : ")]

        # one color per category
        uniqueLeafValues = unique(leafValues)
        individual_leaf_colors = resample_cmap(nodecolormap,length(uniqueLeafValues))
        # decision nodes get `textcolor`; leaves start at `leafcolor` and are
        # overwritten with their per-class color in the loop below
        nlabels_color = Any[p[1] == Node ? textcolor : leafcolor for p in properties]
        for (ix,uLV) = enumerate(uniqueLeafValues)
            ixV = leafValues .== uLV
            nlabels_color[leaf_ix[ixV]] .= individual_leaf_colors[ix]
        end
        return nlabels_color
    end

    # plot :)
    graphplot!(plt,graph;layout=Buchheim(),
               nlabels=labels,
               node_size = 100,
               node_color=nodecolor,
               nlabels_color=nlabels_color,
               nlabels_align=(:center,:center),
               ##tangents=((0,-1),(0,-1))
               )
    return plt
end

Visualizing a DecisionTree

Now finally we are ready to fit & visualize the tree

iris = MLJ.load_iris()
features = hcat(iris[1],iris[2],iris[3],iris[4])
labels = iris[5]
model = DecisionTreeClassifier(max_depth=4)
fit!(model, features, labels)
DecisionTreeClassifier
max_depth:                4
min_samples_leaf:         1
min_samples_split:        2
min_purity_increase:      0.0
pruning_purity_threshold: 1.0
n_subfeatures:            0
classes:                  ["setosa", "versicolor", "virginica"]
root:                     Decision Tree
Leaves: 8
Depth:  4

now we are ready to plot

graphplot(model)
Example block output

you can also specify depth, or modify colors

graphplot(model;maxdepth=3,textcolor=:darkgreen)
Example block output

This page was generated using Literate.jl.