Weather forecasting with neural ODEs

In this example we are going to apply neural ODEs to a multidimensional weather dataset and use it for weather forecasting. This example is adapted from *Forecasting the weather with neural ODEs*, a post on Sebastian Callh's personal blog.

The data

The data is a four-dimensional dataset of daily mean temperature, humidity, wind speed and mean pressure measured over four years in the city of Delhi. Let us download and plot it.

using Random, Dates, Optimization, ComponentArrays, Lux, OptimizationOptimisers, DiffEqFlux,
      OrdinaryDiffEq, CSV, DataFrames, Dates, Statistics, Plots
using Downloads: download

"""
    download_data(data_url = "https://raw.githubusercontent.com/.../data/",
                  data_local_path = "./delhi")

Download the Delhi daily-climate train and test CSV files from `data_url` into the
directory `data_local_path` (created if missing), read both with CSV.jl, and return them
stacked into a single `DataFrame` (train rows first, then test rows).

The files are downloaded on every call.
"""
function download_data(
        data_url = "https://raw.githubusercontent.com/SebastianCallh/neural-ode-weather-forecast/master/data/",
        data_local_path = "./delhi")
    # The default URL already ends in '/'; strip trailing slashes so joining with
    # "/$file_name" never produces "//" in the request path.
    base_url = rstrip(data_url, '/')

    # Fetch one file into the local directory and parse it as a DataFrame.
    function load(file_name)
        local_file = joinpath(data_local_path, file_name)
        download("$base_url/$file_name", local_file)
        return CSV.read(local_file, DataFrame)
    end

    mkpath(data_local_path)
    train_df = load("DailyDelhiClimateTrain.csv")
    test_df = load("DailyDelhiClimateTest.csv")
    return vcat(train_df, test_df)
end

# Download the train and test CSVs into ./delhi and concatenate them into one DataFrame.
df = download_data()
5×5 DataFrame
 Row │ date        meantemp  humidity  wind_speed  meanpressure
     │ Date        Float64   Float64   Float64     Float64
─────┼──────────────────────────────────────────────────────────
   1 │ 2013-01-01  10.0      84.5      0.0         1015.67
   2 │ 2013-01-02   7.4      92.0      2.98        1017.8
   3 │ 2013-01-03   7.16667  87.0      4.63333     1018.67
   4 │ 2013-01-04   8.66667  71.3333   1.23333     1017.17
   5 │ 2013-01-05   6.0      86.8333   3.7         1016.5
# Column names of the modelled features, in the row order used by the feature matrices.
# Declared `const` because several functions below read these globals; non-const
# globals are untyped and make every such function type-unstable.
const FEATURES = [:meantemp, :humidity, :wind_speed, :meanpressure]
# Plot units and display names for the features above, in the same order.
const UNITS = ["Celsius", "g/m³ of water", "km/h", "hPa"]
const FEATURE_NAMES = ["Mean temperature", "Humidity", "Wind speed", "Mean pressure"]

"""
    plot_data(df)

Plot every feature in `FEATURES` against `df.date`: one subplot per feature, titled with
its entry in `FEATURE_NAMES`, y-labelled with its entry in `UNITS`, and given its own
palette colour. Returns the combined grid plot.
"""
function plot_data(df)
    plots = map(enumerate(zip(FEATURES, FEATURE_NAMES, UNITS))) do (i, (f, n, u))
        plot(df[:, :date], df[:, f]; title = n, label = nothing,
            ylabel = u, size = (800, 600), color = i)
    end

    # An integer layout lets Plots choose a near-square grid with exactly one cell per
    # panel (2×2 for four features) and works for any number of features.
    return plot(plots...; layout = length(plots))
end

# One subplot per raw daily feature over the full date range.
plot_data(df)
Example block output

The data show clear annual behaviour (it is difficult to see for pressure due to wild measurement errors, but the pattern is there). It is conceivable that this system can be described with an ODE, but which one? Let us use a neural network to learn the dynamics from the dataset. Training neural networks is easier with standardised data, so we will compute standardised features before training. Finally, we average the daily measurements into monthly values, take the first 20 months for training and use the rest for testing.

"""
    standardize(x)

Standardise each row of the matrix `x` to zero mean and unit (sample) standard
deviation. Returns `(z, μ, σ)`, where `μ` and `σ` are the per-row mean and standard
deviation kept as column matrices (`dims = 2`), so that `z .* σ .+ μ ≈ x`.
"""
function standardize(x)
    centre = mean(x; dims = 2)
    spread = std(x; dims = 2)
    scaled = @. (x - centre) / spread
    return scaled, centre, spread
end

"""
    featurize(raw_df, num_train = 20)

Aggregate the daily measurements in `raw_df` into monthly means and split them into the
first `num_train` months (train) and the remaining months (test). Both splits are
standardised with the statistics of the training split.

Time is encoded as `year + month / 12` (fractional years). Returns
`(t_train, y_train, t_test, y_test, (t_mean, t_scale), (y_mean, y_scale))`:
- `t_*` are vectors;
- `y_*` are matrices with one row per entry of `FEATURES`;
- the two tuples hold what is needed to undo the standardisation.

`raw_df` is not modified.

Throws `ArgumentError` unless `2 <= num_train <= n`, where `n` is the number of months
in `raw_df`. At least two training months are needed for a finite standard deviation.
"""
function featurize(raw_df, num_train = 20)
    # Add the grouping keys on a copy: `transform` returns a new DataFrame, so the
    # caller's table does not silently gain `year`/`month` columns.
    keyed = transform(raw_df,
        :date => (d -> Float64.(year.(d))) => :year,
        :date => (d -> Float64.(month.(d))) => :month)
    # `sort = true` guarantees chronological month order. With the default
    # (`sort = nothing`) DataFrames leaves the group order unspecified, but the ODE needs
    # increasing time.
    monthly = combine(groupby(keyed, [:year, :month]; sort = true),
        :date => (d -> mean(year.(d)) .+ mean(month.(d)) ./ 12),
        :meantemp => mean, :humidity => mean, :wind_speed => mean,
        :meanpressure => mean; renamecols = false)

    n_months = nrow(monthly)
    2 <= num_train <= n_months || throw(ArgumentError(
        "num_train must be between 2 and $n_months (the number of months), got $num_train"))

    # Time as a 1×n row and features as a (feature × time) matrix.
    t_and_y(d) = d.date', Matrix(select(d, FEATURES))'
    t_train, y_train = t_and_y(monthly[1:num_train, :])
    t_test, y_test = t_and_y(monthly[(num_train + 1):end, :])
    t_train, t_mean, t_scale = standardize(t_train)
    y_train, y_mean, y_scale = standardize(y_train)
    # The test split is scaled with the *training* statistics so no test information
    # leaks into the preprocessing.
    t_test = (t_test .- t_mean) ./ t_scale
    y_test = (y_test .- y_mean) ./ y_scale

    return (
        vec(t_train), y_train, vec(t_test), y_test, (t_mean, t_scale), (y_mean, y_scale))
end

"""
    plot_features(t_train, y_train, t_test, y_test)

Plot the standardised training series (solid) and test series (dashed) against
standardised time, one palette colour per feature (row of `y_train`/`y_test`).
Returns the plot.
"""
function plot_features(t_train, y_train, t_test, y_test)
    # Plots treats a 1×n row as one value per series; `colors` is not a Plots
    # attribute, so use `color` with a row of palette indices for both splits.
    series_colors = permutedims(1:size(y_train, 1))
    plt_split = plot(reshape(t_train, :), y_train'; linewidth = 3, color = series_colors,
        xlabel = "Normalized time", ylabel = "Normalized values",
        label = nothing, title = "Features")
    plot!(plt_split, reshape(t_test, :), y_test'; linewidth = 3,
        linestyle = :dash, color = series_colors, label = nothing)
    # Zero-width dummy series whose only purpose is to add "Train"/"Test" legend entries.
    plot!(plt_split, [0], [0]; linewidth = 0, label = "Train", color = 1)
    plot!(plt_split, [0], [0]; linewidth = 0, linestyle = :dash,
        label = "Test", color = 1, ylims = (-5, 5))
    return plt_split
end

# Monthly-averaged, standardised train/test splits, plus the (mean, scale) pairs needed
# to map standardised values back to years and physical units.
t_train, y_train, t_test, y_test, (t_mean, t_scale), (y_mean, y_scale) = featurize(df)
plot_features(t_train, y_train, t_test, y_test)
Example block output

The training data is now centred around 0 with a standard deviation of 1 (the test data is scaled with the same statistics). We will ignore the extreme pressure measurements for simplicity; since they are in the test split, they won't impact training anyway. We are now ready to construct and train our model! To avoid local minima, we will train iteratively with increasing amounts of data.

"""
    neural_ode(t, data_dim)

Build a neural ODE whose vector field is a small Lux MLP on `data_dim`-dimensional
states. It is integrated with `Tsit5` over `extrema(t)` and saved at the times `t`.

Returns `(node, p, state)`: the `NeuralODE`, its initial parameters wrapped in a
`ComponentArray`, and the Lux layer state. Initialisation always uses `Xoshiro(0)`, so
every call yields the same starting parameters.
"""
function neural_ode(t, data_dim)
    # data_dim -> 64 -> 32 -> data_dim, swish on the two hidden layers.
    vector_field = Chain(
        Dense(data_dim => 64, swish),
        Dense(64 => 32, swish),
        Dense(32 => data_dim))
    node = NeuralODE(vector_field, extrema(t), Tsit5();
        saveat = t, abstol = 1e-6, reltol = 1e-3)
    params, layer_state = Lux.setup(Xoshiro(0), vector_field)
    return node, ComponentArray(params), layer_state
end

"""
    train_one_round(node, p, state, y, opt, maxiters, rng, y0 = y[:, 1]; kwargs...)

Run one optimisation stage. It minimises the summed squared error between the
trajectory of `node` started at `y0` (parameters `p`, layer state `state`) and the
target `y`. Optimiser `opt` runs with `maxiters` iterations and Zygote gradients.

Returns `(p_opt, state)`: the optimised parameters and the unchanged layer `state`.
`rng` is accepted for interface compatibility but is not used. Extra keyword arguments
(e.g. `callback`, `progress`) are forwarded to `Optimization.solve`.
"""
function train_one_round(node, p, state, y, opt, maxiters, rng, y0 = y[:, 1]; kwargs...)
    predict(p) = Array(node(y0, p, state)[1])
    loss(p) = sum(abs2, predict(p) .- y)

    adtype = Optimization.AutoZygote()
    # Optimization.jl objectives take (parameters, hyperparameters); the latter is unused.
    optf = OptimizationFunction((p, _) -> loss(p), adtype)
    optprob = OptimizationProblem(optf, p)
    res = solve(optprob, opt; maxiters = maxiters, kwargs...)
    # `u` is the documented solution field; `minimizer` is only a legacy alias for it.
    return res.u, state
end

"""
    train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing; kwargs...)

Fit a neural ODE to observations `y` (features × time) taken at times `t`, using a
curriculum. For each `k` in `obs_grid`, it runs `maxiters` AdamW iterations (learning
rate `lr`) on only the first `k` observations. Each stage warm-starts from the
parameters found by the previous stage, which helps avoid poor local minima.

`p`/`state` seed the parameters and layer state; when `nothing`, they come from
`neural_ode`. `rng` is forwarded to `train_one_round`. Extra keyword arguments are
forwarded to `Optimization.solve`.

Returns `(ps, state, losses)`:
- `ps` and `losses` are the parameter snapshot and loss recorded at every optimiser
  callback across all stages;
- `state` is the final layer state.
"""
function train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing; kwargs...)
    # Callback that records a copy of the current parameters and the loss at every
    # optimiser step; returning `false` tells the optimiser to continue.
    log_results(ps, losses) = (opt_state, loss) -> begin
        push!(ps, copy(opt_state.u))
        push!(losses, loss)
        false
    end

    ps, losses = ComponentArray[], Float32[]
    for k in obs_grid
        # Only the first k observations take part in this stage. The ODE is solved and
        # saved on t[1:k], and the loss compares against y[:, 1:k].
        node, p_new, state_new = neural_ode(t[1:k], size(y, 1))
        p === nothing && (p = p_new)
        state === nothing && (state = state_new)

        p, state = train_one_round(node, p, state, y[:, 1:k],
            OptimizationOptimisers.AdamW(lr), maxiters, rng;
            callback = log_results(ps, losses), kwargs...)
    end
    return ps, state, losses
end

# NOTE(review): `rng` is only forwarded by `train` to `train_one_round`, which never uses
# it. Network weights are initialised from a fixed Xoshiro(0) inside `neural_ode`.
rng = MersenneTwister(123)
obs_grid = 4:4:length(t_train) # stage sizes: each stage is meant to fit the first k observations (see `train`)
maxiters = 150 # optimiser iterations per stage
lr = 5e-3 # AdamW learning rate
# `progress = true` is forwarded through `train` and `train_one_round` to `solve`.
ps, state, losses = train(t_train, y_train, obs_grid, maxiters, lr, rng; progress = true)
(ComponentArrays.ComponentArray[(layer_1 = (weight = Float32[-0.7645059 0.53793734 0.39337653 0.70219153; -0.07752936 -0.15911278 -0.7899115 -0.54955494; … ; -0.49703807 0.2862404 0.17134601 0.12848014; 0.12155315 -0.079533935 -0.21581413 -0.6378495], bias = Float32[0.17611152, -0.3398568, 0.47946006, 0.0011358261, -0.369452, -0.48125637, -0.14892167, 0.36018282, 0.23786569, 0.15563601  …  0.32014894, -0.34898245, 0.3859067, -0.4593916, -0.33226132, 0.26252872, 0.3789305, 0.20379025, -0.10364276, 4.684925f-5]), layer_2 = (weight = Float32[0.14405736 0.18788297 … 0.09533301 0.20118886; 0.07223731 0.20376958 … 0.1997639 0.17007011; … ; 0.12802728 -0.021456784 … -0.09562197 -0.17657918; -0.104789555 -0.11612058 … 0.053488962 -0.111916974], bias = Float32[-0.09538478, -0.014825672, 0.117425635, -0.10896616, 0.109270275, -0.012288272, -0.06560917, 0.050810307, 0.005339995, 0.039113745  …  -0.0020065606, -0.042473897, -0.09489921, -0.0010632724, -0.013844758, -0.022791013, 0.01943487, -0.016084298, -0.08688739, -0.013523743]), layer_3 = (weight = Float32[-0.29617885 0.06500828 … -0.27692604 -0.2488153; -0.2784669 0.05598557 … -0.13869432 0.09098109; 0.12476622 -0.27931806 … 0.09968187 0.2964681; -0.12457438 0.084029615 … -0.011420269 0.08157834], bias = Float32[-0.0956653, 0.15592338, -0.10216346, 0.10081644])), (layer_1 = (weight = Float32[-0.7695059 0.54293734 0.38837653 0.7071915; -0.082529366 -0.15411279 -0.7949115 -0.54455495; … ; -0.50203806 0.2912404 0.16634601 0.13348013; 0.12655315 -0.08453394 -0.21081413 -0.6428495], bias = Float32[0.18111151, -0.3348568, 0.47446007, -0.0038641738, -0.364452, -0.48625636, -0.14392167, 0.36518282, 0.24286568, 0.15063602  …  0.32514894, -0.35398245, 0.3909067, -0.4643916, -0.33726132, 0.2675287, 0.3739305, 0.19879025, -0.09864276, -0.00495315]), layer_2 = (weight = Float32[0.13905737 0.19288297 … 0.09033301 0.20618886; 0.06723731 0.20876957 … 0.1947639 0.1750701; … ; 0.12302727 -0.016456783 … -0.100621976 -0.17157918; -0.09978955 
-0.12112058 … 0.05848896 -0.11691698], bias = Float32[-0.10038478, -0.019825673, 0.11242563, -0.103966154, 0.10427027, -0.017288271, -0.070609175, 0.055810306, 0.00033999514, 0.044113744  …  0.0029934393, -0.0374739, -0.09989921, -0.0060632722, -0.018844757, -0.027791012, 0.014434869, -0.011084299, -0.09188739, -0.008523743]), layer_3 = (weight = Float32[-0.30117884 0.060008284 … -0.27192605 -0.2538153; -0.2734669 0.06098557 … -0.14369431 0.09598109; 0.11976622 -0.28431806 … 0.10468187 0.2914681; -0.119574375 0.08902962 … -0.016420268 0.08657834], bias = Float32[-0.090665296, 0.15092339, -0.09716345, 0.09581644])), (layer_1 = (weight = Float32[-0.77415794 0.54744697 0.3836119 0.7118782; -0.082470536 -0.15390255 -0.7946522 -0.5445939; … ; -0.5062749 0.2953534 0.16202132 0.13775234; 0.13154188 -0.08946696 -0.20580809 -0.6478454], bias = Float32[0.1859277, -0.33550912, 0.46965417, -0.008857113, -0.36131158, -0.4903583, -0.13918692, 0.36736962, 0.24776426, 0.14653559  …  0.3221644, -0.35883018, 0.3957265, -0.46921173, -0.34183937, 0.27211842, 0.3694645, 0.19404387, -0.094268456, -0.009955538]), layer_2 = (weight = Float32[0.13504222 0.19667664 … 0.08633277 0.21026765; 0.062326834 0.21348846 … 0.18984067 0.1799343; … ; 0.11887717 -0.012407768 … -0.104741804 -0.16734572; -0.09581976 -0.12482073 … 0.06242792 -0.12096606], bias = Float32[-0.10437781, -0.024525829, 0.10862694, -0.09903822, 0.099480204, -0.017103007, -0.07427631, 0.06058006, -0.0041402667, 0.04891845  …  0.007930087, -0.032764755, -0.104080245, -0.010771485, -0.023789186, -0.032437388, 0.009640757, -0.006466028, -0.09617484, -0.0045383754]), layer_3 = (weight = Float32[-0.30573422 0.0554925 … -0.26799142 -0.25848314; -0.26893377 0.06547372 … -0.14753108 0.10064726; 0.11475992 -0.28932476 … 0.10918424 0.2864823; -0.11490817 0.09365396 … -0.020457609 0.09133516], bias = Float32[-0.08611815, 0.14637937, -0.09215812, 0.091161616])), (layer_1 = (weight = Float32[-0.7784472 0.5514358 0.37913775 0.71622545; 
-0.0805904 -0.15524928 -0.792524 -0.54649013; … ; -0.50989956 0.29877883 0.15827593 0.1414288; 0.13651451 -0.094275 -0.20080206 -0.65280724], bias = Float32[0.19051373, -0.3380075, 0.4647567, -0.013837405, -0.3582967, -0.49362695, -0.13478616, 0.36459637, 0.25248504, 0.14325428  …  0.3182262, -0.36321783, 0.4003107, -0.47369504, -0.345822, 0.2763524, 0.36554664, 0.18921791, -0.090440065, -0.014952682]), layer_2 = (weight = Float32[0.1314511 0.2002193 … 0.082776226 0.21400695; 0.05750185 0.21832521 … 0.18504572 0.18475854; … ; 0.115331784 -0.008962317 … -0.1082444 -0.16367407; -0.092612274 -0.12762693 … 0.06561293 -0.12426508], bias = Float32[-0.1081068, -0.029363412, 0.10495155, -0.094359264, 0.09494208, -0.014072283, -0.076658204, 0.06506586, -0.0081013, 0.052774165  …  0.012571787, -0.028363757, -0.10738469, -0.015115489, -0.028451696, -0.036862597, 0.0051993458, -0.002137383, -0.099933036, -0.0014898488]), layer_3 = (weight = Float32[-0.30989394 0.051404454 … -0.26488325 -0.2627834; -0.265002 0.06933495 … -0.15047744 0.10472585; 0.10977174 -0.29429308 … 0.1127075 0.28150427; -0.11064131 0.097845696 … -0.023676563 0.09569541], bias = Float32[-0.08196931, 0.14241463, -0.08716401, 0.08691494])), (layer_1 = (weight = Float32[-0.78241605 0.5549757 0.37495112 0.7202708; -0.077940784 -0.15720427 -0.78963757 -0.5491844; … ; -0.51308984 0.3017166 0.15495385 0.14468096; 0.14140123 -0.09874973 -0.19587937 -0.65763605], bias = Float32[0.1948654, -0.34122184, 0.4598687, -0.018710908, -0.35452545, -0.4962406, -0.13078839, 0.36091018, 0.2568988, 0.14056459  …  0.31392375, -0.3673283, 0.40463546, -0.47771916, -0.34917292, 0.28029904, 0.3620883, 0.18477328, -0.0870033, -0.019972371]), layer_2 = (weight = Float32[0.12802568 0.20404558 … 0.07942832 0.21769916; 0.05321358 0.22305222 … 0.18085112 0.18896061; … ; 0.1121883 -0.005852232 … -0.111335345 -0.16036643; -0.0899377 -0.12978429 … 0.06828076 -0.12703672], bias = Float32[-0.112008914, -0.034107566, 0.10105312, -0.09007942, 
0.09070389, -0.010258551, -0.07810972, 0.069268174, -0.011582746, 0.055399004  …  0.016453253, -0.024214584, -0.1099872, -0.019021593, -0.032520182, -0.04116842, 0.0011478653, 0.0020279842, -0.10336012, 0.0008403931]), layer_3 = (weight = Float32[-0.31376415 0.047627546 … -0.2624018 -0.26678166; -0.26169744 0.072574586 … -0.15284705 0.10813269; 0.1049303 -0.29908773 … 0.11520245 0.2767296; -0.10676504 0.10163139 … -0.026285624 0.09962962], bias = Float32[-0.07811634, 0.13905728, -0.08230431, 0.08307263])), (layer_1 = (weight = Float32[-0.7861062 0.5581502 0.3710307 0.7240585; -0.07489895 -0.15946546 -0.78637594 -0.55230254; … ; -0.515958 0.3042903 0.15195285 0.14761999; 0.14602898 -0.10272893 -0.19125065 -0.6621619], bias = Float32[0.19898708, -0.34476656, 0.4550543, -0.023384739, -0.35033512, -0.498381, -0.12722443, 0.35678074, 0.26090553, 0.13830854  …  0.30957782, -0.37147877, 0.4086851, -0.48124382, -0.3520195, 0.28400025, 0.35900882, 0.18115163, -0.08385005, -0.024922695]), layer_2 = (weight = Float32[0.12473722 0.20820393 … 0.076271966 0.22137818; 0.049699772 0.22735754 … 0.17745769 0.19213648; … ; 0.10934157 -0.00292478 … -0.114120714 -0.15731953; -0.08768775 -0.13130593 … 0.07054728 -0.12936868], bias = Float32[-0.116098024, -0.038429055, 0.09706401, -0.08624281, 0.08677984, -0.006041338, -0.07897367, 0.07321769, -0.014644456, 0.057288878  …  0.019264711, -0.020284362, -0.11208828, -0.02233481, -0.035821836, -0.045342103, -0.002547843, 0.0060935616, -0.10655742, 0.002526204]), layer_3 = (weight = Float32[-0.3174117 0.044088315 … -0.2604195 -0.270527; -0.25897893 0.07525056 … -0.15485258 0.110911235; 0.100432225 -0.30352375 … 0.11699138 0.27240297; -0.103192404 0.10510502 … -0.02842101 0.103227176], bias = Float32[-0.07449646, 0.13626906, -0.07777412, 0.07955015])), (layer_1 = (weight = Float32[-0.7895461 0.56102806 0.36735559 0.7276191; -0.07163592 -0.16191429 -0.7829071 -0.55567515; … ; -0.51856357 0.30657375 0.14921995 0.15030552; 0.15015545 -0.10622231 
-0.18713075 -0.6661672], bias = Float32[0.20287816, -0.34847602, 0.450342, -0.027807888, -0.34590483, -0.50019544, -0.124045365, 0.35249966, 0.26451027, 0.1364134  …  0.30553854, -0.3758451, 0.4124585, -0.48432583, -0.35458022, 0.287472, 0.35625482, 0.17806411, -0.080924205, -0.029441295]), layer_2 = (weight = Float32[0.12163746 0.21256724 … 0.0733519 0.22497767; 0.046946913 0.23091725 … 0.1748093 0.19426478; … ; 0.10674489 -0.00012070895 … -0.116649404 -0.15449376; -0.08580994 -0.13209864 … 0.072474346 -0.13127813], bias = Float32[-0.12023462, -0.042123877, 0.09336217, -0.08276929, 0.08314633, -0.0017475043, -0.07952675, 0.076961964, -0.017359897, 0.059187874  …  0.021193268, -0.016577538, -0.113859594, -0.024855513, -0.038422607, -0.04927784, -0.005956726, 0.010064345, -0.10955334, 0.0035783492]), layer_3 = (weight = Float32[-0.32085797 0.040760055 … -0.2588373 -0.27403632; -0.25675762 0.07745437 … -0.15661475 0.11316918; 0.09647894 -0.30742297 … 0.118443884 0.268688; -0.099787794 0.1084 … -0.030130204 0.1066255], bias = Float32[-0.07109056, 0.13396277, -0.07377221, 0.07621078])), (layer_1 = (weight = Float32[-0.7927525 0.5636592 0.36390978 0.7309674; -0.06824794 -0.16448542 -0.77932394 -0.55920476; … ; -0.5209368 0.30861428 0.14672852 0.15276703; 0.15359157 -0.10931408 -0.18359366 -0.6694151], bias = Float32[0.20653158, -0.35226837, 0.44574496, -0.031982314, -0.3416363, -0.501788, -0.12114059, 0.34845567, 0.26780483, 0.1348598  …  0.30220857, -0.3804188, 0.41597092, -0.48707798, -0.35707378, 0.29071638, 0.3537934, 0.1747486, -0.0782031, -0.032907467]), layer_2 = (weight = Float32[0.11878177 0.21691671 … 0.070697054 0.22840019; 0.044836 0.23360927 … 0.17277007 0.19551435; … ; 0.10437736 0.0025547706 … -0.11894586 -0.15187997; -0.08426772 -0.132114 … 0.07410493 -0.1327698], bias = Float32[-0.12420989, -0.04518004, 0.09054291, -0.07952078, 0.07976086, 0.0019068341, -0.07997226, 0.08055008, -0.019802315, 0.061645206  …  0.02268656, -0.013120047, -0.115430854, 
-0.026447643, -0.040531185, -0.052823886, -0.009151464, 0.013912071, -0.112341285, 0.0040274835]), layer_3 = (weight = Float32[-0.3240967 0.037643626 … -0.25756738 -0.27731013; -0.2549241 0.07928972 … -0.15819074 0.11502431; 0.09320356 -0.3106774 … 0.11983449 0.26561943; -0.09642276 0.111640155 … -0.031406496 0.10994009], bias = Float32[-0.06790516, 0.13203105, -0.070427075, 0.072923645])), (layer_1 = (weight = Float32[-0.79574156 0.5660798 0.3606759 0.7341131; -0.06479695 -0.16713375 -0.775685 -0.5628255; … ; -0.52309865 0.31044638 0.14445823 0.1550231; 0.15631819 -0.11208066 -0.18056786 -0.6717888], bias = Float32[0.20994417, -0.35609704, 0.4412766, -0.035929266, -0.33815476, -0.5032171, -0.11840439, 0.34509158, 0.27089164, 0.13363402  …  0.29988706, -0.38515067, 0.4192472, -0.48960632, -0.35963774, 0.2937368, 0.3515977, 0.17095149, -0.07567599, -0.034928627]), layer_2 = (weight = Float32[0.11619356 0.22092398 … 0.06829985 0.23156938; 0.04324155 0.23551138 … 0.17121628 0.19606708; … ; 0.10222324 0.0050734906 … -0.12102939 -0.1494741; -0.0830229 -0.13143495 … 0.075476505 -0.133873], bias = Float32[-0.12782104, -0.047687534, 0.08897833, -0.076389425, 0.07659159, 0.0041352334, -0.08042742, 0.084016606, -0.022028724, 0.06474154  …  0.024114875, -0.009929766, -0.11688118, -0.027138997, -0.04234918, -0.055867586, -0.012177843, 0.017606178, -0.11491193, 0.0039530206]), layer_3 = (weight = Float32[-0.32711828 0.034743562 … -0.25653687 -0.28035188; -0.25338203 0.08084589 … -0.15960678 0.116571866; 0.09062246 -0.31328544 … 0.121306784 0.263132; -0.09302914 0.11489253 … -0.032247134 0.11322661], bias = Float32[-0.064948715, 0.13037996, -0.06775296, 0.06961754])), (layer_1 = (weight = Float32[-0.7985383 0.5683205 0.3576271 0.737076; -0.06131981 -0.1698291 -0.7720229 -0.5664928; … ; -0.5250728 0.31209943 0.14238362 0.15709467; 0.1584856 -0.114602886 -0.17791773 -0.67344135], bias = Float32[0.21313144, -0.35993537, 0.436965, -0.039643142, -0.33556855, -0.50449854, -0.11579506, 
0.34253487, 0.27382255, 0.1326872  …  0.29855743, -0.38998726, 0.422314, -0.49196604, -0.36228624, 0.2965499, 0.3496336, 0.16680568, -0.07332828, -0.035726447]), layer_2 = (weight = Float32[0.11385885 0.2243997 … 0.06613135 0.23447853; 0.042067643 0.23668824 … 0.17005517 0.19604148; … ; 0.10026246 0.0074253213 … -0.122922055 -0.14726232; -0.08203131 -0.13025206 … 0.076623924 -0.13465104], bias = Float32[-0.13098954, -0.049729142, 0.088459, -0.07335011, 0.07364477, 0.0050802696, -0.08091746, 0.0873696, -0.024069471, 0.068300016  …  0.025580436, -0.0069893757, -0.118236385, -0.027119737, -0.043967295, -0.05841312, -0.015030997, 0.021141376, -0.11727281, 0.0034768782]), layer_3 = (weight = Float32[-0.32992962 0.032049794 … -0.25569668 -0.28317922; -0.2520721 0.08217948 … -0.16088562 0.117873885; 0.08864197 -0.31534067 … 0.12288206 0.26111114; -0.089623116 0.1181456 … -0.032703552 0.11647462], bias = Float32[-0.062212717, 0.1289515, -0.06565841, 0.06630551]))  …  (layer_1 = (weight = Float32[-0.8617641 0.46417505 0.41542935 0.7306285; -0.17157699 -0.015206084 -0.94998646 -0.34225953; … ; -0.5953644 -0.004186425 0.19018964 0.17380862; 0.28578147 0.3132422 -0.56452125 -0.6191688], bias = Float32[0.18829797, -0.68605125, 0.574121, -0.34595692, -0.120001934, -0.6760025, 0.2104961, 0.19397916, 0.42903757, 0.2335338  …  0.30219465, -0.4371226, 0.2658142, -0.5137027, -0.24723434, -0.28571874, 0.32078326, 0.28251475, 0.098387636, -0.16169462]), layer_2 = (weight = Float32[-0.00021404144 0.38549307 … -0.014354283 0.54058504; 0.10381855 -0.07908282 … 0.11731824 -0.02885043; … ; 0.12699276 0.14298792 … -0.087254874 -0.18339847; 0.1919698 0.5961368 … 0.5169506 -0.26939753], bias = Float32[-0.32551542, 0.102545306, -0.003129791, 0.09840133, 0.101581074, -0.04502873, -0.098200664, 0.2205432, -0.06247879, 0.20233653  …  -0.070962034, 0.047442716, -0.21800405, 0.036069117, 0.10522691, -0.13551873, -0.08045002, 0.05568812, -0.21403994, -0.018703865]), layer_3 = (weight = 
Float32[-0.2345467 -0.0041446122 … -0.15866281 -0.3710875; -0.45959264 0.075214036 … -0.13350531 0.28641176; 0.2741871 -0.4775424 … 0.2679575 0.20913236; 0.12523606 0.16466995 … 0.01894504 0.19766247], bias = Float32[-0.010392048, 0.07850496, -0.14373294, 0.008765191])), (layer_1 = (weight = Float32[-0.86184365 0.4641885 0.41542095 0.7307021; -0.17181262 -0.015344103 -0.95001 -0.34209117; … ; -0.59548753 -0.0042846105 0.19004993 0.17394514; 0.28575256 0.3132717 -0.56439155 -0.61921084], bias = Float32[0.18837088, -0.6861591, 0.5745823, -0.34604326, -0.11997373, -0.6761404, 0.21059, 0.19410528, 0.42909056, 0.23363602  …  0.30243006, -0.43708462, 0.26586393, -0.5139256, -0.24718717, -0.28568563, 0.32064772, 0.2826225, 0.09866776, -0.16174223]), layer_2 = (weight = Float32[-0.0001696008 0.3855035 … -0.014321462 0.54056394; 0.10384142 -0.07903878 … 0.11732079 -0.02879647; … ; 0.12696157 0.14305244 … -0.087272644 -0.18345207; 0.19198896 0.59635746 … 0.5170098 -0.2695527], bias = Float32[-0.32553616, 0.10282825, -0.003141831, 0.09853901, 0.10174213, -0.0450852, -0.0982624, 0.2205929, -0.062473662, 0.20243135  …  -0.07098363, 0.04752122, -0.21805733, 0.0361213, 0.10527748, -0.1355253, -0.08043526, 0.055746302, -0.21411324, -0.018720727]), layer_3 = (weight = Float32[-0.23448783 -0.004148848 … -0.15866666 -0.37109047; -0.4596461 0.075188406 … -0.13341884 0.28648975; 0.2741246 -0.4776628 … 0.26802364 0.20912573; 0.12522992 0.16467601 … 0.018951433 0.19767258], bias = Float32[-0.010402217, 0.07845447, -0.14379351, 0.008766052])), (layer_1 = (weight = Float32[-0.8619223 0.46420187 0.41541234 0.730775; -0.17204793 -0.015483395 -0.9500329 -0.34192318; … ; -0.59561056 -0.0043831053 0.18991056 0.17408139; 0.28572515 0.3133024 -0.56426084 -0.61925435], bias = Float32[0.18844318, -0.68626755, 0.57504064, -0.34612876, -0.11994556, -0.67628056, 0.2106832, 0.19423129, 0.42914274, 0.23373887  …  0.3026639, -0.43704772, 0.2659125, -0.514147, -0.24714023, -0.28565267, 0.32051194, 
0.28272885, 0.098945856, -0.1617888]), layer_2 = (weight = Float32[-0.00012594856 0.38551337 … -0.014289217 0.5405433; 0.103864945 -0.07899416 … 0.11732406 -0.028743766; … ; 0.12692873 0.14311634 … -0.08729173 -0.18350431; 0.19200686 0.59657735 … 0.5170683 -0.26970822], bias = Float32[-0.32555676, 0.1031094, -0.0031528026, 0.09867723, 0.10190156, -0.04514174, -0.09832457, 0.22064131, -0.062468957, 0.20252572  …  -0.071005635, 0.047598526, -0.21811078, 0.036172107, 0.105326906, -0.13553233, -0.08042063, 0.0558032, -0.214188, -0.018738309]), layer_3 = (weight = Float32[-0.23442766 -0.004152262 … -0.15866998 -0.3710929; -0.45969948 0.07516336 … -0.13333192 0.2865688; 0.2740636 -0.47778246 … 0.26808926 0.20911837; 0.12522419 0.16468221 … 0.018957345 0.19768244], bias = Float32[-0.010412296, 0.07840421, -0.14385355, 0.008767125])), (layer_1 = (weight = Float32[-0.8619996 0.46421528 0.4154033 0.73084676; -0.17228226 -0.015624616 -0.9500545 -0.34175637; … ; -0.5957331 -0.0044814865 0.18977267 0.17421657; 0.28569898 0.31333554 -0.56413037 -0.61929905], bias = Float32[0.18851437, -0.6863769, 0.5754966, -0.34621373, -0.119916715, -0.6764232, 0.21077424, 0.19435774, 0.429193, 0.23384085  …  0.30289608, -0.43701133, 0.26595876, -0.5143653, -0.2470934, -0.28561917, 0.32037637, 0.28283313, 0.09922182, -0.16183369]), layer_2 = (weight = Float32[-8.343584f-5 0.38552192 … -0.014258008 0.54052395; 0.10388981 -0.07894868 … 0.11732876 -0.028693452; … ; 0.12689225 0.14317928 … -0.08731376 -0.18355373; 0.19202279 0.59679574 … 0.5171258 -0.26986426], bias = Float32[-0.32557648, 0.103387825, -0.003162532, 0.098816134, 0.10205799, -0.045199256, -0.09838772, 0.22068709, -0.062465712, 0.20261994  …  -0.07102904, 0.047673836, -0.21816523, 0.036221776, 0.10537412, -0.13554023, -0.08040629, 0.0558581, -0.21426679, -0.018757109]), layer_3 = (weight = Float32[-0.23436636 -0.004154356 … -0.15867169 -0.3710939; -0.45975345 0.075138934 … -0.13324408 0.28664997; 0.27400485 -0.47790074 … 0.2681545 
0.2091099; 0.12521851 0.16468863 … 0.018963218 0.19769233], bias = Float32[-0.010422442, 0.07835452, -0.14391322, 0.008768248])), (layer_1 = (weight = Float32[-0.8620756 0.46422872 0.41539383 0.7309176; -0.17251563 -0.015767433 -0.9500752 -0.3415906; … ; -0.5958552 -0.004579738 0.18963602 0.1743508; 0.28567407 0.31337082 -0.56399995 -0.61934495], bias = Float32[0.18858466, -0.68648684, 0.5759498, -0.34629804, -0.11988732, -0.67656773, 0.21086344, 0.1944845, 0.42924157, 0.23394227  …  0.3031267, -0.43697524, 0.266003, -0.5145808, -0.24704707, -0.2855852, 0.3202408, 0.2829355, 0.09949575, -0.16187702]), layer_2 = (weight = Float32[-4.1894433f-5 0.3855294 … -0.014227746 0.54050577; 0.10391583 -0.078902535 … 0.11733473 -0.028645338; … ; 0.12685254 0.14324163 … -0.087338425 -0.18360041; 0.19203669 0.5970129 … 0.5171821 -0.27002048], bias = Float32[-0.32559526, 0.103663534, -0.0031711499, 0.09895548, 0.102211654, -0.04525759, -0.0984517, 0.22073045, -0.062463757, 0.20271401  …  -0.071053624, 0.04774744, -0.21822055, 0.03627042, 0.10541929, -0.13554886, -0.08039223, 0.055911373, -0.21434897, -0.018777043]), layer_3 = (weight = Float32[-0.23430417 -0.0041553504 … -0.15867199 -0.37109363; -0.45980802 0.07511499 … -0.1331555 0.286733; 0.27394816 -0.47801772 … 0.2682194 0.20910054; 0.12521285 0.16469522 … 0.018969033 0.19770221], bias = Float32[-0.010432645, 0.07830536, -0.14397256, 0.008769403])), (layer_1 = (weight = Float32[-0.86215127 0.46424222 0.41538405 0.7309882; -0.17274857 -0.015910732 -0.9500959 -0.34142485; … ; -0.595977 -0.004678199 0.18949933 0.17448485; 0.28565055 0.3134068 -0.5638683 -0.6193922], bias = Float32[0.18865487, -0.6865967, 0.5763992, -0.3463813, -0.11985826, -0.67671293, 0.21095246, 0.19461085, 0.42928976, 0.23404486  …  0.3033558, -0.4369396, 0.26604685, -0.51479495, -0.24700202, -0.28555137, 0.32010457, 0.2830369, 0.09976789, -0.1619196]), layer_2 = (weight = Float32[-7.492854f-7 0.38553673 … -0.014197893 0.54048806; 0.10394213 -0.07885637 … 
0.117341086 -0.028598152; … ; 0.1268119 0.14330421 … -0.08736387 -0.18364578; 0.19204909 0.5972299 … 0.5172375 -0.27017617], bias = Float32[-0.32561356, 0.103937246, -0.0031790154, 0.09909473, 0.102364086, -0.04531573, -0.098515764, 0.22077295, -0.062461924, 0.20280768  …  -0.071078196, 0.047820505, -0.21827573, 0.036318023, 0.10546363, -0.13555764, -0.080378234, 0.0559642, -0.21443124, -0.01879754]), layer_3 = (weight = Float32[-0.23424146 -0.0041560302 … -0.1586721 -0.37109312; -0.45986256 0.075091235 … -0.1330669 0.2868165; 0.27389258 -0.47813407 … 0.268284 0.20909093; 0.12520744 0.1647018 … 0.018974392 0.19771177], bias = Float32[-0.010442742, 0.07825638, -0.14403136, 0.008770698])), (layer_1 = (weight = Float32[-0.8622273 0.4642558 0.41537413 0.7310595; -0.1729817 -0.016053284 -0.95011777 -0.34125793; … ; -0.596099 -0.0047772466 0.18936117 0.17461957; 0.28562862 0.31344184 -0.56373394 -0.61944103], bias = Float32[0.18872589, -0.6867057, 0.5768438, -0.34646314, -0.119830444, -0.67685735, 0.21104312, 0.194736, 0.429339, 0.23415057  …  0.30358365, -0.43690452, 0.26609197, -0.5150094, -0.24695913, -0.2855184, 0.3199669, 0.28313822, 0.10003854, -0.16196227]), layer_2 = (weight = Float32[4.0635132f-5 0.38554493 … -0.014167866 0.5404702; 0.10396773 -0.07881088 … 0.11734687 -0.028550504; … ; 0.12677288 0.14336799 … -0.08738805 -0.18369134; 0.19206049 0.59744793 … 0.51729214 -0.27033052], bias = Float32[-0.32563192, 0.10420973, -0.0031865414, 0.09923329, 0.10251695, -0.0453726, -0.09857915, 0.22081627, -0.062458962, 0.20290072  …  -0.07110143, 0.04789429, -0.21832973, 0.0363646, 0.10550843, -0.13556589, -0.08036414, 0.056017872, -0.21451002, -0.018818002]), layer_3 = (weight = Float32[-0.23417859 -0.004157277 … -0.15867333 -0.37109348; -0.45991638 0.07506733 … -0.13297902 0.2868989; 0.2738371 -0.47825053 … 0.26834825 0.20908175; 0.12520249 0.16470817 … 0.018978883 0.1977206], bias = Float32[-0.010452604, 0.07820716, -0.14408948, 0.008772223])), (layer_1 = (weight = 
Float32[-0.86230415 0.46426952 0.4153641 0.73113173; -0.17321514 -0.01619443 -0.95014143 -0.34108937; … ; -0.59622115 -0.004876953 0.18922096 0.1747552; 0.28560817 0.31347528 -0.5635964 -0.6194913], bias = Float32[0.1887981, -0.68681335, 0.5772831, -0.34654343, -0.11980419, -0.67700016, 0.21113612, 0.19485967, 0.42938977, 0.23426017  …  0.30381036, -0.43686974, 0.26613912, -0.51522475, -0.24691905, -0.28548652, 0.31982732, 0.28323993, 0.100308016, -0.16200545]), layer_2 = (weight = Float32[8.260316f-5 0.38555446 … -0.014137436 0.540452; 0.1039922 -0.07876647 … 0.11735165 -0.028501863; … ; 0.12673648 0.14343357 … -0.08741018 -0.18373749; 0.19207086 0.5976676 … 0.517346 -0.27048296], bias = Float32[-0.32565027, 0.10448109, -0.003194009, 0.09937077, 0.10267084, -0.04542785, -0.0986416, 0.220861, -0.06245445, 0.2029932  …  -0.071122766, 0.04796937, -0.21838222, 0.036410365, 0.10555413, -0.13557327, -0.08034994, 0.05607302, -0.21458372, -0.018838266]), layer_3 = (weight = Float32[-0.23411605 -0.0041595846 … -0.1586761 -0.37109518; -0.45996943 0.075042926 … -0.13289219 0.2869795; 0.2737812 -0.47836745 … 0.26841232 0.20907335; 0.12519793 0.16471416 … 0.018982457 0.19772854], bias = Float32[-0.010462298, 0.07815745, -0.14414695, 0.008773909])), (layer_1 = (weight = Float32[-0.8623815 0.46428338 0.41535383 0.73120475; -0.17344846 -0.016334342 -0.9501667 -0.34091967; … ; -0.59634304 -0.0049769925 0.18907928 0.17489122; 0.28558892 0.31350777 -0.56345654 -0.6195427], bias = Float32[0.18887128, -0.6869196, 0.5777174, -0.34662247, -0.119779006, -0.67714125, 0.21123068, 0.19498219, 0.42944148, 0.2343729  …  0.30403608, -0.43683466, 0.26618767, -0.51544005, -0.246882, -0.28545532, 0.31968582, 0.2833417, 0.100576535, -0.16204886]), layer_2 = (weight = Float32[0.00012506914 0.385565 … -0.0141068585 0.540434; 0.10401584 -0.07872314 … 0.117355786 -0.028452855; … ; 0.12670167 0.14350107 … -0.08743115 -0.18378326; 0.19207951 0.5978885 … 0.5173986 -0.27063337], bias = Float32[-0.325668, 
0.10475068, -0.0032014686, 0.09950708, 0.10282494, -0.045482084, -0.09870349, 0.22090632, -0.062449053, 0.20308547  …  -0.071142696, 0.048045363, -0.2184338, 0.036455706, 0.10560007, -0.13557984, -0.080335855, 0.05612937, -0.21465355, -0.018858712]), layer_3 = (weight = Float32[-0.23405431 -0.004162859 … -0.15867968 -0.3710978; -0.4600223 0.075017825 … -0.13280608 0.28705874; 0.27372518 -0.47848442 … 0.26847652 0.20906554; 0.12519336 0.16471972 … 0.018985499 0.19773571], bias = Float32[-0.010472094, 0.078107275, -0.14420405, 0.00877552])), (layer_1 = (weight = Float32[-0.8623815 0.46428338 0.41535383 0.73120475; -0.17344846 -0.016334342 -0.9501667 -0.34091967; … ; -0.59634304 -0.0049769925 0.18907928 0.17489122; 0.28558892 0.31350777 -0.56345654 -0.6195427], bias = Float32[0.18887128, -0.6869196, 0.5777174, -0.34662247, -0.119779006, -0.67714125, 0.21123068, 0.19498219, 0.42944148, 0.2343729  …  0.30403608, -0.43683466, 0.26618767, -0.51544005, -0.246882, -0.28545532, 0.31968582, 0.2833417, 0.100576535, -0.16204886]), layer_2 = (weight = Float32[0.00012506914 0.385565 … -0.0141068585 0.540434; 0.10401584 -0.07872314 … 0.117355786 -0.028452855; … ; 0.12670167 0.14350107 … -0.08743115 -0.18378326; 0.19207951 0.5978885 … 0.5173986 -0.27063337], bias = Float32[-0.325668, 0.10475068, -0.0032014686, 0.09950708, 0.10282494, -0.045482084, -0.09870349, 0.22090632, -0.062449053, 0.20308547  …  -0.071142696, 0.048045363, -0.2184338, 0.036455706, 0.10560007, -0.13557984, -0.080335855, 0.05612937, -0.21465355, -0.018858712]), layer_3 = (weight = Float32[-0.23405431 -0.004162859 … -0.15867968 -0.3710978; -0.4600223 0.075017825 … -0.13280608 0.28705874; 0.27372518 -0.47848442 … 0.26847652 0.20906554; 0.12519336 0.16471972 … 0.018985499 0.19773571], bias = Float32[-0.010472094, 0.078107275, -0.14420405, 0.00877552]))], (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), Float32[192.39883, 135.20009, 113.53052, 102.90518, 96.61792, 92.35046, 89.20726, 
86.74313, 84.62893, 82.73641  …  5.7015543, 5.6941614, 5.686838, 5.6795506, 5.672301, 5.665117, 5.6579885, 5.6508975, 5.6438537, 5.6438537])

We can now animate the training to get a better understanding of the fit.

"""
    predict(y0, t, p, state)

Build a neural ODE with `neural_ode(t, length(y0))`, solve it from the initial state
`y0` with parameters `p` and layer state `state`, and return the saved trajectory at the
times `t` as a plain `Array`.
"""
function predict(y0, t, p, state)
    model = first(neural_ode(t, length(y0)))
    solution, _ = model(y0, p, state)
    return Array(solution)
end

"""
    plot_pred(t_train, y_train, t_grid, rescale_t, rescale_y, num_iters, p, state, loss,
              y0 = y_train[:, 1])

Predict the trajectory on `t_grid` from `y0` with parameters `p`. Map observations and
prediction back to physical units with `rescale_t`/`rescale_y`, and draw them together
with the loss history via `plot_result`.
"""
function plot_pred(t_train, y_train, t_grid, rescale_t, rescale_y,
        num_iters, p, state, loss, y0 = y_train[:, 1])
    y_grid = predict(y0, t_grid, p, state)
    observed_t = rescale_t(t_train)
    observed_y = rescale_y(y_train)
    predicted_t = rescale_t(t_grid)
    predicted_y = rescale_y(y_grid)
    return plot_result(observed_t, observed_y, predicted_t, predicted_y, loss, num_iters)
end

"""
    plot_pred(t, y, y_pred)

Scatter the observations `y` and overlay the prediction `y_pred` as lines, both at the
times `t`. Returns the plot.
"""
function plot_pred(t, y, y_pred)
    fig = Plots.scatter(t, y; label = "Observation")
    return Plots.plot!(fig, t, y_pred; label = "Prediction")
end

"""
    plot_pred(t, y, t_pred, y_pred; kwargs...)

Return a vector with one plot per feature, for as many rows as there are feature names.
Each plot shows row `i` of `y_pred` over `t_pred` as a thick line and row `i` of `y` over
`t` as scatter points, titled with `FEATURE_NAMES[i]` and y-labelled with `UNITS[i]`.
`kwargs` are passed to the prediction line and can override its defaults.
"""
function plot_pred(t, y, t_pred, y_pred; kwargs...)
    # One feature panel: prediction line first, observations scattered on top.
    function panel(i, observed, predicted, name, unit)
        fig = Plots.plot(t_pred, predicted; label = "Prediction", color = i,
            linewidth = 3, legend = nothing, title = name, kwargs...)
        return Plots.scatter!(fig, t, observed; label = "Observation", xlabel = "Time",
            ylabel = unit, markersize = 5, color = i)
    end

    rows = zip(eachrow(y), eachrow(y_pred), FEATURE_NAMES, UNITS)
    return [panel(i, row...) for (i, row) in enumerate(rows)]
end

"""
    plot_result(t, y, t_pred, y_pred, loss, num_iters; kwargs...)

Stack the four per-feature prediction panels from `plot_pred` (with fixed y-ranges in
physical units) above a panel of the loss history `loss`, whose x-axis spans
`(0, num_iters)`. `kwargs` are passed on to `plot_pred`.
"""
function plot_result(t, y, t_pred, y_pred, loss, num_iters; kwargs...)
    feature_panels = plot_pred(t, y, t_pred, y_pred; kwargs...)
    # y-ranges for temperature, humidity, wind speed and pressure, in that order.
    y_ranges = ((10, 40), (20, 100), (2, 12), (990, 1025))
    # Only the first panel shows a legend, placed at (0.65, 1.0).
    plot!(feature_panels[1]; ylim = y_ranges[1], legend = (0.65, 1.0))
    for i in 2:4
        plot!(feature_panels[i]; ylim = y_ranges[i])
    end

    loss_panel = Plots.plot(loss; label = nothing, linewidth = 3, title = "Loss",
        xlabel = "Iterations", xlim = (0, num_iters))
    panels = [feature_panels..., loss_panel]
    return plot(panels...; layout = grid(length(panels), 1), size = (900, 900))
end

"""
    animate_training(plot_frame, t_train, y_train, ps, losses, obs_grid; pause_for = 300)

Build an animation that replays training. Frame `step` calls `plot_frame` with the first
`k` observations, parameter snapshot `ps[step]` and loss history `losses[1:step]`, where
`k` is the `obs_grid` entry of the stage that `step` falls into. The stage is found by
splitting `1:length(losses)` evenly across the entries of `obs_grid`. After the last
step, the final frame is repeated for `pause_for` more steps. Only every second frame is
kept.
"""
function animate_training(
        plot_frame, t_train, y_train, ps, losses, obs_grid; pause_for = 300)
    stage_sizes = collect(obs_grid)
    n_steps = length(losses)
    n_stages = length(stage_sizes)
    # Steps 2..n_steps replay training; the extra steps are clamped to the final one.
    frame_steps = min.(2:(n_steps + pause_for), n_steps)
    anim = @animate for step in frame_steps
        stage = floor(Int, (step - 1) / n_steps * n_stages)
        k = stage_sizes[stage + 1]
        plot_frame(t_train[1:k], y_train[:, 1:k], ps[step], losses[1:step])
    end every 2
    return anim
end

# Total number of recorded optimiser steps; fixes the x-range of the loss panel.
num_iters = length(losses)
# 500 evenly spaced (standardised) times across the training span, for smooth curves.
t_train_grid = collect(range(extrema(t_train)...; length = 500))
# Undo the standardisation from `featurize`: back to fractional years and physical units.
rescale_t(x) = t_scale .* x .+ t_mean
rescale_y(x) = y_scale .* x .+ y_mean
# Render one animation frame: the model with parameters `p` against the observations
# (`t`, `y`) seen so far, plus the loss history `loss`. It reads the globals
# `t_train_grid`, `rescale_t`, `rescale_y`, `num_iters` and `state`.
plot_frame(t, y, p, loss) = plot_pred(
    t, y, t_train_grid, rescale_t, rescale_y, num_iters, p, state, loss)
# Replay the recorded parameter snapshots as an animation and save it as a GIF.
anim = animate_training(plot_frame, t_train, y_train, ps, losses, obs_grid)
gif(anim, "node_weather_forecast_training.gif")
Example block output

Looks good! But how well does the model forecast?

"""
    plot_extrapolation(t_train, y_train, t_test, y_test, t̂, ŷ)

Plot one panel per feature, stacked vertically. Each panel shows:
- the model trajectory `ŷ` over times `t̂`;
- the training observations;
- the held-out test observations (white marker outline).

Each panel uses a fixed y-range in physical units.
"""
function plot_extrapolation(t_train, y_train, t_test, y_test, t̂, ŷ)
    panels = plot_pred(t_train, y_train, t̂, ŷ)
    foreach(enumerate(zip(panels, eachrow(y_test)))) do (i, (panel, test_row))
        scatter!(panel, t_test, test_row; color = i, markerstrokecolor = :white,
            label = "Test observation")
    end

    # y-ranges for temperature, humidity, wind speed and pressure, in that order.
    y_ranges = ((10, 40), (20, 100), (2, 12), (990, 1025))
    plot!(panels[1]; ylim = y_ranges[1], legend = :topleft)
    for i in 2:4
        plot!(panels[i]; ylim = y_ranges[i])
    end
    return plot(panels...; layout = grid(length(panels), 1), size = (900, 900))
end

# Integrate the trained model (last parameter snapshot) from the first training
# observation across both the training and the test span, then plot in physical units.
t_grid = collect(range(minimum(t_train), maximum(t_test); length = 500))
y_pred = predict(y_train[:, 1], t_grid, ps[end], state)
plot_extrapolation(rescale_t(t_train), rescale_y(y_train), rescale_t(t_test),
    rescale_y(y_test), rescale_t(t_grid), rescale_y(y_pred))
Example block output

While there is some drift in the weather patterns, the model extrapolates very well!