Weather forecasting with neural ODEs

In this example we are going to apply neural ODEs to a multidimensional weather dataset and use the resulting model for weather forecasting. This example is adapted from "Forecasting the weather with neural ODEs" on Sebastian Callh's personal blog.

The data

The data is a four-dimensional dataset of daily temperature, humidity, wind speed and pressure measured over four years in the city of Delhi. Let us download and plot it.

using Random, Dates, Optimization, ComponentArrays, Lux, OptimizationOptimisers, DiffEqFlux,
      OrdinaryDiffEq, CSV, DataFrames, Dates, Statistics, Plots, DataDeps

"""
    download_data(data_url = "...", data_local_path = "./delhi")

Download the Delhi daily-climate training and test CSV files from `data_url` into
`data_local_path`, read each with `CSV.read`, and return them stacked into a single
`DataFrame` (training rows first, then test rows).
"""
function download_data(
        data_url = "https://raw.githubusercontent.com/SebastianCallh/neural-ode-weather-forecast/master/data/",
        data_local_path = "./delhi")
    # `data_url` may or may not end in '/'; strip it so the joined URL contains exactly
    # one separator between the directory and the file name (no "data//file.csv").
    base_url = rstrip(data_url, '/')

    # Fetch one CSV into `data_local_path` and parse it.
    function load(file_name)
        data_dep = DataDep("delhi/train", "", "$base_url/$file_name")
        Base.download(data_dep, data_local_path; i_accept_the_terms_of_use = true)
        return CSV.read(joinpath(data_local_path, file_name), DataFrame)
    end

    train_df = load("DailyDelhiClimateTrain.csv")
    test_df = load("DailyDelhiClimateTest.csv")
    return vcat(train_df, test_df)
end

# Download the training and test CSVs into ./delhi and stack them into one DataFrame.
df = download_data()
5×5 DataFrame
 Row │ date        meantemp  humidity  wind_speed  meanpressure
     │ Date        Float64   Float64   Float64     Float64
─────┼──────────────────────────────────────────────────────────
   1 │ 2013-01-01  10.0      84.5      0.0         1015.67
   2 │ 2013-01-02   7.4      92.0      2.98        1017.8
   3 │ 2013-01-03   7.16667  87.0      4.63333     1018.67
   4 │ 2013-01-04   8.66667  71.3333   1.23333     1017.17
   5 │ 2013-01-05   6.0      86.8333   3.7         1016.5
# Weather variables used throughout: DataFrame column names, plot titles and y-axis units,
# all listed in the same order. Declared `const` because several functions below read
# them as globals, and non-constant globals are not type-stable.
const FEATURES = [:meantemp, :humidity, :wind_speed, :meanpressure]
const UNITS = ["Celsius", "g/m³ of water", "km/h", "hPa"]
const FEATURE_NAMES = ["Mean temperature", "Humidity", "Wind speed", "Mean pressure"]

"""
    plot_data(df)

Plot every variable in `FEATURES` against `df.date`, one subplot per feature, titled with
the matching `FEATURE_NAMES` entry and labelled with the matching `UNITS` entry.
"""
function plot_data(df)
    plots = map(enumerate(zip(FEATURES, FEATURE_NAMES, UNITS))) do (i, (f, n, u))
        plot(df[:, :date], df[:, f]; title = n, label = nothing,
            ylabel = u, size = (800, 600), color = i)
    end

    # An integer layout lets Plots choose a grid with room for exactly `n` subplots
    # (2×2 for the four features). `(n / 2, n / 2)` was only right for n == 4 and threw
    # `InexactError` for an odd number of features.
    n = length(plots)
    return plot(plots...; layout = n)
end

# One time-series panel per weather variable.
plot_data(df)
Example block output

The data show clear annual behaviour (it is difficult to see for pressure due to wild measurement errors, but the pattern is there). It is conceivable that this system can be described with an ODE, but which one? Let us use a neural network to learn the dynamics from the dataset. Training neural networks is easier with standardised data, so we will compute standardised features before training. Finally, we aggregate the daily measurements into monthly averages and take the first 20 months for training and the rest for testing.

"""
    standardize(x)

Z-score every row of `x`: subtract the row mean and divide by the row standard deviation.

Return `(z, μ, σ)`, where `μ` and `σ` are the per-row statistics computed with
`dims = 2` (so they keep a singleton second dimension and broadcast back over `x`).
"""
function standardize(x)
    row_mean = mean(x; dims = 2)
    row_std = std(x; dims = 2)
    return (x .- row_mean) ./ row_std, row_mean, row_std
end

"""
    featurize(raw_df, num_train = 20)

Aggregate the daily observations in `raw_df` into monthly means, split the months into the
first `num_train` (training) and the rest (test), and standardize both splits using the
training statistics only. Time is encoded as `year + month / 12`. `raw_df` is not modified.

Returns `(t_train, y_train, t_test, y_test, (t_mean, t_scale), (y_mean, y_scale))`: the
`t_*` are vectors of standardized times, the `y_*` are `length(FEATURES) × months`
matrices, and the mean/scale pairs are what is needed to undo the standardization.
"""
function featurize(raw_df, num_train = 20)
    # Add the grouping keys to a copy: `transform` returns a new DataFrame, so the caller's
    # `raw_df` is left untouched instead of gaining `year`/`month` columns as a side effect.
    keyed_df = transform(raw_df,
        :date => ByRow(d -> Float64(year(d))) => :year,
        :date => ByRow(d -> Float64(month(d))) => :month)
    df = combine(groupby(keyed_df, [:year, :month]),
        :date => (d -> mean(year.(d)) .+ mean(month.(d)) ./ 12),
        :meantemp => mean, :humidity => mean, :wind_speed => mean,
        :meanpressure => mean; renamecols = false)
    # Row vector of times and a (features × months) matrix of observations.
    t_and_y(df) = df.date', Matrix(select(df, FEATURES))'
    t_train, y_train = t_and_y(df[1:num_train, :])
    t_test, y_test = t_and_y(df[(num_train + 1):end, :])
    # Statistics come from the training split only and are then applied to the test split.
    t_train, t_mean, t_scale = standardize(t_train)
    y_train, y_mean, y_scale = standardize(y_train)
    t_test = (t_test .- t_mean) ./ t_scale
    y_test = (y_test .- y_mean) ./ y_scale

    return (
        vec(t_train), y_train, vec(t_test), y_test, (t_mean, t_scale), (y_mean, y_scale))
end

"""
    plot_features(t_train, y_train, t_test, y_test)

Plot the standardized training features (solid lines) and test features (dashed lines)
against standardized time in one panel, one color per feature, with a "Train"/"Test"
legend and y-limits of (-5, 5).
"""
function plot_features(t_train, y_train, t_test, y_test)
    # `color` with a row matrix gives one color per series (one series per feature),
    # matching the test-split call below; `colors = 1:4` was not that keyword.
    plt_split = plot(reshape(t_train, :), y_train'; linewidth = 3, color = [1 2 3 4],
        xlabel = "Normalized time", ylabel = "Normalized values",
        label = nothing, title = "Features")
    plot!(plt_split, reshape(t_test, :), y_test'; linewidth = 3,
        linestyle = :dash, color = [1 2 3 4], label = nothing)
    # Zero-width dummy series whose only purpose is to create the legend entries.
    plot!(plt_split, [0], [0]; linewidth = 0, label = "Train", color = 1)
    return plot!(plt_split, [0], [0]; linewidth = 0, linestyle = :dash,
        label = "Test", color = 1, ylims = (-5, 5))
end

# Monthly train/test splits in standardized units, plus the statistics needed to undo the scaling.
t_train, y_train, t_test, y_test, (t_mean, t_scale), (y_mean, y_scale) = featurize(df)
plot_features(t_train, y_train, t_test, y_test)
Example block output

The dataset is now centered around 0 with a standard deviation of 1. We will ignore the extreme pressure measurements for simplicity. Since they are in the test split, they won't impact training anyway. We are now ready to construct and train our model! To avoid getting stuck in local minima, we will train iteratively with increasing amounts of data.

"""
    neural_ode(t, data_dim)

Build a `NeuralODE` whose vector field is an MLP `data_dim → 64 → 32 → data_dim` with
`swish` on the two hidden layers. It is integrated with `Tsit5` over `extrema(t)`, saves at
the points in `t`, and uses `abstol = reltol = 1e-9`.

Return `(node, p, state)`, with the Lux parameters wrapped in a `ComponentArray`. The
parameters are initialised from a fixed `Xoshiro(0)` seed, so repeated calls with the same
`data_dim` produce the same initial parameters.
"""
function neural_ode(t, data_dim)
    hidden_1, hidden_2 = 64, 32
    dynamics = Chain(
        Dense(data_dim => hidden_1, swish),
        Dense(hidden_1 => hidden_2, swish),
        Dense(hidden_2 => data_dim))

    tspan = extrema(t)
    node = NeuralODE(dynamics, tspan, Tsit5(); saveat = t, abstol = 1e-9, reltol = 1e-9)

    params, layer_state = Lux.setup(Xoshiro(0), dynamics)
    return node, ComponentArray(params), layer_state
end

"""
    train_one_round(node, p, state, y, opt, maxiters, rng, y0 = y[:, 1]; kwargs...)

Fit the parameters of `node` so that its trajectory from `y0` matches the observations `y`
under a sum-of-squared-errors loss. Optimization uses `opt` with Zygote gradients for at
most `maxiters` iterations. The columns of `y` are assumed to line up with `node`'s save
points. Extra keyword arguments (e.g. `callback`, `progress`) are forwarded to `solve`.

Return the optimized parameters and the (unchanged) Lux `state`. `rng` is not used; it
is kept so existing callers remain valid.
"""
function train_one_round(node, p, state, y, opt, maxiters, rng, y0 = y[:, 1]; kwargs...)
    predict(p) = Array(node(y0, p, state)[1])
    loss(p) = sum(abs2, predict(p) .- y)

    adtype = Optimization.AutoZygote()
    optf = OptimizationFunction((p, _) -> loss(p), adtype)
    optprob = OptimizationProblem(optf, p)
    res = solve(optprob, opt; maxiters = maxiters, kwargs...)
    # `u` is the solution's canonical field (the callback reads `.u` too);
    # `minimizer` is only a legacy alias for it.
    return res.u, state
end

"""
    train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing; kwargs...)

Train a neural ODE on `(t, y)` as a curriculum. For each `k` in `obs_grid`, the model is
fitted to the first `k` observations only (`t[1:k]`, `y[:, 1:k]`) with `AdamW(lr)` for up to
`maxiters` iterations. Each stage warm-starts from the previous stage's parameters.
If `p`/`state` are `nothing`, they are initialised by `neural_ode`. Extra keyword
arguments are forwarded to `train_one_round` (and from there to `solve`).

Return `(ps, state, losses)`: `ps` and `losses` hold the parameters and the loss recorded
at every optimizer callback, across all stages in order.
"""
function train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing; kwargs...)
    # Optimization.jl callback: record a copy of the current parameters and the loss;
    # returning `false` tells the optimizer to continue.
    log_results(ps, losses) = (opt_state, loss_value) -> begin
        push!(ps, copy(opt_state.u))
        push!(losses, loss_value)
        false
    end

    ps, losses = ComponentArray[], Float32[]
    for k in obs_grid
        # Only the first `k` observations take part in this stage. The ODE is rebuilt so it
        # integrates over, and saves at, exactly those time points. (Previously `k` was
        # ignored and every stage trained on the full data set.)
        t_k, y_k = t[1:k], y[:, 1:k]
        node, p_new, state_new = neural_ode(t_k, size(y, 1))
        p === nothing && (p = p_new)
        state === nothing && (state = state_new)

        p, state = train_one_round(node, p, state, y_k, OptimizationOptimisers.AdamW(lr),
            maxiters, rng; callback = log_results(ps, losses), kwargs...)
    end
    return ps, state, losses
end

# Training hyperparameters: AdamW learning rate and optimizer iterations per stage.
lr = 5e-3
maxiters = 150
rng = MersenneTwister(123)
# Curriculum stage sizes k = 4, 8, …, length(t_train):
# we train on an increasing amount of the first k obs.
obs_grid = 4:4:length(t_train)
ps, state, losses = train(t_train, y_train, obs_grid, maxiters, lr, rng; progress = true)
(ComponentArrays.ComponentArray[(layer_1 = (weight = Float32[-0.7645059 0.53793734 0.39337653 0.70219153; -0.07752936 -0.15911278 -0.7899115 -0.54955494; … ; -0.49703807 0.2862404 0.17134601 0.12848014; 0.12155315 -0.079533935 -0.21581413 -0.6378495], bias = Float32[0.17611152, -0.3398568, 0.47946006, 0.0011358261, -0.369452, -0.48125637, -0.14892167, 0.36018282, 0.23786569, 0.15563601  …  0.32014894, -0.34898245, 0.3859067, -0.4593916, -0.33226132, 0.26252872, 0.3789305, 0.20379025, -0.10364276, 4.684925f-5]), layer_2 = (weight = Float32[0.14405736 0.18788297 … 0.09533301 0.20118886; 0.07223731 0.20376958 … 0.1997639 0.17007011; … ; 0.12802728 -0.021456784 … -0.09562197 -0.17657918; -0.104789555 -0.11612058 … 0.053488962 -0.111916974], bias = Float32[-0.09538478, -0.014825672, 0.117425635, -0.10896616, 0.109270275, -0.012288272, -0.06560917, 0.050810307, 0.005339995, 0.039113745  …  -0.0020065606, -0.042473897, -0.09489921, -0.0010632724, -0.013844758, -0.022791013, 0.01943487, -0.016084298, -0.08688739, -0.013523743]), layer_3 = (weight = Float32[-0.29617885 0.06500828 … -0.27692604 -0.2488153; -0.2784669 0.05598557 … -0.13869432 0.09098109; 0.12476622 -0.27931806 … 0.09968187 0.2964681; -0.12457438 0.084029615 … -0.011420269 0.08157834], bias = Float32[-0.0956653, 0.15592338, -0.10216346, 0.10081644])), (layer_1 = (weight = Float32[-0.7695059 0.54293734 0.38837653 0.7071915; -0.082529366 -0.15411279 -0.7949115 -0.54455495; … ; -0.50203806 0.2912404 0.16634601 0.13348013; 0.12655315 -0.08453394 -0.21081413 -0.6428495], bias = Float32[0.18111151, -0.3348568, 0.47446007, -0.0038641738, -0.364452, -0.48625636, -0.14392167, 0.36518282, 0.24286568, 0.15063602  …  0.32514894, -0.35398245, 0.3909067, -0.4643916, -0.33726132, 0.2675287, 0.3739305, 0.19879025, -0.09864276, -0.00495315]), layer_2 = (weight = Float32[0.13905737 0.19288297 … 0.09033301 0.20618886; 0.06723731 0.20876957 … 0.1947639 0.1750701; … ; 0.12302727 -0.016456783 … -0.100621976 -0.17157918; -0.09978955 
-0.12112058 … 0.05848896 -0.11691698], bias = Float32[-0.10038478, -0.01982567, 0.11242563, -0.103966154, 0.10427027, -0.017288273, -0.070609175, 0.055810306, 0.00033999514, 0.044113744  …  0.0029934393, -0.0374739, -0.09989921, -0.0060632722, -0.018844757, -0.027791012, 0.014434869, -0.011084299, -0.09188739, -0.008523743]), layer_3 = (weight = Float32[-0.30117884 0.060008284 … -0.27192605 -0.2538153; -0.2734669 0.06098557 … -0.14369431 0.09598109; 0.11976622 -0.28431806 … 0.10468187 0.2914681; -0.119574375 0.08902962 … -0.016420268 0.08657834], bias = Float32[-0.090665296, 0.15092339, -0.09716345, 0.09581644])), (layer_1 = (weight = Float32[-0.77415794 0.54744697 0.3836119 0.7118782; -0.08247053 -0.15390255 -0.7946522 -0.5445939; … ; -0.5062749 0.2953534 0.16202132 0.13775234; 0.13154188 -0.08946695 -0.20580809 -0.6478454], bias = Float32[0.1859277, -0.33550912, 0.46965417, -0.008857113, -0.36131158, -0.4903583, -0.13918692, 0.36736962, 0.24776426, 0.14653559  …  0.3221644, -0.35883018, 0.3957265, -0.46921173, -0.34183937, 0.27211842, 0.3694645, 0.19404387, -0.094268456, -0.009955538]), layer_2 = (weight = Float32[0.13504224 0.19667664 … 0.08633277 0.21026765; 0.062326834 0.21348846 … 0.18984067 0.1799343; … ; 0.11887717 -0.012407767 … -0.104741804 -0.16734572; -0.09581976 -0.12482073 … 0.06242792 -0.12096606], bias = Float32[-0.104377806, -0.024525825, 0.10862695, -0.09903822, 0.099480204, -0.017103026, -0.07427631, 0.06058006, -0.0041402685, 0.04891845  …  0.007930089, -0.032764755, -0.104080245, -0.010771483, -0.023789186, -0.032437388, 0.009640757, -0.0064660273, -0.09617484, -0.004538373]), layer_3 = (weight = Float32[-0.30573422 0.0554925 … -0.26799142 -0.25848314; -0.26893377 0.06547372 … -0.14753108 0.10064726; 0.11475992 -0.28932476 … 0.10918424 0.2864823; -0.11490817 0.09365397 … -0.02045761 0.09133516], bias = Float32[-0.08611815, 0.14637937, -0.09215812, 0.091161616])), (layer_1 = (weight = Float32[-0.7784472 0.5514358 0.37913775 0.71622545; 
-0.08059039 -0.15524928 -0.792524 -0.54649013; … ; -0.50989956 0.29877883 0.15827593 0.1414288; 0.13651451 -0.09427499 -0.20080206 -0.65280724], bias = Float32[0.19051373, -0.3380075, 0.4647567, -0.013837405, -0.35829666, -0.49362695, -0.13478616, 0.36459637, 0.25248504, 0.14325428  …  0.3182262, -0.36321783, 0.4003107, -0.47369504, -0.345822, 0.2763524, 0.36554664, 0.18921791, -0.090440065, -0.014952682]), layer_2 = (weight = Float32[0.13145112 0.2002193 … 0.082776226 0.21400695; 0.057501853 0.21832521 … 0.18504572 0.18475854; … ; 0.115331784 -0.008962315 … -0.108244404 -0.16367406; -0.092612274 -0.12762693 … 0.06561293 -0.12426508], bias = Float32[-0.10810679, -0.029363409, 0.10495157, -0.094359264, 0.09494208, -0.0140723055, -0.076658204, 0.06506586, -0.008101305, 0.052774176  …  0.012571787, -0.028363753, -0.10738469, -0.0151154855, -0.028451694, -0.036862597, 0.005199345, -0.0021373788, -0.099933036, -0.0014898446]), layer_3 = (weight = Float32[-0.30989394 0.05140445 … -0.26488325 -0.2627834; -0.265002 0.06933495 … -0.15047744 0.10472585; 0.10977174 -0.29429308 … 0.1127075 0.28150427; -0.1106413 0.09784571 … -0.023676567 0.09569541], bias = Float32[-0.081969306, 0.14241463, -0.087164015, 0.086914934])), (layer_1 = (weight = Float32[-0.78241605 0.5549757 0.37495112 0.7202708; -0.07794078 -0.15720427 -0.78963757 -0.5491844; … ; -0.5130899 0.3017166 0.15495384 0.14468096; 0.14140123 -0.09874972 -0.19587936 -0.65763605], bias = Float32[0.1948654, -0.34122184, 0.4598687, -0.018710908, -0.3545254, -0.4962406, -0.13078839, 0.36091018, 0.2568988, 0.14056458  …  0.31392375, -0.3673283, 0.40463546, -0.47771916, -0.34917292, 0.28029904, 0.3620883, 0.18477328, -0.0870033, -0.01997237]), layer_2 = (weight = Float32[0.12802568 0.20404558 … 0.079428315 0.21769917; 0.05321359 0.22305222 … 0.18085112 0.18896061; … ; 0.1121883 -0.005852226 … -0.11133535 -0.16036642; -0.0899377 -0.12978429 … 0.06828076 -0.12703672], bias = Float32[-0.11200892, -0.034107562, 0.10105313, 
-0.09007941, 0.09070389, -0.010258576, -0.07810972, 0.06926818, -0.011582752, 0.055399027  …  0.016453242, -0.024214579, -0.1099872, -0.01902159, -0.032520175, -0.041168425, 0.001147863, 0.002027993, -0.103360124, 0.0008403966]), layer_3 = (weight = Float32[-0.31376418 0.04762754 … -0.2624018 -0.26678166; -0.26169744 0.072574586 … -0.15284705 0.10813269; 0.1049303 -0.29908773 … 0.11520246 0.2767296; -0.10676503 0.1016314 … -0.02628563 0.09962962], bias = Float32[-0.07811633, 0.13905728, -0.082304314, 0.083072625])), (layer_1 = (weight = Float32[-0.7861062 0.5581502 0.3710307 0.7240585; -0.074898936 -0.15946546 -0.78637594 -0.55230254; … ; -0.5159581 0.3042903 0.15195283 0.14761999; 0.146029 -0.10272893 -0.19125064 -0.6621619], bias = Float32[0.19898708, -0.34476656, 0.4550543, -0.023384739, -0.35033506, -0.498381, -0.12722443, 0.35678074, 0.26090553, 0.13830853  …  0.30957782, -0.37147877, 0.4086851, -0.48124382, -0.3520195, 0.28400025, 0.3590088, 0.18115164, -0.08385004, -0.024922695]), layer_2 = (weight = Float32[0.12473721 0.20820393 … 0.07627195 0.2213782; 0.049699787 0.22735754 … 0.17745769 0.19213648; … ; 0.10934157 -0.0029247701 … -0.11412072 -0.1573195; -0.08768775 -0.13130593 … 0.07054728 -0.12936868], bias = Float32[-0.11609804, -0.038429048, 0.09706402, -0.0862428, 0.08677984, -0.0060413647, -0.07897367, 0.073217705, -0.014644463, 0.057288907  …  0.019264683, -0.020284353, -0.11208828, -0.022334807, -0.035821818, -0.04534211, -0.0025478464, 0.006093576, -0.10655744, 0.002526206]), layer_3 = (weight = Float32[-0.31741172 0.044088304 … -0.2604195 -0.270527; -0.25897893 0.07525056 … -0.15485258 0.110911235; 0.100432225 -0.30352375 … 0.11699139 0.27240297; -0.1031924 0.105105035 … -0.028421018 0.103227176], bias = Float32[-0.07449645, 0.13626906, -0.07777413, 0.07955014])), (layer_1 = (weight = Float32[-0.7895461 0.56102806 0.36735559 0.7276192; -0.0716359 -0.16191429 -0.7829071 -0.55567515; … ; -0.5185636 0.30657375 0.14921992 0.15030554; 0.15015547 
-0.10622231 -0.18713072 -0.6661672], bias = Float32[0.20287816, -0.34847602, 0.450342, -0.027807888, -0.34590477, -0.50019544, -0.124045365, 0.35249966, 0.26451027, 0.13641338  …  0.30553854, -0.3758451, 0.41245854, -0.48432583, -0.35458022, 0.287472, 0.3562548, 0.17806415, -0.0809242, -0.0294413]), layer_2 = (weight = Float32[0.121637456 0.21256724 … 0.07335188 0.2249777; 0.04694693 0.23091725 … 0.1748093 0.19426478; … ; 0.106744885 -0.00012069638 … -0.11664941 -0.15449373; -0.08580994 -0.13209864 … 0.072474346 -0.13127813], bias = Float32[-0.120234646, -0.04212387, 0.09336216, -0.08276929, 0.08314633, -0.0017475286, -0.07952675, 0.07696198, -0.017359903, 0.0591879  …  0.021193223, -0.016577527, -0.113859594, -0.024855508, -0.03842258, -0.04927785, -0.00595673, 0.010064365, -0.10955336, 0.0035783513]), layer_3 = (weight = Float32[-0.320858 0.04076004 … -0.2588373 -0.27403635; -0.25675762 0.07745437 … -0.15661477 0.11316918; 0.09647894 -0.30742297 … 0.1184439 0.268688; -0.09978779 0.10840002 … -0.030130213 0.1066255], bias = Float32[-0.07109054, 0.13396277, -0.073772214, 0.076210774])), (layer_1 = (weight = Float32[-0.7927525 0.5636592 0.36390975 0.73096746; -0.068247914 -0.16448542 -0.77932394 -0.55920476; … ; -0.52093685 0.30861428 0.14672849 0.15276705; 0.1535916 -0.10931408 -0.18359363 -0.6694151], bias = Float32[0.20653158, -0.35226837, 0.44574496, -0.031982314, -0.34163624, -0.501788, -0.12114059, 0.34845567, 0.26780483, 0.13485979  …  0.30220854, -0.3804188, 0.41597095, -0.48707798, -0.35707378, 0.29071638, 0.35379338, 0.17474867, -0.07820309, -0.032907486]), layer_2 = (weight = Float32[0.11878176 0.21691671 … 0.07069703 0.22840022; 0.044836022 0.23360927 … 0.17277007 0.19551435; … ; 0.10437735 0.0025547852 … -0.11894587 -0.15187994; -0.08426772 -0.132114 … 0.07410493 -0.1327698], bias = Float32[-0.12420992, -0.045180034, 0.09054289, -0.07952078, 0.07976086, 0.0019068227, -0.07997225, 0.0805501, -0.019802323, 0.06164522  …  0.022686502, -0.013120033, 
-0.11543085, -0.026447635, -0.040531147, -0.052823905, -0.009151466, 0.013912096, -0.112341315, 0.004027489]), layer_3 = (weight = Float32[-0.32409677 0.037643608 … -0.25756738 -0.27731016; -0.25492412 0.07928972 … -0.15819077 0.11502431; 0.09320355 -0.31067744 … 0.119834505 0.26561943; -0.096422754 0.11164017 … -0.031406507 0.10994009], bias = Float32[-0.06790514, 0.13203105, -0.07042708, 0.07292364])), (layer_1 = (weight = Float32[-0.79574156 0.5660798 0.36067587 0.73411316; -0.064796925 -0.16713375 -0.775685 -0.5628255; … ; -0.5230987 0.31044638 0.1444582 0.15502313; 0.15631822 -0.11208066 -0.18056783 -0.6717888], bias = Float32[0.20994417, -0.35609704, 0.4412766, -0.035929263, -0.3381547, -0.5032171, -0.11840439, 0.34509158, 0.27089164, 0.13363399  …  0.299887, -0.38515067, 0.41924724, -0.48960632, -0.35963774, 0.2937368, 0.35159767, 0.17095158, -0.07567597, -0.03492866]), layer_2 = (weight = Float32[0.11619355 0.22092398 … 0.06829982 0.23156941; 0.04324157 0.23551139 … 0.17121628 0.19606708; … ; 0.10222323 0.0050735055 … -0.1210294 -0.14947407; -0.0830229 -0.13143496 … 0.075476505 -0.133873], bias = Float32[-0.12782107, -0.047687534, 0.08897829, -0.076389425, 0.07659158, 0.0041352417, -0.080427416, 0.08401662, -0.022028731, 0.06474154  …  0.02411481, -0.00992975, -0.11688116, -0.027138984, -0.042349137, -0.05586761, -0.012177843, 0.017606206, -0.11491196, 0.0039530327]), layer_3 = (weight = Float32[-0.32711834 0.03474354 … -0.25653687 -0.2803519; -0.25338206 0.080845885 … -0.15960681 0.116571866; 0.09062245 -0.3132855 … 0.1213068 0.263132; -0.09302913 0.11489254 … -0.032247145 0.11322661], bias = Float32[-0.0649487, 0.13037996, -0.06775295, 0.06961753])), (layer_1 = (weight = Float32[-0.7985383 0.5683205 0.35762706 0.73707604; -0.061319787 -0.1698291 -0.7720229 -0.5664928; … ; -0.5250729 0.31209943 0.14238359 0.15709472; 0.15848564 -0.114602886 -0.1779177 -0.6734414], bias = Float32[0.21313144, -0.35993537, 0.436965, -0.039643135, -0.3355685, -0.50449854, 
-0.11579506, 0.34253487, 0.27382258, 0.13268717  …  0.29855734, -0.38998726, 0.42231402, -0.49196604, -0.36228624, 0.2965499, 0.34963357, 0.16680579, -0.073328264, -0.035726495]), layer_2 = (weight = Float32[0.113858834 0.22439969 … 0.06613132 0.23447856; 0.042067666 0.23668829 … 0.17005517 0.19604148; … ; 0.100262456 0.007425336 … -0.12292206 -0.14726229; -0.08203131 -0.1302521 … 0.076623924 -0.13465105], bias = Float32[-0.13098957, -0.049729154, 0.08845896, -0.07335012, 0.07364476, 0.005080295, -0.08091745, 0.08736962, -0.02406948, 0.0683  …  0.02558037, -0.006989359, -0.11823636, -0.027119715, -0.043967254, -0.058413144, -0.015030994, 0.021141404, -0.11727284, 0.003476899]), layer_3 = (weight = Float32[-0.32992968 0.032049768 … -0.25569668 -0.28317925; -0.25207213 0.08217947 … -0.16088565 0.11787388; 0.08864194 -0.31534076 … 0.12288207 0.26111114; -0.0896231 0.11814562 … -0.03270356 0.11647462], bias = Float32[-0.062212702, 0.12895152, -0.06565839, 0.0663055]))  …  (layer_1 = (weight = Float32[-0.8642322 0.51959974 0.38125366 0.7301147; 0.012227924 0.012897763 -0.8730856 -0.50603926; … ; -0.57210606 0.32993433 0.11673741 0.19313128; 0.38289875 0.30425158 -0.5197825 -0.7023212], bias = Float32[0.23059948, -0.3909284, 0.54047835, -0.13724223, -0.2259015, -0.62170535, 0.006207114, 0.16968219, 0.44869405, 0.10475772  …  0.19186494, -0.4761982, 0.46626744, -0.531174, -0.33513448, 0.2837692, 0.2848356, 0.21750131, -0.008974588, -0.04327586]), layer_2 = (weight = Float32[0.036546376 0.3184318 … -0.001885412 0.6413054; 0.14001606 -0.074564606 … 0.15648507 -0.14429817; … ; 0.031664364 0.059153672 … -0.18545759 -0.07539616; -0.108178765 -0.11003402 … 0.05086655 0.011880766], bias = Float32[-0.02593531, 0.08140213, 0.08393245, -0.02150179, 0.03593486, 0.033381578, -0.08693115, 0.12154997, -0.062326178, 0.13904735  …  -0.0135383215, 0.044050194, -0.14685859, 0.03225994, 0.07246305, -0.064179525, -0.036940314, 0.13580991, -0.1346709, -0.010749425]), layer_3 = (weight = 
Float32[-0.43785423 0.008134754 … -0.30692267 -0.3796847; -0.3653514 0.1736465 … -0.19215283 0.20678563; 0.24436869 -0.43365198 … 0.1426252 0.16837592; -0.029332984 0.1777881 … 0.0028432587 0.15403736], bias = Float32[-0.027141454, 0.14913608, -0.07116537, -0.0011289619])), (layer_1 = (weight = Float32[-0.8641219 0.5196009 0.38126084 0.7301125; 0.012212635 0.01289287 -0.87308365 -0.5060389; … ; -0.57210875 0.32993507 0.11673792 0.19313248; 0.3829726 0.3041174 -0.51955974 -0.7024173], bias = Float32[0.23043044, -0.3909311, 0.54042983, -0.13698606, -0.22595675, -0.62174845, 0.006218374, 0.16964863, 0.44871378, 0.10475541  …  0.19184443, -0.47625232, 0.4662525, -0.5311826, -0.33510935, 0.28319132, 0.2847714, 0.21751827, -0.009024123, -0.04318511]), layer_2 = (weight = Float32[0.036541384 0.31821966 … -0.001894722 0.64126056; 0.1400194 -0.07454067 … 0.1564902 -0.14423035; … ; 0.03152471 0.058529377 … -0.18556558 -0.07538326; -0.10818004 -0.11007311 … 0.050863653 0.011900533], bias = Float32[-0.025906881, 0.0813891, 0.08385508, -0.021506095, 0.035968784, 0.033392075, -0.08693132, 0.121234976, -0.062299803, 0.13901074  …  -0.013603874, 0.044133328, -0.14686534, 0.03222028, 0.07241977, -0.06422085, -0.036904175, 0.13575685, -0.13434663, -0.010731649]), layer_3 = (weight = Float32[-0.4378009 0.008137497 … -0.3069298 -0.37968126; -0.3656089 0.1736459 … -0.19211946 0.20679148; 0.24441311 -0.43364975 … 0.14260541 0.16837683; -0.02925033 0.1777907 … 0.0027962828 0.15403964], bias = Float32[-0.027083138, 0.14909422, -0.071137674, -0.0010774003])), (layer_1 = (weight = Float32[-0.86401254 0.51960206 0.38126802 0.7301102; 0.012197258 0.01288795 -0.8730817 -0.50603855; … ; -0.5721114 0.3299358 0.11673843 0.19313362; 0.38304675 0.30398214 -0.51933473 -0.7025144], bias = Float32[0.23026234, -0.39093587, 0.5403815, -0.13673007, -0.22601174, -0.62179226, 0.0062296367, 0.16961503, 0.44873407, 0.10475303  …  0.19182388, -0.4763071, 0.46623772, -0.53119147, -0.33508393, 0.28260878, 
0.28470722, 0.21753742, -0.009073621, -0.043094836]), layer_2 = (weight = Float32[0.036536276 0.31800738 … -0.0019041422 0.6412156; 0.14002277 -0.074517116 … 0.15649535 -0.14416075; … ; 0.03138413 0.057904214 … -0.1856741 -0.07537044; -0.108181305 -0.11011149 … 0.050860785 0.011919611], bias = Float32[-0.02587912, 0.08137925, 0.08377769, -0.021510264, 0.036002856, 0.03340274, -0.08693145, 0.12092153, -0.062273335, 0.13897403  …  -0.013669614, 0.044217307, -0.14687207, 0.03218056, 0.07237711, -0.06426224, -0.036868, 0.13570426, -0.13402466, -0.010714424]), layer_3 = (weight = Float32[-0.43774754 0.008140268 … -0.30693752 -0.37967783; -0.36586425 0.17364533 … -0.19208618 0.2067974; 0.24445717 -0.4336475 … 0.14258577 0.16837773; -0.02916776 0.17779332 … 0.0027494507 0.15404192], bias = Float32[-0.027023783, 0.1490523, -0.07110996, -0.0010257371])), (layer_1 = (weight = Float32[-0.86390406 0.5196031 0.3812752 0.73010796; 0.012181792 0.012883002 -0.8730797 -0.50603825; … ; -0.57211405 0.3299365 0.11673895 0.19313473; 0.3831212 0.30384573 -0.5191075 -0.70261246], bias = Float32[0.2300952, -0.3909427, 0.54033333, -0.13647427, -0.22606646, -0.6218368, 0.006240901, 0.16958138, 0.44875494, 0.10475059  …  0.1918033, -0.47636256, 0.4662231, -0.5312006, -0.3350582, 0.28202155, 0.28464302, 0.21755877, -0.009123082, -0.04300504]), layer_2 = (weight = Float32[0.036531053 0.31779495 … -0.0019136742 0.6411705; 0.14002617 -0.07449395 … 0.15650047 -0.14408934; … ; 0.031242616 0.057278145 … -0.18578315 -0.07535769; -0.108182564 -0.11014916 … 0.050857946 0.0119380085], bias = Float32[-0.025852026, 0.08137259, 0.08370028, -0.021514298, 0.036037076, 0.033413578, -0.08693154, 0.12060965, -0.062246773, 0.13893722  …  -0.013735544, 0.04430214, -0.1468788, 0.03214077, 0.072335064, -0.06430369, -0.03683178, 0.13565217, -0.13370493, -0.010697745]), layer_3 = (weight = Float32[-0.4376941 0.008143067 … -0.30694583 -0.37967438; -0.36611745 0.1736448 … -0.19205299 0.20680343; 0.24450088 -0.43364528 
… 0.14256628 0.16837859; -0.029085277 0.17779598 … 0.0027027652 0.1540442], bias = Float32[-0.02696338, 0.14901035, -0.07108221, -0.00097397104])), (layer_1 = (weight = Float32[-0.86379653 0.519604 0.38128242 0.7301057; 0.012166241 0.012878026 -0.8730777 -0.506038; … ; -0.5721167 0.3299372 0.11673948 0.19313578; 0.38319594 0.3037082 -0.51887804 -0.70271146], bias = Float32[0.229929, -0.3909516, 0.54028535, -0.13621865, -0.22612092, -0.62188196, 0.0062521645, 0.16954768, 0.4487764, 0.10474807  …  0.19178268, -0.47641873, 0.4662086, -0.53120995, -0.33503222, 0.2814296, 0.28457883, 0.21758233, -0.009172508, -0.04291573]), layer_2 = (weight = Float32[0.036525715 0.31758234 … -0.0019233194 0.64112526; 0.14002958 -0.07447117 … 0.15650558 -0.14401613; … ; 0.031100173 0.056651134 … -0.18589272 -0.075345; -0.108183816 -0.11018613 … 0.050855137 0.011955735], bias = Float32[-0.025825597, 0.08136913, 0.08362285, -0.021518199, 0.036071446, 0.03342459, -0.086931586, 0.12029933, -0.062220115, 0.13890031  …  -0.013801667, 0.044387832, -0.14688548, 0.032100912, 0.07229364, -0.064345196, -0.036795523, 0.13560055, -0.13338736, -0.010681604]), layer_3 = (weight = Float32[-0.4376406 0.008145896 … -0.30695477 -0.37967092; -0.36636853 0.17364427 … -0.1920199 0.20680955; 0.24454422 -0.43364304 … 0.14254694 0.16837944; -0.029002877 0.17779866 … 0.002656229 0.15404648], bias = Float32[-0.026901918, 0.14896834, -0.071054436, -0.0009221009])), (layer_1 = (weight = Float32[-0.8636899 0.5196049 0.38128963 0.7301034; 0.012150605 0.012873022 -0.87307566 -0.5060378; … ; -0.57211924 0.32993785 0.11674001 0.1931368; 0.38327098 0.3035695 -0.5186463 -0.7028114], bias = Float32[0.22976378, -0.3909626, 0.54023755, -0.13596323, -0.22617516, -0.62192786, 0.0062634246, 0.16951393, 0.44879842, 0.104745485  …  0.19176202, -0.4764756, 0.46619427, -0.5312196, -0.33500594, 0.28083286, 0.28451467, 0.21760812, -0.009221897, -0.042826906]), layer_2 = (weight = Float32[0.036520258 0.31736955 … -0.0019330797 
0.6410799; 0.14003302 -0.07444877 … 0.1565107 -0.1439411; … ; 0.0309568 0.056023136 … -0.1860028 -0.07533238; -0.10818506 -0.110222414 … 0.05085236 0.011972801], bias = Float32[-0.02579983, 0.081368886, 0.0835454, -0.021521965, 0.036105964, 0.03343578, -0.086931586, 0.119990595, -0.06219336, 0.1388633  …  -0.0138679845, 0.04447439, -0.14689216, 0.032060992, 0.07225284, -0.06438676, -0.036759224, 0.13554943, -0.13307187, -0.010665997]), layer_3 = (weight = Float32[-0.43758705 0.008148754 … -0.3069643 -0.37966746; -0.36661747 0.17364378 … -0.1919869 0.20681576; 0.24458721 -0.4336408 … 0.14252776 0.16838026; -0.028920563 0.17780137 … 0.0026098457 0.15404876], bias = Float32[-0.026839394, 0.14892627, -0.07102663, -0.0008701259])), (layer_1 = (weight = Float32[-0.8635842 0.51960576 0.38129684 0.73010105; 0.012134885 0.01286799 -0.8730736 -0.5060376; … ; -0.5721218 0.3299385 0.11674055 0.19313776; 0.38334635 0.30342966 -0.51841223 -0.70291233], bias = Float32[0.22959952, -0.39097568, 0.5401899, -0.13570802, -0.22622918, -0.62197447, 0.0062746797, 0.16948013, 0.448821, 0.104742825  …  0.1917413, -0.47653314, 0.46618012, -0.5312295, -0.33497936, 0.2802313, 0.2844505, 0.21763614, -0.009271252, -0.04273857]), layer_2 = (weight = Float32[0.03651468 0.31715655 … -0.0019429566 0.6410344; 0.1400365 -0.07442676 … 0.15651579 -0.14386424; … ; 0.030812496 0.0553941 … -0.18611343 -0.07531982; -0.1081863 -0.11025801 … 0.050849605 0.0119892135], bias = Float32[-0.025774717, 0.081371866, 0.08346793, -0.021525599, 0.036140636, 0.03344715, -0.086931534, 0.119683444, -0.06216651, 0.13882618  …  -0.013934498, 0.044561815, -0.1468988, 0.032021005, 0.072212666, -0.06442839, -0.03672288, 0.1354988, -0.1327584, -0.010650916]), layer_3 = (weight = Float32[-0.43753347 0.008151641 … -0.30697447 -0.379664; -0.3668643 0.17364332 … -0.191954 0.20682208; 0.24462986 -0.43363854 … 0.14250873 0.16838107; -0.028838333 0.17780411 … 0.0025636188 0.15405104], bias = Float32[-0.026775802, 0.14888416, 
-0.07099879, -0.00081804546])), (layer_1 = (weight = Float32[-0.8634795 0.5196065 0.38130406 0.7300987; 0.012119084 0.012862929 -0.8730715 -0.5060374; … ; -0.57212436 0.32993913 0.11674109 0.19313869; 0.38342202 0.30328867 -0.5181759 -0.7030142], bias = Float32[0.22943625, -0.39099085, 0.54014254, -0.13545302, -0.226283, -0.6220218, 0.0062859273, 0.16944627, 0.44884416, 0.10474009  …  0.19172055, -0.47659138, 0.4661661, -0.5312397, -0.33495247, 0.27962488, 0.28438634, 0.21766639, -0.00932057, -0.042650726]), layer_2 = (weight = Float32[0.03650898 0.31694332 … -0.001952952 0.6409888; 0.14004 -0.07440514 … 0.15652087 -0.14378555; … ; 0.030667264 0.054763965 … -0.18622458 -0.07530731; -0.10818753 -0.110292934 … 0.050846882 0.012004982], bias = Float32[-0.025750257, 0.08137807, 0.08339044, -0.021529099, 0.03617546, 0.033458706, -0.08693143, 0.11937789, -0.06213956, 0.13878895  …  -0.014001208, 0.04465011, -0.14690544, 0.03198095, 0.07217312, -0.06447008, -0.03668649, 0.13544868, -0.13244684, -0.010636355]), layer_3 = (weight = Float32[-0.43747982 0.008154558 … -0.30698523 -0.37966052; -0.36710903 0.17364287 … -0.19192122 0.20682849; 0.24467215 -0.43363628 … 0.14248987 0.16838184; -0.028756186 0.17780688 … 0.0025175523 0.15405332], bias = Float32[-0.026711138, 0.148842, -0.070970915, -0.0007658592])), (layer_1 = (weight = Float32[-0.8633757 0.5196071 0.3813113 0.73009634; 0.012103201 0.0128578385 -0.87306935 -0.5060373; … ; -0.57212687 0.32993972 0.11674164 0.19313957; 0.383498 0.3031465 -0.51793724 -0.703117], bias = Float32[0.22927396, -0.3910081, 0.5400953, -0.13519824, -0.22633663, -0.62206984, 0.006297165, 0.16941236, 0.4488679, 0.10473729  …  0.19169976, -0.47665033, 0.46615225, -0.5312501, -0.3349253, 0.27901357, 0.28432217, 0.21769887, -0.009369853, -0.04256337]), layer_2 = (weight = Float32[0.03650316 0.3167298 … -0.0019630678 0.64094317; 0.14004353 -0.0743839 … 0.15652595 -0.14370503; … ; 0.030521102 0.054132678 … -0.18633626 -0.07529485; -0.10818875 
-0.11032719 … 0.050844185 0.012020115], bias = Float32[-0.025726443, 0.081387505, 0.08331292, -0.021532465, 0.036210436, 0.033470444, -0.086931266, 0.11907394, -0.06211251, 0.13875163  …  -0.0140681155, 0.044739276, -0.14691204, 0.03194083, 0.0721342, -0.064511836, -0.036650058, 0.13539906, -0.13213713, -0.010622308]), layer_3 = (weight = Float32[-0.43742615 0.008157506 … -0.30699658 -0.37965703; -0.36735168 0.17364246 … -0.19188854 0.206835; 0.2447141 -0.433634 … 0.14247116 0.16838259; -0.028674124 0.17780967 … 0.0024716505 0.1540556], bias = Float32[-0.026645398, 0.1487998, -0.070943005, -0.0007135671])), (layer_1 = (weight = Float32[-0.851305 0.5520059 0.348024 0.7627709; 0.04632334 -0.009440457 -0.84029984 -0.5389861; … ; -0.5488786 0.30675948 0.13969672 0.16979802; 0.35596856 0.29062322 -0.50772625 -0.72199595], bias = Float32[0.28249922, -0.43494698, 0.5250078, -0.19528101, -0.2529324, -0.6427979, 0.027838308, 0.15167223, 0.47351483, 0.12807012  …  0.17229655, -0.4989907, 0.4908063, -0.5544986, -0.31426495, 0.30664715, 0.32898936, 0.22719437, -0.02488722, -0.07349131]), layer_2 = (weight = Float32[0.013762804 0.30396003 … -0.024262637 0.6697997; 0.16267325 -0.05548455 … 0.17881392 -0.12640543; … ; 0.07619615 0.09810671 … -0.14295425 -0.10001697; -0.085060485 -0.065308996 … 0.07442124 -0.032206066], bias = Float32[-0.053305052, 0.061915, 0.075528115, 0.004602924, 0.054446906, 0.05275818, -0.06672602, 0.17027602, -0.04425291, 0.1224139  …  0.031487763, 0.051902093, -0.12465611, 0.015045725, 0.057254896, -0.0809194, -0.019129237, 0.119476944, -0.15801132, 0.005014836]), layer_3 = (weight = Float32[-0.42147848 0.030777631 … -0.2665356 -0.35715672; -0.34472734 0.15084106 … -0.17458127 0.18331176; 0.25937232 -0.41101426 … 0.12505206 0.19113027; -0.01800699 0.20046614 … -0.010197557 0.1767238], bias = Float32[-0.014718303, 0.13196266, -0.053407278, 0.013244771]))], (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), Float32[192.39882, 
135.20038, 113.5306, 102.90519, 96.617905, 92.35046, 89.20726, 86.74315, 84.62894, 82.736435  …  56.736588, 56.623013, 56.5094, 56.395733, 56.28199, 56.16816, 56.054214, 55.940136, 55.825912, 27.031372])

We can now animate the training to get a better understanding of the fit.

"""
    predict(y0, t, p, state)

Build a fresh neural ODE with `neural_ode(t, length(y0))` and integrate it from `y0` with
parameters `p` and Lux `state`. Return the solution at the points `t` as a dense `Array`
(the plotting code below reads one feature per row).
"""
function predict(y0, t, p, state)
    node = first(neural_ode(t, length(y0)))
    solution, _ = node(y0, p, state)
    return Array(solution)
end

"""
    plot_pred(t_train, y_train, t_grid, rescale_t, rescale_y, num_iters, p, state, loss,
              y0 = y_train[:, 1])

Predict on `t_grid` from `y0` with parameters `p`, then map the observations and the
prediction back to original units with `rescale_t`/`rescale_y` and draw them, together
with the loss history `loss`, through `plot_result`.
"""
function plot_pred(t_train, y_train, t_grid, rescale_t, rescale_y,
        num_iters, p, state, loss, y0 = y_train[:, 1])
    y_grid = predict(y0, t_grid, p, state)
    t_obs = rescale_t(t_train)
    y_obs = rescale_y(y_train)
    t_curve = rescale_t(t_grid)
    y_curve = rescale_y(y_grid)
    return plot_result(t_obs, y_obs, t_curve, y_curve, loss, num_iters)
end

"""
    plot_pred(t, y, y_pred)

Scatter the observations `y` against `t` and overlay the prediction `y_pred` as a line on
the same axes. Return the plot.
"""
function plot_pred(t, y, y_pred)
    observed = Plots.scatter(t, y; label = "Observation")
    Plots.plot!(observed, t, y_pred; label = "Prediction")
    return observed
end

"""
    plot_pred(t, y, t_pred, y_pred; kwargs...)

Return a vector with one plot per feature. Each plot draws the corresponding row of
`y_pred` as a line over `t_pred` and scatters the matching row of observations `y` over
`t`. Titles and y-axis labels come from `FEATURE_NAMES` and `UNITS`; extra keywords are
passed to the line plot.
"""
function plot_pred(t, y, t_pred, y_pred; kwargs...)
    # Build the panel for the `i`-th feature.
    function feature_panel(i, observed, predicted, name, unit)
        panel = Plots.plot(t_pred, predicted; label = "Prediction", color = i,
            linewidth = 3, legend = nothing, title = name, kwargs...)
        Plots.scatter!(panel, t, observed; label = "Observation", xlabel = "Time",
            ylabel = unit, markersize = 5, color = i)
        return panel
    end

    per_feature = zip(eachrow(y), eachrow(y_pred), FEATURE_NAMES, UNITS)
    return [feature_panel(i, row...) for (i, row) in enumerate(per_feature)]
end

"""
    plot_result(t, y, t_pred, y_pred, loss, num_iters; kwargs...)

Stack the per-feature prediction panels, each with a fixed y-range, above a loss curve
whose x-axis spans `(0, num_iters)`. The result is a single-column 900×900 figure.
"""
function plot_result(t, y, t_pred, y_pred, loss, num_iters; kwargs...)
    panels = plot_pred(t, y, t_pred, y_pred; kwargs...)

    # Fixed y-ranges for the four features, so consecutive animation frames line up.
    feature_ylims = ((10, 40), (20, 100), (2, 12), (990, 1025))
    for (idx, lims) in enumerate(feature_ylims)
        plot!(panels[idx]; ylim = lims)
    end
    plot!(panels[1]; legend = (0.65, 1.0))

    loss_panel = Plots.plot(loss; label = nothing, linewidth = 3, title = "Loss",
        xlabel = "Iterations", xlim = (0, num_iters))
    all_panels = vcat(panels, loss_panel)
    return plot(all_panels...; layout = grid(length(all_panels), 1), size = (900, 900))
end

"""
    animate_training(plot_frame, t_train, y_train, ps, losses, obs_grid; pause_for = 300)

Build a Plots `@animate` animation of the training run. The frame for iteration `i` calls
`plot_frame` with the observations of that iteration's curriculum stage, the parameters
`ps[i]`, and the loss history `losses[1:i]`. Iterations are assumed to be split evenly
across the stages in `obs_grid`. After the last iteration, the final frame is held for
about `pause_for` more steps. Only every second frame is kept.
"""
function animate_training(
        plot_frame, t_train, y_train, ps, losses, obs_grid; pause_for = 300)
    num_losses, num_stages = length(losses), length(obs_grid)
    # Stage index (0-based) → number of observations used in that stage.
    obs_count = Dict(i - 1 => n for (i, n) in enumerate(obs_grid))
    # Iteration indices to draw; values past the end are clamped to the last iteration.
    is = [min(i, num_losses) for i in 2:(num_losses + pause_for)]
    @animate for i in is
        # Exact integer version of floor((i - 1) / num_losses * num_stages). The float form
        # can land just below a stage boundary (e.g. 1/49*49 < 1) and pick the wrong stage.
        stage = fld((i - 1) * num_stages, num_losses)
        k = obs_count[stage]
        plot_frame(t_train[1:k], y_train[:, 1:k], ps[i], losses[1:i])
    end every 2
end

# Settings shared by every animation frame.
num_iters = length(losses)                      # x-extent of the loss panel
t_train_grid = let bounds = extrema(t_train)    # dense standardized time grid
    collect(range(bounds[1], bounds[2]; length = 500))
end
# Undo the standardization applied by `featurize`, giving original units.
rescale_t(x) = @. t_scale * x + t_mean
rescale_y(x) = @. y_scale * x + y_mean
"""
    plot_frame(t, y, p, loss)

Render one training-animation frame: the observations `(t, y)`, the prediction with
parameters `p` on `t_train_grid`, and the loss history `loss`, in original units.
Reads the globals `t_train_grid`, `rescale_t`, `rescale_y`, `num_iters` and `state`.
"""
plot_frame(t, y, p, loss) =
    plot_pred(t, y, t_train_grid, rescale_t, rescale_y, num_iters, p, state, loss)
# Render the training animation and write it out as a GIF.
anim = animate_training(plot_frame, t_train, y_train, ps, losses, obs_grid)
gif(anim, "node_weather_forecast_training.gif")
Example block output

Looks good! But how well does the model forecast?

"""
    plot_extrapolation(t_train, y_train, t_test, y_test, t̂, ŷ)

Draw one panel per feature. Each panel shows the training observations and the prediction
`ŷ` over `t̂` (via `plot_pred`), plus the held-out test observations as white-edged markers.
Each feature gets a fixed y-range, and the panels are stacked in a single-column 900×900
figure.
"""
function plot_extrapolation(t_train, y_train, t_test, y_test, t̂, ŷ)
    panels = plot_pred(t_train, y_train, t̂, ŷ)

    # Overlay the test observations on the panel of the same feature.
    for i in 1:min(length(panels), size(y_test, 1))
        scatter!(panels[i], t_test, view(y_test, i, :); color = i,
            markerstrokecolor = :white, label = "Test observation")
    end

    feature_ylims = ((10, 40), (20, 100), (2, 12), (990, 1025))
    for (idx, lims) in enumerate(feature_ylims)
        plot!(panels[idx]; ylim = lims)
    end
    plot!(panels[1]; legend = :topleft)
    return plot(panels...; layout = grid(length(panels), 1), size = (900, 900))
end

# Predict from the first training observation over a standardized time grid that covers
# both the training and the test period, using the final trained parameters.
t_grid = collect(range(minimum(t_train), maximum(t_test); length = 500))
y_pred = predict(y_train[:, 1], t_grid, last(ps), state)
# Plot everything back in original units.
plot_extrapolation(rescale_t(t_train), rescale_y(y_train), rescale_t(t_test),
    rescale_y(y_test), rescale_t(t_grid), rescale_y(y_pred))
Example block output

While there is some drift in the weather patterns, the model extrapolates very well!