Decision tree using rpart to produce a sankey diagram

Here is my attempt:

From what I see, the challenge is to generate the nodes and source variables.

Sample data:

library(rpart)  # also provides the kyphosis data set
fit <- rpart(Kyphosis ~ Age + Number + Start,
             method = "class", data = kyphosis)

Generate nodes:

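frame <- fit$frame
# Leaves are the rows whose split variable is "<leaf>"
isLeave <- frame$var == "<leaf>"
nodes <- rep(NA, length(isLeave))
ylevel <- attr(fit, "ylevels")
# Leaves are labelled with their predicted class
nodes[isLeave] <- ylevel[frame$yval][isLeave]
# Inner nodes take the split label of the row that follows them
nodes[!isLeave] <- labels(fit)[-1][!isLeave[-length(isLeave)]]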
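
To see why the labels are shifted with [-1], it can help to put the pieces side by side. This is only a sketch for inspection; the name overview is made up here and is not part of the solution. As far as I can tell, labels(fit) describes the branch leading into each node (with "root" first), so dropping the first element gives every inner node the condition of the row right after it.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start,
             method = "class", data = kyphosis)
frame <- fit$frame

# One row per node, in the same order as fit$frame
overview <- data.frame(
  node  = as.numeric(row.names(frame)),       # rpart node numbers
  var   = as.character(frame$var),            # split variable or "<leaf>"
  label = labels(fit),                        # branch label, "root" first
  class = attr(fit, "ylevels")[frame$yval],   # predicted class per node
  n     = frame$n                             # observations in the node
)
print(overview)
```

Comparing the label column with the nodes vector shows which split text ends up on which node.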

Generate source:

node <- as.numeric(row.names(frame))
depth <- rpart:::tree.depth(node)
source <- depth[-1] - 1

reps <- rle(source)
tobeAdded <- reps$values[sapply(reps$values, function(val) sum(val >= which(reps$lengths > 1))) > 0]
update <- source %in% tobeAdded
source[update] <- source[update] + sapply(tobeAdded, function(tobeAdd) rep(sum(which(reps$lengths > 1) <= tobeAdd), 2))
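
Not the approach above, but a possible alternative sketch: rpart numbers the two children of node k as 2k and 2k + 1, so the parent of any non-root node can be read off with integer division. The names parent, altSource and altTarget are made up here, and I have only reasoned this through for the kyphosis example.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start,
             method = "class", data = kyphosis)

# Row names of the frame are the rpart node numbers (root is 1)
node <- as.numeric(row.names(fit$frame))

# Children of node k are numbered 2k and 2k + 1,
# so the parent of a non-root node is node %/% 2
parent <- node[-1] %/% 2

# Translate node numbers into zero-based row positions for plotly
altSource <- match(parent, node) - 1
altTarget <- seq_along(node)[-1] - 1
```

Each link then runs from altSource[i] to altTarget[i], which avoids the run-length bookkeeping at the cost of relying on the numbering convention.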

Tested with:

library(rpart)
fit <- rpart(Kyphosis ~ Age + Number + Start,
             method="class", data=kyphosis)
fit2 <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
              parms = list(prior = c(.65,.35), split = "information"))

How to get there:

See: getS3method("print", "rpart")


I have a temporary solution for the time being, although I don't like loading an extra library. Here it is.

Fitting the model for the iris dataset:

library(rpart)
fit <- rpart(Species ~ Sepal.Length + Sepal.Width,
             method = "class", data = iris)

printcp(fit)
plot(fit, uniform = TRUE,
     main = "Classification Tree for IRIS")
text(fit, use.n = TRUE, all = TRUE, cex = .8)

The way I used to get the node names was:

treeFrame <- fit$frame
# The last element of each node's path is that node's own split label
nodes <- sapply(row.names(treeFrame), function(x) {
  path <- unlist(rpart::path.rpart(fit, x, print.it = FALSE))
  path[length(path)]
})

But @BigDataScientist's solution has a better way:

treeFrame=fit$frame
isLeave <- treeFrame$var == "<leaf>"
nodes <- rep(NA, length(isLeave))
ylevel <- attr(fit, "ylevels")
nodes[isLeave] <- ylevel[treeFrame$yval][isLeave]
nodes[!isLeave] <- labels(fit)[-1][!isLeave[-length(isLeave)]]

Getting the source and target is still a bit tricky, but what helped me was the rpart.utils package:

library(rpart.utils)
treeFrame <- fit$frame
treeRules <- rpart.utils::rpart.rules(fit)

# Split each node's rule string into the steps of its path
targetPaths <- sapply(as.numeric(row.names(treeFrame)), function(x)
  strsplit(unlist(treeRules[x]), split = ","))

# Last step of each path, and the step just before it
lastStop <- sapply(seq_along(targetPaths), function(x)
  targetPaths[[x]][length(targetPaths[[x]])])

oneBefore <- sapply(seq_along(targetPaths), function(x)
  targetPaths[[x]][length(targetPaths[[x]]) - 1])

target <- c()
source <- c()
values <- treeFrame$n
for (i in 2:length(oneBefore)) {
  tmpNode <- oneBefore[[i]]
  q <- which(lastStop == tmpNode)

  # No match: fall back to the first row (the root)
  q <- ifelse(length(q) == 0, 1, q)
  source <- c(source, q)
  target <- c(target, i)
}

# plotly expects zero-based node indices
source <- source - 1
target <- target - 1
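
Before plotting, it can be worth reading the links back as text to check that every child hangs off the right parent. This is only a sketch with made-up toy values ("split A" and so on are placeholders, not output from the iris fit); with the real vectors you would skip the first four assignments.

```r
# Toy stand-ins for the vectors built above (hypothetical values)
nodes  <- c("split A", "class X", "split B", "class Y", "class Z")
source <- c(0, 0, 2, 2)
target <- c(1, 2, 3, 4)
values <- c(100, 40, 60, 25, 35)   # treeFrame$n, root first

# One row per link; + 1 converts the zero-based indices back to R positions
links <- data.frame(
  from = nodes[source + 1],
  to   = nodes[target + 1],
  n    = values[-1]                # the root has no incoming link
)
print(links)
```

If a from/to pair looks wrong here, the Sankey diagram will be wrong in the same place.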

I don't like using an extra library, but this seems to work for various data sets, and @BigDataScientist's way of getting the nodes is better. I will still look out for better solutions. @BigDataScientist, I think your solution will work better; maybe something small needs to change. But I don't yet understand the "reps" part of your code that well.

And the code for the final plot is:

library(plotly)

p <- plot_ly(
  type = "sankey",
  orientation = "v",

  node = list(
    label = nodes,
    pad = 15,
    thickness = 20,
    line = list(
      color = "black",
      width = 0.5
    )
  ),

  link = list(
    source = source,
    target = target,
    value = values[-1]  # drop the root count: the root has no incoming link
  )
) %>%
  layout(
    title = "Basic Sankey Diagram",
    font = list(
      size = 10
    )
  )
p