
I am using the gbm function in R (gbm package) to fit stochastic gradient boosting models for multiclass classification. I want to obtain the importance of each predictor separately for each class, as in the figure on p. 382 of The Elements of Statistical Learning (Hastie et al.).


However, the function summary.gbm only returns the overall importance of the predictors (their importance averaged over all classes).

Does anyone know how to get the relative importance values for each class?

1 Answer


Hopefully this function helps. The example uses the marketing data from the ElemStatLearn package. The function finds the classes in the given column, splits the data by class, fits a separate gbm() model to each subset, and draws a bar plot of relative importance for each model.

# install.packages("ElemStatLearn"); install.packages("gbm")
library(ElemStatLearn)  # provides the marketing data set
library(gbm)

set.seed(137531)

# formula: the formula to pass to gbm()
# data: the data set to use
# column: the class column to use (index or name)
classPlots <- function (formula, data, column) {
    class_column <- as.character(data[, column])
    class_values <- names(table(class_column))
    # split() always returns a list with one data frame per class;
    # use the data argument here rather than the global marketing object
    split_data <- split(data, class_column)[class_values]
    object <- lapply(split_data, function(x) gbm(formula, data = x))
    rel.inf <- lapply(object, function(x) summary(x, plotit = FALSE))
    for (i in seq_along(class_values)) {
        tmp <- rel.inf[[i]]
        # named vector of relative influences for barplot()
        tmp <- setNames(tmp$rel.inf, as.character(tmp$var))
        barplot(tmp, horiz = TRUE, col = 'red',
                xlab = "Relative importance", main = paste0("Class = ", class_values[i]))
    }
    rel.inf
}

# column 2 of marketing (Sex) has two classes, hence the 1 x 2 plot layout
par(mfrow = c(1, 2))
classPlots(Income ~ Marital + Age, data = marketing, column = 2)
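
The function above fits one model per subset of the data. If the goal is the importance of each predictor for each class of the response, a different approach that may get closer is one-vs-rest: fit one binary gbm() model per class (class k against all others) and collect the relative influence from each. The sketch below is only an illustration under that assumption. It is not the multinomial per-class measure computed in the book. The helper name perClassInfluence and its arguments are made up for this example, and the built-in iris data is used so that it runs without ElemStatLearn.

```r
library(gbm)

set.seed(1)

# Hypothetical helper (not part of the gbm package): one-vs-rest relative influence.
# data: data frame; response: name of the class column;
# predictors: character vector of predictor column names.
perClassInfluence <- function(data, response, predictors, n.trees = 100) {
    classes <- levels(factor(data[[response]]))
    out <- lapply(classes, function(k) {
        d <- data[, predictors, drop = FALSE]
        # 0/1 indicator for "observation belongs to class k"
        d$is_class_k <- as.integer(data[[response]] == k)
        fit <- gbm(is_class_k ~ ., data = d, distribution = "bernoulli",
                   n.trees = n.trees, verbose = FALSE)
        s <- summary(fit, n.trees = n.trees, plotit = FALSE)
        # Return the influences in a fixed predictor order
        setNames(s$rel.inf, as.character(s$var))[predictors]
    })
    names(out) <- classes
    # Matrix with one row per predictor and one column per class
    do.call(cbind, out)
}

infl <- perClassInfluence(iris, "Species", names(iris)[1:4])
print(round(infl, 1))
```

Each column of infl is normalized by summary() to sum to 100, so values are comparable within a class but not across classes. Calling barplot(infl[, k], horiz = TRUE) for each class k would give one plot per class, similar in layout to the ones produced above.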

