Adapted from code by PharmCat on GitHub (the linked source has similar code, with comments in Russian and English).
begin
using LsqFit
using MortalityTables
using Plots
using Distributions
using Optim
using DataFrames
using PlutoUI; TableOfContents()
end
Sample data:
# Observed survival proportions, one per consecutive time step starting at 1.
data = let
    observed = [
        0.99, 0.98, 0.95, 0.9, 0.8, 0.65, 0.5,
        0.38, 0.25, 0.2, 0.1, 0.05, 0.02, 0.01,
    ]
    # Column names are given explicitly so the local can be named freely.
    DataFrame(; times=eachindex(observed), survival=observed)
end
Row | times | survival |
---|---|---|
1 | 1 | 0.99 |
2 | 2 | 0.98 |
3 | 3 | 0.95 |
4 | 4 | 0.9 |
5 | 5 | 0.8 |
6 | 6 | 0.65 |
7 | 7 | 0.5 |
8 | 8 | 0.38 |
9 | 9 | 0.25 |
10 | 10 | 0.2 |
11 | 11 | 0.1 |
12 | 12 | 0.05 |
13 | 13 | 0.02 |
14 | 14 | 0.01 |
# Base plot of the observed data; the fitted curve is added to `plt` later.
plt = plot(
    data.times,
    data.survival;
    label="observed survival proportion",
    xlabel="time",
)
Define the two-parameter Weibull model:
- `x`: array of independent variables
- `p`: array of model parameters

`model(x, p)` will accept the full data set as the first argument `x`. This means that we need to write our model function so that it applies the model to the full dataset. We use `@.` to apply ("broadcast") the calculations across all rows.
# Weibull survival model in the `model(x, p)` form that LsqFit expects:
# p[1] is passed as `m` and p[2] as `σ`. `@.` dots the calls in the body so
# that `survival` is evaluated for every element of `x`.
@. function model1(x, p)
    survival(MortalityTables.Weibull(; m=p[1], σ=p[2]), x)
end
model1 (generic function with 1 method)
And fit the model with LsqFit.jl:
# Least-squares fit of the model parameters to the observed proportions.
fit1 = LsqFit.curve_fit(
    model1,
    data.times,     # independent variable
    data.survival,  # observed response
    [1.0, 1.0],     # starting values for [m, σ]
)
LsqFit.LsqFitResult{Vector{Float64}, Vector{Float64}, Matrix{Float64}, Vector{Float64}}([8.147278225549185, 2.8507079975810274], [0.007512289530956395, 0.0021036645286324207, -0.0059128358266353764, -0.02271535116992418, -0.019432525084171814, 0.008926964702135387, 0.02305783722832544, 0.007050778443488892, 0.01472116019643549, -0.03405121675050801, -0.005432126302967605, -0.0014077035272865995, 0.0023337166510039725, -0.0008907842989644022], [0.0026998722509579453 -0.005225236784169305; 0.014959447652454607 -0.024973364570363275; … ; 0.015866902563881607 0.03977450356606966; 0.00688563201685319 0.02322957839976213], true, Float64[])
# Overlay the fitted curve, evaluated at the observed times, on the existing plot.
plot!(
    plt,
    data.times,
    model1(data.times, fit1.param);
    label="fitted model",
)
Generate 100 sample datapoints:
# `fit1.param` is ordered [m, σ] as in `model1`, whereas `Weibull` from
# Distributions takes (shape, scale), hence the swapped argument order.
t = let
    m, σ = fit1.param
    rand(Weibull(σ, m), 100)
end
100-element Vector{Float64}: 9.702067663069327 3.0428689574368324 9.059639240991595 10.970739445634079 3.9040404619667113 4.165794231001152 4.842637489172447 ⋮ 5.6539647956004515 6.139331596615337 9.641578582932453 4.896864716185817 4.434432566682164 11.444165198474122
# Baseline with no censoring: maximum-likelihood Weibull fit that uses every
# value in `t` as-is.
fit_mle(Weibull, t)
Distributions.Weibull{Float64}(α=2.852156865689702, θ=8.150334609678797)
# Event indicator, one entry per sampled time: `true` = event observed,
# `false` = censored. Everything starts as observed. The length is taken
# from `t` so the two vectors cannot get out of sync if the sample size changes.
c = fill(true, length(t))
100-element Vector{Bool}: 1 1 1 1 1 1 1 ⋮ 1 1 1 1 1 1
# Flag observations 1, 3, 7 and 9 as censored (`false` in the event indicator).
c[[1,3,7,9]] .= false
4-element view(::Vector{Bool}, [1, 3, 7, 9]) with eltype Bool: 0 0 0 0
"""
    survmle(x)

Negative log-likelihood of the sample `t` under a Weibull distribution with
censoring, intended as an objective function to minimize.

`x` holds the candidate parameters; they are passed on as
`Weibull(x[2], x[1])`, i.e. `x[1]` is the scale and `x[2]` the shape. The
function reads the globals `t` (times) and `c` (event indicators): where
`c[i]` is `true` the observation contributes its log-density (`logpdf`);
where it is `false` (censored) it contributes the log-survival (`logccdf`).

Returns the negated sum of these contributions.
"""
function survmle(x)
    # The distribution depends only on `x`, so build it once, not per observation.
    dist = Weibull(x[2], x[1])
    loglik = 0.0
    # `eachindex(t, c)` also verifies that `t` and `c` have matching indices.
    for i in eachindex(t, c)
        loglik += c[i] ? logpdf(dist, t[i]) : logccdf(dist, t[i])
    end
    return -loglik
end
survmle (generic function with 1 method)
# Box-constrained minimization of the negative log-likelihood `survmle`.
opt = let
    lower = [1.0, 1.0]      # lower bound for each parameter
    upper = [15.0, 15.0]    # upper bound for each parameter
    initial = [3.0, 3.0]    # initial guess
    Optim.optimize(survmle, lower, upper, initial)
end
* Status: success * Candidate solution Final objective value: 2.379075e+02 * Found with Algorithm: Fminbox with L-BFGS * Convergence measures |x - x'| = 9.46e-08 ≰ 0.0e+00 |x - x'|/|x'| = 1.08e-08 ≰ 0.0e+00 |f(x) - f(x')| = 0.00e+00 ≤ 0.0e+00 |f(x) - f(x')|/|f(x')| = 0.00e+00 ≤ 0.0e+00 |g(x)| = 5.75e-09 ≤ 1.0e-08 * Work counters Seconds run: 1 (vs limit Inf) Iterations: 4 f(x) calls: 50 ∇f(x) calls: 50
The solution converges to values similar to the parameters used to generate the synthetic data:
# Parameter estimates in the order `survmle` reads them: x[1] is used as the
# Weibull scale (second argument) and x[2] as the shape (first argument).
Optim.minimizer(opt)
2-element Vector{Float64}: 8.26940849837394 2.8564783058448087
KaplanMeier comes from Survival.jl
# 4
# Fit the model to the empirical survival function from Kaplan-Meier.
using Survival
# t: vector of times; c: event indicator per observation
# (true = event observed, false = censored, matching how `survmle` branches on c).
km = fit(Survival.KaplanMeier, t, c)
KaplanMeier{Float64}([1.5497716226250795, 1.6775379881561192, 1.7085009082750642, 2.5181132678997726, 2.68032348518394, 2.964498811102057, 3.0428689574368324, 3.1715173900300746, 3.3408214242675918, 3.558112309435965 … 10.933124905309146, 10.968571881614759, 10.970739445634079, 11.237243842252303, 11.343262031763562, 11.444165198474122, 11.53918544474052, 12.139001873035417, 12.760811345926422, 13.990010616742882], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1 … 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0 … 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [100, 99, 98, 97, 96, 95, 94, 93, 92, 91 … 10, 9, 8, 7, 6, 5, 4, 3, 2, 1], [0.99, 0.98, 0.97, 0.96, 0.95, 0.94, 0.9299999999999999, 0.9199999999999999, 0.9099999999999999, 0.8999999999999999 … 0.1000085820703387, 0.08889651739585662, 0.07778445272137455, 0.06667238804689248, 0.055560323372410396, 0.04444825869792832, 0.03333619402344624, 0.022224129348964162, 0.011112064674482081, 0.011112064674482081], [0.01005037815259212, 0.014285714285714285, 0.017586311452816476, 0.020412414523193152, 0.022941573387056175, 0.025264557631995567, 0.02743516305843672, 0.029488391230979426, 0.03144854510165755, 0.03333333333333333 … 0.3122238849169251, 0.33372540089346836, 0.35948544624038536, 0.39120238990600864, 0.4317089797554669, 0.4861816977237071, 0.5654254827427213, 0.6974042179407192, 0.9931629489673415, 0.9931629489673415])
# Empirical (Kaplan-Meier) survival curve; the fitted curve is added to `plt2` later.
plt2 = plot(
    km.times,
    km.survival;
    labels="Empirical",
)
# Weibull survival model in the `model(x, p)` form that LsqFit expects:
# p[1] is passed as `m` and p[2] as `σ`. `@.` dots the calls in the body so
# that `survival` is evaluated for every element of `x`.
@. function model(x, p)
    survival(MortalityTables.Weibull(; m=p[1], σ=p[2]), x)
end
model (generic function with 1 method)
# Least-squares fit of the Weibull model to the Kaplan-Meier survival estimates.
mfit = LsqFit.curve_fit(
    model,
    km.times,     # independent variable
    km.survival,  # empirical survival estimates
    [2.0, 2.0],   # starting values for [m, σ]
)
LsqFit.LsqFitResult{Vector{Float64}, Vector{Float64}, Matrix{Float64}, Vector{Float64}}([8.378620456975263, 3.2256169578710185], [-0.002403778609877949, 0.0047839066859725055, 0.014049535027020887, -0.0030835485933451423, -0.0004737254028512883, -0.005074670900177991, 0.0005240823447953957, 0.0029460936677042637, 0.002303863238325876, -0.0024660638160870496 … 0.03584410203827022, 0.044685608616582315, 0.05565969008761626, 0.05055522784081323, 0.05562383494771387, 0.06119548794504302, 0.0672761069565029, 0.050610626595534725, 0.039553131229845835, 0.011545102549992596], [0.010270434299658012 -0.01675139926361279; 0.01221021915932723 -0.019556832019687578; … ; 0.027138545352528534 0.05119262536384386; 0.012964385596268838 0.035424971300125514], true, Float64[])
# Overlay the fitted Weibull curve on the empirical Kaplan-Meier plot.
plot!(
    plt2,
    km.times,
    model(km.times, mfit.param);
    labels="Theoretical",
)
Built with Julia 1.8.5 and DataFrames 1.3.2.
To run this page locally, download this file and open it with Pluto.jl.