Fitting survival data with MortalityTables.jl

JuliaActuary is an ecosystem of packages that makes Julia the easiest language to get started with for actuarial workflows.
begin
    using LsqFit
    using MortalityTables
    using Plots
    using Distributions
    using Optim
    using DataFrames
    using Survival
end

This tutorial is based on code by PharmCat on GitHub, which has similar code with comments in Russian and English.

Fitting a Weibull survival curve

Sample data:

data = let
    survival = [0.99, 0.98, 0.95, 0.9, 0.8, 0.65, 0.5, 0.38, 0.25, 0.2, 0.1, 0.05, 0.02, 0.01]
    times = 1:length(survival)
    DataFrame(; times, survival)
end
14×2 DataFrame
 Row │ times  survival
     │ Int64  Float64
─────┼─────────────────
   1 │     1      0.99
   2 │     2      0.98
   3 │     3      0.95
   4 │     4      0.9
   5 │     5      0.8
   6 │     6      0.65
   7 │     7      0.5
   8 │     8      0.38
   9 │     9      0.25
  10 │    10      0.2
  11 │    11      0.1
  12 │    12      0.05
  13 │    13      0.02
  14 │    14      0.01

Visualizing the data:

plt = plot(data.times, data.survival, label="observed survival proportion", xlabel="time")

Define the two-parameter Weibull model:

  • x: array of independent variables
  • p: array of model parameters

LsqFit calls model(x, p) with the full data set as the first argument x, so the model function must apply the model to every row at once. We use @. to apply (“broadcast”) the calculations across all rows.

@. model1(x, p) = survival(MortalityTables.Weibull(; m=p[1], σ=p[2]), x)
model1 (generic function with 1 method)
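To see what @. does in a model of this shape, here is a minimal sketch that writes the survival function out by hand instead of calling MortalityTables. It assumes the standard two-parameter Weibull survival function S(t) = exp(-(t/θ)^k) with scale θ and shape k, which is not necessarily the parameterization MortalityTables.Weibull uses; `weibull_survival` and `handmodel` are hypothetical names.

```julia
# Hypothetical helper: standard Weibull survival S(t) = exp(-(t/θ)^k),
# with θ = scale and k = shape (assumed, not MortalityTables' own definition).
weibull_survival(t, θ, k) = exp(-(t / θ)^k)

# Same shape as model1: x is the whole vector of times, p the parameter vector.
# @. dots every call and operator on the right-hand side, so this becomes
# weibull_survival.(x, p[1], p[2]); the indexing p[1] itself is not broadcast.
@. handmodel(x, p) = weibull_survival(x, p[1], p[2])

handmodel(1:14, [8.0, 2.5])  # one survival probability per time point
```

Because the whole vector goes in and a vector of the same length comes out, this is the calling convention curve_fit expects.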

Fitting the Model

Then fit the model with LsqFit.jl:

fit1 = curve_fit(model1, data.times, data.survival, [1.0, 1.0])

plot!(plt, data.times, model1(data.times, fit1.param), label="fitted model")
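curve_fit chooses the parameters that minimize the sum of squared differences between the model and the observed proportions (LsqFit does this with the Levenberg-Marquardt algorithm). The sketch below makes that objective explicit in Base Julia. It swaps in a crude grid search for Levenberg-Marquardt and again assumes the standard Weibull survival exp(-(t/θ)^k), so its numbers need not line up one-to-one with fit1.param; `sse` and the grid bounds are hypothetical choices.

```julia
# Observed data from the table above.
times = 1:14
surv = [0.99, 0.98, 0.95, 0.9, 0.8, 0.65, 0.5, 0.38, 0.25, 0.2, 0.1, 0.05, 0.02, 0.01]

# Assumed standard Weibull survival S(t) = exp(-(t/θ)^k) (hypothetical helper).
weibull_survival(t, θ, k) = exp(-(t / θ)^k)

# Sum of squared residuals for a parameter pair p = (θ, k):
# the quantity a least-squares fit minimizes.
sse(p) = sum(abs2, weibull_survival.(times, p[1], p[2]) .- surv)

# Crude grid search standing in for Levenberg-Marquardt (argmin(f, itr) needs Julia ≥ 1.7).
grid = [(θ, k) for θ in 1.0:0.05:15.0, k in 0.5:0.05:6.0]
best = argmin(sse, vec(grid))
```

A grid search is only practical for two parameters; an iterative method such as Levenberg-Marquardt scales better and converges to a much finer answer.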

Maximum Likelihood Estimation

Generate 100 sample datapoints:

t = rand(Weibull(fit1.param[2], fit1.param[1]), 100)
100-element Vector{Float64}:
  8.825787337622568
 16.569618848181783
 12.956834421711289
  6.555400926197694
  8.504194949892344
  8.713067857978032
  6.902907682735063
  2.9238737618887756
  6.220224938590163
  5.637392847808691
  6.566266756942628
  3.4227977121519046
  9.072382858601902
  ⋮
 12.427306670864814
 13.020181642791094
  8.937824930348103
  7.384760880970262
  8.62992521964653
  9.064473129442108
  7.784665870395106
  8.063042056800445
  3.4419321614824834
 10.237415171326523
  6.022490383277525
  7.869905502329432
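rand draws these samples from the Distributions.jl Weibull, whose two arguments are the shape α and the scale θ. One standard way to generate Weibull draws by hand is inverse-transform sampling: if U is uniform on (0, 1], then θ(-log U)^(1/α) has survival function exp(-(t/θ)^α). A minimal Base-Julia sketch, with the hypothetical helper `sample_weibull` and illustrative parameter values:

```julia
using Random

# Hypothetical helper: inverse-transform sampling for a Weibull with shape k and scale θ.
# 1 - rand(rng) lies in (0, 1], so the logarithm is finite.
sample_weibull(rng, k, θ, n) = [θ * (-log(1 - rand(rng)))^(1 / k) for _ in 1:n]

rng = MersenneTwister(42)
tsim = sample_weibull(rng, 2.6, 8.3, 100_000)  # illustrative shape and scale
```

As a quick sanity check, the fraction of draws above the scale θ should be close to S(θ) = exp(-1) ≈ 0.368.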

Without Censored Data

#No censored data
fit_mle(Weibull, t)
Weibull{Float64}(α=2.624123641855999, θ=8.249520319926305)
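fit_mle maximizes the likelihood directly. For an uncensored Weibull sample, the maximum-likelihood equations reduce to a single equation in the shape k: k solves Σ tᵢ^k log tᵢ / Σ tᵢ^k - 1/k - mean(log tᵢ) = 0, after which the scale is θ = (mean(tᵢ^k))^(1/k). The sketch below solves that equation by bisection in Base Julia. `weibull_mle` and its bracketing interval are hypothetical choices, not how Distributions.jl implements fit_mle.

```julia
using Random

# Profile score for the Weibull shape k (standard result for uncensored data):
# g(k) = Σ tᵢ^k log tᵢ / Σ tᵢ^k - 1/k - mean(log tᵢ); g increases in k.
function weibull_shape_score(t, k)
    tk = t .^ k
    return sum(tk .* log.(t)) / sum(tk) - 1 / k - sum(log, t) / length(t)
end

# Hypothetical helper: bisection for the root of g, then the scale given the shape.
function weibull_mle(t; lo = 0.01, hi = 50.0, tol = 1e-10)
    while hi - lo > tol
        mid = (lo + hi) / 2
        weibull_shape_score(t, mid) > 0 ? (hi = mid) : (lo = mid)
    end
    k = (lo + hi) / 2
    θ = (sum(t .^ k) / length(t))^(1 / k)
    return (shape = k, scale = θ)
end

# Demo on simulated Weibull(shape 2.6, scale 8.3) data (illustrative values).
rng = MersenneTwister(1)
tdemo = [8.3 * (-log(1 - rand(rng)))^(1 / 2.6) for _ in 1:10_000]
est = weibull_mle(tdemo)
```

Here `shape` and `scale` play the roles of α and θ in the Distributions.jl output above.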

With Censored Data

Pick some arbitrary observations to censor:

c = collect(trues(100))
c[[1, 3, 7, 9]] .= false
4-element view(::Vector{Bool}, [1, 3, 7, 9]) with eltype Bool:
 0
 0
 0
 0
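Censoring a handful of fixed indices is enough to exercise the likelihood. A common alternative when simulating right-censored data is to draw an independent censoring time C for each subject and record min(T, C) together with the indicator T ≤ C. A minimal Base-Julia sketch, with illustrative Weibull parameters and an arbitrarily chosen Uniform(0, 20) censoring distribution; the variable names are hypothetical.

```julia
using Random

# Hypothetical simulation of independent right-censoring: each subject has a true
# event time T and a censoring time C; we observe min(T, C) and whether T ≤ C.
rng = MersenneTwister(7)
n = 1_000
T = [8.3 * (-log(1 - rand(rng)))^(1 / 2.6) for _ in 1:n]  # Weibull(shape 2.6, scale 8.3)
C = 20 .* rand(rng, n)                                     # Uniform(0, 20) censoring times
tobs = min.(T, C)   # the time actually recorded
event = T .<= C     # true = event observed, false = censored (same convention as c)
```

The indicator follows the same convention as c above: true marks an observed event and false a censored observation.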
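# Negative log-likelihood of right-censored Weibull data.
# x[1] is the scale θ and x[2] is the shape α; t and c are the globals defined above.
function survmle(x)
    dist = Weibull(x[2], x[1])
    ll = 0.0
    for i in eachindex(t)
        if c[i]
            ll += logpdf(dist, t[i])   # event observed: log f(t)
        else
            ll += logccdf(dist, t[i])  # censored: log(1 - F(t))
        end
    end
    return -ll  # Optim minimizes, so negate the log-likelihood
end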
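survmle relies on the Distributions.jl functions logpdf and logccdf. Written out for a Weibull with shape k and scale θ, these are log f(t) = log(k/θ) + (k - 1) log(t/θ) - (t/θ)^k and log S(t) = -(t/θ)^k. The hypothetical helpers below spell this out in Base Julia, so the censored log-likelihood is visibly a sum of one term per observation.

```julia
# Closed-form Weibull (shape k, scale θ) log-density and log-survival,
# using the standard parameterization S(t) = exp(-(t/θ)^k).
weibull_logpdf(t, k, θ) = log(k / θ) + (k - 1) * log(t / θ) - (t / θ)^k
weibull_logsurv(t, k, θ) = -(t / θ)^k

# Hypothetical helper: censored log-likelihood, adding log f(tᵢ) for observed
# events and log S(tᵢ) for censored observations.
function censored_loglik(t, event, k, θ)
    ll = 0.0
    for i in eachindex(t, event)
        ll += event[i] ? weibull_logpdf(t[i], k, θ) : weibull_logsurv(t[i], k, θ)
    end
    return ll
end

censored_loglik([1.0, 2.5, 4.0], [true, false, true], 2.0, 3.0)
```

A censored observation only says the event had not happened by tᵢ, so it contributes the survival probability rather than the density.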

opt = Optim.optimize(
    survmle,       # function to minimize
    [1.0, 1.0],    # lower bounds
    [15.0, 15.0],  # upper bounds
    [3.0, 3.0],    # initial guess
)
 * Status: success

 * Candidate solution
    Final objective value:     2.447550e+02

 * Found with
    Algorithm:     Fminbox with L-BFGS

 * Convergence measures
    |x - x'|               = 9.89e-08 ≰ 0.0e+00
    |x - x'|/|x'|          = 1.13e-08 ≰ 0.0e+00
    |f(x) - f(x')|         = 0.00e+00 ≤ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 0.00e+00 ≤ 0.0e+00
    |g(x)|                 = 2.72e-09 ≤ 1.0e-08

 * Work counters
    Seconds run:   0  (vs limit Inf)
    Iterations:    4
    f(x) calls:    53
    ∇f(x) calls:   53

The solution is close to the parameters used to generate the synthetic data:

Optim.minimizer(opt)
2-element Vector{Float64}:
 8.36026550865313
 2.5856602163415103
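Optim.optimize with lower and upper bounds runs Fminbox around L-BFGS, as the output above reports. For the censored Weibull likelihood there is also a one-dimensional route: with d observed events, the shape k solves Σ_all tᵢ^k log tᵢ / Σ_all tᵢ^k - 1/k - (1/d) Σ_events log tᵢ = 0, and then θ = (Σ_all tᵢ^k / d)^(1/k); with no censoring this is the ordinary Weibull MLE equation. The sketch below is a hypothetical cross-check in Base Julia using bisection, shown on simulated data with independent uniform censoring. It is not what Optim does internally.

```julia
using Random

# Profile score for the Weibull shape k under right-censoring (d = number of events):
# g(k) = Σ_all tᵢ^k log tᵢ / Σ_all tᵢ^k - 1/k - mean of log tᵢ over observed events.
function censored_shape_score(t, event, k)
    tk = t .^ k
    return sum(tk .* log.(t)) / sum(tk) - 1 / k - sum(log, t[event]) / count(event)
end

# Hypothetical helper: g increases in k, so bisect for its root, then recover the scale.
function censored_weibull_mle(t, event; lo = 0.01, hi = 50.0, tol = 1e-10)
    while hi - lo > tol
        mid = (lo + hi) / 2
        censored_shape_score(t, event, mid) > 0 ? (hi = mid) : (lo = mid)
    end
    k = (lo + hi) / 2
    θ = (sum(t .^ k) / count(event))^(1 / k)
    return (shape = k, scale = θ)
end

# Demo: Weibull(shape 2.6, scale 8.3) event times, Uniform(0, 20) censoring (illustrative).
rng = MersenneTwister(3)
T = [8.3 * (-log(1 - rand(rng)))^(1 / 2.6) for _ in 1:20_000]
C = 20 .* rand(rng, 20_000)
est = censored_weibull_mle(min.(T, C), T .<= C)
```

Note that censored times still enter the sums over all observations; dropping them instead would bias the scale downward.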

Fitting Kaplan-Meier

The KaplanMeier estimator comes from Survival.jl.

# t: observed times; c: event indicator (true = event observed, false = censored)
km = fit(Survival.KaplanMeier, t, c)
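KaplanMeier implements the product-limit estimator Ŝ(t) = Π_{tᵢ ≤ t} (1 - dᵢ/nᵢ), where dᵢ is the number of events at time tᵢ and nᵢ the number still at risk just before tᵢ. A minimal hand-rolled version, with the hypothetical name `kaplan_meier`, is sketched below. It reports Ŝ only at times where at least one event occurred and makes no claim to match Survival.jl's output format.

```julia
# Hypothetical hand-rolled Kaplan-Meier (product-limit) estimator.
# t: observed times; event: true = event observed, false = right-censored.
function kaplan_meier(t, event)
    times = sort(unique(t[event]))   # distinct times with at least one event
    surv = Float64[]
    s = 1.0
    for τ in times
        n_at_risk = count(>=(τ), t)  # still under observation just before τ
        d = count(i -> event[i] && t[i] == τ, eachindex(t, event))  # events at τ
        s *= 1 - d / n_at_risk
        push!(surv, s)
    end
    return times, surv
end

# Small example: the subject censored at time 2 still counts as at risk at time 2.
times, surv = kaplan_meier([1, 2, 2, 3, 4, 5], [true, true, false, true, false, true])
```

Here the estimate steps down to 5/6, 2/3, 4/9 and finally 0 at times 1, 2, 3 and 5.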

plt2 = plot(km.events.time, km.survival; label="Empirical")
@. model(x, p) = survival(MortalityTables.Weibull(; m=p[1], σ=p[2]), x)

mfit = LsqFit.curve_fit(model, km.events.time, km.survival, [2.0, 2.0])

plot!(plt2, km.events.time, model(km.events.time, mfit.param), label="Theoretical")