
triplot's Introduction

triplot

Introduction

The triplot package provides tools for exploring machine learning predictive models. It contains an instance-level explainer called predict_aspects (AKA aspects_importance) that explains the contribution of whole groups of explanatory variables. Furthermore, the package delivers a functionality called triplot, which illustrates how the importance of aspects (groups of predictors) changes depending on the size of the aspects.

Key functions:

  • predict_triplot() and model_triplot() for instance- and data-level summaries of automatic aspect importance grouping,
  • predict_aspects() for calculating the importance of feature groups (called aspects importance) for a selected observation,
  • group_variables() for grouping correlated numeric features into aspects (see the sketch below).
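
As a quick illustration of the grouping step, here is a minimal sketch; it assumes the apartments data from DALEX and that the h argument of group_variables() is the cut-off height of the correlation-based clustering tree.

library("DALEX")
library("triplot")

# keep only the numeric columns of the apartments data
apartments_num <- apartments[, sapply(apartments, is.numeric)]

# group predictors whose correlation-based distance is below the cut-off h
group_variables(apartments_num[, -1], h = 0.5)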

The triplot package is part of the DrWhy.AI universe. More information about the analysis of machine learning models can be found in the Explanatory Model Analysis. Explore, Explain and Examine Predictive Models e-book.

Installation

# from CRAN:
install.packages("triplot")
# from GitHub (development version):
# install.packages("devtools")
devtools::install_github("ModelOriented/triplot")

Overview

triplot shows, in one place:

  • the importance of every single feature,
  • hierarchical aspects importance,
  • order of grouping features into aspects.

We can use it to investigate feature importance at the instance level (using the predict_aspects() function) or at the model level (using the model_parts() function from the DALEX package). triplot can only be used with numerical features. More information about this functionality can be found in the triplot overview.

Basic triplot for a model

To showcase triplot, we will choose the apartments dataset from DALEX, use its numeric features to build a model, create a DALEX explainer, use model_triplot() to calculate the triplot object, and then plot it with the generic plot() function.

Import apartments and train a linear model

library("DALEX")
apartments_num <- apartments[,unlist(lapply(apartments, is.numeric))]

model_apartments <- lm(m2.price ~ ., data = apartments_num)

Create an explainer

explain_apartments <- DALEX::explain(model = model_apartments, 
                              data = apartments_num[, -1],
                              y = apartments_num$m2.price,
                              verbose = FALSE)
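
The Overview above mentions model-level importance of single features via model_parts(); a minimal sketch using the explainer we just created:

# model-level importance of individual features with DALEX
mp_apartments <- model_parts(explain_apartments)
plot(mp_apartments)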

Create a triplot object

set.seed(123)
library("triplot")

tri_apartments <- model_triplot(explain_apartments)

plot(tri_apartments) + 
  patchwork::plot_annotation(title = "Global triplot for four variables in the linear model")

The left panel shows the global importance of individual variables. Right panel shows global correlation structure visualized by hierarchical clustering. The middle panel shows the importance of groups of variables determined by the hierarchical clustering.

At the model level, surface and floor have the biggest contributions. We also see that number of rooms and surface are strongly correlated and together have a strong influence on the model prediction. Construction year has a small influence on the prediction and is correlated with neither number of rooms nor surface; adding construction year to that group only slightly increases its importance.
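
The clustering behind the middle and right panels can also be tuned. A hedged sketch, assuming model_triplot() forwards the cor_method and clust_method arguments of the underlying hierarchical importance calculation (cor_method is mentioned in the 1.3 changelog below):

tri_apartments_pearson <- model_triplot(explain_apartments,
                                        cor_method = "pearson",
                                        clust_method = "average")
plot(tri_apartments_pearson)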

Basic triplot for an observation

Next, we build a triplot for a single instance and its prediction.

(new_apartment <- apartments_num[6, -1])
##   construction.year surface floor no.rooms
## 6              1926      61     6        2
tri_apartments <- predict_triplot(explain_apartments, 
                                  new_observation = new_apartment)

plot(tri_apartments) + 
  patchwork::plot_annotation(title = "Local triplot for four variables in the linear model")

The left panel shows the local importance of individual variables (similar to LIME). Right panel shows global correlation structure visualized by hierarchical clustering. The middle panel shows the local importance of groups of variables (similar to LIME) determined by the hierarchical clustering.

We can observe that for the given apartment, surface also has a significant, positive influence on the prediction. Adding number of rooms increases its contribution. However, adding construction year to those two features decreases the group importance.

We can also notice that floor has a small influence on the prediction for this observation, unlike in the model-level analysis.
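
The same local analysis can be repeated for any other observation; for example (the row index 10 is arbitrary):

another_apartment <- apartments_num[10, -1]
tri_another <- predict_triplot(explain_apartments,
                               new_observation = another_apartment)
plot(tri_another)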

Aspect importance for single instance

For this example we use the titanic dataset with a logistic regression model that predicts passenger survival. Features are combined into thematic aspects.

Importing dataset and building a logistic regression model

set.seed(123)

model_titanic_glm <- glm(survived ~ ., titanic_imputed, family = "binomial")

Manual selection of aspects

aspects_titanic <-
  list(
    wealth = c("class", "fare"),
    family = c("sibsp", "parch"),
    personal = c("age", "gender"),
    embarked = "embarked"
  )
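
Because the aspects are defined by hand, a small sanity check (plain base R, not part of the package API) can confirm that every listed feature exists in the data:

stopifnot(all(unlist(aspects_titanic) %in% colnames(titanic_imputed)))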

Select an instance

We are interested in explaining the model prediction for the johny_d example.

(johny_d <- titanic_imputed[2,])
##   gender age class    embarked  fare sibsp parch survived
## 2   male  13   3rd Southampton 20.05     0     2        0
predict(model_titanic_glm, johny_d, type = "response")
##         2 
## 0.1531932

It turns out that the model prediction for this passenger’s survival is very low. Let’s see which aspects have the biggest influence on it.

We start with a DALEX explainer.

explain_titanic <- DALEX::explain(model_titanic_glm, 
                           data = titanic_imputed,
                           y = titanic_imputed$survived,
                           label = "Logistic Regression",
                           verbose = FALSE)

Then we use it to call the triplot::predict_aspects() function. Afterwards, we print and plot the results.

library("triplot")

ai_titanic <- predict_aspects(x = explain_titanic, 
                              new_observation = johny_d[,-8],
                              variable_groups = aspects_titanic)

print(ai_titanic, show_features = TRUE)
##   variable_groups importance     features
## 2          wealth  -0.122049  class, fare
## 3          family   0.023564 sibsp, parch
## 5        embarked  -0.007929     embarked
## 4        personal   0.004069  age, gender
plot(ai_titanic)

We can observe that the wealth (class, fare) variables have the biggest contribution to the prediction, and this contribution is negative. The personal (age, gender) and family (sibsp, parch) variables have a positive influence on the prediction, but it is much smaller. The embarked feature has a very small, negative contribution to the prediction.
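
For more stable estimates, the number of sampled observations can be increased; a hedged sketch, assuming predict_aspects() exposes the N argument visible in the aspect_importance() signature used elsewhere in this package:

ai_titanic_stable <- predict_aspects(x = explain_titanic,
                                     new_observation = johny_d[, -8],
                                     variable_groups = aspects_titanic,
                                     N = 5000)
print(ai_titanic_stable, show_features = TRUE)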

Acknowledgments

Work on this package was financially supported by the NCBR Grant POIR.01.01.01-00-0328/17 and the NCN Opus grant 2017/27/B/ST6/01307.

triplot's People

Contributors

kasiapekala, pbiecek


triplot's Issues

triplot() does not work in my case..

Hello everyone

I am using the triplot library for feature selection, but it gives me an error. I simply train a model using mlr3, create an explainer, and then run the following:

library(triplot)
model_triplot(explainer)

The error message is:

Error in hclust(as.dist(1 - abs(cor(data, method = cor_method))), method = clust_method) :
NA/NaN/Inf in foreign function call (arg 10)
In addition: Warning message:
In cor(data, method = cor_method) : the standard deviation is zero
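
The warning "the standard deviation is zero" points at a constant column in the explainer's data, which makes cor() produce NA values and breaks hclust(). A hedged diagnostic sketch (explainer is the object from the snippet above):

# find constant columns that break the correlation matrix
const_cols <- vapply(as.data.frame(explainer$data),
                     function(v) isTRUE(sd(v) == 0),
                     logical(1))
names(const_cols)[const_cols]
# dropping these columns before building the explainer should let
# model_triplot() run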

before CRAN todo

  • readme - remove collapsible sections
  • aspect_importance() - switch order of parameters sample_method and n_var
  • remove == 1 in glm model for titanic
  • check predict_aspects() for titanic glm results (readme)

How to use the triplot tool with iBreakDown or SHAP

Hello

If we are using iBreakDown or SHAP for instance-level explanation, how can we make sure that they take correlated metrics into account when producing explanations? I mean, if I am using iBreakDown, for instance, can I use this triplot tool so that it considers the correlated metrics? I read somewhere that without this (considering correlated metrics), the explanation produced is misleading.

Thanks

Problem with many highly correlated features

I have data with highly correlated variables, and many of them cluster together at height 0 in the clustering tree. The function hierarchical_importance does not work with this example.

Below I provide a reproducible example.

library(dplyr)
library(triplot)
library(DALEX)
library(gbm)

## download https://github.com/woznicak/MetaFeaturesImpact/blob/master/summary_results_surrogate_models_rank_per_algo.Rd
## This is the Rdata object with a list of explainers

summary_results <- readRDS('summary_results_surrogate_models_rank_per_algo.Rd')

explainer_gbm <- summary_results$explainer_GBM_deep[[11]]

tri_var_imp <- calculate_triplot(explainer_gbm, 
                                 data = explainer_gbm$data,
                                 y = explainer_gbm$y,
                                 new_observation = explainer_gbm$data[1,],
                                 predict_function = explainer_gbm$predict_function)

I worked around this problem; here is my fix of the function:

hierarchical_importance <- function(x, data, y = NULL, predict_function = predict,
                                    type = "predict", new_observation = NULL, N = 1000,
                                    loss_function = DALEX::loss_root_mean_square, B = 10,
                                    fi_type = c("raw", "ratio", "difference"),
                                    clust_method = "complete", cor_method = "spearman", ...)
{
  if (all(type != "predict", is.null(y))) {
    stop("Target is needed for hierarchical_importance calculated at model level")
  }
  fi_type <- match.arg(fi_type)
  # note: list_variables() is internal to triplot, so this redefinition needs
  # triplot:::list_variables when run outside the package
  x_hc <- hclust(as.dist(1 - abs(cor(data, method = cor_method))),
                 method = clust_method)
  cutting_heights <- x_hc$height
  # original: aspects_list_previous <- list_variables(x_hc, 1)
  # with many variables clustered at height 0, start from single-variable aspects
  aspects_list_previous <- as.list(colnames(data))
  int_node_importance <- as.data.frame(NULL)
  for (i in c(1:(length(cutting_heights) - 1))) {
    aspects_list_current <- list_variables(x_hc, 1 - cutting_heights[i])
    t1 <- match(aspects_list_current,
                setdiff(aspects_list_current, aspects_list_previous))
    for (k in na.omit(t1)) {
      t2 <- which(t1 == k)
      t3 <- aspects_list_current[t2]
      group_name <- names(t3)
      if (type != "predict") {
        explainer <- explain(model = x, data = data, y = y,
                             predict_function = predict_function, verbose = FALSE)
        res_ai <- feature_importance(explainer = explainer,
                                     variable_groups = aspects_list_current, N = N,
                                     loss_function = loss_function, B = B,
                                     type = fi_type)
        class(res_ai) <- c("model_parts", "feature_importance_explainer",
                           class(res_ai))
        res_ai <- res_ai[res_ai$permutation == "0", ]
        int_node_importance[nrow(int_node_importance) + 1, 1] <-
          res_ai[res_ai$variable == group_name, ]$dropout_loss
      } else {
        res_ai <- aspect_importance(x = x, data = data,
                                    predict_function = predict_function,
                                    new_observation = new_observation,
                                    variable_groups = aspects_list_current, N = N)
        int_node_importance[nrow(int_node_importance) + 1, 1] <-
          res_ai[res_ai$variable_groups == group_name, ]$importance
      }
      int_node_importance[nrow(int_node_importance), 2] <- group_name
      int_node_importance[nrow(int_node_importance), 3] <- cutting_heights[i]
    }
    aspects_list_previous <- aspects_list_current
  }
  if (type != "predict") {
    # the posted snippet had a dangling `variable_groups = ,` here; grouping all
    # variables into one explicitly named aspect (assumed fix) makes the lookup
    # on "aspect.group1" below work
    res <- feature_importance(explainer = explainer,
                              variable_groups = list(aspect.group1 = colnames(data)),
                              N = N, loss_function = loss_function, B = B)
    res <- res[res$permutation == "0", ]
    baseline_val <- res[res$variable == "aspect.group1", ]$dropout_loss
    int_node_importance[length(cutting_heights), 1] <- baseline_val
  } else {
    int_node_importance[(nrow(int_node_importance) + 1):length(cutting_heights), 1] <- NA
  }
  x_hc$height <- int_node_importance$V1
  hi <- list(x_hc, type, new_observation)
  class(hi) <- "hierarchical_importance"
  return(hi)
}
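
A hedged usage sketch of the patched function at the instance level, mapping the explainer's components onto the arguments from the snippet above (untested against the linked data):

hi <- hierarchical_importance(x = explainer_gbm$model,
                              data = explainer_gbm$data,
                              new_observation = explainer_gbm$data[1, ],
                              predict_function = explainer_gbm$predict_function)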

triplot 1.3

  • workaround for model_triplot and patchwork 1.3 issue

  • new parameter for triplot (correlation method)

  • compatibility with DALEX 1.3 (n_sample parameter)

  • add tests for plots

  • rename plot.predict_aspects parameter (issue #14)

mlr3 based multiclass problems cannot be handled by triplot

mlr3-based 'prob'-type learners for multiclass tasks produce the following error:
Error in x[[jj]][iseq] <- vjj : replacement has length zero
In addition: There were 50 or more warnings (use warnings() to see the first 50)

A minimal working example and session info are pasted below. Help requested, please.

library(tidyverse)
library(mlr3verse)
library(DALEX)
library(DALEXtra)
library(triplot)

df=data.frame(v = c(3.4,5.6,1.3,9.8,7.3, 4.6,5.5,2.3,8.9,7.1, 4.9,6.5,2.3,4.1,3.37, 3.4,6.0,2.3,7.8,3.7),
w = c(34,65,23,78,37, 34,65,23,78,37, 34,65,23,78,37, 34,65,23,78,37),
x.a = c(1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0),
x.b = c(0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0),
x.c = c(0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1),
y = c(1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0),
z = c('alpha','alpha','delta','delta','phi', 'alpha','alpha','delta','delta','phi', 'alpha','alpha','delta','delta','phi', 'alpha','alpha','delta','delta','phi')
)

df_task <- TaskClassif$new(id = "my_df", backend = df, target = "z")
lrn_rf <- lrn("classif.ranger", predict_type = "prob")
lrn_rf$train(df_task)

lrn_rf_exp <- explain_mlr3(lrn_rf,
                           data = df[, -7],
                           y = df$z,
                           label = "rf_exp")
tri_plts <- model_triplot(lrn_rf_exp)

sessionInfo()
R version 3.6.1 (2019-07-05)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 18362)

Matrix products: default

locale:
[1] LC_COLLATE=English_United States.1252 LC_CTYPE=English_United States.1252
[3] LC_MONETARY=English_United States.1252 LC_NUMERIC=C
[5] LC_TIME=English_United States.1252

attached base packages:
[1] stats graphics grDevices utils datasets methods base

other attached packages:
[1] rSAFE_0.1.0 mlr3verse_0.1.1 paradox_0.2.0.9000 mlr3viz_0.1.1.9002
[5] mlr3tuning_0.1.2 mlr3pipelines_0.1.3 mlr3learners_0.2.0 mlr3filters_0.2.0
[9] mlr3db_0.1.5 mlr3_0.3.0 forcats_0.5.0 stringr_1.4.0
[13] dplyr_1.0.0 purrr_0.3.4 readr_1.3.1 tidyr_1.1.0
[17] tibble_3.0.1 ggplot2_3.3.2 tidyverse_1.2.1 triplot_1.2.0
[21] DALEXtra_1.3 DALEX_1.3.0

loaded via a namespace (and not attached):
[1] ggdendro_0.1-20 httr_1.4.0 jsonlite_1.6.1 splines_3.6.1 foreach_1.5.0
[6] modelr_0.1.4 assertthat_0.2.1 lgr_0.3.4 ingredients_1.2.0 cellranger_1.1.0
[11] mlr3misc_0.3.0 pillar_1.4.4 backports_1.1.8 lattice_0.20-38 glue_1.4.1
[16] reticulate_1.16 uuid_0.1-4 digest_0.6.25 checkmate_2.0.0 rvest_0.3.4
[21] colorspace_1.4-1 Matrix_1.2-17 pkgconfig_2.0.3 broom_0.5.6 haven_2.3.1
[26] patchwork_1.0.1 scales_1.1.1 ranger_0.11.2 generics_0.0.2 ellipsis_0.3.1
[31] withr_2.2.0 cli_2.0.2 survival_3.1-12 magrittr_1.5 crayon_1.3.4
[36] readxl_1.3.1 fansi_0.4.1 nlme_3.1-140 MASS_7.3-51.4 xml2_1.2.1
[41] tools_3.6.1 data.table_1.12.8 hms_0.5.3 lifecycle_0.2.0 munsell_0.5.0
[46] glmnet_4.0-2 packrat_0.5.0 compiler_3.6.1 rlang_0.4.6 grid_3.6.1
[51] iterators_1.0.12 rstudioapi_0.11 rappdirs_0.3.1 gtable_0.3.0 codetools_0.2-16
[56] sets_1.0-18 R6_2.4.1 gridExtra_2.3 lubridate_1.7.4 utf8_1.1.4
[61] shape_1.4.4 stringi_1.4.6 Rcpp_1.0.4.6 vctrs_0.3.1 tidyselect_1.1.0

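
One possible workaround, offered only as a hedged sketch: triplot expects a single numeric prediction per observation, so a multiclass model can be explained one class at a time by passing a custom predict_function that returns that class's probability (the mlr3 $predict_newdata()$prob accessor is assumed):

lrn_rf_exp_alpha <- explain_mlr3(lrn_rf,
                                 data = df[, -7],
                                 y = as.numeric(df$z == "alpha"),
                                 predict_function = function(m, d)
                                   m$predict_newdata(d)$prob[, "alpha"],
                                 label = "rf_alpha")
tri_alpha <- model_triplot(lrn_rf_exp_alpha)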

merge aspect_importance with aspect_importance_single

For simplicity, aspect_importance(..., group_variables = NULL, ...) should calculate aspect_importance_single, and aspect_importance_single should be removed.

However, aspect_importance_single provides the new_observation column that is necessary for triplot, so this has to be addressed before removing aspect_importance_single.
