A transformer model for cause-specific hazard prediction
BMC Bioinformatics volume 25, Article number: 175 (2024)
Abstract
Background
Modelling discrete-time cause-specific hazards in the presence of competing events and non-proportional hazards is a challenging task in many domains. Survival analysis in longitudinal cohorts often requires such models, notably when the data are gathered at discrete points in time and the predicted events display complex dynamics. Current models often rely on the strong assumption of proportional hazards, which is rarely verified in practice, or do not handle sequential data in a meaningful way. This study proposes a Transformer architecture for the prediction of cause-specific hazards in discrete-time competing risks. Unlike the multilayer perceptrons already used for this task (DeepHit), the Transformer architecture is especially suited to handling complex relationships in sequential data, and it has displayed state-of-the-art performance in numerous tasks with few underlying assumptions about the task at hand.
Results
Using synthetic datasets of 2000–50,000 patients, we showed that our Transformer model surpassed the CoxPH, PyDTS, and DeepHit models for the prediction of cause-specific hazards, especially when the proportional hazards assumption did not hold. The error along simulated time highlighted the ability of our model to anticipate the evolution of cause-specific hazards at later time steps, where few events are observed. It was also superior to current models for the prediction of dementia and other psychiatric conditions in the English Longitudinal Study of Ageing cohort, as measured by the integrated Brier score and the time-dependent concordance index. We also illustrated the explainability of our model’s predictions using the integrated gradients method.
Conclusions
Our model provided state-of-the-art prediction of cause-specific hazards without adopting prior parametric assumptions on the hazard rates. It outperformed other models in non-proportional hazards settings for both the synthetic dataset and the longitudinal cohort study. We also observed that basic models such as CoxPH were better suited to extremely simple settings than deep learning models. Our model is therefore especially suited to survival analysis on longitudinal cohorts with complex dynamics of the covariate-to-outcome relationship, which are common in clinical practice. The integrated gradients provided importance scores for the input variables, indicating which variables guided the model in its predictions. This model is ready to be used for time-to-event prediction in longitudinal cohorts.
Introduction
Survival analysis under competing risks describes the time of occurrence of the first of several possible outcomes. This can be done by predicting the cause-specific hazards from a set of explanatory variables, also called covariates. Competing risks models have many applications in the analysis of failure times, including client churn and the probability of a borrower defaulting on a loan [1, 2]. In medicine, modelling competing events can be used to measure the time-to-event for several possible outcomes, for example when assessing treatment effects on a patient or predicting the time of death after colon cancer diagnosis [3, 4].
Previous work has addressed the prediction of cause-specific hazards under competing risks. Firstly, the semi-parametric Cox proportional hazards (CoxPH) model was introduced for survival analysis under the assumption of proportional hazards, namely a linear relationship between the log-hazard ratio and the covariates [5]. Because the original CoxPH model failed in the presence of collinear variables when applied to high-dimensional data, the Regularized CoxPH (RCoxPH) was introduced. This model maximizes CoxPH’s partial likelihood function subject to an additional elastic net penalty [6]. This model had numerous uses, such as the identification of breast cancer prognosis markers [7]. Secondly, a collapsed log-likelihood approach was developed and applied to colon cancer data [4]. This method does not rely on the proportional hazards assumption of the CoxPH model, which improves its applicability to real-world data. It was recently implemented in the PyDTS Python package [8]. Lastly, several studies used deep learning models to minimize a loss function adapted to datasets with censored data [9]. Multilayer perceptron models outperformed previous models in both continuous time (DeepSurv) and discrete time (DeepHit) [10, 11]. These deep learning models are able to learn without strong assumptions on the predicted hazard rates; however, they were not initially designed to handle temporal covariates or produce temporal predictions, which limits their performance in survival analysis on longitudinal cohorts.
Additionally, several studies reported on the failure of the proportional hazards assumption in survival analysis, notably for treatment response and oncology [12,13,14,15]. This highlights the need for modelling competing risks with non-proportional hazards.
In various tasks involving sequential data, such as natural language processing and time series forecasting, Transformer-based models demonstrated excellent performance in learning complex dynamics from sequential data [16, 17]. Transformer models are particularly suited for sequence generation, which motivated their use in time series predictions of discrete-time cause-specific hazards. A Transformer model was recently used for survival analysis with a single event [18]. In this study, we introduce a Transformer-based deep learning model for the prediction of the cause-specific hazards in discrete time under competing risks.
Because the true data-generating mechanisms underlying the targeted cause-specific hazards are unknown in practice, we used synthetic data to compare our model against three state-of-the-art models [19]. We followed the ADEMP guidelines (Aims, Data-generating mechanisms, Estimands, Methods, and Performance Measures) for simulation and reporting of results [20]. We then validated our model on the English longitudinal study of ageing (ELSA) dataset for the prediction of death, dementia and psychiatric conditions [21]. To our knowledge, this is the first study to use a Transformer-based model for the prediction of the cause-specific hazards in discrete time under competing risks.
This article is organized as follows: the “Methodology” section describes our Transformer-based model, the benchmark models, and the simulated and ELSA datasets; the “Results” section presents the predictive performance of each model on the synthetic and ELSA datasets; finally, the “Discussion” and “Conclusions” sections discuss the findings and conclusions of this study.
Our code is openly available at https://github.com/USMCHUFGuyon/cause_specific_hazard_transformer.
Methodology
Notations
Competing risks analysis considers a patient described by a vector of covariates X, that may experience one of E separate events on a [0, T] period of time. A patient may be censored at \(t_0 \le T\), in which case it is only known that no event occurred before \(t_0\). For convenience, competing events were denoted \(\{1, \dots , E\}\). If event e occurred at time t, the outcome is written (e, t) with \(e \in \{0, 1, \dots , E\}\), \(t \in [0, T]\), and \(e=0\) indicating censoring.
The cause-specific hazard \(\lambda _{e, X}(t)\), for \(e\ge 1\), defined by (1), is the instantaneous rate of occurrence of event e at time t, given that the patient remained event-free until t. A model of cause-specific hazards explores the relation between the covariates X and the cause-specific hazard \(\lambda _{e, X}\) for each event e [22].
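For reference, the standard textbook form of this definition can be sketched as follows. The symbols \(\tau\) (time of the first event) and \(\epsilon\) (type of that event) are notational assumptions introduced here, and the second line gives the usual discrete-time counterpart.

```latex
\[
\lambda_{e,X}(t) \;=\; \lim_{\Delta t \to 0}
\frac{P\left(t \le \tau < t + \Delta t,\ \epsilon = e \,\middle|\, \tau \ge t,\ X\right)}{\Delta t}
\qquad \text{(continuous time)}
\]
\[
\lambda_{e,X}(t) \;=\; P\left(\tau = t,\ \epsilon = e \,\middle|\, \tau \ge t,\ X\right)
\qquad \text{(discrete time)}
\]
```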
Note that in discrete-time competing risks, the cause-specific hazard is defined as a probability and not as an unbounded positive number [23]. We also introduce the cumulative incidence function (2). This is a function of the cause-specific hazard that describes the proportion of patients that experienced event e up until time t.
where \(i_{e, X}\) is the incidence function defined by:
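In the discrete-time setting, these two quantities are commonly written as below. This is a sketch of the standard form rather than a verbatim copy of the article's equations, and the symbol \(F_{e,X}\) for the cumulative incidence function is an assumption. The product term is the probability of remaining event-free before time k.

```latex
\[
F_{e,X}(t) \;=\; \sum_{k \le t} i_{e,X}(k),
\qquad
i_{e,X}(k) \;=\; \lambda_{e,X}(k) \prod_{j < k} \Bigl( 1 - \sum_{e'=1}^{E} \lambda_{e',X}(j) \Bigr)
\]
```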
The goal of this study is to build a prediction model for the cause-specific hazards \((\lambda _{e, X})_{e \in \{1, \ldots , E\}}\) from a set of covariates X. This study focused on the cause-specific hazard but did not explore the prediction of the subdistribution hazard. In the following, X may be constant or longitudinal data.
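To make the link between predicted hazards and cumulative incidence concrete, the following minimal NumPy sketch converts a matrix of discrete-time cause-specific hazards into cumulative incidence curves. The function name `cumulative_incidence` and the (E, T) array layout are illustrative assumptions, not part of the published code.

```python
import numpy as np


def cumulative_incidence(hazards):
    """Cumulative incidence of each event from discrete-time hazards.

    hazards: array of shape (E, T), hazards[e, t] = P(event e at step t | event-free before t).
    Returns an array of shape (E, T).
    """
    hazards = np.asarray(hazards, dtype=float)
    overall = hazards.sum(axis=0)                   # hazard of any event at each step
    survival = np.cumprod(1.0 - overall)            # P(event-free up to and including step t)
    survival_before = np.concatenate(([1.0], survival[:-1]))
    incidence = hazards * survival_before           # P(event e occurs exactly at step t)
    return np.cumsum(incidence, axis=1)


hazards = np.array([[0.1, 0.1],
                    [0.2, 0.2]])
cif = cumulative_incidence(hazards)
```

At any time step, the cumulative incidences of all events plus the event-free survival probability sum to one, which is a convenient sanity check on predicted hazards.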
A transformer-based model for cause-specific hazard prediction in discrete time
We used a Transformer-based deep learning model to predict the cause-specific hazard \(\lambda _{e, X}\) of each event e from covariates X. This section describes the input and output data, the loss function that was minimized and the model architecture.
Input and output data
In real-world applications, the cause-specific hazards are unknown. The available data are the covariates X and outcomes (e, t), where e is the experienced event (or censoring) and t the time-to-event. Our model predicts the cause-specific hazards \(\lambda _{e, X}\) of events e from the covariates X as a time series of length T. The output of the model may be written as matrix (4).
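A plausible layout of this output matrix, with one row per event and one column per time step, is sketched below. The symbol \(\Lambda_X\) is an assumed name for the matrix.

```latex
\[
\Lambda_X \;=\;
\begin{pmatrix}
\lambda_{1,X}(1) & \lambda_{1,X}(2) & \cdots & \lambda_{1,X}(T) \\
\vdots           & \vdots           & \ddots & \vdots           \\
\lambda_{E,X}(1) & \lambda_{E,X}(2) & \cdots & \lambda_{E,X}(T)
\end{pmatrix}
\]
```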
Loss function
The collapsed log-likelihood (5) from the PyDTS package was used as a loss function [8]. This function evaluates the consistency between the predicted cause-specific hazards \(\lambda _{X=x}\) and the observed outcome \((e_x,t_x)\).
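One common way to write a collapsed log-likelihood loss of this kind for a single patient x is sketched below. The exact form used in the article may differ; the indicator \(\delta_{j,k}(x)\) and the at-risk set \(R(x)\) are notational assumptions introduced here.

```latex
\[
\mathcal{L}(x) \;=\; - \sum_{j=1}^{E} \sum_{k \in R(x)}
\Bigl[ \delta_{j,k}(x) \log \lambda_{j,k}(x)
     + \bigl(1 - \delta_{j,k}(x)\bigr) \log \bigl(1 - \lambda_{j,k}(x)\bigr) \Bigr]
\]
\[
\delta_{j,k}(x) = \mathbb{1}\bigl[(j,k) = (e_x, t_x)\bigr],
\qquad
R(x) =
\begin{cases}
\{1, \dots, t_x\}     & \text{if } e_x \ge 1 \\
\{1, \dots, t_x - 1\} & \text{if } e_x = 0 \text{ (censored)}
\end{cases}
\]
```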
Minimizing this loss encourages:

A high value of \(\lambda _{e,t}(x)\), which represents the predicted hazard for the observed outcome \((e_x, t_x)\)

Low values of \(\lambda _{j,k}(x)\) for \((j,k) \ne (e_x,t_x)\), which represent the predicted hazards for outcomes that were not observed
Note that a patient censored at \(t_{x}\) will contribute to low values of \(\lambda _{j,k}(x)\) for each event j and each time \(k < t_{x}\).
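The behaviour described above can be illustrated with a minimal NumPy sketch of a collapsed negative log-likelihood, written as a sum of binary cross-entropy terms over the time steps at which each patient is still at risk. The function name `collapsed_nll`, the array layout, and the choice to average over patients are assumptions made for illustration; this is not the PyDTS implementation.

```python
import numpy as np


def collapsed_nll(hazards, event, time, eps=1e-12):
    """Mean collapsed negative log-likelihood over patients.

    hazards: (N, E, T) predicted cause-specific hazards in (0, 1).
    event:   (N,) integers in 0..E, where 0 means censored.
    time:    (N,) integers in 1..T, time of the event or of censoring.
    """
    n, n_events, n_times = hazards.shape
    steps = np.arange(1, n_times + 1)
    # Assumption: a patient censored at t is at risk only for k < t,
    # whereas a patient with an event at t is at risk for k <= t.
    last_step = np.where(event > 0, time, time - 1)
    at_risk = steps[None, :] <= last_step[:, None]            # (N, T)

    # One-hot target: 1 only at the observed (event, time) pair.
    target = np.zeros_like(hazards, dtype=float)
    observed = np.flatnonzero(event > 0)
    target[observed, event[observed] - 1, time[observed] - 1] = 1.0

    h = np.clip(hazards, eps, 1.0 - eps)
    bce = -(target * np.log(h) + (1.0 - target) * np.log(1.0 - h))
    return float((bce * at_risk[:, None, :]).sum() / n)


flat = np.full((1, 2, 3), 0.5)                                # E = 2 events, T = 3 steps
loss_event = collapsed_nll(flat, np.array([1]), np.array([2]))
loss_censored = collapsed_nll(flat, np.array([0]), np.array([2]))
```

With uniform hazards of 0.5, the patient with an event at step 2 contributes four terms (two events over two at-risk steps) and the patient censored at step 2 contributes two, so the losses are 4 log 2 and 2 log 2 respectively.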
Transformer-based model architecture
The Transformer model is a sequence-to-sequence architecture that was introduced as a response to the vanishing-gradients problem that affected long short-term memory (LSTM) and other recurrent neural networks [24]. It uses the self-attention mechanism in an encoder–decoder architecture to learn complex temporal features of input and/or output data. Transformers are especially suited to producing meaningful sequential output, which initially motivated their use for natural language processing tasks. Consequently, the Transformer architecture also proved to be efficient for time series prediction from sequential or constant input data. A gentle introduction to the Transformer architecture is provided in Appendix 1.
Our model architecture is presented in Fig. 1. It is based on a Transformer encoder and a linear decoder that predict cause-specific hazards as a time series for each event. An input vector of covariates X is encoded by a linear layer and concatenated with an embedding of time. A positional encoding is added to the resulting tensor, which is fed to the Transformer encoder; the encoder outputs a single time series of length \(E \times T\). This time series is then decoded into a matrix of shape (E, T) by a single linear layer. Training with the loss function (5) ensures that the model learns to predict cause-specific hazards. This model was implemented using the PyTorch framework.
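A rough PyTorch sketch of an architecture in this spirit is given below. It is one possible reading of the description, not the authors' implementation: the class name `HazardTransformer`, the embedding sizes, the number of heads and layers, the use of one token per time step, and the softmax over "no event" plus the E events (which keeps the hazards of one time step summing to at most one) are all assumptions.

```python
import math

import torch
from torch import nn


class HazardTransformer(nn.Module):
    """Covariates -> (E, T) matrix of discrete-time cause-specific hazards (illustrative)."""

    def __init__(self, n_covariates, n_events, n_times,
                 d_cov=48, d_time=16, n_heads=4, n_layers=2):
        super().__init__()
        d_model = d_cov + d_time
        self.n_times = n_times
        self.cov_encoder = nn.Linear(n_covariates, d_cov)      # linear encoding of X
        self.time_embedding = nn.Embedding(n_times, d_time)    # learned embedding of time
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=64,
            dropout=0.0, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.decoder = nn.Linear(d_model, n_events + 1)        # +1 for "no event"

        # Fixed sinusoidal positional encoding of shape (T, d_model).
        position = torch.arange(n_times).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(n_times, d_model)
        pe[:, 0::2] = torch.sin(position * div)
        pe[:, 1::2] = torch.cos(position * div)
        self.register_buffer("pe", pe)

    def forward(self, x):
        batch = x.shape[0]
        # Repeat the encoded covariates at every time step: (B, T, d_cov).
        cov = self.cov_encoder(x).unsqueeze(1).expand(-1, self.n_times, -1)
        steps = torch.arange(self.n_times, device=x.device)
        time = self.time_embedding(steps).unsqueeze(0).expand(batch, -1, -1)
        tokens = torch.cat([cov, time], dim=-1) + self.pe      # (B, T, d_model)
        encoded = self.encoder(tokens)
        probs = torch.softmax(self.decoder(encoded), dim=-1)   # (B, T, E + 1)
        return probs[..., 1:].transpose(1, 2)                  # hazards, (B, E, T)


torch.manual_seed(0)
model = HazardTransformer(n_covariates=5, n_events=3, n_times=30)
hazards = model(torch.rand(4, 5))                              # shape (4, 3, 30)
```

The hidden size of 64 and the absence of dropout follow the benchmark settings reported in the text; every other hyperparameter here is a placeholder.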
Performance evaluation
Benchmark models
The performance of our Transformer-based model in predicting cause-specific hazards was compared to three existing models.
Firstly, we used the semi-parametric RCoxPH model from the lifelines package in Python [25]. Secondly, we used the PyDTS model from Lee et al. [4, 8]. Finally, we implemented a model equivalent to the original DeepHit model using the PyTorch framework [11]. This model contains a feed-forward subnetwork with one hidden linear layer for each competing event and minimizes the loss function (5). All models predicted a time-discretized cause-specific hazard for each competing event in the form of an \(E\times T\) matrix, as presented in (4).
Benchmark designs
We evaluated all models using the same experimental setup, for both the synthetic and ELSA data. Data were split as 80% for training and 20% for validation. As described in the “Loss function” section, models learned to predict patients’ cause-specific hazards for each competing event by learning from observed events in the training data. Both deep learning models had hidden layers of 64 neurons and no dropout.
Additional implementation details are available in our code repository.
Synthetic data benchmark
We simulated populations of 2000–50,000 patients described by five covariates who could experience one of three competing events. Their covariates were independent and uniformly distributed between 0 and 1. Events were drawn using the cause-specific hazard functions defined in Table 5 from Appendix. Cumulative incidences of each event, and the number of patients at risk at each time step, are illustrated in Fig. 2a. Note that one of the simulated events had a proportional hazard and the other two had non-proportional hazards. Departure from the proportional hazards hypothesis is common in clinical data, but represents a strong limitation for most survival analysis models [12].
Finally, censoring times were drawn uniformly between 1 and 49. A patient was censored if the drawn censoring time was earlier than the drawn event time. Events (and censoring) were drawn 10 times separately; training and evaluation were done on each drawn dataset to measure performance variability.
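The simulation procedure can be sketched in NumPy as follows. The three hazard functions below are hypothetical stand-ins chosen only to mimic a constant, an increasing and a bell-shaped hazard; they are not the functions of Table 5. The function name `simulate_cohort` and its parameters are likewise assumptions.

```python
import numpy as np


def simulate_cohort(n_patients, n_times=30, max_censoring=49, seed=0):
    """Draw covariates, competing events and censoring for a synthetic cohort."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(0.0, 1.0, size=(n_patients, 5))
    t = np.arange(1, n_times + 1)[None, :]                       # (1, T)

    # Hypothetical hazards (NOT those of Table 5), each of shape (N, T).
    proportional = 0.02 * np.exp(x[:, [0]]) * np.ones_like(t, dtype=float)
    increasing = 0.002 * t * (1.0 + x[:, [1]])
    mean, sd = 10.0 + 10.0 * x[:, [2]], 2.0 + 3.0 * x[:, [3]]
    non_monotonic = 0.05 * np.exp(-0.5 * ((t - mean) / sd) ** 2)
    hazards = np.stack([proportional, increasing, non_monotonic], axis=1)  # (N, 3, T)

    # At each step, one uniform draw selects "no event" or one of the events.
    u = rng.uniform(size=(n_patients, 1, n_times))
    n_above = (u < np.cumsum(hazards, axis=1)).sum(axis=1)        # (N, T), values 0..3
    code = np.where(n_above > 0, hazards.shape[1] + 1 - n_above, 0)
    has_event = (code > 0).any(axis=1)
    first = (code > 0).argmax(axis=1)                             # index of first event step
    event_time = np.where(has_event, first + 1, n_times + 1)
    event = np.where(has_event, code[np.arange(n_patients), first], 0)

    # Uniform censoring: a patient is censored if censoring precedes the event.
    censoring = rng.integers(1, max_censoring + 1, size=n_patients)
    censored = censoring < event_time
    observed_event = np.where(censored, 0, event)
    observed_time = np.minimum(np.minimum(event_time, censoring), n_times)
    return x, observed_event, observed_time, hazards


x, e, t, true_hazards = simulate_cohort(2000)
```

Repeating the call with different seeds gives independent draws of events and censoring, which mirrors the repeated-draw design used to measure performance variability.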
In this synthetic experiment, ground truth cause-specific hazards are known. For this reason, model predictions were evaluated on the mean absolute error of the cause-specific hazard prediction. We also evaluated the models’ predictive performance along simulated time, and with varying training sample size.
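As a small illustration of this metric, the sketch below computes the mean absolute error between predicted and true hazards, both per event and per event and time step. The function name `hazard_mae` and the (N, E, T) layout are assumptions.

```python
import numpy as np


def hazard_mae(predicted, truth):
    """Mean absolute error of hazard predictions.

    predicted, truth: arrays of shape (N, E, T).
    Returns (per_event, per_event_and_time) with shapes (E,) and (E, T).
    """
    error = np.abs(np.asarray(predicted) - np.asarray(truth))
    return error.mean(axis=(0, 2)), error.mean(axis=0)


predicted = np.zeros((2, 3, 4))
truth = np.full((2, 3, 4), 0.1)
per_event, per_event_and_time = hazard_mae(predicted, truth)
```

The second output is the quantity to plot against time when looking at how the error evolves along the simulated follow-up.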
ELSA data benchmark
The ELSA dataset is a representative cohort of the English population older than 50. It features economic, social, psychological, cognitive, health, biological and genetic data [21]. This longitudinal study currently features 9 waves of data acquired over 18 years and includes various diagnoses of cardiovascular, ocular, and psychiatric diseases.
We used this longitudinal cohort to evaluate the models’ prediction of dementia and psychiatric conditions. In the ELSA dataset, a psychiatric condition refers to any of the following psychiatric disorders: hallucinations, anxiety, depression, emotional problems, schizophrenia, psychosis, mood swings, and manic depression. Our study population was the cohort from wave 2, which started in 2004. Patients already diagnosed with a psychiatric condition or dementia were excluded. Because mortality data were last updated in 2012, the study period was 2004–2012. We evaluated the models on the following competing events:

Dementia: new diagnosis of dementia

Psychiatric condition: new diagnosis of a psychiatric condition

Death
Contrary to our synthetic dataset, the ground truth for the cause-specific hazard is unknown; hence, models were evaluated on the Integrated Brier Score and Time-dependent Concordance Index for each event [26, 27]. The Brier Score is a generalization of the mean squared error applied to the comparison of predicted probabilities and observed events. The Concordance Index is a generalization of the area under the receiver operating characteristic curve (AUROC); it evaluates the ranking of failure times from the predicted probabilities [28]. The Integrated Brier Score and Time-dependent Concordance Index are the respective variants of the Brier Score and Concordance Index adapted to the prediction of time series. The mean error and \(95\%\) confidence intervals were computed by bootstrapping on the test dataset. Finally, the assumption of proportional hazards was evaluated by computing the p values of the Schoenfeld residuals from the RCoxPH model [29].
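The bootstrap procedure can be illustrated with a deliberately simplified sketch: a Brier score at a single time horizon that ignores censoring (a full Integrated Brier Score would add inverse-probability-of-censoring weights and integrate over time), wrapped in a percentile bootstrap. The function names and the toy data are assumptions.

```python
import numpy as np


def brier_score(predicted_probability, observed):
    """Squared error between predicted event probabilities and 0/1 outcomes.

    Simplification: censoring is ignored, unlike in the Integrated Brier Score.
    """
    return float(np.mean((predicted_probability - observed) ** 2))


def bootstrap_ci(metric, predicted, observed, n_boot=1000, seed=0):
    """Mean and 95% percentile interval of a metric over bootstrap resamples."""
    rng = np.random.default_rng(seed)
    n = len(observed)
    stats = np.empty(n_boot)
    for b in range(n_boot):
        idx = rng.integers(0, n, size=n)          # resample patients with replacement
        stats[b] = metric(predicted[idx], observed[idx])
    low, high = np.percentile(stats, [2.5, 97.5])
    return float(stats.mean()), float(low), float(high)


toy_score = brier_score(np.array([0.1, 0.8, 0.3, 0.6]), np.array([0.0, 1.0, 0.0, 1.0]))

rng = np.random.default_rng(1)
observed = rng.integers(0, 2, size=200).astype(float)
predicted = np.clip(0.2 + 0.6 * observed + rng.normal(0.0, 0.1, size=200), 0.0, 1.0)
mean, low, high = bootstrap_ci(brier_score, predicted, observed)
```

Any other metric with the same signature, such as a concordance index, can be passed to `bootstrap_ci` in place of `brier_score`.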
We used the Integrated Gradients method on both deep learning models to provide an importance score for the input features [30]. This method provides importance scores with a lower computational cost than Shapley values when applied to a large number of input variables and time series output. In this work, we present the total importance scores over the whole ELSA dataset; however, these scores are available for each prediction. Such importance scores were shown to improve the usability of artificial intelligence in clinical practice [31].
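The principle of Integrated Gradients can be shown on a toy differentiable model. The sketch below approximates the path integral with a midpoint Riemann sum; the logistic "hazard" model, its weights, and the all-zero baseline are illustrative assumptions and are unrelated to the models evaluated in this study.

```python
import numpy as np


def integrated_gradients(grad_fn, x, baseline, n_steps=200):
    """Attribution of each input feature along the straight path baseline -> x."""
    alphas = (np.arange(n_steps) + 0.5) / n_steps             # midpoints of [0, 1]
    path = baseline + alphas[:, None] * (x - baseline)        # (n_steps, n_features)
    grads = np.stack([grad_fn(point) for point in path])
    return (x - baseline) * grads.mean(axis=0)


# Toy model: a logistic function of three covariates (hypothetical weights).
weights = np.array([1.5, -2.0, 0.5])
bias = -1.0


def model(x):
    return 1.0 / (1.0 + np.exp(-(x @ weights + bias)))


def model_gradient(x):
    p = model(x)
    return p * (1.0 - p) * weights


x = np.array([0.5, 1.0, -0.3])
baseline = np.zeros(3)
attributions = integrated_gradients(model_gradient, x, baseline)
```

A useful check is the completeness property: the attributions sum (up to numerical error) to the difference between the model output at `x` and at the baseline.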
Results
Evaluation on synthetic data
Simulated data
We simulated datasets with sample sizes of 2000, 5000, 10,000, 20,000, and 50,000 patients, each described by 5 covariates and able to experience one of 3 competing events during a period of 30 time steps. In total, approximately \(40\%\) of patients were censored.
A sample of simulated cause-specific hazards for each event is shown in Fig. 2a. We introduced three simulated events: a Proportional hazard event that had a constant hazard in time, and two non-proportional hazard events, denoted the Increasing hazard and Non-monotonic hazard events, which featured a temporal evolution with a nonlinear dependence on the covariates. The Non-monotonic hazard event had a bell-curve distribution whose mean and standard deviation depended on patients’ covariates (see Table 5 from Appendix).
Figure 2b shows the cumulative incidence of each of the three events over the simulated time. We noted that fewer events were observed at the later time steps of the simulated time due to a smaller number of patients at risk.
Performance comparison
The mean absolute error of the cause-specific hazard prediction for several sizes of synthetic datasets is presented in Table 1. The Transformer-based model outperformed or equalled other models on non-proportional hazard events for all dataset sizes, and was better than or equivalent to other models on the Proportional hazard event with training data of \(> {5000}\) patients. These results highlight a strong performance improvement when using deep learning models on non-proportional events. Moreover, the benefit of the Transformer compared to the DeepHit model was more pronounced on smaller dataset sizes. Additionally, Fig. 3 shows the mean absolute error of the cause-specific hazard predictions as a function of time. Our Transformer model had better performance on the Proportional hazard event despite a lower precision at early time steps of this hazard’s predictions. We observed that our Transformer-based model always had a large benefit towards the end of the simulated time frame, which indicates a better ability to extrapolate cause-specific hazards from the set of observed events. We also noted that the PyDTS and RCoxPH models had extremely poor performance on the later part of the simulated time, where fewer events were observed. This was true for the Proportional hazard event, but even more pronounced for non-proportional hazard events.
Evaluation on the ELSA dataset
Collected data
The cohort size was 3564 patients. We selected 74 variables, of which 54 were binary. Over the 8-year study period, there were 542 diagnoses of psychiatric conditions, 150 diagnoses of dementia, and 499 recordings of death. Cumulative incidences of each event are illustrated in Fig. 2c. The list of selected variables is shown in Table 6 from Appendix. Some variables had a large number of missing values (up to 45%), and 22 variables had more than 10% missing values. The missing values were imputed using the median value for continuous variables and the most frequent value for binary variables. Because the evaluated models other than the Transformer and RCoxPH models do not inherently support sequential input data, we used singleton-length input data to provide a fair comparison between all models. All models learnt from singleton-length input sequences and produced cause-specific hazard predictions as a fixed-length time series.
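The imputation rule described above can be sketched in a few lines of NumPy. The function name `impute` and the boolean mask `is_binary` marking binary columns are assumptions; missing values are assumed to be encoded as NaN.

```python
import numpy as np


def impute(data, is_binary):
    """Fill NaNs with the column median (continuous) or most frequent value (binary)."""
    filled = np.array(data, dtype=float)           # work on a copy
    for j in range(filled.shape[1]):
        column = filled[:, j]                      # view into `filled`
        missing = np.isnan(column)
        observed = column[~missing]
        if not missing.any() or observed.size == 0:
            continue                               # nothing to fill, or nothing to learn from
        if is_binary[j]:
            values, counts = np.unique(observed, return_counts=True)
            fill_value = values[np.argmax(counts)]
        else:
            fill_value = np.median(observed)
        column[missing] = fill_value
    return filled


raw = np.array([[1.0, 0.0],
                [np.nan, 1.0],
                [3.0, np.nan],
                [5.0, 1.0]])
clean = impute(raw, is_binary=[False, True])
```

In a train/validation setting, the fill values would normally be computed on the training split only and reused on the validation split; the sketch omits this for brevity.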
Performance comparison
Integrated Brier Scores and Time-dependent Concordance Indices for each model are presented in Table 2. The mean value and 95% confidence interval were obtained by bootstrapping on the test dataset. Our Transformer-based model had the best Integrated Brier Score and Time-dependent Concordance Index. Moreover, the PyDTS model was slightly better than the RCoxPH model, but in comparison, the Transformer model allowed for a major improvement on both metrics. Finally, despite a strong Integrated Brier Score, the DeepHit model showed a poor Concordance Index on the ELSA dataset.
Feature importance
The most important features on average for the prediction of each event by the DeepHit and Transformer models are shown in Fig. 4. See Table 6 from Appendix for details on each feature. The age feature was the most important feature for the Transformer model’s predictions. In the prediction of death, the Transformer model notably used the binary features limiting illness and cancer, which stated, respectively, “Whether limited by long-time illness” and “Ever diagnosed with cancer”. In the Transformer model’s predictions, happy mood only appeared among the important features for the psychiatric condition and dementia predictions.
Proportional hazard assumption
Variables that broke the proportional hazard assumption are shown in Table 3. This table lists the variables of each dataset where Schoenfeld residuals of the fitted RCoxPH model had p values lower than 0.05. In the synthetic dataset, none of the five variables broke the proportional hazard assumption for the Proportional hazard event, whereas the Increasing hazard event and Non-monotonic hazard event had respectively five and four variables breaking the proportional hazard assumption. Events from the ELSA dataset had four to six Schoenfeld residuals with p values lower than 0.05. This indicates that the Death, Psychiatric condition, and Dementia events had non-proportional hazard rates.
Discussion
We introduced a Transformer-based deep learning model for the prediction of cause-specific hazards in the context of discrete-time competing risks. This model provides state-of-the-art hazard prediction without strong assumptions on the relation between covariates and cause-specific hazards. It strongly outperformed current models even with relatively small training datasets, and was especially successful on events with highly non-proportional hazards or few observed outcomes. We noted that basic models could perform better in a simplistic setting of time-independent proportional hazards with a small training sample; however, our Transformer model was generally the best for proportional hazards too.
Our Transformer-based model had the best predictive performance of the cause-specific hazard for simulated datasets with sizes ranging from 5000 to 50,000. It also had the best Integrated Brier Score and Time-dependent Concordance Index on the prediction of three competing events from the ELSA dataset. The experiment on simulated data showed that our model notably outperformed other models in predicting the cause-specific hazards at later time steps, where fewer outcomes were observed. This resulted in improved performance on the hazard prediction of rare events, a key benefit of our model. Such behaviour could be expected because of the ability of the Transformer architecture to learn and extrapolate complex temporal features from input data and generate coherent time series.
The analysis of the proportional hazard assumption on the synthetic data showed that only the Proportional hazard event had a proportional hazard rate. This was consistent with the definition of each event. The same analysis on the ELSA dataset indicated that all three events had non-proportional hazards, which is consistent with other findings of departure from the proportional hazard assumption in clinical data [12,13,14,15]. In both the synthetic and ELSA datasets, our model strongly outperformed current models on all events featuring non-proportional hazard rates.
Moreover, our model outperformed the DeepHit model on non-proportional hazards by a larger margin for synthetic datasets with sample sizes of 2000–10,000. This indicates that the Transformer model generalizes better from limited data. Such results greatly increase the usability of our model on relatively small datasets such as ELSA and most longitudinal cohorts. Additionally, the interpretability through integrated gradients provided the main features that affected the result of a prediction. This can be used by clinicians to build trust in the model’s prediction, and to focus their attention on the features that it deemed most relevant. This is critical for clinical use of any machine learning model, as no decision-making ought to be based on a non-explainable prediction.
Some limitations remain in our study. Firstly, our model has a large number of parameters, unlike the RCoxPH and PyDTS models. While non-optimized parameters already outperform other models, fine-tuning the network size and training parameters could improve performance. Secondly, our Transformer-based model was consistently better than the simpler architecture of the DeepHit model. However, the gain in performance came with a higher computational cost. This was not limiting in our study, as the training times did not exceed several minutes. Finally, to provide a fair comparison between models, only singleton-length input sequences were used in the data examples, as models other than the RCoxPH and Transformer were not designed for handling sequential input. This experiment did demonstrate the ability of the Transformer model to generate meaningful sequences, but did not benefit from its ability to understand complex dynamics of input sequences.
Conclusions
This study introduces a Transformer-based deep learning model with state-of-the-art performance on cause-specific hazard prediction in the context of discrete-time competing risks. Our model outperformed current models in cause-specific hazard prediction, especially for non-proportional hazard rates and few observed outcomes. It had an increased benefit compared to current models for datasets of 2000–50,000 patients. The designs where our model shows greater benefits encompass those of most clinical survival analysis studies on longitudinal cohorts. Our Transformer-based model is ready to be used for improving current hazard predictions on longitudinal cohorts with complex covariate-to-outcome dynamics.
Availability of data and materials
Our code and simulated data are openly available at https://github.com/USMCHUFGuyon/cause_specific_hazard_transformer
Abbreviations
ELSA: English longitudinal study of ageing
LSTM: Long short-term memory
RCoxPH: Regularized Cox proportional hazards model
References
Routh P, Roy A, Meyer J. Estimating customer churn under competing risks. J Oper Res Soc. 2020;72(1–18):08.
Wycinka E. Competing risk models of default in the presence of early repayments. Econometrics. 2019;23:06.
Cope S, Jansen J. Quantitative summaries of treatment effect estimates obtained with network meta-analysis of survival curves to inform decision-making. BMC Med Res Methodol. 2013;13(147):12.
Lee M, Feuer EJ, Fine JP. On the analysis of discrete time competing risks data. Biometrics. 2018;74(4):1468–81.
Cox DR. Regression models and life-tables. J R Stat Soc Ser B (Methodol). 1972;34(2):187–202.
Liu C, Liang Y, Luan XZ, Leung KS, Chan TM, Xu ZB, Zhang H. The l1/2 regularization method for variable selection in the cox model. Appl Soft Comput. 2014;14:498–503.
Li L, Liu ZP. Detecting prognostic biomarkers of breast cancer by regularized cox proportional hazards models. J Transl Med. 2021;19:12.
Meir T, Gutman R, Gorfine M. PyDTS: a python package for discrete-time survival (regularized) regression with competing risks. 2022. arXiv e-prints, arXiv:2204.05731
Steingrimsson JA, Morrison S. Deep learning for survival outcomes. Stat Med. 2020;39(17):2339–49.
Katzman JL, Shaham U, Cloninger A, Bates J, Jiang T, Kluger Y. DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC Med Res Methodol. 2018;18(1):02.
Lee C, Zame W, Yoon J, van der Schaar M. Deephit: A deep learning approach to survival analysis with competing risks. In: Proceedings of the AAAI conference on artificial intelligence, vol. 32, no. 1. 2018;p. 04.
Trinquart L, Jacot J, Conner SC, Porcher R. Comparison of treatment effects measured by the hazard ratio and by the ratio of restricted mean survival times in oncology randomized controlled trials. J Clin Oncol Off J Am Soc Clin Oncol. 2016;34:02.
Jiménez J. Quantifying treatment differences in confirmatory trials under non-proportional hazards. J Appl Stat. 2020;49(1–19):09.
Diao G, Ibrahim J. Quantifying time-varying cause-specific hazard and subdistribution hazard ratios with competing risks data. Clin Trials (Lond, Engl). 2019;16:06.
Van Wijk RC, Simonsson USH. Finding the right hazard function for time-to-event modeling: a tutorial and shiny application. CPT Pharm Syst Pharmacol. 2022;11(8):991–1001.
Wolf T, Debut L, Sanh V, Chaumond J, Delangue C, Moi A, Cistac P, Rault T, Louf R, Funtowicz M, Brew J. Huggingface’s transformers: state-of-the-art natural language processing. CoRR. 2019. arXiv:1910.03771.
Wu N, Green B, Ben X, O’Banion S. Deep transformer models for time series forecasting: the influenza prevalence case. CoRR. 2020. arXiv:2001.08317.
Lin J, Luo S. Deep learning for the dynamic prediction of multivariate longitudinal and survival data. Stat Med. 2022;41(15):2894–907.
Boulesteix AL, Groenwold RHH, Abrahamowicz M, Binder H, Briel M, Hornung R, Morris TP, Rahnenführer J, Sauerbrei W. Introduction to statistical simulations in health research. BMJ Open. 2020;10(12): e039921.
Morris T, White I, Crowther M. Using simulation studies to evaluate statistical methods. Stat Med. 2019;38:01.
Steptoe A, Breeze E, Banks J, Nazroo J. Cohort profile: the English longitudinal study of ageing. Int J Epidemiol. 2012;42:11.
Austin P, Fine J. Practical recommendations for reporting Fine-Gray model analyses for competing risk data. Stat Med. 2017;36:09.
Schmid M, Berger M. Competing risks analysis for discrete time-to-event data. WIREs Comput Stat. 2021;13(5): e1529.
Vaswani A, Shazeer N, Parmar N, Uszkoreit J, Jones L, Gomez AN, Kaiser L, Polosukhin I. Attention is all you need. CoRR. 2017. arXiv:1706.03762.
Davidson-Pilon C. lifelines: survival analysis in python. J Open Source Softw. 2019;4(40):1317.
Graf E, Schmoor C, Sauerbrei W, Schumacher M. Assessment and comparison of prognostic classification schemes for survival data. Stat Med. 1999;18(17–18):2529–45.
Antolini L, Boracchi P, Biganzoli E. A time-dependent discrimination index for survival data. Stat Med. 2005;24(3927–44):12.
Uno H, Cai T, Pencina MJ, D’Agostino RB, Wei LJ. On the C-statistics for evaluating overall adequacy of risk prediction procedures with censored survival data. Stat Med. 2011;30(1105–17):05.
Gill R, Schumacher M. On a simple test of the proportional hazards model. Biometrika. 1987;74:289–300.
Sundararajan M, Taly A, Yan Q. Axiomatic attribution for deep networks. CoRR. 2017. arXiv:1703.01365.
Magboo MSA, Magboo VPC. Feature importance measures as explanation for classification applied to hospital readmission prediction. Procedia Comput Sci. 2022;207:1388–97.
Wen Q, Zhou T, Zhang C, Chen W, Ma Z, Yan J, Sun L. Transformers in time series: a survey. 2023. arXiv:2202.07125
Acknowledgements
We thank Andrew Hobson for his editorial assistance and Laëtitia Berly for her contribution to this research.
Funding
None.
Author information
Contributions
MO performed the formal analysis, investigation, data curation, software, and visualization, and wrote the original draft. CF undertook the methodology, conceptualization, investigation, project administration, resources, supervision, and validation, and reviewed the manuscript. MD, NA, and JA carried out the conceptualization, investigation, supervision, project administration, and validation, and reviewed the manuscript.
Ethics declarations
Ethics approval and consent to participate
Not applicable.
Consent for publication
Not applicable.
Competing interests
The authors declare no competing interests.
Additional information
Publisher's Note
Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Appendices
Appendix 1: Introduction to transformer models
Transformers, introduced by Vaswani et al. [24], have become the go-to architecture for sequence-to-sequence tasks. As shown in Fig. 1, input sequences go through the following stack of modules: Embedding, Transformer Encoder, Linear Decoder. This section provides a qualitative explanation, along with a more detailed description of each module.
Embedding: The embedding adds temporal information to the input sequences. This allows the subsequent blocks to process the embedded sequences as a temporal sequence rather than an unordered set of values.
Transformer encoder: The Transformer Encoder uses the attention mechanism to extract the information relevant to the prediction task. By learning attention scores, it encodes the input sequences into a vector that depends solely on relevant temporal information from the input sequences. This encoded vector provides a lower-dimensional representation of the input sequences that is easier to process for the prediction task.
Linear decoder: Vectors encoded by the Transformer Encoder are decoded into the final prediction using a linear network. This is a simple architecture that processes the input vector using a set of trained weights in a single layer of neurons.
Appendix 1.1: Embedding
Unlike recurrent neural networks, the Transformer architecture does not inherently account for the temporal order of input sequences. The aim of the embedding step is to learn a representation of the input sequences that contains temporal information [32]. The following operations are applied:

1. Input sequences X are embedded using a feedforward network: we denote this embedding \(X_{emb} = XF_1^*\), where \(F_1^*\) denotes the trained weights for the embedding of input sequences. This embedding is a representation of the input vector in a slightly lower dimension.

2. A time embedding is then concatenated to the embedded time series \(X_{emb}\):
$$\begin{aligned}&T_{emb} = TF_{time} \\&X_{time\text{-}emb} = X_{emb} \oplus T_{emb} \end{aligned}$$
where T are the timesteps of the input sequences X, \(F_{time}\) is the time-embedding operator, and \(\oplus\) denotes concatenation. We call the tensor \(X_{time\text{-}emb}\) the time-embedded sequences.

3. Positional encoding is applied, then summed to the time-embedded sequences. In Transformer models, the positional encoding operator (PE) is usually defined as:
$$\begin{aligned} PE(i, pos)&= \sin \left( \frac{pos}{10{,}000^{i/d}}\right) \quad {\text{when}}\; i\; {\text{is even,}}\\ PE(i, pos)&= \cos \left( \frac{pos}{10{,}000^{i/d}}\right) \quad {\text{when}}\; i\; {\text{is odd,}} \end{aligned}$$
where i is the index of the time series, pos is the position of the element, and d is the dimensionality of the embedding. This positional encoding operator is applied on the first axis of \(X_{time\text{-}emb}\), i.e. identically for all patients. It produces a tensor of the same shape as the input embedding. The positionally-encoded embedding \(X_{pe}\) of input sequences X is
$$\begin{aligned} X_{pe} = {{\,\textrm{tanh}\,}}(X_{time\text{-}emb})(1+PE) \end{aligned}$$
The reason for summing the positional encoding to the time-embedded sequences is to preserve the dimensionality of the embedded space, while adding the temporal information to the sequence.
The positionally-encoded sequences \(X_{pe}\) are a representation of the input sequences that includes temporal information about the timesteps at which the input variables were measured. This tensor is the input of the Transformer encoder. In the following, \(X_{pe}\) is called the embedding of X.
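The three embedding operations can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the tensor layout (patients, timesteps, features), the use of a single linear map for \(F_{time}\), the random stand-in weights, and the function names `positional_encoding` and `embed` are all hypothetical.

```python
import numpy as np

def positional_encoding(n_steps: int, d: int) -> np.ndarray:
    """PE(i, pos) as written above: sine for even i, cosine for odd i."""
    pos = np.arange(n_steps)[:, None]        # (n_steps, 1)
    i = np.arange(d)[None, :]                # (1, d)
    angle = pos / (10_000.0 ** (i / d))      # (n_steps, d)
    return np.where(i % 2 == 0, np.sin(angle), np.cos(angle))

def embed(X, T, F1, F_time):
    """X: (n_patients, n_steps, n_features); T: (n_steps,) timesteps.

    F1 and F_time stand in for the trained weights (assumed linear maps).
    """
    X_emb = X @ F1                                        # step 1: X_emb = X F_1
    T_emb = T[:, None] @ F_time                           # step 2: T_emb = T F_time
    T_emb = np.broadcast_to(T_emb, X_emb.shape[:2] + (T_emb.shape[-1],))
    X_time_emb = np.concatenate([X_emb, T_emb], axis=-1)  # concatenation
    PE = positional_encoding(X_time_emb.shape[1], X_time_emb.shape[2])
    return np.tanh(X_time_emb) * (1.0 + PE)               # step 3: X_pe

# Hypothetical sizes: 4 patients, 10 timesteps, 6 input variables.
rng = np.random.default_rng(0)
X = rng.normal(size=(4, 10, 6))
T = np.arange(10, dtype=float)
F1 = rng.normal(size=(6, 5))       # slightly lower dimension than the input
F_time = rng.normal(size=(1, 3))
X_pe = embed(X, T, F1, F_time)
print(X_pe.shape)
```

The positional encoding has shape (timesteps, embedding dimension) and is broadcast over the patient axis, which matches the statement that it is applied identically for all patients.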
Appendix 1.2: Transformer encoder
The Transformer encoder is the crux of the Transformer architecture. It features a multi-head attention module followed by layer normalization and a linear layer.
In this work, we used a number of attention heads \(n_{head} = 1\) and an embedding dimension \(n_{lat} = 64\). A single attention head h contains three subnetworks \(Q_h\), \(K_h\), and \(V_h\) respectively called Query, Key, and Value subnetwork. Their respective trained weights are denoted \(Q^*_h\), \(K^*_h\), and \(V^*_h\). An attention head h computes the attention of each element x using its embedding \(x_{pe}\) and the embedding \(X_{pe}\) of the input sequences X.

1. The embedding of x is fed to the Query subnetwork, which outputs \(q_{x,h} = x_{pe}Q^*_h\).

2. The embedding of the input sequences X is fed to both the Key and Value subnetworks, which respectively output \(k_{X,h} = X_{pe}K^*_h\) and \(v_{X,h} = X_{pe}V^*_h\).

3. The attention score of the element is given by
$$\begin{aligned} a_{x,X,h} = q_{x,h}* k_{X,h} \end{aligned}$$

4. The element’s attention output \(A_{x,X,h}\) is obtained by weighting \(v_{X,h}\) with a function of the attention score \(a_{x,X,h}\):
$$\begin{aligned} A_{x,X,h} = {{\,\textrm{softmax}\,}}\left( \frac{a_{x,X,h}}{\sqrt{n_{dim}}}\right) * v_{X,h} \end{aligned}$$
The output of the multi-head attention module is a weighted sum of each head’s attention: \(A_{x,X} = \sum _{h} A_{x,X,h}w_h\), where \(W = [w_1, \dots , w_{n_{head}}]\) is a trained parameter.
The concatenation of all elements’ attention yields the attention matrix.
Attention captures complex relationships between a number of input sequences. It weights the informativeness of each input sequence within the context of the whole set of input sequences. The subsequent normalization and feedforward networks use the attention matrix to produce a lower-dimensional latent representation of the input sequences. Weakly informative elements of the input sequences, e.g. those highly correlated with other input sequences, will obtain a low attention value and will scarcely contribute to the latent representation.
Attention and embeddings of the input sequences are then given to a feedforward encoder to produce the final latent representation \(X_l\).
In short, this attention mechanism allows generating a latent representation of large and complex input sequences by effectively compressing embeddings of the input sequences in a way that preserves informative values and their temporality.
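The attention computation of Appendix 1.2 can be sketched in NumPy for all elements of one patient's embedded sequence at once. This is an illustrative sketch, not the authors' code: the weights are random stand-ins, the scaling constant \(n_{dim}\) is assumed to be the key dimension, the product of queries and keys is assumed to be a matrix product with the transposed keys, and the names `attention_head` and `multi_head` are hypothetical.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention_head(X_pe, Q, K, V):
    """X_pe: (n_steps, n_lat). Returns the head output and the attention weights."""
    q = X_pe @ Q                              # queries, one row per element x
    k = X_pe @ K                              # keys
    v = X_pe @ V                              # values
    scores = q @ k.T                          # attention scores a_{x,X,h}
    weights = softmax(scores / np.sqrt(k.shape[-1]), axis=-1)
    return weights @ v, weights

def multi_head(X_pe, heads, w):
    """Weighted sum of the head outputs, A = sum_h w_h A_h."""
    outputs = [attention_head(X_pe, Q, K, V)[0] for (Q, K, V) in heads]
    return sum(w_h * A_h for w_h, A_h in zip(w, outputs))

# n_head = 1 and n_lat = 64 as in the text; 10 timesteps is a hypothetical choice.
rng = np.random.default_rng(0)
n_steps, n_lat = 10, 64
X_pe = rng.normal(size=(n_steps, n_lat))
Q, K, V = (rng.normal(size=(n_lat, n_lat)) * 0.1 for _ in range(3))
A_head, weights = attention_head(X_pe, Q, K, V)
A = multi_head(X_pe, [(Q, K, V)], w=[1.0])
print(A.shape, weights.shape)
```

Each row of `weights` sums to one, so every element's output is a convex combination of the value vectors; elements with low weights contribute little, which is the behaviour described above for weakly informative inputs.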
Appendix 1.3: Linear decoder
The feedforward decoder uses the latent representation \(X_l\) for the prediction task at hand. The predicted values are \(P = X_lF_{dec}^*\) where \(F_{dec}^*\) are the trained weights of the decoder. In this encoder–decoder architecture, modules learn in unison to respectively encode the large input data to a relevant latent space and to utilize the latent representation for producing accurate predictions.
This architecture is able to process a large amount of input data while keeping reasonable dimensionality of the training weights. This is especially helpful to improve computation times and reduce the risk of overfitting.
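As a small illustration of the decoder step \(P = X_lF_{dec}^*\), the sketch below applies a single weight matrix to a latent representation. The latent size of 64 follows Appendix 1.2; the number of patients, the number of outputs per patient, and the random weights are hypothetical placeholders, and no output activation is shown because the text does not specify one.

```python
import numpy as np

rng = np.random.default_rng(0)
n_patients, n_lat = 5, 64   # n_lat = 64 as stated in Appendix 1.2
n_out = 30                  # hypothetical number of predicted values per patient

X_l = rng.normal(size=(n_patients, n_lat))       # stand-in for the latent representation
F_dec = rng.normal(size=(n_lat, n_out)) * 0.1    # stand-in for the trained decoder weights

P = X_l @ F_dec             # P = X_l F_dec
print(P.shape, F_dec.size)  # decoder parameter count depends on n_lat, not on input length
```

Because the decoder only sees the latent representation, its number of weights is set by the latent and output sizes rather than by the length of the input sequences.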
Appendix 1.4: Implicit assumptions
The Transformer architecture allows predictions to be made without explicit assumptions on the predicted variable. Its efficiency has been shown experimentally in multiple fields of application. However, it features some implicit assumptions that should be stated.
Positional encoding effectively conveys temporal information to the model: This architecture assumes that the use of sinusoidal functions is efficient for conveying the temporal information to the Transformer encoder. This was not rigorously demonstrated, but the method’s effectiveness was observed empirically. Nevertheless, positional encoding could fail to capture some nuances of temporal dependency.
Attention is stationary: The attention mechanism does not explicitly compute a temporal variation of the variable informativeness. This can be problematic if a series of a variable contains highly informative values at some times, and non-informative values the rest of the time. However, the initial embedding may isolate such highly informative values and mitigate the limitations caused by this assumption.
Attention as a proxy for relevance: Attention as computed by the multi-head attention module is based on learning parameters that identify relations between a set of input and output sequences. This concept might not perfectly align with the human notion of relevance.
Appendix 2: Additional results
Appendix 2.1: Individual prediction visualization
We present some individual patients’ predicted hazards in Fig. 5. This figure illustrates the ability of the Transformer model to produce meaningful and individualized predictions, which greatly improves usability in clinical practice. The RCoxPH and PyDTS models offer decent average performance but fail to produce individually accurate hazard estimates.
Appendix 2.2: Peak hazard time prediction
Using the Non-monotonic hazard event, we designed an experiment to evaluate each model’s ability to create individualized predictions. This is not a standard metric but rather a qualitative insight into the models’ performance. The Non-monotonic hazard event reaches a maximum hazard value between the 3rd and 25th time steps. We compared the time at maximum hazard between the ground-truth and predicted values. The mean absolute error is presented in Table 4. We observed that the Transformer model achieves a much better performance, highlighting its ability to produce a meaningful temporal prediction for each patient rather than predictions that are only good on average.
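The comparison of peak hazard times can be sketched as follows. This is one plausible reading of the experiment, assuming hazards are stored as arrays of shape (patients, time steps) and that the peak time is the index of the maximum hazard; the function name `peak_time_mae` and the toy values are hypothetical and are not the study's data.

```python
import numpy as np

def peak_time_mae(h_true, h_pred):
    """Mean absolute error between the time steps of maximum hazard.

    h_true, h_pred: (n_patients, n_steps) hazard curves for a single event.
    """
    t_true = np.argmax(h_true, axis=1)   # time step of peak ground-truth hazard
    t_pred = np.argmax(h_pred, axis=1)   # time step of peak predicted hazard
    return float(np.mean(np.abs(t_true - t_pred)))

# Toy example with two patients and three time steps.
h_true = np.array([[0.1, 0.5, 0.2],
                   [0.3, 0.2, 0.6]])
h_pred = np.array([[0.2, 0.1, 0.4],
                   [0.1, 0.2, 0.9]])
mae = peak_time_mae(h_true, h_pred)
print(mae)  # 0.5: the first patient's peak is off by one step, the second matches
```

A model that predicts the same curve shape for every patient can score well on averaged metrics yet still obtain a large error here, which is why this quantity is used as a check on individualized predictions.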
Appendix 3: Supplementary tables
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if changes were made. The images or other third party material in this article are included in the article's Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article's Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by/4.0/. The Creative Commons Public Domain Dedication waiver (http://creativecommons.org/publicdomain/zero/1.0/) applies to the data made available in this article, unless otherwise stated in a credit line to the data.
About this article
Cite this article
Oliver, M., Allou, N., Devineau, M. et al. A transformer model for cause-specific hazard prediction. BMC Bioinformatics 25, 175 (2024). https://0doiorg.brum.beds.ac.uk/10.1186/s12859024057992
DOI: https://0doiorg.brum.beds.ac.uk/10.1186/s12859024057992