Parameter Estimation on Highly Stiff Systems

This tutorial shows how to estimate the parameters of a highly stiff chemical reaction system (the classic Robertson equations) from time-series data.

Copy-Pasteable Code

Before getting to the explanation, here's some code to start with. A full walkthrough of the model definition and training process follows:

using DifferentialEquations, DiffEqFlux, Optimization, OptimizationOptimJL, LinearAlgebra
using ForwardDiff
using DiffEqBase: UJacobianWrapper
using Plots
# Robertson's stiff chemical reaction system: three species, three rate constants
function rober(du,u,p,t)
    y₁,y₂,y₃ = u
    k₁,k₂,k₃ = p
    du[1] = -k₁*y₁+k₃*y₂*y₃
    du[2] =  k₁*y₁-k₂*y₂^2-k₃*y₂*y₃
    du[3] =  k₂*y₂^2
    nothing
end

p = [0.04,3e7,1e4]
u0 = [1.0,0.0,0.0]
prob = ODEProblem(rober,u0,(0.0,1e5),p)
sol = solve(prob,Rosenbrock23())
ts = sol.t
# Precompute weight matrices I + 0.1*J(u) along the reference solution;
# they scale each residual in the loss by the local stiffness of the system
Js = map(u->I + 0.1*ForwardDiff.jacobian(UJacobianWrapper(rober, 0.0, p), u), sol.u)

function predict_adjoint(p)
    p = exp.(p) # parameters are optimized in log-space; map back to positive rates
    _prob = remake(prob,p=p)
    Array(solve(_prob,Rosenbrock23(autodiff=false),saveat=ts,sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
end

function loss_adjoint(p)
    prediction = predict_adjoint(p)
    prediction = [prediction[:, i] for i in axes(prediction, 2)]
    diff = map((J,u,data) -> J * (abs2.(u .- data)) , Js, prediction, sol.u)
    loss = sum(abs, sum(diff)) |> sqrt
    loss, prediction
end

callback = function (p, l, pred) # callback function to observe training
    println("Loss: $l")
    println("Parameters: $(exp.(p))")
    # use `remake` to re-create `prob` with the current parameters
    plot(solve(remake(prob, p = exp.(p)), Rosenbrock23())) |> display
    return false # do not halt the optimization; returning true would stop it
end

initp = ones(3)
# Display the ODE with the initial parameter values.
callback(initp,loss_adjoint(initp)...)

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss_adjoint(x), adtype)
optprob = Optimization.OptimizationProblem(optf, initp)

# First, run ADAM to move the parameters into the right region
res = Optimization.solve(optprob, ADAM(0.01), callback = callback, maxiters = 300)

# Then polish the fit with BFGS, restarting from the ADAM result
optprob2 = Optimization.OptimizationProblem(optf, res.u)

res2 = Optimization.solve(optprob2, BFGS(), callback = callback, maxiters = 30, allow_f_increases=true)
println("Ground truth: $(p)\nFinal parameters: $(round.(exp.(res2.u), sigdigits=5))\nError: $(round(norm(exp.(res2.u) - p) ./ norm(p) .* 100, sigdigits=3))%")
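One detail worth calling out before looking at the training output: the optimizer never sees the rate constants directly. `predict_adjoint` exponentiates the raw parameter vector, so the search happens in log-space. This keeps `k₁, k₂, k₃` positive for any unconstrained step and puts parameters spanning nine orders of magnitude (0.04 to 3e7) on a comparable scale. A minimal sketch of the reparameterization (plain Julia, no packages):

```julia
# Optimize θ = log.(k) rather than k itself: exp.(θ) is strictly
# positive, so any unconstrained gradient step maps to valid rates.
θ = ones(3)          # the raw variables handed to the optimizer
k = exp.(θ)          # the rates the ODE solver actually uses
@assert all(k .> 0)

# Even an aggressive step in θ cannot produce a negative rate:
θ_step = θ .- 50.0
@assert all(exp.(θ_step) .> 0)

# Recovering θ from known positive rates is just a log:
@assert exp.(log.([0.04, 3e7, 1e4])) ≈ [0.04, 3e7, 1e4]
```

This is why `initp = ones(3)` corresponds to an initial guess of `exp.(ones(3)) ≈ [2.718, 2.718, 2.718]` in the first callback printout below.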
Loss: 20.019492983370657
Parameters: [2.718281828459045, 2.718281828459045, 2.718281828459045]
⋮
Loss: 3.225939961621297
Parameters: [0.26411773490354185, 17.43350328735883, 56.60653775199007]
⋮
Loss: 0.00017091785201744334
Parameters: [0.03999070697489044, 2.788551115654892e7, 9636.521529990678]
Loss: 0.00014923633902541174
Parameters: [0.040001069638812295, 2.7898698360543653e7, 9641.95336185187]
Loss: 0.00014888854710167496
Parameters: [0.040000077397721705, 2.7902091047777608e7, 9642.303359539448]
Loss: 0.00014263001426481785
Parameters: [0.040002055962451724, 2.8001316818717916e7, 9660.172780495499]
Loss: 0.00014262983378802814
Parameters: [0.039994414724460865, 2.8538464347461e7, 9750.651341973593]
Ground truth: [0.04, 3.0e7, 10000.0]
Final parameters: [0.039994, 2.8538e7, 9750.7]
Error: 4.87%


Explanation

First, let's generate a time series from the Robertson equations to use as data.

using DifferentialEquations, DiffEqFlux, Optimization, OptimizationOptimJL, LinearAlgebra
using ForwardDiff
using DiffEqBase: UJacobianWrapper
using Plots
function rober(du,u,p,t)
    y₁,y₂,y₃ = u
    k₁,k₂,k₃ = p
    du[1] = -k₁*y₁+k₃*y₂*y₃
    du[2] =  k₁*y₁-k₂*y₂^2-k₃*y₂*y₃
    du[3] =  k₂*y₂^2
    nothing
end

p = [0.04,3e7,1e4]
u0 = [1.0,0.0,0.0]
prob = ODEProblem(rober,u0,(0.0,1e5),p)
sol = solve(prob,Rosenbrock23())
ts = sol.t
Js = map(u->I + 0.1*ForwardDiff.jacobian(UJacobianWrapper(rober, 0.0, p), u), sol.u)
61-element Vector{Matrix{Float64}}:
 [0.996 0.0 0.0; 0.004 1.0 0.0; 0.0 0.0 1.0]
 [0.996 3.9181897521319503e-7 0.0012780900152625978; 0.004 -6.668540483394562 -0.0012780900152625978; 0.0 7.668540091575586 1.0]
 [0.996 4.175663804739006e-5 0.005718510461294757; 0.004 -33.31110452440659 -0.005718510461294757; 0.0 34.31106276776854 1.0]
 [0.996 0.00024992454904057104 0.009992106612572492; 0.004 -58.95288959998399 -0.009992106612572492; 0.0 59.952639675434945 1.0]
 [0.996 0.0016037077316934775 0.017833623941038088; 0.004 -106.00334735396024 -0.017833623941038088; 0.0 107.00174364622853 1.0]
 [0.996 0.004682453587410618 0.02403488562731424; 0.004 -143.21399621747284 -0.02403488562731424; 0.0 144.20931376388543 1.0]
 [0.996 0.01288429926109498 0.030390689334989115; 0.004 -181.3570203091958 -0.030390689334989115; 0.0 182.3441360099347 1.0]
 [0.996 0.02531711709494679 0.03388427339038224; 0.004 -202.3309574593884 -0.03388427339038224; 0.0 203.30564034229346 1.0]
 [0.996 0.046868882246842165 0.03583508669306405; 0.004 -214.0573890406311 -0.03583508669306405; 0.0 215.01052015838428 1.0]
 [0.996 0.077296222064754 0.03641240161925743; 0.004 -217.5517059376093 -0.03641240161925743; 0.0 218.47440971554457 1.0]
 ⋮
 [0.996 944.3646828226424 0.00023546322394505425; 0.004 -944.7774621663126 -0.00023546322394505425; 0.0 1.4127793436703255 1.0]
 [0.996 952.0744466258235 0.00020121496539471083; 0.004 -952.2817364181918 -0.00020121496539471083; 0.0 1.2072897923682648 1.0]
 [0.996 958.7664043966818 0.00017192789116847034; 0.004 -958.7979717436925 -0.00017192789116847034; 0.0 1.031567347010822 1.0]
 [0.996 964.5628529142931 0.0001468836276202215; 0.004 -964.4441546800144 -0.0001468836276202215; 0.0 0.8813017657213289 1.0]
 [0.996 969.574501432246 0.00012546809864160538; 0.004 -969.3273100240957 -0.00012546809864160538; 0.0 0.7528085918496323 1.0]
 [0.996 973.9007597174119 0.00010715608821694972; 0.004 -973.5436962467137 -0.00010715608821694972; 0.0 0.6429365293016983 1.0]
 [0.996 977.6302161336013 9.149845157822547e-5; 0.004 -977.1792068430707 -9.149845157822547e-5; 0.0 0.5489907094693528 1.0]
 [0.996 980.8413583945696 7.811096455346304e-5; 0.004 -980.3100241818904 -7.811096455346304e-5; 0.0 0.4686657873207783 1.0]
 [0.996 982.1720335649052 7.258919980138978e-5; 0.004 -981.6075687637136 -7.258919980138978e-5; 0.0 0.4355351988083387 1.0]

Note that along with the solution we also computed a shifted and scaled Jacobian, I + 0.1*J, at every saved time point. We will use these matrices to weight the loss later.
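
The first weight matrix in `Js` can be reproduced by hand: at u0 = [1, 0, 0] every Jacobian entry containing y₂ or y₃ vanishes. A minimal sketch using the analytic Jacobian of the Robertson right-hand side:

```julia
using LinearAlgebra

k₁, k₂, k₃ = 0.04, 3e7, 1e4     # true parameters
y₁, y₂, y₃ = 1.0, 0.0, 0.0      # initial condition

# Analytic Jacobian of the Robertson RHS with respect to u
J0 = zeros(3, 3)
J0[1,1] = -k₁;  J0[1,2] = k₃*y₃;            J0[1,3] =  k₃*y₂
J0[2,1] =  k₁;  J0[2,2] = -2k₂*y₂ - k₃*y₃;  J0[2,3] = -k₃*y₂
J0[3,2] = 2k₂*y₂

W0 = I + 0.1*J0   # matches the first weight matrix shown above
```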

We fit the parameters in log space so that quantities of wildly different magnitudes are optimized on a comparable scale, and compute exp.(p) inside the loss to recover the physical parameters.
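
The true parameters differ by roughly nine orders of magnitude; in log space they are all of comparable size, so a single learning rate works for all of them. A quick sketch of the round trip, using the true parameters for illustration:

```julia
p = [0.04, 3e7, 1e4]
logp = log.(p)     # ≈ [-3.22, 17.22, 9.21], all of comparable size
exp.(logp) ≈ p     # exponentiating recovers the original parameters
```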

function predict_adjoint(p)
    p = exp.(p)
    _prob = remake(prob,p=p)
    Array(solve(_prob,Rosenbrock23(autodiff=false),saveat=ts,sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
end

function loss_adjoint(p)
    prediction = predict_adjoint(p)
    prediction = [prediction[:, i] for i in axes(prediction, 2)]
    diff = map((J,u,data) -> J * (abs2.(u .- data)) , Js, prediction, sol.u)
    loss = sum(abs, sum(diff)) |> sqrt
    loss, prediction
end
loss_adjoint (generic function with 1 method)

The squared difference between the data and the prediction at each time point is weighted by the transformed Jacobian, which rescales the loss according to the local stiffness of the system.
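
To make the weighting concrete, here is the core of `loss_adjoint` applied to a single time point; the prediction and data vectors below are hypothetical, chosen only to show the shapes involved:

```julia
J    = [0.996 0.0 0.0; 0.004 1.0 0.0; 0.0 0.0 1.0]  # first weight matrix from above
u    = [0.9, 0.05, 0.05]    # hypothetical prediction at one time point
data = [1.0, 0.0, 0.0]      # data at the same time point

r = abs2.(u .- data)   # elementwise squared residual
d = J * r              # residual reweighted by the local Jacobian
# The full loss is sqrt(sum(abs, sum(diff))) over all such vectors d.
```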

We define a callback function.

callback = function (p,l,pred) #callback function to observe training
    println("Loss: $l")
    println("Parameters: $(exp.(p))")
    # using `remake` to re-create our `prob` with current parameters `p`
    plot(solve(remake(prob, p=exp.(p)), Rosenbrock23())) |> display
    return false # Tell it to not halt the optimization. If return true, then optimization stops
end
#7 (generic function with 1 method)

We then minimize the loss with a combination of ADAM and BFGS: ADAM makes robust initial progress, and switching to BFGS afterwards accelerates convergence near the minimum. The initial guess for the log-space parameters is chosen to be ones(3).

initp = ones(3)
# Display the ODE with the initial parameter values.
callback(initp,loss_adjoint(initp)...)

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss_adjoint(x), adtype)

optprob = Optimization.OptimizationProblem(optf, initp)
res = Optimization.solve(optprob, ADAM(0.01), callback = callback, maxiters = 300)

optprob2 = Optimization.OptimizationProblem(optf, res.u)
res2 = Optimization.solve(optprob2, BFGS(), callback = callback, maxiters = 30, allow_f_increases=true)
u: 3-element Vector{Float64}:
 -3.219015466506119
 17.16676336153872
  9.185089366065569
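
The returned `u` is still in log space, so exponentiating it gives the fitted physical parameters:

```julia
u = [-3.219015466506119, 17.16676336153872, 9.185089366065569]  # res2.u from above
exp.(u)   # ≈ [0.039994, 2.8538e7, 9750.7]
```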

Finally, we can analyze the difference between the fitted parameters and the ground truth.

println("Ground truth: $(p)\nFinal parameters: $(round.(exp.(res2.u), sigdigits=5))\nError: $(round(norm(exp.(res2.u) - p) ./ norm(p) .* 100, sigdigits=3))%")
Ground truth: [0.04, 3.0e7, 10000.0]
Final parameters: [0.039994, 2.8538e7, 9750.7]
Error: 4.87%
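
Note that the reported error is a 2-norm relative error, so it is dominated by the largest parameter k₂ ≈ 3e7; the 4.87% above essentially measures the mismatch in k₂ alone. The figure can be reproduced directly:

```julia
using LinearAlgebra

p_true = [0.04, 3.0e7, 10000.0]
p_fit  = [0.039994, 2.8538e7, 9750.7]

err = norm(p_fit - p_true) / norm(p_true) * 100   # ≈ 4.87
```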

A previous run of this tutorial gave

Ground truth: [0.04, 3.0e7, 10000.0]
Final parameters: [0.040002, 3.0507e7, 10084.0]
Error: 1.69%

so the exact figures vary from run to run, but the parameters are consistently recovered to within a few percent.