modeloriented / survex
Explainable Machine Learning in Survival Analysis
Home Page: https://modeloriented.github.io/survex
License: GNU General Public License v3.0
See
survex/R/plot_surv_ceteris_paribus.R
Line 172 in 561b62e
survex/R/plot_surv_ceteris_paribus.R
Line 189 in 561b62e
Also, should scales be used here?
survex/R/plot_surv_ceteris_paribus.R
Line 215 in 561b62e
via the explanation_label parameter
They're all scalar metrics, so we should adapt the input and output to match our needs and make them usable for users: https://mlr3proba.mlr-org.com/#measures
Line 16 in 64e96d6
Hi all - really appreciate your excellent work on this much needed package thus far. I am writing to inquire about a potential feature request.
In model_performance() and its helper functions (e.g., surv_model_performance), the "newdata" argument is extracted from the trained model object. It would be very useful to be able to specify an alternative (presumably randomly split) dataframe to be used for performance evaluation (i.e., provide a "test" dataset). Perhaps this "test_data" dataframe could be specified at the explain() step, and then preferentially extracted for use in model_performance().
A separate option (but potentially more complex to implement) would be to build in k-fold cross validation.
Please let me know if I am fundamentally misunderstanding something about your package. Forgive me if this isn't the best place to raise this issue, I'm a github noob. Thanks again for all the great work.
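A minimal sketch of the requested workflow, assuming a hypothetical `test_data` argument in `explain()` (this argument does not exist in the current survex API; the data frame here is made up):

```r
# Random 70/30 holdout split; `test_data` below is a hypothetical argument.
set.seed(42)
df <- data.frame(time = rexp(100), event = rbinom(100, 1, 0.5), x = rnorm(100))
train_idx <- sample(nrow(df), size = floor(0.7 * nrow(df)))
train_df <- df[train_idx, ]
test_df <- df[-train_idx, ]
# explainer <- explain(model,
#                      data = train_df,
#                      y = Surv(train_df$time, train_df$event),
#                      test_data = test_df)   # hypothetical argument
# model_performance(explainer)                # would then evaluate on test_df
```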
plot.model_profile in DALEX/ingredients uses the color parameter to choose a dimension for plot splitting, e.g. by the model (label). This allows for comparing two models in one plot.
library(DALEX)
titanic_glm_model <- glm(survived ~ ., data = titanic_imputed, family = "binomial")
explainer_glm <- explain(titanic_glm_model, data = titanic_imputed, verbose = FALSE)
expl_glm <- model_profile(explainer_glm)

library("ranger")
titanic_ranger_model <- ranger(survived ~ ., data = titanic_imputed, num.trees = 50,
                               probability = TRUE)
explainer_ranger <- explain(titanic_ranger_model, data = titanic_imputed, verbose = FALSE)
expl_ranger <- model_profile(explainer_ranger)

# split by model label: both models compared in one plot
plot(expl_ranger$agr_profiles, expl_glm$agr_profiles, color = "_label_")
# default: no splitting by label
plot(expl_ranger$agr_profiles, expl_glm$agr_profiles)
It would be nice to have similar functionality in survex
ve_rf <- model_profile(explainer_rf, variables = "age")
plot(ve_rf)
ve_cox <- model_profile(explainer_cox, variables = "age")
plot(ve_cox, ve_rf) # won't work
Firstly, at this moment the above code should return an error.
The idea is to add a facet parameter, which by default is set to "variables" but can be changed to "_label_", etc.:
plot(ve_cox, ve_rf, facet = "_label_")
Mockup, using patchwork:
(plot(ve_rf) + ggtitle("rsf")) +
  (plot(ve_cox) + ggtitle("cph")) +
  plot_annotation(subtitle = "")
Good day, I could not find information in the manual about which specific C/D AUC measure was used (see https://doi.org/10.1186/s12874-017-0332-6). Could you please provide a reference? Thank you for this amazing package.
At this moment, plot functions with the subtitle parameter have NULL as the default. Therefore, it is not possible for the user to remove the subtitle with the conventional subtitle = NULL value. Perhaps set a different default value? The effective default is clearly not NULL, because the subtitle appears by default.
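One common fix is a sentinel default, so an explicit NULL can be told apart from "argument not supplied". A sketch, not survex's actual implementation (the function and sentinel value here are hypothetical):

```r
# Sketch: with a sentinel default, subtitle = NULL can mean "no subtitle".
make_subtitle <- function(label, subtitle = "auto") {
  if (identical(subtitle, "auto")) {
    subtitle <- paste("created for the", label, "model")  # generated default
  }
  subtitle  # NULL passes through and removes the subtitle downstream
}

make_subtitle("cph")                    # generated default text
make_subtitle("cph", subtitle = NULL)   # NULL -> subtitle removed
```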
survex/R/plot_surv_feature_importance.R
Lines 73 to 76 in 561b62e
Perhaps we should change to drWhy
A few things come to mind:
- explain()/explain_survival() documentation does not say when to apply which function, while all examples use explain()
- document surv_fi and surv_integrated_fi, e.g. what the function does
- call explainer arguments "explainers"/"explainer objects", not "models"?
- Add references (articles, URLs) to model_performance() and metrics/loss functions: https://modeloriented.github.io/survex/reference/model_performance.surv_explainer.html
- Add references (articles, URLs) to SurvLIME/SurvSHAP(t) methods
- Let's add information about which functions return the objects named in the descriptions of plotting functions, e.g. "the surv_shap object returned by the predict_parts(..., type="SurvSHAP(t)") function"
- Add information to the https://github.com/mi2dataLab/survxai readme that the package won't be maintained.
Missing 0 value, end value.
For some reason, when plotting model_profile for categorical variables, the legend title disappears.
survex/R/plot_surv_ceteris_paribus.R
Line 213 in 561b62e
survex/R/plot_surv_ceteris_paribus.R
Lines 161 to 162 in 561b62e
Does copying code from numerical (2) to categorical (1) break something?
I used mlr3 and mlr3proba with surv.xgboost as follows; since xgboost cannot output a survival matrix, I used a pipeline:

xgb_task = TaskSurv$new("train",
                        backend = data.frame(train), time = "time",
                        event = "event")
xgb_lrn = as_learner(ppl(c("distrcompositor"),
                         lrn("surv.xgboost", objective = "survival:cox",
                             nrounds = 300L, eta = 0.1), form = "ph"))

Then I try to explain the model. I set predict_survival_function and predict_cumulative_hazard_function, which run successfully on their own:

explainer <- explain(
  model = xgb_lrn,
  data = train,
  y = Surv(time[!fold == i1], event[!fold == i1]),
  predict_function = function(model, newdata) {
    predict(model, newdata, predict_type = "<Prediction>")$crank
  },
  predict_survival_function = function(model, newdata, times) {
    t(predict(model, newdata, predict_type = "<Prediction>")$distr$survival(times))
  },
  predict_cumulative_hazard_function = function(model, newdata, times) {
    t(predict(model, newdata, predict_type = "<Prediction>")$distr$cumHazard(times))
  }
)
Preparation of a new explainer is initiated
-> model label : R6 ( default )
-> data : 1498 rows 62 cols
-> target variable : 1498 values
-> predict function : function(model, newdata) { predict(model, newdata, predict_type = "")$crank }
-> predicted values : No value for predict function target column. ( default )
-> model_info : package mlr3 , ver. 0.14.0 , task regression ( default )
-> predicted values : numerical, min = -7.026592 , mean = -0.8300239 , max = 11.02893
-> residual function : difference between y and yhat ( default )
-> residuals : the residual_function returns an error when executed ( WARNING )
predict(explainer, train)
works to get the risk score. However, I get an error when adding output_type:
Error in predict_function(model, newdata, ...) :
unused argument (output_type = "chf")
model_parts(explainer)
Error in Ops.Surv(observed, predicted) :
Invalid operation on a survival time
model_performance(explainer)
Error in Ops.Surv(y, predict_function(model, data)) :
Invalid operation on a survival time
I don't know what's wrong with my code.
For p = 17, this line produces a square matrix of dim [131072, 131072]:
Line 106 in 64e96d6
See rnorm(131072*131072).
Source:
library(survival)
library(survex)
library(ranger)
N <- 100
P <- 17
X <- matrix(rnorm(N*P), ncol=P, nrow=N)
time <- rnorm(N)
status <- rbinom(N, 1, 0.5)
df <- cbind(data.frame(time=time, status=status), data.frame(X))
dim(df)
model <- ranger(Surv(time, status) ~ ., data = df, num.trees = 50)
explainer <- explain(model, data = df[, -c(1, 2)], y = Surv(df$time, df$status))
model_performance(explainer)
predict_parts(explainer, df[1, -c(1, 2)])
# Error: cannot allocate vector of size 128.0 Gb
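The reported size matches a dense 2^p x 2^p double matrix, which for p = 17 is 128 GiB exactly:

```r
# Back-of-envelope check of the reported allocation failure:
# a dense [2^17, 2^17] matrix of doubles needs exactly 128 GiB.
p <- 17
n <- 2^p            # 131072 rows/columns (all feature subsets)
bytes <- n^2 * 8    # 8 bytes per double
gib <- bytes / 2^30
gib                 # 128 -> matches "cannot allocate vector of size 128.0 Gb"
```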
Is there any parallel option in the model_parts function? I want to compute the importance of 5K features. R only uses 10% of the CPU and 15% of memory; I want to make full use of the CPU to save time.
Thank you so much.
BTW, for m_parts there are permutations from 0 to 10; which permutation is used in plot(m_parts)? I am now thinking the plot uses permutation = 0.
The purpose of this formal feature request is mainly for a documentation of ongoing works and to avoid duplicate efforts.
I've started to work on implementing support for computing SurvSHAP values using the treeshap algorithm in this fork of survex.
As of now, both local and global computation of SurvSHAP values using the ranger algorithm technically work.
Statistical/mathematical correctness has yet to be verified for:
Hello,
while troubleshooting a false-positive CRAN NOTE, I noticed that you use:
if (requireNamespace("progressr", quietly = TRUE)) prog()
in several places. You can avoid the quite expensive(*) calls to requireNamespace("progressr", quietly = TRUE) if you do:
if (requireNamespace("progressr", quietly = TRUE)) {
prog <- progressr::progressor(along = 1:((length(variables) + 2) * B))
} else {
prog <- function() NULL
}
Then you can replace:
if (requireNamespace("progressr", quietly = TRUE)) prog()
with
prog()
(*) requireNamespace() hits the file system each time if progressr is not installed.
DALEX can now set themes globally using DALEX::set_theme_dalex("ema"). To support this you may need to replace theme_drwhy() with DALEX::theme_default_dalex(). Note that set_theme_dalex takes two themes, for horizontal and vertical plots.
survex/R/plot_model_profile_survival.R
Line 18 in c490bf3
They are no longer arranged with gridExtra::grid.arrange.
It should take into account censoring and the target based on times.
survex/R/surv_model_performance.R
Lines 41 to 55 in e107cd2
Rename permutation to _permutation_, times to _times_ and reference to _reference_ to avoid potential clashes with feature names:
survex/R/surv_feature_importance.R
Line 160 in 64e96d6
survex/R/surv_feature_importance.R
Line 156 in 64e96d6
Use ret[['_reference_']] instead of ret[, ncol(ret)] for readability and code robustness:
survex/R/surv_feature_importance.R
Line 175 in 64e96d6
Handle _permutation_ after merge() like it is done in type="raw":
survex/R/surv_feature_importance.R
Line 170 in 64e96d6
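A toy illustration (made-up data, not survex code) of why name-based access is the more robust choice: it keeps working if the column order changes, while positional access silently picks the wrong column.

```r
# Name-based vs positional column access on a toy data frame.
ret <- data.frame(age = 60, `_reference_` = 0.5, check.names = FALSE)
ret[["_reference_"]]   # 0.5, regardless of column position
ret <- ret[, c(2, 1)]  # reorder the columns
ret[["_reference_"]]   # still 0.5
ret[, ncol(ret)]       # now returns `age` (60), not `_reference_`
```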
Hi,
it would be great to have a vignette similar to https://modeloriented.github.io/survex/articles/mlr3proba-usage.html
but with examples for
https://journal.r-project.org/archive/2018/RJ-2018-005/RJ-2018-005.pdf
and
https://journal.r-project.org/archive/2020/RJ-2020-018/RJ-2020-018.pdf
Change cumsum(c(0, x))[length(x)+1] to sum(x) in all three scenarios:
Lines 39 to 55 in 64e96d6
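The replacement is an identity: the last element of cumsum(c(0, x)) is the running total of all of x, i.e. sum(x). A quick check:

```r
# Last element of the padded cumulative sum equals the plain sum.
x <- c(1.5, -2, 0.25, 4)
cumsum(c(0, x))[length(x) + 1]  # 3.75
sum(x)                          # 3.75
```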
library(mlr3proba)
library(mlr3extralearners)
library(mlr3pipelines)
library(survex)
library(survival)
veteran_task <- as_task_surv(veteran,
time = "time",
event = "status",
type = "right")
ranger_learner <- lrn("surv.ranger")
ranger_learner$train(veteran_task)
ranger_learner_explainer <- explain(ranger_learner,
data = veteran[, -c(3,4)],
y = Surv(veteran$time, veteran$status),
label = "Ranger model")
# error
pp <- predict_parts(ranger_learner_explainer,
ranger_learner_explainer$data[1,])
At this time, the calculation of the Integrated Brier Score and the Integrated AUC is performed as a Riemann integral. This gives times at the end of the considered time frame (where there might be few observations) a great influence on the integrated score.
Two proposed normalization methods are in [1]. We should try to implement (one of/both of) these methods.
There should also be a possibility to cut off the integration at different quantiles of the times in the data (Q25, Q50, Q75) [2].
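A sketch of the quantile cutoff idea (the helper name and its interface are hypothetical, not survex code): integrate the time-dependent score only up to a chosen quantile of the observed times, here with the trapezoidal rule.

```r
# Hypothetical helper: integrate `score` over `times` up to the q-quantile,
# normalized by the length of the integration range.
integrate_to_quantile <- function(times, score, q = 0.75) {
  cutoff <- stats::quantile(times, q)
  keep <- times <= cutoff
  t <- times[keep]
  s <- score[keep]
  # trapezoidal rule over the retained grid
  sum(diff(t) * (head(s, -1) + tail(s, -1)) / 2) / (t[length(t)] - t[1])
}

# Sanity check: a constant score integrates to itself, for any cutoff.
integrate_to_quantile(times = 1:100, score = rep(0.25, 100), q = 0.5)  # 0.25
```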
Line 31 in 64e96d6
Check if the kernelshap package can be used for SurvSHAP(t) calculation and implement it if possible.
Hi,
I got some nasty errors and hopefully the authors can look into this.
Something odd here: there is no error in the model, but there is in the explain step:
model.src <- randomForestSRC::rfsrc(Surv(OS, status) ~ ., data = data)
explain(model.src)
Error in grep("NODE", lns):(length(lns) - 2) : argument of length 0
In addition: Warning messages:
1: In if (substring(pattern, 1, 4) == "@rm_") { :
the condition has length > 1 and only the first element will be used
2: In if (substring(pattern, 1, 1) == "@") { :
the condition has length > 1 and only the first element will be used
...like in the time-dependent permutational feature importance plot
Lines 76 to 78 in 64e96d6
Lines 34 to 35 in 64e96d6