modeloriented / survex

Explainable Machine Learning in Survival Analysis

Home Page: https://modeloriented.github.io/survex

License: GNU General Public License v3.0

R 100.00%
brier-scores cox-regression explainable-ai explainable-machine-learning explainable-ml explanatory-model-analysis interpretable-machine-learning interpretable-ml machine-learning shap

survex's People

Contributors

agosiewska, hbaniecki, kapsner, krzyzinskim, mikolajsp, pbiecek

survex's Issues

Request: Ability to use testing data for performance evaluation

Hi all - really appreciate your excellent work on this much needed package thus far. I am writing to inquire about a potential feature request.

In model_performance() and its helper functions (e.g., surv_model_performance), the "newdata" argument is extracted from the trained model object. It would be very useful to be able to specify an alternative (presumably randomly split) dataframe to be used for performance evaluation (i.e., provide a "test" dataset). Perhaps this "test_data" dataframe could be specified at the explain() step, and then preferentially extracted for use in model_performance().

A separate option (but potentially more complex to implement) would be to build in k-fold cross validation.

Please let me know if I am fundamentally misunderstanding something about your package. Forgive me if this isn't the best place to raise this issue, I'm a github noob. Thanks again for all the great work.
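
One possible workaround, sketched here on the assumption that explain() accepts any data and y (as the other examples on this page suggest), is to build a second explainer on a held-out split and pass that explainer to model_performance(). The 70/30 split, the seed, and the object names are illustrative choices, not part of the package.

```r
library(survival)
library(survex)

set.seed(1)
# Illustrative 70/30 split of the veteran data shipped with survival
idx   <- sample(nrow(veteran), size = floor(0.7 * nrow(veteran)))
train <- veteran[idx, ]
test  <- veteran[-idx, ]

# Fit on the training part only
cph <- coxph(Surv(time, status) ~ ., data = train, model = TRUE, x = TRUE)

# Build the explainer on the held-out part; columns 3 and 4 are time and status
explainer_test <- explain(cph,
                          data = test[, -c(3, 4)],
                          y = Surv(test$time, test$status),
                          label = "cph (test set)",
                          verbose = FALSE)

# Performance is now computed on data the model has not seen
perf <- model_performance(explainer_test)
```

With this approach the explainer, not the fitted model, decides which data is used for evaluation, so one model can have separate "train" and "test" explainers.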

Add `plot.model_profile` for two or more models

plot.model_profile in DALEX/ingredients uses the color parameter to choose a dimension for plot splitting, e.g. by the model (label). This allows for comparing two models in one plot.

library(DALEX)

titanic_glm_model <- glm(survived~., data = titanic_imputed, family = "binomial")
explainer_glm <- explain(titanic_glm_model, data = titanic_imputed, verbose = FALSE)
expl_glm <- model_profile(explainer_glm)

library("ranger")
titanic_ranger_model <- ranger(survived~., data = titanic_imputed, num.trees = 50,
                               probability = TRUE)
explainer_ranger  <- explain(titanic_ranger_model, data = titanic_imputed, verbose = FALSE)
expl_ranger <- model_profile(explainer_ranger)

plot(expl_ranger$agr_profiles, expl_glm$agr_profiles, color = "_label_") 
plot(expl_ranger$agr_profiles, expl_glm$agr_profiles)

It would be nice to have similar functionality in survex

ve_rf <- model_profile(explainer_rf, variables = "age")
plot(ve_rf)
ve_cox <- model_profile(explainer_cox, variables = "age")
plot(ve_cox, ve_rf) # won't work

Firstly, at this moment the above code should return an error

The idea is to add some facet parameter, which by default is set to variables but can be changed to _label_ etc.

plot(ve_cox, ve_rf, facet="_label_")

Mockup (image not included):

Using:

(plot(ve_rf) + ggtitle("rsf")) +
  (plot(ve_cox) + ggtitle("cph")) + 
  plot_annotation(subtitle="")

Allow the user to remove the subtitle in plots

At this moment, plot functions with the subtitle parameter have it set to NULL as default.

Therefore, it is not possible for the user to remove the subtitle using the conventional subtitle=NULL value.

Perhaps set a different default value? The effective default is clearly not NULL, because the subtitle appears by default.

if (is.null(subtitle)) {
  glm_labels <- paste0(label, collapse = ", ")
  subtitle <- paste0("created for the ", glm_labels, " model")
}
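
A minimal sketch of one way to do this, assuming a sentinel string is acceptable as the default: the generated text is used only when the sentinel is passed, so subtitle = NULL really removes the subtitle. make_subtitle() is a hypothetical helper written for illustration, not a survex function.

```r
# Hypothetical helper: "default" is a sentinel, NULL means "no subtitle"
make_subtitle <- function(label, subtitle = "default") {
  if (!is.null(subtitle) && identical(subtitle, "default")) {
    glm_labels <- paste0(label, collapse = ", ")
    subtitle <- paste0("created for the ", glm_labels, " model")
  }
  subtitle
}

make_subtitle(c("cph", "rsf"))        # generated subtitle
make_subtitle("cph", subtitle = NULL) # NULL, so no subtitle is drawn
make_subtitle("cph", subtitle = "My own text")
```

Passing the result to ggplot2::labs(subtitle = ...) would then drop the subtitle when it is NULL.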

Improve the documentation of API

A few things come to mind:

  • Description of explain()/explain_survival() does not say when to use which function, and all examples use explain()

  • Add informative descriptions to surv_fi and surv_integrated_fi, e.g. what the function does

  • Call x and explainer arguments "explainers"/"explainer objects" not "models"?

  • Add references (articles, URLs) to model_performance() and metrics/loss functions https://modeloriented.github.io/survex/reference/model_performance.surv_explainer.html

  • Add references (articles, URLs) to SurvLIME/SurvSHAP(t) methods

  • Let's add information about which functions return the objects named in the descriptions of plotting functions, e.g. "the surv_shap object returned by the predict_parts(..., type="SurvSHAP(t)") function"

different legend titles in `plot.model_profile` between `categorical` and `numerical` plots

For some reason, when plotting model_profile for categorical variables, the legend title disappears

scale_color_manual(name = "",

but for numerical variables, it is set to the variable's name
scale_colour_gradient2(
name = paste0(unique(df$`_vname_`), " value"),

Does copying code from numerical (2) to categorical (1) break something?

errors "Error in Ops.Surv(observed, predicted) : Invalid operation on a survival time"

I used mlr3 and mlr3proba with surv.xgboost as follows. Since xgboost cannot output a survival matrix, I used a pipeline:

xgb_task = TaskSurv$new("train",
                        backend = data.frame(train), time = "time",
                        event = "event")

xgb_lrn = as_learner(ppl(c("distrcompositor"),
                         lrn("surv.xgboost", objective = "survival:cox", nrounds = 300L, eta = 0.1),
                         form = "ph"))

Then I tried to explain the model. I set predict_survival_function and predict_cumulative_hazard_function, both of which run successfully when I call them directly:

explainer <- explain(
  model = xgb_lrn,
  data = train,
  y = Surv(time[!fold == i1], event[!fold == i1]),
  predict_function = function(model, newdata) {
    predict(model, newdata, predict_type = "<Prediction>")$crank
  },
  predict_survival_function = function(model, newdata, times) {
    t(predict(model, newdata, predict_type = "<Prediction>")$distr$survival(times))
  },
  predict_cumulative_hazard_function = function(model, newdata, times) {
    t(predict(model, newdata, predict_type = "<Prediction>")$distr$cumHazard(times))
  }
)

Preparation of a new explainer is initiated
-> model label : R6 ( default )
-> data : 1498 rows 62 cols
-> target variable : 1498 values
-> predict function : function(model, newdata) { predict(model, newdata, predict_type = "")$crank }
-> predicted values : No value for predict function target column. ( default )
-> model_info : package mlr3 , ver. 0.14.0 , task regression ( default )
-> predicted values : numerical, min = -7.026592 , mean = -0.8300239 , max = 11.02893
-> residual function : difference between y and yhat ( default )
-> residuals : the residual_function returns an error when executed ( WARNING )

  • I could successfully run predict(explainer,train) to get the risk score

However, I get an error when adding output_type:

Error in predict_function(model, newdata, ...) :
unused argument (output_type = "chf")

  • I also get an error when I try to explain the model:
    model_parts(explainer)

Error in Ops.Surv(observed, predicted) :
Invalid operation on a survival time

model_performance(explainer)

Error in Ops.Surv(y, predict_function(model, data)) :
Invalid operation on a survival time

I don't know what's wrong with my code.

predict_parts survshap runs out of memory when `p >= 17`

For p=17, this line produces a square matrix of dim [131072, 131072]

W <- diag(shap_kernel_weights)

See rnorm(131072*131072).

Source:

library(survival)
library(survex)
library(ranger)

N <- 100
P <- 17

X <- matrix(rnorm(N*P), ncol=P, nrow=N)
time <- rnorm(N)
status <- rbinom(N, 1, 0.5)
df <- cbind(data.frame(time=time, status=status), data.frame(X))

dim(df)

model <- ranger(Surv(time, status) ~ ., data = df, num.trees = 50)

explainer <- explain(model, data = df[, -c(1, 2)], y = Surv(df$time, df$status))

model_performance(explainer)

predict_parts(explainer, df[1, -c(1, 2)])
# Error: cannot allocate vector of size 128.0 Gb
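
The size in the error message follows directly from the dimensions: 2^17 = 131072 coalitions, and a dense 131072 x 131072 matrix of doubles needs 128 GiB. The sketch below checks that arithmetic and shows, on small random data, that a weighted least squares fit can be computed without materialising diag(w). It assumes W enters a product of the form t(X) %*% W %*% X, as is usual in kernel SHAP; how survex uses W beyond the quoted line is not shown on this page.

```r
# Memory needed for a dense diagonal weight matrix with p = 17
p <- 17
n_coalitions <- 2^p                    # 131072
gib <- n_coalitions^2 * 8 / 2^30       # 8 bytes per double
gib                                    # 128

# Small illustrative problem: weighted least squares with and without diag()
set.seed(1)
X <- matrix(rnorm(50 * 3), nrow = 50, ncol = 3)
w <- runif(50)
y <- rnorm(50)

# Dense version, analogous to W <- diag(shap_kernel_weights)
W <- diag(w)
beta_dense <- solve(t(X) %*% W %*% X, t(X) %*% W %*% y)

# Same result without the n x n matrix: w * X scales row i of X by w[i]
beta_lean <- solve(crossprod(X, w * X), crossprod(X, w * y))

all.equal(beta_dense, beta_lean)
```

The second form only stores objects of size n x p, so memory grows linearly in the number of coalitions instead of quadratically.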

Parallel analysis for model_parts function

Is there any support for parallel computation in the model_parts function? I want to compute the importance of 5K features. R only uses 10% of the CPU and 15% of the memory, and I want to make full use of the CPU to save time.

Thank you so much.

BTW, m_parts contains permutations from 0 to 10. Which permutation is used in plot(m_parts)? I now think the plot uses permutation = 0.

Update vignette

  • No information about risk score
  • Maybe change the example to ranger so that the unusual creation of an explainer is demonstrated

Feature: add support for computing SurvSHAP using treeshap-algorithm

The purpose of this formal feature request is mainly to document ongoing work and to avoid duplicated effort.

I've started to work on implementing support for computing SurvSHAP values using the treeshap algorithm in this fork of survex.

As of now, both local and global computation of SurvSHAP values using the ranger algorithm work technically.

Statistical / mathematical correctness still needs to be verified for:

  • computation of predictions of the survival probability at different time points using treeshap
    • see this issue
    • related code is located here
    • one enhancement that may be required is to use the stepfunction logic from here in the treeshap package and apply it to the predictions from ranger. Currently, only the time points computed internally by ranger are used for computing the survival probabilities (see here), which results in this code in survex; this is a workaround for now
  • calculation of aggregated "global" survshap values for each time-point and feature across multiple observations

Suggestion on progressr

Hello,

while troubleshooting a false-positive CRAN NOTE, I noticed that you use:

if (requireNamespace("progressr", quietly = TRUE)) prog()

in several places. You can avoid the quite expensive(*) calls to requireNamespace("progressr", quietly = TRUE) if you do:

if (requireNamespace("progressr", quietly = TRUE)) {
  prog <- progressr::progressor(along = 1:((length(variables) + 2) * B))
} else {
  prog <- function() NULL
}

Then you can replace:

if (requireNamespace("progressr", quietly = TRUE)) prog()

with

prog()

(*) requireNamespace() hits the file system each time if progressr is not installed.

default themes

DALEX can now set themes globally by using

DALEX::set_theme_dalex("ema")

to support this you may need to replace theme_drwhy() with DALEX::theme_default_dalex()

note that set_theme_dalex takes two themes for horizontal and vertical plots

ROC AUC for a binary classification problem

It should take into account censoring and the target based on times.

else {
  if (is.null(times)) stop("Times cannot be NULL for type `roc`")
  rocs <- lapply(times, function(time) {
    labels <- 1 - explainer$y[, 2]
    scores <- explainer$predict_survival_function(explainer$model, newdata, time)
    labels <- labels[order(scores, decreasing = TRUE)]
    cbind(time = time, data.frame(TPR = cumsum(labels) / sum(labels), FPR = cumsum(!labels) / sum(!labels), labels))
  })
  rocs_df <- do.call(rbind, rocs)
  class(rocs_df) <- c("surv_model_performance_rocs", class(rocs_df))
  attr(rocs_df, "label") <- explainer$label
  rocs_df
}
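
In the snippet above the labels come from the event indicator alone, so they are the same for every evaluated time. A minimal sketch of time-dependent labels is given below. td_labels() is a hypothetical helper, and dropping subjects censored before the evaluated time is a simplifying assumption; weighting by the inverse probability of censoring would be a more complete treatment.

```r
# Hypothetical helper: binary target at time t for right-censored data
#   1  = event observed at or before t (case)
#   0  = still at risk after t (control)
#   NA = censored at or before t, status at t unknown, so excluded
td_labels <- function(time, status, t) {
  ifelse(time <= t & status == 1, 1L,
         ifelse(time > t, 0L, NA_integer_))
}

time   <- c(1, 2, 3, 4)
status <- c(1, 0, 1, 0)   # 1 = event, 0 = censored

td_labels(time, status, t = 2.5)
# Subject 2 is censored at time 2, so its status at t = 2.5 is unknown
```

Rows with NA labels would be removed, together with their scores, before TPR and FPR are computed for that time point.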

refactor feature importance code

  1. rename permutation to _permutation_, times to _times_ and reference to _reference_ to omit potential problems with feature names
    raw$permutation <- rep(1:B, each = length(times))

    cbind(data.frame(times = times), "_full_model_" = loss_full, loss_variables, "_baseline_" = loss_baseline)
  2. use ret[['_reference_']] instead of ret[, ncol(ret)] for readability and code robustness
    res[, 2:(ncol(res) - 3)] <- res[, 2:(ncol(res) - 3)] / res[, (ncol(res))]
  3. sort the result by _permutation_ after merge() like it is done in type="raw"
    res <- merge(res, res_full, by = "times")
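
The first two points can be illustrated on a toy data frame. The column layout below is hypothetical and only mimics the kind of result the quoted lines operate on; it includes a feature that happens to be called times to show why the reserved columns get underscores.

```r
# Hypothetical result layout; check.names = FALSE keeps the underscore names
res <- data.frame(
  `_times_`       = c(1, 2),
  `_permutation_` = c(0, 0),
  `_full_model_`  = c(0.2, 0.4),
  age             = c(0.3, 0.6),
  times           = c(0.25, 0.5),   # a feature named "times" no longer clashes
  `_baseline_`    = c(0.5, 0.8),
  `_reference_`   = c(0.2, 0.4),
  check.names = FALSE
)

# Select loss columns by name instead of by position
meta_cols <- c("_times_", "_permutation_", "_reference_")
loss_cols <- setdiff(names(res), meta_cols)

# Divide every loss column by the reference column, addressed by name
res[loss_cols] <- res[loss_cols] / res[["_reference_"]]
res
```

Because nothing depends on column order, adding or reordering columns later does not silently change which values are normalised.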

refactor `loss_integrate`

change cumsum(c(0, x))[length(x)+1] to sum(x) in all three scenarios

survex/R/metrics.R

Lines 39 to 55 in 64e96d6

integrated_metric <- cumsum(c(0, tmp))[length(tmp)+1] / (max(times) - min(times))
return(integrated_metric)
}
else if (normalization == "t_max") {
tmp <- (loss_values[1:(n - 1)] + loss_values[2:n]) * diff(times) / 2
integrated_metric <- cumsum(c(0, tmp))[length(tmp)+1]
return(integrated_metric/max(times))
} else if (normalization == "survival"){
kaplan_meier <- survfit(y_true~1)
estimator <- transform_to_stepfunction(kaplan_meier, eval_times = sort(unique(y_true[,1])), type = "survival")
dwt <- 1 - estimator(times)
tmp <- (loss_values[1:(n - 1)] + loss_values[2:n]) * diff(dwt) / 2
integrated_metric <- cumsum(c(0, tmp))[length(tmp)+1]
return(integrated_metric/(1 - estimator(max(times))))
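
A quick check of the proposed change on made-up values: the last element of cumsum(c(0, x)) is the sum of x, so both expressions give the same trapezoid integral. The loss values and time points are illustrative only.

```r
loss_values <- c(0.10, 0.20, 0.15, 0.30)
times       <- c(1, 2, 4, 7)
n <- length(times)

# Trapezoid areas between consecutive time points, as in the quoted code
tmp <- (loss_values[1:(n - 1)] + loss_values[2:n]) * diff(times) / 2

old_way <- cumsum(c(0, tmp))[length(tmp) + 1]
new_way <- sum(tmp)

all.equal(old_way, new_way)
```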

`predict_parts` with mlr3

library(mlr3proba)
library(mlr3extralearners)
library(mlr3pipelines)
library(survex)
library(survival)
veteran_task <- as_task_surv(veteran,
                             time = "time",
                             event = "status",
                             type = "right")
ranger_learner <- lrn("surv.ranger")    
ranger_learner$train(veteran_task)
ranger_learner_explainer <- explain(ranger_learner, 
                                    data = veteran[, -c(3,4)],
                                    y = Surv(veteran$time, veteran$status),
                                    label = "Ranger model")
# error
pp <- predict_parts(ranger_learner_explainer,
                    ranger_learner_explainer$data[1,])

Integration method for time-dependent measures

At this time, the Integrated Brier Score and Integrated AUC are calculated as a Riemann integral. As a result, times at the end of the considered time frame (where there might be few observations) have a large influence on the integrated score.

Two normalization methods are proposed in [1]. We should try to implement one or both of them.

There should also be an option to cut off the integration at different quantiles of times in the data (Q25, Q50, Q75) [2]
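
The cut-off idea can be sketched as follows. integrate_until() is a hypothetical helper that uses trapezoid integration and divides by the length of the retained interval; it does not implement the normalizations from [1].

```r
# Hypothetical helper: integrate a time-dependent loss up to t_max only
integrate_until <- function(loss_values, times, t_max) {
  keep <- times <= t_max
  lv <- loss_values[keep]
  tt <- times[keep]
  n <- length(tt)
  if (n < 2) return(NA_real_)          # not enough points to integrate
  area <- sum((lv[-n] + lv[-1]) * diff(tt) / 2)
  area / (max(tt) - min(tt))           # normalise by the interval length
}

times       <- 1:10
loss_values <- rep(0.3, 10)            # constant loss for an easy check
observed    <- 1:10                    # stand-in for observed times in the data

q50 <- quantile(observed, 0.5)         # cut-off at the median (Q50)
integrate_until(loss_values, times, t_max = q50)
```

With a constant loss the normalised integral equals that constant whatever the cut-off, which makes the helper easy to sanity-check before it is applied to real Brier scores.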

error from explain

Hi,
I got some nasty errors and hope the authors can look into this.
Something odd here: fitting the model raises no error, but the explain step does:
model.src <- randomForestSRC::rfsrc(Surv(OS, status) ~ ., data = data)
explain(model.src)
Error in grep("NODE", lns):(length(lns) - 2) : argument of length 0
In addition: Warning messages:
1: In if (substring(pattern, 1, 4) == "@rm_") { :
the condition has length > 1 and only the first element will be used
2: In if (substring(pattern, 1, 1) == "@") { :
the condition has length > 1 and only the first element will be used
