Weather forecasting with neural ODEs

In this example we are going to apply neural ODEs to a multidimensional weather dataset and use them for weather forecasting. This example is adapted from "Forecasting the weather with neural ODEs" on Sebastian Callh's personal blog.

The data

The data is a four-dimensional dataset of daily temperature, humidity, wind speed and pressure measured over four years in the city of Delhi. Let us download and plot it.

using Random, Dates, Optimization, ComponentArrays, Lux, OptimizationOptimisers, DiffEqFlux,
      OrdinaryDiffEq, CSV, DataFrames, Dates, Statistics, Plots
using Downloads: download

"""
    download_data(data_url = "https://raw.githubusercontent.com/SebastianCallh/neural-ode-weather-forecast/master/data/",
                  data_local_path = "./delhi")

Download the Delhi climate train and test CSV files from `data_url` into
`data_local_path` (created if it does not exist) and return both tables
concatenated, train rows first, as one `DataFrame`.

A trailing `/` on `data_url` is ignored, so the request URL never contains `//`.
"""
function download_data(
        data_url = "https://raw.githubusercontent.com/SebastianCallh/neural-ode-weather-forecast/master/data/",
        data_local_path = "./delhi")
    # The default URL ends in `/`; strip it so joining with `/` stays well-formed.
    base_url = rstrip(data_url, '/')

    # Fetch a single CSV into the local directory and parse it.
    function load(file_name)
        local_file = joinpath(data_local_path, file_name)
        download("$base_url/$file_name", local_file)
        return CSV.read(local_file, DataFrame)
    end

    mkpath(data_local_path)
    train_df = load("DailyDelhiClimateTrain.csv")
    test_df = load("DailyDelhiClimateTest.csv")
    return vcat(train_df, test_df)
end

# Fetch the train and test CSV files and keep them as one combined table.
df = download_data()
5×5 DataFrame
 Row │ date        meantemp  humidity  wind_speed  meanpressure
     │ Date        Float64   Float64   Float64     Float64
─────┼──────────────────────────────────────────────────────────
   1 │ 2013-01-01  10.0      84.5      0.0         1015.67
   2 │ 2013-01-02  7.4       92.0      2.98        1017.8
   3 │ 2013-01-03  7.16667   87.0      4.63333     1018.67
   4 │ 2013-01-04  8.66667   71.3333   1.23333     1017.17
   5 │ 2013-01-05  6.0       86.8333   3.7         1016.5
# Column names of the four measured quantities, with their units and display names.
# The three arrays are index-aligned (they are zipped together when plotting).
# Declared `const` because they are read as globals inside several functions.
const FEATURES = [:meantemp, :humidity, :wind_speed, :meanpressure]
const UNITS = ["Celsius", "g/m³ of water", "km/h", "hPa"]
const FEATURE_NAMES = ["Mean temperature", "Humidity", "Wind speed", "Mean pressure"]

"""
    plot_data(df)

Plot every column listed in `FEATURES` against `df.date`, one panel per feature,
and arrange the panels in a square grid.
"""
function plot_data(df)
    dates = df[:, :date]
    panels = [plot(dates, df[:, feature]; title = name, label = nothing,
                  ylabel = unit, size = (800, 600), color = idx)
              for (idx, (feature, name, unit)) in
                  enumerate(zip(FEATURES, FEATURE_NAMES, UNITS))]

    # Half the panel count per side, i.e. a 2×2 grid for the four features.
    side = Int(length(panels) / 2)
    return plot(panels...; layout = (side, side))
end

# Show the raw series, one panel per feature.
plot_data(df)
Example block output

The data show clear annual behaviour (it is difficult to see for pressure due to wild measurement errors, but the pattern is there). It is conceivable that this system can be described with an ODE, but which one? Let us use a neural network to learn the dynamics from the dataset. Training neural networks is easier with standardised data, so we will compute standardised features before training. Finally, we take the first 20 months (the features are averaged per calendar month) for training and the rest for testing.

"""
    standardize(x)

Standardize each row of `x` to zero mean and unit standard deviation.

Returns `(z, μ, σ)`: the standardized array together with the row means and row
standard deviations (both reduced along dimension 2) that were used.
"""
function standardize(x)
    row_mean = mean(x; dims = 2)
    row_std = std(x; dims = 2)
    return (x .- row_mean) ./ row_std, row_mean, row_std
end

"""
    featurize(raw_df, num_train = 20)

Aggregate the daily table `raw_df` into per-month means and split it into
standardized train and test sets.

Rows are grouped by calendar year and month; each group yields the mean of every
column in `FEATURES` and a time value `year + month / 12`. The first `num_train`
groups form the training split and the remaining groups the test split. Both splits
are standardized with the mean and standard deviation of the training split.

`raw_df` is not modified.

# Returns
`(t_train, y_train, t_test, y_test, (t_mean, t_scale), (y_mean, y_scale))`, where the
`t_*` are vectors, the `y_*` are feature-by-time matrices, and the two trailing tuples
hold the statistics needed to undo the standardization.
"""
function featurize(raw_df, num_train = 20)
    # Work on a copy so the caller's table does not gain `year`/`month` columns.
    daily = copy(raw_df)
    daily.year = Float64.(year.(daily.date))
    daily.month = Float64.(month.(daily.date))
    # With `renamecols = false` the aggregated columns keep their source names, so
    # `date` now holds the fractional-year time of each month.
    df = combine(groupby(daily, [:year, :month]),
        :date => (d -> mean(year.(d)) .+ mean(month.(d)) ./ 12),
        :meantemp => mean, :humidity => mean, :wind_speed => mean,
        :meanpressure => mean; renamecols = false)
    # Time as a 1×N row and features as a (features × N) matrix.
    t_and_y(df) = df.date', Matrix(select(df, FEATURES))'
    t_train, y_train = t_and_y(df[1:num_train, :])
    t_test, y_test = t_and_y(df[(num_train + 1):end, :])
    t_train, t_mean, t_scale = standardize(t_train)
    y_train, y_mean, y_scale = standardize(y_train)
    # The test split reuses the training statistics.
    t_test = (t_test .- t_mean) ./ t_scale
    y_test = (y_test .- y_mean) ./ y_scale

    return (
        vec(t_train), y_train, vec(t_test), y_test, (t_mean, t_scale), (y_mean, y_scale))
end

"""
    plot_features(t_train, y_train, t_test, y_test)

Plot the standardized training features (solid lines) and test features (dashed
lines) against normalized time in a single panel and return the plot.

`y_train` and `y_test` are expected as feature-by-time matrices with four rows, one
per entry of the per-series color specification.
"""
function plot_features(t_train, y_train, t_test, y_test)
    # A 1×4 row matrix assigns one color per series; the same specification is used
    # for both splits so each feature keeps its color across train and test.
    series_colors = [1 2 3 4]
    plt_split = plot(reshape(t_train, :), y_train'; linewidth = 3, color = series_colors,
        xlabel = "Normalized time", ylabel = "Normalized values",
        label = nothing, title = "Features")
    plot!(plt_split, reshape(t_test, :), y_test'; linewidth = 3,
        linestyle = :dash, color = series_colors, label = nothing)
    # Zero-width dummy series that only contribute the "Train"/"Test" legend entries.
    plot!(plt_split, [0], [0]; linewidth = 0, label = "Train", color = 1)
    plot!(plt_split, [0], [0]; linewidth = 0, linestyle = :dash,
        label = "Test", color = 1, ylims = (-5, 5))
    return plt_split
end

# Build the standardized train/test split from the raw table and plot both splits.
t_train, y_train, t_test, y_test, (t_mean, t_scale), (y_mean, y_scale) = featurize(df)
plot_features(t_train, y_train, t_test, y_test)
Example block output

The dataset is now centered around 0 with a standard deviation of 1. We will ignore the extreme pressure measurements for simplicity. Since they are in the test split they won't impact training anyway. We are now ready to construct and train our model! To avoid local minima we will train iteratively with increasing amounts of data.

"""
    neural_ode(t, data_dim)

Build a neural ODE for a `data_dim`-dimensional state, integrated over the span of
`t` and saved at the time points `t`.

Returns `(node, p, state)`: the `NeuralODE`, the network's initial parameters wrapped
in a `ComponentArray`, and the network's initial state.
"""
function neural_ode(t, data_dim)
    dynamics = Chain(
        Dense(data_dim => 64, swish),
        Dense(64 => 32, swish),
        Dense(32 => data_dim))
    tspan = extrema(t)
    node = NeuralODE(dynamics, tspan, Tsit5(); saveat = t, abstol = 1e-6, reltol = 1e-3)

    # A freshly seeded generator is used on every call to set up the network.
    params, st = Lux.setup(Xoshiro(0), dynamics)
    return node, ComponentArray(params), st
end

"""
    train_one_round(node, p, state, y, opt, maxiters, rng, y0 = y[:, 1]; kwargs...)

Fit the parameters `p` of `node` to the observations `y` with optimizer `opt` for at
most `maxiters` iterations, starting the ODE from `y0`.

The loss is the sum of squared differences between the predicted trajectory and `y`,
differentiated with Zygote. Extra keyword arguments are forwarded to `solve`.
`rng` is accepted for interface compatibility and is not used.

Returns `(p_opt, state)`: the optimized parameters and the unchanged network state.
"""
function train_one_round(node, p, state, y, opt, maxiters, rng, y0 = y[:, 1]; kwargs...)
    predict_trajectory(p) = Array(node(y0, p, state)[1])
    loss(p) = sum(abs2, predict_trajectory(p) .- y)

    adtype = Optimization.AutoZygote()
    optf = OptimizationFunction((p, _) -> loss(p), adtype)
    optprob = OptimizationProblem(optf, p)
    res = solve(optprob, opt; maxiters = maxiters, kwargs...)
    # `res.u` is the solution field; `res.minimizer` is the deprecated alias for it.
    return res.u, state
end

"""
    train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing; kwargs...)

Train a neural ODE on progressively longer prefixes of the data.

For every `k` in `obs_grid`, one round of `maxiters` AdamW iterations with learning
rate `lr` is run on the first `k` observations (`t[1:k]`, `y[:, 1:k]`), warm-started
from the parameters of the previous round. When `p` or `state` is `nothing`, the
initial values produced by `neural_ode` are used. Extra keyword arguments are
forwarded to `train_one_round`.

Returns `(ps, state, losses)`: the parameters and loss recorded by the callback at
every optimizer iteration of every round, and the final network state.
"""
function train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing; kwargs...)
    # Callback factory: record parameters and loss, and return `false` so the
    # optimizer is never asked to stop early.
    log_results(ps, losses) = (opt_state, loss) -> begin
        push!(ps, copy(opt_state.u))
        push!(losses, loss)
        false
    end

    ps, losses = ComponentArray[], Float32[]
    for k in obs_grid
        # Restrict this round to the first k observations; without the slicing every
        # round would fit the full series and the incremental schedule would be a no-op.
        t_k = t[1:k]
        y_k = y[:, 1:k]

        node, p_new, state_new = neural_ode(t_k, size(y, 1))
        p === nothing && (p = p_new)
        state === nothing && (state = state_new)

        p, state = train_one_round(node, p, state, y_k, OptimizationOptimisers.AdamW(lr),
            maxiters, rng; callback = log_results(ps, losses), kwargs...)
    end
    return ps, state, losses
end

# Seeded generator handed to `train`, which passes it on to `train_one_round`.
rng = MersenneTwister(123)
obs_grid = 4:4:length(t_train) # we train on an increasing amount of the first k obs
maxiters = 150
lr = 5e-3
# Extra keyword arguments such as `progress` are forwarded down to the `solve` call.
ps, state, losses = train(t_train, y_train, obs_grid, maxiters, lr, rng; progress = true)
(ComponentArrays.ComponentArray[(layer_1 = (weight = Float32[-0.7645059 0.53793734 0.39337653 0.70219153; -0.07752936 -0.15911278 -0.7899115 -0.54955494; … ; -0.49703807 0.2862404 0.17134601 0.12848014; 0.12155315 -0.079533935 -0.21581413 -0.6378495], bias = Float32[0.17611152, -0.3398568, 0.47946006, 0.0011358261, -0.369452, -0.48125637, -0.14892167, 0.36018282, 0.23786569, 0.15563601  …  0.32014894, -0.34898245, 0.3859067, -0.4593916, -0.33226132, 0.26252872, 0.3789305, 0.20379025, -0.10364276, 4.684925f-5]), layer_2 = (weight = Float32[0.14405736 0.18788297 … 0.09533301 0.20118886; 0.07223731 0.20376958 … 0.1997639 0.17007011; … ; 0.12802728 -0.021456784 … -0.09562197 -0.17657918; -0.104789555 -0.11612058 … 0.053488962 -0.111916974], bias = Float32[-0.09538478, -0.014825672, 0.117425635, -0.10896616, 0.109270275, -0.012288272, -0.06560917, 0.050810307, 0.005339995, 0.039113745  …  -0.0020065606, -0.042473897, -0.09489921, -0.0010632724, -0.013844758, -0.022791013, 0.01943487, -0.016084298, -0.08688739, -0.013523743]), layer_3 = (weight = Float32[-0.29617885 0.06500828 … -0.27692604 -0.2488153; -0.2784669 0.05598557 … -0.13869432 0.09098109; 0.12476622 -0.27931806 … 0.09968187 0.2964681; -0.12457438 0.084029615 … -0.011420269 0.08157834], bias = Float32[-0.0956653, 0.15592338, -0.10216346, 0.10081644])), (layer_1 = (weight = Float32[-0.7695059 0.54293734 0.38837653 0.7071915; -0.082529366 -0.15411279 -0.7949115 -0.54455495; … ; -0.50203806 0.2912404 0.16634601 0.13348013; 0.12655315 -0.08453394 -0.21081413 -0.6428495], bias = Float32[0.18111151, -0.3348568, 0.47446007, -0.0038641738, -0.364452, -0.48625636, -0.14392167, 0.36518282, 0.24286568, 0.15063602  …  0.32514894, -0.35398245, 0.3909067, -0.4643916, -0.33726132, 0.2675287, 0.3739305, 0.19879025, -0.09864276, -0.00495315]), layer_2 = (weight = Float32[0.13905737 0.19288297 … 0.09033301 0.20618886; 0.06723731 0.20876957 … 0.1947639 0.1750701; … ; 0.12302727 -0.016456783 … -0.100621976 -0.17157918; -0.09978955 
-0.12112058 … 0.05848896 -0.11691698], bias = Float32[-0.10038478, -0.019825673, 0.11242563, -0.103966154, 0.10427027, -0.017288271, -0.070609175, 0.055810306, 0.00033999514, 0.044113744  …  0.0029934393, -0.0374739, -0.09989921, -0.0060632722, -0.018844757, -0.027791012, 0.014434869, -0.011084299, -0.09188739, -0.008523743]), layer_3 = (weight = Float32[-0.30117884 0.060008284 … -0.27192605 -0.2538153; -0.2734669 0.06098557 … -0.14369431 0.09598109; 0.11976622 -0.28431806 … 0.10468187 0.2914681; -0.119574375 0.08902962 … -0.016420268 0.08657834], bias = Float32[-0.090665296, 0.15092339, -0.09716345, 0.09581644])), (layer_1 = (weight = Float32[-0.77415794 0.54744697 0.3836119 0.7118782; -0.082470536 -0.15390255 -0.7946522 -0.5445939; … ; -0.5062749 0.2953534 0.16202132 0.13775234; 0.13154188 -0.08946696 -0.20580809 -0.6478454], bias = Float32[0.1859277, -0.33550912, 0.46965417, -0.008857113, -0.36131158, -0.4903583, -0.13918692, 0.36736962, 0.24776426, 0.14653559  …  0.3221644, -0.35883018, 0.3957265, -0.46921173, -0.34183937, 0.27211842, 0.3694645, 0.19404387, -0.094268456, -0.009955538]), layer_2 = (weight = Float32[0.13504222 0.19667664 … 0.08633277 0.21026765; 0.062326834 0.21348846 … 0.18984067 0.1799343; … ; 0.11887717 -0.012407768 … -0.104741804 -0.16734572; -0.09581976 -0.12482073 … 0.06242792 -0.12096606], bias = Float32[-0.10437781, -0.024525829, 0.10862694, -0.09903822, 0.099480204, -0.017103007, -0.07427631, 0.06058006, -0.0041402667, 0.04891845  …  0.007930087, -0.032764755, -0.104080245, -0.010771485, -0.023789186, -0.032437388, 0.009640757, -0.006466028, -0.09617484, -0.0045383754]), layer_3 = (weight = Float32[-0.30573422 0.0554925 … -0.26799142 -0.25848314; -0.26893377 0.06547372 … -0.14753108 0.10064726; 0.11475992 -0.28932476 … 0.10918424 0.2864823; -0.11490817 0.09365396 … -0.020457609 0.09133516], bias = Float32[-0.08611815, 0.14637937, -0.09215812, 0.091161616])), (layer_1 = (weight = Float32[-0.7784472 0.5514358 0.37913775 0.71622545; 
-0.0805904 -0.15524928 -0.792524 -0.54649013; … ; -0.50989956 0.29877883 0.15827593 0.1414288; 0.13651451 -0.094275 -0.20080206 -0.65280724], bias = Float32[0.19051373, -0.3380075, 0.4647567, -0.013837405, -0.3582967, -0.49362695, -0.13478616, 0.36459637, 0.25248504, 0.14325428  …  0.3182262, -0.36321783, 0.4003107, -0.47369504, -0.345822, 0.2763524, 0.36554664, 0.18921791, -0.090440065, -0.014952682]), layer_2 = (weight = Float32[0.1314511 0.2002193 … 0.082776226 0.21400695; 0.05750185 0.21832521 … 0.18504572 0.18475854; … ; 0.115331784 -0.008962317 … -0.1082444 -0.16367407; -0.092612274 -0.12762693 … 0.06561293 -0.12426508], bias = Float32[-0.1081068, -0.029363412, 0.10495155, -0.094359264, 0.09494208, -0.014072283, -0.076658204, 0.06506586, -0.0081013, 0.052774165  …  0.012571787, -0.028363757, -0.10738469, -0.015115489, -0.028451696, -0.036862597, 0.0051993458, -0.002137383, -0.099933036, -0.0014898488]), layer_3 = (weight = Float32[-0.30989394 0.051404454 … -0.26488325 -0.2627834; -0.265002 0.06933495 … -0.15047744 0.10472585; 0.10977174 -0.29429308 … 0.1127075 0.28150427; -0.11064131 0.097845696 … -0.023676563 0.09569541], bias = Float32[-0.08196931, 0.14241463, -0.08716401, 0.08691494])), (layer_1 = (weight = Float32[-0.78241605 0.5549757 0.37495112 0.7202708; -0.077940784 -0.15720427 -0.78963757 -0.5491844; … ; -0.51308984 0.3017166 0.15495385 0.14468096; 0.14140123 -0.09874973 -0.19587937 -0.65763605], bias = Float32[0.1948654, -0.34122184, 0.4598687, -0.018710908, -0.35452545, -0.4962406, -0.13078839, 0.36091018, 0.2568988, 0.14056459  …  0.31392375, -0.3673283, 0.40463546, -0.47771916, -0.34917292, 0.28029904, 0.3620883, 0.18477328, -0.0870033, -0.019972371]), layer_2 = (weight = Float32[0.12802568 0.20404558 … 0.07942832 0.21769916; 0.05321358 0.22305222 … 0.18085112 0.18896061; … ; 0.1121883 -0.005852232 … -0.111335345 -0.16036643; -0.0899377 -0.12978429 … 0.06828076 -0.12703672], bias = Float32[-0.112008914, -0.034107566, 0.10105312, -0.09007942, 
0.09070389, -0.010258551, -0.07810972, 0.069268174, -0.011582746, 0.055399004  …  0.016453253, -0.024214584, -0.1099872, -0.019021593, -0.032520182, -0.04116842, 0.0011478653, 0.0020279842, -0.10336012, 0.0008403931]), layer_3 = (weight = Float32[-0.31376415 0.047627546 … -0.2624018 -0.26678166; -0.26169744 0.072574586 … -0.15284705 0.10813269; 0.1049303 -0.29908773 … 0.11520245 0.2767296; -0.10676504 0.10163139 … -0.026285624 0.09962962], bias = Float32[-0.07811634, 0.13905728, -0.08230431, 0.08307263])), (layer_1 = (weight = Float32[-0.7861062 0.5581502 0.3710307 0.7240585; -0.07489895 -0.15946546 -0.78637594 -0.55230254; … ; -0.515958 0.3042903 0.15195285 0.14761999; 0.14602898 -0.10272893 -0.19125065 -0.6621619], bias = Float32[0.19898708, -0.34476656, 0.4550543, -0.023384739, -0.35033512, -0.498381, -0.12722443, 0.35678074, 0.26090553, 0.13830854  …  0.30957782, -0.37147877, 0.4086851, -0.48124382, -0.3520195, 0.28400025, 0.35900882, 0.18115163, -0.08385005, -0.024922695]), layer_2 = (weight = Float32[0.12473722 0.20820393 … 0.076271966 0.22137818; 0.049699772 0.22735754 … 0.17745769 0.19213648; … ; 0.10934157 -0.00292478 … -0.114120714 -0.15731953; -0.08768775 -0.13130593 … 0.07054728 -0.12936868], bias = Float32[-0.116098024, -0.038429055, 0.09706401, -0.08624281, 0.08677984, -0.006041338, -0.07897367, 0.07321769, -0.014644456, 0.057288878  …  0.019264711, -0.020284362, -0.11208828, -0.02233481, -0.035821836, -0.045342103, -0.002547843, 0.0060935616, -0.10655742, 0.002526204]), layer_3 = (weight = Float32[-0.3174117 0.044088315 … -0.2604195 -0.270527; -0.25897893 0.07525056 … -0.15485258 0.110911235; 0.100432225 -0.30352375 … 0.11699138 0.27240297; -0.103192404 0.10510502 … -0.02842101 0.103227176], bias = Float32[-0.07449646, 0.13626906, -0.07777412, 0.07955015])), (layer_1 = (weight = Float32[-0.7895461 0.56102806 0.36735559 0.7276191; -0.07163592 -0.16191429 -0.7829071 -0.55567515; … ; -0.51856357 0.30657375 0.14921995 0.15030552; 0.15015545 -0.10622231 
-0.18713075 -0.6661672], bias = Float32[0.20287816, -0.34847602, 0.450342, -0.027807888, -0.34590483, -0.50019544, -0.124045365, 0.35249966, 0.26451027, 0.1364134  …  0.30553854, -0.3758451, 0.4124585, -0.48432583, -0.35458022, 0.287472, 0.35625482, 0.17806411, -0.080924205, -0.029441295]), layer_2 = (weight = Float32[0.12163746 0.21256724 … 0.0733519 0.22497767; 0.046946913 0.23091725 … 0.1748093 0.19426478; … ; 0.10674489 -0.00012070895 … -0.116649404 -0.15449376; -0.08580994 -0.13209864 … 0.072474346 -0.13127813], bias = Float32[-0.12023462, -0.042123877, 0.09336217, -0.08276929, 0.08314633, -0.0017475043, -0.07952675, 0.076961964, -0.017359897, 0.059187874  …  0.021193268, -0.016577538, -0.113859594, -0.024855513, -0.038422607, -0.04927784, -0.005956726, 0.010064345, -0.10955334, 0.0035783492]), layer_3 = (weight = Float32[-0.32085797 0.040760055 … -0.2588373 -0.27403632; -0.25675762 0.07745437 … -0.15661475 0.11316918; 0.09647894 -0.30742297 … 0.118443884 0.268688; -0.099787794 0.1084 … -0.030130204 0.1066255], bias = Float32[-0.07109056, 0.13396277, -0.07377221, 0.07621078])), (layer_1 = (weight = Float32[-0.7927525 0.5636592 0.36390978 0.7309674; -0.06824794 -0.16448542 -0.77932394 -0.55920476; … ; -0.5209368 0.30861428 0.14672852 0.15276703; 0.15359157 -0.10931408 -0.18359366 -0.6694151], bias = Float32[0.20653158, -0.35226837, 0.44574496, -0.031982314, -0.3416363, -0.501788, -0.12114059, 0.34845567, 0.26780483, 0.1348598  …  0.30220857, -0.3804188, 0.41597092, -0.48707798, -0.35707378, 0.29071638, 0.3537934, 0.1747486, -0.0782031, -0.032907467]), layer_2 = (weight = Float32[0.11878177 0.21691671 … 0.070697054 0.22840019; 0.044836 0.23360927 … 0.17277007 0.19551435; … ; 0.10437736 0.0025547706 … -0.11894586 -0.15187997; -0.08426772 -0.132114 … 0.07410493 -0.1327698], bias = Float32[-0.12420989, -0.04518004, 0.09054291, -0.07952078, 0.07976086, 0.0019068341, -0.07997226, 0.08055008, -0.019802315, 0.061645206  …  0.02268656, -0.013120047, -0.115430854, 
-0.026447643, -0.040531185, -0.052823886, -0.009151464, 0.013912071, -0.112341285, 0.0040274835]), layer_3 = (weight = Float32[-0.3240967 0.037643626 … -0.25756738 -0.27731013; -0.2549241 0.07928972 … -0.15819074 0.11502431; 0.09320356 -0.3106774 … 0.11983449 0.26561943; -0.09642276 0.111640155 … -0.031406496 0.10994009], bias = Float32[-0.06790516, 0.13203105, -0.070427075, 0.072923645])), (layer_1 = (weight = Float32[-0.79574156 0.5660798 0.3606759 0.7341131; -0.06479695 -0.16713375 -0.775685 -0.5628255; … ; -0.52309865 0.31044638 0.14445823 0.1550231; 0.15631819 -0.11208066 -0.18056786 -0.6717888], bias = Float32[0.20994417, -0.35609704, 0.4412766, -0.035929266, -0.33815476, -0.5032171, -0.11840439, 0.34509158, 0.27089164, 0.13363402  …  0.29988706, -0.38515067, 0.4192472, -0.48960632, -0.35963774, 0.2937368, 0.3515977, 0.17095149, -0.07567599, -0.034928627]), layer_2 = (weight = Float32[0.11619356 0.22092398 … 0.06829985 0.23156938; 0.04324155 0.23551138 … 0.17121628 0.19606708; … ; 0.10222324 0.0050734906 … -0.12102939 -0.1494741; -0.0830229 -0.13143495 … 0.075476505 -0.133873], bias = Float32[-0.12782104, -0.047687534, 0.08897833, -0.076389425, 0.07659159, 0.0041352334, -0.08042742, 0.084016606, -0.022028724, 0.06474154  …  0.024114875, -0.009929766, -0.11688118, -0.027138997, -0.04234918, -0.055867586, -0.012177843, 0.017606178, -0.11491193, 0.0039530206]), layer_3 = (weight = Float32[-0.32711828 0.034743562 … -0.25653687 -0.28035188; -0.25338203 0.08084589 … -0.15960678 0.116571866; 0.09062246 -0.31328544 … 0.121306784 0.263132; -0.09302914 0.11489253 … -0.032247134 0.11322661], bias = Float32[-0.064948715, 0.13037996, -0.06775296, 0.06961754])), (layer_1 = (weight = Float32[-0.7985383 0.5683205 0.3576271 0.737076; -0.06131981 -0.1698291 -0.7720229 -0.5664928; … ; -0.5250728 0.31209943 0.14238362 0.15709467; 0.1584856 -0.114602886 -0.17791773 -0.67344135], bias = Float32[0.21313144, -0.35993537, 0.436965, -0.039643142, -0.33556855, -0.50449854, -0.11579506, 
0.34253487, 0.27382255, 0.1326872  …  0.29855743, -0.38998726, 0.422314, -0.49196604, -0.36228624, 0.2965499, 0.3496336, 0.16680568, -0.07332828, -0.035726447]), layer_2 = (weight = Float32[0.11385885 0.2243997 … 0.06613135 0.23447853; 0.042067643 0.23668824 … 0.17005517 0.19604148; … ; 0.10026246 0.0074253213 … -0.122922055 -0.14726232; -0.08203131 -0.13025206 … 0.076623924 -0.13465104], bias = Float32[-0.13098954, -0.049729142, 0.088459, -0.07335011, 0.07364477, 0.0050802696, -0.08091746, 0.0873696, -0.024069471, 0.068300016  …  0.025580436, -0.0069893757, -0.118236385, -0.027119737, -0.043967295, -0.05841312, -0.015030997, 0.021141376, -0.11727281, 0.0034768782]), layer_3 = (weight = Float32[-0.32992962 0.032049794 … -0.25569668 -0.28317922; -0.2520721 0.08217948 … -0.16088562 0.117873885; 0.08864197 -0.31534067 … 0.12288206 0.26111114; -0.089623116 0.1181456 … -0.032703552 0.11647462], bias = Float32[-0.062212717, 0.1289515, -0.06565841, 0.06630551]))  …  (layer_1 = (weight = Float32[-0.8425154 0.47209346 0.4262888 0.7217816; -0.1832003 0.047218647 -0.962523 -0.35235357; … ; -0.5939851 0.17652166 0.15409958 0.17145; 0.3380191 0.2874551 -0.5194925 -0.6752922], bias = Float32[0.20582238, -0.6503386, 0.5789662, -0.31104192, -0.13403398, -0.60990554, 0.105613306, 0.18083678, 0.41252893, 0.13008697  …  0.26961076, -0.42284358, 0.26242468, -0.48932496, -0.24052006, -0.3589664, 0.2949812, 0.26432073, -0.11567488, 0.05036465]), layer_2 = (weight = Float32[0.056675777 0.56482655 … 0.045017213 0.5533752; 0.12783337 -0.108699545 … 0.13484874 -0.07524599; … ; 0.13097066 0.13904385 … -0.07265607 -0.13481057; 0.1804065 0.5087473 … 0.5002913 -0.21141255], bias = Float32[-0.266675, 0.11704979, -0.008544433, -0.11307532, 0.06071536, -0.024754785, -0.101410955, 0.11837238, -0.06986591, 0.18118371  …  -0.060763687, 0.038008634, -0.1937314, 0.03566476, 0.11558485, -0.11416286, -0.103505336, 0.06108386, -0.18995403, -0.034759574]), layer_3 = (weight = Float32[-0.3349216 
-0.014001183 … -0.17344637 -0.3866775; -0.4338676 0.09212331 … -0.24300706 0.23325579; 0.26086378 -0.48789045 … 0.2609942 0.1609167; 0.07370992 0.15668656 … 0.01541595 0.18122868], bias = Float32[-0.014426909, 0.111894615, -0.1377133, -0.0033461594])), (layer_1 = (weight = Float32[-0.84252924 0.47208098 0.42629468 0.72179496; -0.18263434 0.047092732 -0.96237004 -0.35250998; … ; -0.5939856 0.17646368 0.15416528 0.17142676; 0.3380807 0.2882798 -0.519705 -0.6750831], bias = Float32[0.20582533, -0.65071166, 0.5789136, -0.3110854, -0.13407344, -0.61007833, 0.10569476, 0.18080571, 0.4125509, 0.13034563  …  0.2697328, -0.42293796, 0.26235282, -0.48936242, -0.24063717, -0.3595618, 0.29496288, 0.26432392, -0.11627118, 0.05143129]), layer_2 = (weight = Float32[0.05647773 0.5639307 … 0.044759624 0.5533944; 0.12784515 -0.10864811 … 0.13485952 -0.075270556; … ; 0.13097674 0.13815756 … -0.07265284 -0.1347913; 0.18040249 0.5087533 … 0.5002884 -0.21139866], bias = Float32[-0.26672953, 0.117102146, -0.0085039595, -0.11268365, 0.06075903, -0.024754807, -0.10162593, 0.1183557, -0.06985215, 0.18118855  …  -0.06145902, 0.03808312, -0.19381227, 0.03559196, 0.11559477, -0.11425384, -0.10351865, 0.061082277, -0.18990159, -0.03476629]), layer_3 = (weight = Float32[-0.33485857 -0.013961114 … -0.17345656 -0.38663778; -0.4338507 0.09211887 … -0.2428743 0.23325819; 0.26093134 -0.4878358 … 0.26103014 0.160937; 0.07376637 0.1567187 … 0.015383452 0.18125862], bias = Float32[-0.014341298, 0.11172521, -0.13770458, -0.0032761658])), (layer_1 = (weight = Float32[-0.84254366 0.47206914 0.42629957 0.7218085; -0.18207328 0.046970323 -0.9622195 -0.3526634; … ; -0.5939857 0.17640626 0.15423246 0.17140314; 0.33813563 0.28910995 -0.51993424 -0.67486197], bias = Float32[0.2058303, -0.6510765, 0.57885885, -0.31112728, -0.13411132, -0.6102434, 0.105780415, 0.18077426, 0.4125735, 0.1306012  …  0.269851, -0.42302817, 0.2622849, -0.4893987, -0.24075705, -0.360145, 0.29494599, 0.2643249, -0.116867945, 
0.052489914]), layer_2 = (weight = Float32[0.056282796 0.5630352 … 0.04450305 0.55341667; 0.12785567 -0.108597755 … 0.13486956 -0.075295985; … ; 0.13098364 0.13728485 … -0.07264914 -0.13477111; 0.180399 0.5087594 … 0.50028586 -0.21138234], bias = Float32[-0.26678094, 0.11714059, -0.008464717, -0.11229839, 0.060800277, -0.024754964, -0.10183152, 0.118346885, -0.06983893, 0.18119326  …  -0.062139157, 0.0381589, -0.19388993, 0.035520833, 0.11560324, -0.11434315, -0.10353113, 0.061086368, -0.1898417, -0.034771338]), layer_3 = (weight = Float32[-0.3347972 -0.013921924 … -0.17346977 -0.3865983; -0.43383443 0.092114456 … -0.2427435 0.23326007; 0.2610007 -0.48777974 … 0.2610654 0.16095927; 0.07382221 0.15675068 … 0.01535149 0.18128864], bias = Float32[-0.014257645, 0.111559935, -0.1376952, -0.0032076263])), (layer_1 = (weight = Float32[-0.8425587 0.4720579 0.42630354 0.7218222; -0.18151814 0.046851348 -0.9620713 -0.35281372; … ; -0.5939853 0.17634942 0.15430084 0.17137922; 0.33818486 0.28994414 -0.52017874 -0.6746301], bias = Float32[0.20583728, -0.6514331, 0.5788018, -0.31116742, -0.13414769, -0.610401, 0.10587019, 0.18074228, 0.4125968, 0.13085368  …  0.26996508, -0.42311436, 0.26222104, -0.4894338, -0.24087925, -0.36071682, 0.29493046, 0.26432374, -0.11746614, 0.053541128]), layer_2 = (weight = Float32[0.056090865 0.56214195 … 0.044247657 0.5534418; 0.12786494 -0.10854855 … 0.13487887 -0.075322114; … ; 0.13099137 0.13642934 … -0.072644964 -0.13475047; 0.18039607 0.50876564 … 0.5002836 -0.2113638], bias = Float32[-0.2668295, 0.11716591, -0.008426658, -0.11191793, 0.06083925, -0.024755133, -0.10202772, 0.11834713, -0.06982617, 0.18119769  …  -0.062803306, 0.0382359, -0.1939643, 0.0354513, 0.11561046, -0.11443082, -0.10354277, 0.061095797, -0.18977493, -0.034774657]), layer_3 = (weight = Float32[-0.33473742 -0.013883594 … -0.17348564 -0.38655913; -0.43381864 0.09211015 … -0.24261439 0.23326145; 0.26107165 -0.4877224 … 0.26109987 0.16098347; 0.073877454 0.1567825 … 
0.015320176 0.18131873], bias = Float32[-0.014175927, 0.11139882, -0.13768506, -0.0031404924])), (layer_1 = (weight = Float32[-0.84257436 0.47204724 0.42630664 0.7218361; -0.18096979 0.046735737 -0.96192557 -0.35296088; … ; -0.5939846 0.17629315 0.15437017 0.17135508; 0.33822924 0.29078102 -0.5204371 -0.6743888], bias = Float32[0.20584622, -0.65178156, 0.57874244, -0.3112057, -0.13418265, -0.61055136, 0.10596399, 0.18070966, 0.412621, 0.13110308  …  0.27007484, -0.42319676, 0.2621613, -0.48946783, -0.2410034, -0.3612781, 0.29491624, 0.26432052, -0.11806662, 0.05458534]), layer_2 = (weight = Float32[0.05590184 0.56125283 … 0.043993607 0.5534695; 0.12787297 -0.10850055 … 0.13488744 -0.07534877; … ; 0.13099992 0.13559484 … -0.07264028 -0.13472989; 0.18039371 0.5087721 … 0.50028163 -0.21134326], bias = Float32[-0.26687542, 0.1171789, -0.0083897365, -0.11154086, 0.060876112, -0.024755191, -0.102214575, 0.1183575, -0.069813766, 0.18120174  …  -0.0634508, 0.038314085, -0.1940353, 0.03538328, 0.11561662, -0.11451688, -0.10355357, 0.06111032, -0.18970188, -0.034776203]), layer_3 = (weight = Float32[-0.33467916 -0.013846104 … -0.17350388 -0.38652036; -0.43380317 0.09210601 … -0.24248666 0.23326236; 0.26114407 -0.4876639 … 0.26113352 0.1610095; 0.07393213 0.15681416 … 0.015289619 0.18134886], bias = Float32[-0.0140961, 0.1112419, -0.13767414, -0.003074711])), (layer_1 = (weight = Float32[-0.84259063 0.47203708 0.4263089 0.72185016; -0.18042898 0.046623405 -0.96178234 -0.35310486; … ; -0.59398365 0.17623746 0.15444027 0.17133081; 0.33826956 0.2916193 -0.520708 -0.67413926], bias = Float32[0.2058571, -0.65212184, 0.5786807, -0.31124195, -0.13421625, -0.61069477, 0.10606174, 0.18067628, 0.41264623, 0.13134943  …  0.27018014, -0.42327556, 0.26210582, -0.48950088, -0.2411291, -0.36182955, 0.29490328, 0.2643154, -0.11867011, 0.055622783]), layer_2 = (weight = Float32[0.055715635 0.5603697 … 0.04374109 0.5534994; 0.12787981 -0.10845382 … 0.13489528 -0.075375795; … ; 0.13100931 
0.13478513 … -0.072635055 -0.13470986; 0.18039197 0.50877875 … 0.50028 -0.21132086], bias = Float32[-0.266919, 0.117180385, -0.008353912, -0.11116602, 0.06091103, -0.024755016, -0.102392145, 0.118378915, -0.06980161, 0.1812053  …  -0.06408107, 0.038393464, -0.1941029, 0.035316695, 0.11562194, -0.11460135, -0.103563555, 0.061129775, -0.18962319, -0.034775946]), layer_3 = (weight = Float32[-0.33462235 -0.013809432 … -0.17352422 -0.38648203; -0.4337879 0.092102095 … -0.24236009 0.2332628; 0.26121777 -0.48760432 … 0.26116627 0.16103731; 0.07398627 0.15684567 … 0.015259915 0.181379], bias = Float32[-0.014018102, 0.11108918, -0.13766237, -0.0030102248])), (layer_1 = (weight = Float32[-0.84260756 0.4720274 0.4263104 0.7218644; -0.17989634 0.046514258 -0.96164155 -0.35324565; … ; -0.59398246 0.17618237 0.1545109 0.17130645; 0.33830652 0.29245785 -0.5209902 -0.6738825], bias = Float32[0.20586991, -0.65245414, 0.5786166, -0.3112762, -0.1342486, -0.61083156, 0.10616335, 0.18064205, 0.41267264, 0.13159275  …  0.27028087, -0.42335093, 0.26205465, -0.48953304, -0.24125606, -0.36237195, 0.29489148, 0.26430857, -0.11927724, 0.05665352]), layer_2 = (weight = Float32[0.055532172 0.55949426 … 0.04349028 0.5535313; 0.12788549 -0.108408414 … 0.13490239 -0.07540301; … ; 0.13101952 0.13400385 … -0.07262928 -0.13469084; 0.18039085 0.5087857 … 0.5002788 -0.2112968], bias = Float32[-0.2669604, 0.11717116, -0.00831915, -0.11079246, 0.060944177, -0.024754496, -0.10256054, 0.11841212, -0.06978962, 0.18120831  …  -0.06469371, 0.038474057, -0.19416703, 0.03525147, 0.11562663, -0.114684284, -0.10357273, 0.06115405, -0.18953945, -0.034773875]), layer_3 = (weight = Float32[-0.33456698 -0.013773556 … -0.17354645 -0.3864442; -0.43377277 0.09209845 … -0.2422345 0.23326282; 0.26129258 -0.48754382 … 0.26119807 0.16106682; 0.07403988 0.15687706 … 0.015231154 0.18140914], bias = Float32[-0.013941857, 0.11094066, -0.13764971, -0.002946975])), (layer_1 = (weight = Float32[-0.84262514 0.47201815 0.42631114 
0.7218788; -0.17937234 0.046408202 -0.96150327 -0.35338327; … ; -0.59398115 0.17612788 0.1545819 0.17128205; 0.33834067 0.29329562 -0.5212825 -0.67361945], bias = Float32[0.20588465, -0.65277857, 0.57855016, -0.31130832, -0.13427977, -0.610962, 0.10626872, 0.1806069, 0.41270036, 0.13183308  …  0.270377, -0.42342308, 0.26200783, -0.4895644, -0.24138398, -0.36290598, 0.29488075, 0.26430026, -0.119888455, 0.057677474]), layer_2 = (weight = Float32[0.055351388 0.55862826 … 0.043241356 0.5535649; 0.12789004 -0.108364396 … 0.13490877 -0.07543026; … ; 0.13103054 0.1332544 … -0.07262295 -0.13467325; 0.18039036 0.508793 … 0.5002779 -0.21127123], bias = Float32[-0.2669999, 0.11715202, -0.008285414, -0.1104194, 0.060975723, -0.024753528, -0.10271992, 0.11845773, -0.06977769, 0.1812107  …  -0.06528842, 0.03855591, -0.19422771, 0.035187524, 0.11563092, -0.11476571, -0.10358113, 0.061183076, -0.18945128, -0.034769997]), layer_3 = (weight = Float32[-0.33451295 -0.0137384515 … -0.17357036 -0.38640693; -0.43375766 0.09209512 … -0.24210975 0.2332624; 0.26136833 -0.4874825 … 0.26122886 0.16109793; 0.07409298 0.1569083 … 0.0152034145 0.18143922], bias = Float32[-0.01386728, 0.110796325, -0.13763615, -0.0028849011])), (layer_1 = (weight = Float32[-0.8426434 0.4720093 0.42631117 0.7218934; -0.17885733 0.04630513 -0.9613674 -0.35351777; … ; -0.5939797 0.17607398 0.15465313 0.17125767; 0.33837253 0.29413173 -0.521584 -0.67335093], bias = Float32[0.2059013, -0.65309536, 0.5784814, -0.31133834, -0.13430987, -0.61108637, 0.10637774, 0.18057078, 0.41272947, 0.13207048  …  0.2704686, -0.42349222, 0.2619654, -0.48959506, -0.24151264, -0.36343223, 0.294871, 0.26429066, -0.12050411, 0.05869446]), layer_2 = (weight = Float32[0.05517322 0.55777323 … 0.042994484 0.55359995; 0.12789348 -0.108321816 … 0.13491443 -0.07545739; … ; 0.13104239 0.1325398 … -0.07261603 -0.13465749; 0.1803905 0.5088006 … 0.5002774 -0.21124429], bias = Float32[-0.26703766, 0.11712374, -0.0082526775, -0.11004626, 0.06100583, 
-0.024752017, -0.102870464, 0.118516155, -0.06976574, 0.1812124  …  -0.06586503, 0.03863907, -0.19428493, 0.035124794, 0.115635, -0.11484568, -0.103588775, 0.061216824, -0.18935923, -0.034764327]), layer_3 = (weight = Float32[-0.33446026 -0.013704095 … -0.1735958 -0.3863702; -0.43374252 0.09209213 … -0.24198572 0.23326157; 0.26144487 -0.48742047 … 0.2612586 0.16113059; 0.07414557 0.15693942 … 0.015176766 0.18146926], bias = Float32[-0.0137942815, 0.110656165, -0.13762169, -0.002823944])), (layer_1 = (weight = Float32[-0.8646588 0.49752995 0.39984375 0.7443921; -0.12747608 0.034922883 -0.94088745 -0.3890372; … ; -0.6110231 0.21088429 0.1416041 0.18473044; 0.29769173 0.31554145 -0.55249757 -0.6282835], bias = Float32[0.22948335, -0.65948206, 0.5695705, -0.32701632, -0.14653566, -0.6276898, 0.13503374, 0.1670128, 0.4384215, 0.134501  …  0.26643053, -0.4438132, 0.29045627, -0.5158026, -0.23445924, -0.26908663, 0.31886137, 0.2757164, 0.07960184, -0.14310713]), layer_2 = (weight = Float32[0.037270233 0.36452237 … 0.017393485 0.5650064; 0.15091862 -0.08929501 … 0.15813234 -0.056013606; … ; 0.10709002 0.15214895 … -0.09650536 -0.14895537; 0.20088336 0.53337485 … 0.5207664 -0.2205419], bias = Float32[-0.26288947, 0.12399343, 0.009231143, 0.0632344, 0.0790139, -0.0107539445, -0.08985133, 0.23966151, -0.05921632, 0.17936991  …  -0.059383392, 0.058192983, -0.18699138, 0.02726415, 0.131886, -0.12809463, -0.0847442, 0.10093904, -0.20886137, -0.016853908]), layer_3 = (weight = Float32[-0.32316336 0.007957895 … -0.17228039 -0.3631168; -0.4284567 0.11545352 … -0.16096818 0.25497854; 0.2743785 -0.4609105 … 0.25434622 0.18445309; 0.08598345 0.17958412 … 0.026900865 0.20400114], bias = Float32[-0.00464844, 0.103954785, -0.1301573, 0.0037584899]))], (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), Float32[192.39883, 135.20009, 113.53052, 102.90518, 96.61792, 92.35046, 89.20726, 86.74313, 84.62893, 82.73641  …  19.981504, 19.835167, 19.692476, 19.553118, 
19.416779, 19.283154, 19.151941, 19.022848, 18.895592, 8.750851])

We can now animate the training to get a better understanding of the fit.

"""
    predict(y0, t, p, state)

Integrate a neural ODE with parameters `p` and network state `state` from the initial
value `y0`, saving at the time points `t`, and return the trajectory as an `Array`.
"""
function predict(y0, t, p, state)
    # Only the model is needed; the fresh parameters and state it comes with are dropped.
    node = first(neural_ode(t, length(y0)))
    trajectory = first(node(y0, p, state))
    return Array(trajectory)
end

"""
    plot_pred(t_train, y_train, t_grid, rescale_t, rescale_y,
              num_iters, p, state, loss, y0 = y_train[:, 1])

Predict on `t_grid` with parameters `p`, map observations and prediction through
`rescale_t` / `rescale_y`, and draw them together with the loss history via
`plot_result`.
"""
function plot_pred(t_train, y_train, t_grid, rescale_t, rescale_y,
        num_iters, p, state, loss, y0 = y_train[:, 1])
    ŷ_grid = predict(y0, t_grid, p, state)
    obs_t, obs_y = rescale_t(t_train), rescale_y(y_train)
    pred_t, pred_y = rescale_t(t_grid), rescale_y(ŷ_grid)
    return plot_result(obs_t, obs_y, pred_t, pred_y, loss, num_iters)
end

"""
    plot_pred(t, y, y_pred)

Scatter the observations `y` over `t` and overlay the prediction `y_pred`, evaluated
at the same `t`, as a line.
"""
function plot_pred(t, y, y_pred)
    observed = Plots.scatter(t, y; label = "Observation")
    return Plots.plot!(observed, t, y_pred; label = "Prediction")
end

"""
    plot_pred(t, y, t_pred, y_pred; kwargs...)

Return one plot per row of `y`: the prediction row of `y_pred` over `t_pred` as a
line and the observations over `t` as markers, titled and labelled with the matching
entries of `FEATURE_NAMES` and `UNITS`. `kwargs` are forwarded to the line plot.
"""
function plot_pred(t, y, t_pred, y_pred; kwargs...)
    # Draw the panel for a single feature; `idx` doubles as the color index.
    function feature_panel((idx, (observed, predicted, name, unit)))
        panel = Plots.plot(t_pred, predicted; label = "Prediction", color = idx,
            linewidth = 3, legend = nothing, title = name, kwargs...)
        return Plots.scatter!(panel, t, observed; label = "Observation",
            xlabel = "Time", ylabel = unit, markersize = 5, color = idx)
    end

    per_feature = zip(eachrow(y), eachrow(y_pred), FEATURE_NAMES, UNITS)
    return map(feature_panel, enumerate(per_feature))
end

"""
    plot_result(t, y, t_pred, y_pred, loss, num_iters; kwargs...)

Stack the four per-feature prediction panels (with fixed y-limits) on top of a loss
curve whose x-axis runs from 0 to `num_iters`, and return the combined plot.
`kwargs` are forwarded to `plot_pred`.
"""
function plot_result(t, y, t_pred, y_pred, loss, num_iters; kwargs...)
    panels = plot_pred(t, y, t_pred, y_pred; kwargs...)

    # Fixed axis ranges per feature; only the first panel carries the legend.
    plot!(panels[1]; ylim = (10, 40), legend = (0.65, 1.0))
    for (panel, limits) in zip(panels[2:4], ((20, 100), (2, 12), (990, 1025)))
        plot!(panel; ylim = limits)
    end

    loss_panel = Plots.plot(loss; label = nothing, linewidth = 3, title = "Loss",
        xlabel = "Iterations", xlim = (0, num_iters))
    stacked = [panels..., loss_panel]
    return plot(stacked...; layout = grid(length(stacked), 1), size = (900, 900))
end

"""
    animate_training(plot_frame, t_train, y_train, ps, losses, obs_grid; pause_for = 300)

Build an animation of the training history. Frame `i` calls `plot_frame` with the
first `k` observations, the parameters `ps[i]`, and the losses up to `i`, where `k`
is the entry of `obs_grid` for the training stage that iteration `i` falls in. The
last iteration is repeated for `pause_for` extra frame indices, and every second
frame index is kept.
"""
function animate_training(
        plot_frame, t_train, y_train, ps, losses, obs_grid; pause_for = 300)
    n_losses = length(losses)
    n_stages = length(obs_grid)
    # Zero-based stage index => number of observations used in that stage.
    obs_count = Dict(zip(0:(n_stages - 1), obs_grid))
    # Indices beyond the last recorded iteration are clamped to it.
    frames = min.(2:(n_losses + pause_for), n_losses)
    anim = @animate for i in frames
        stage = floor(Int, (i - 1) / n_losses * n_stages)
        k = obs_count[stage]
        plot_frame(t_train[1:k], y_train[:, 1:k], ps[i], losses[1:i])
    end every 2
    return anim
end

# Number of recorded loss values; passed on as the x-axis limit of the loss panel.
num_iters = length(losses)
# 500 evenly spaced time points spanning the training interval.
t_train_grid = collect(range(extrema(t_train)...; length = 500))
# Map standardized values back using the statistics returned by `featurize`.
rescale_t(x) = t_scale .* x .+ t_mean
rescale_y(x) = y_scale .* x .+ y_mean
# Frame renderer handed to `animate_training`. The prediction grid, the rescaling
# functions, the iteration count and the network state are taken from the script's
# global variables.
function plot_frame(t, y, p, loss)
    return plot_pred(
        t, y, t_train_grid, rescale_t, rescale_y, num_iters, p, state, loss)
end
# Render the recorded training history and write it out as a GIF.
anim = animate_training(plot_frame, t_train, y_train, ps, losses, obs_grid)
gif(anim, "node_weather_forecast_training.gif")
Example block output

Looks good! But how well does the model forecast?

"""
    plot_extrapolation(t_train, y_train, t_test, y_test, t̂, ŷ)

Draw the per-feature panels of training observations and the prediction `ŷ` over
`t̂`, add the test observations with white marker outlines, apply fixed y-limits,
and return the panels stacked in one column.
"""
function plot_extrapolation(t_train, y_train, t_test, y_test, t̂, ŷ)
    panels = plot_pred(t_train, y_train, t̂, ŷ)

    # Overlay the held-out observations on the panel of the matching feature.
    foreach(enumerate(zip(panels, eachrow(y_test)))) do (idx, (panel, held_out))
        scatter!(panel, t_test, held_out; color = idx, markerstrokecolor = :white,
            label = "Test observation")
    end

    plot!(panels[1]; ylim = (10, 40), legend = :topleft)
    for (panel, limits) in zip(panels[2:4], ((20, 100), (2, 12), (990, 1025)))
        plot!(panel; ylim = limits)
    end
    return plot(panels...; layout = grid(length(panels), 1), size = (900, 900))
end

# Predict on 500 points covering both splits, starting from the first training
# observation and using the last recorded parameters, then plot in original units.
t_grid = collect(range(minimum(t_train), maximum(t_test); length = 500))
y_pred = predict(y_train[:, 1], t_grid, ps[end], state)
plot_extrapolation(rescale_t(t_train), rescale_y(y_train), rescale_t(t_test),
    rescale_y(y_test), rescale_t(t_grid), rescale_y(y_pred))
Example block output

While there is some drift in the weather patterns, the model extrapolates very well!