Multiple Shooting Functionality
The form of multiple shooting found here is a specialized form for implicit layer deep learning (known as data shooting), which assumes full observability of the underlying dynamics and the absence of noise. For a more general implementation of multiple shooting, see JuliaSimModelOptimizer. For an implementation more directly tied to parameter estimation against data, see DiffEqParamEstim.jl.
DiffEqFlux.multiple_shoot (Function)
multiple_shoot(p, ode_data, tsteps, prob, loss_function,
    [continuity_loss = _default_continuity_loss], solver, group_size;
    continuity_term = 100, kwargs...)
Performs direct multiple shooting on ode_data and returns the total loss together with an array of predictions, one per group (smaller interval). In direct multiple shooting, the time interval is split into smaller intervals (groups), and the ODE defined by the neural network is solved on each group separately, starting from the observed data at the group's first time point. The default continuity_term is 100, meaning that any loss arising from a discontinuity between two adjacent groups is scaled by 100.
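The procedure can be sketched in plain Julia. This is a minimal illustration under stated assumptions, not the DiffEqFlux implementation: explicit Euler stands in for the ODE solver, ode_data is assumed to be a (state × time) matrix, groups overlap by one point as in the group_ranges example, and the names euler_solve and sketch_multiple_shoot are hypothetical.

```julia
# Hypothetical helper: explicit Euler stand-in for an ODE solver.
# Returns a (state × length(ts)) matrix whose first column is u0.
function euler_solve(f, u0, p, ts; substeps = 10)
    us = zeros(length(u0), length(ts))
    us[:, 1] = u0
    u = copy(u0)
    for j in 2:length(ts)
        h = (ts[j] - ts[j - 1]) / substeps
        for s in 0:(substeps - 1)
            u = u .+ h .* f(u, p, ts[j - 1] + s * h)
        end
        us[:, j] = u
    end
    return us
end

# Hypothetical sketch of a data-shooting multiple-shooting loss.
function sketch_multiple_shoot(p, ode_data, tsteps, f, loss_function, group_size;
                               continuity_term = 100)
    datasize = size(ode_data, 2)
    # Groups overlap by one point, e.g. 1:5, 5:9, 9:10 for 10 points and size 5.
    ranges = [i:min(datasize, i + group_size - 1) for i in 1:(group_size - 1):(datasize - 1)]
    # Data shooting: each group starts from the observed state at its first point.
    preds = [euler_solve(f, ode_data[:, first(rg)], p, tsteps[rg]) for rg in ranges]
    loss = 0.0
    for (i, rg) in enumerate(ranges)
        loss += loss_function(ode_data[:, rg], preds[i])
        if i > 1
            # Weighted penalty for a jump between group i-1 and group i.
            loss += continuity_term * sum(abs, preds[i - 1][:, end] - preds[i][:, 1])
        end
    end
    return loss, preds
end

# Usage: data from du/dt = -u, so p = [1.0] should fit better than p = [2.0].
tsteps = collect(range(0.0, 1.0; length = 10))
ode_data = reshape(exp.(-tsteps), 1, :)
decay(u, p, t) = -p[1] .* u
sqloss(u, uhat) = sum(abs2, u .- uhat)
loss_good, preds = sketch_multiple_shoot([1.0], ode_data, tsteps, decay, sqloss, 5)
loss_bad, _ = sketch_multiple_shoot([2.0], ode_data, tsteps, decay, sqloss, 5)
```

The per-group loss and the weighted continuity penalty are simply added, which is why a large continuity_term pushes the optimizer toward predictions that join up across group boundaries.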
Arguments:
p: The parameters of the neural network to be trained.
ode_data: Original data to be modelled.
tsteps: Timesteps on which ode_data was calculated.
prob: ODE problem that the neural network attempts to solve.
loss_function: Arbitrary function that computes the loss between a group's data and its prediction.
continuity_loss: Function that takes states $\hat{u}_{end}$ of group $k$ and $u_{0}$ of group $k+1$ as input and calculates the prediction continuity loss between them. If no custom continuity_loss is specified, sum(abs, û_end - u_0) is used.
solver: ODE solver algorithm.
group_size: Number of data points per group when ode_data is split into groups of equal size.
continuity_term: Weight term that enforces continuity of predictions across different groups.
kwargs: Additional arguments splatted to the ODE solver. Refer to the Local Sensitivity Analysis and Common Solver Arguments documentation for more details.
continuity_term should be a relatively large number, so that a large penalty is applied whenever the last point of one group does not coincide with the first point of the next group.
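The default continuity loss and its weighting can be illustrated with a short sketch. The helper names default_continuity_loss and continuity_penalty are hypothetical, and each group prediction is assumed to be a (state × time) matrix; only the formula sum(abs, û_end - u_0) and the default weight of 100 come from this documentation.

```julia
# Hypothetical stand-in for the documented default continuity loss.
default_continuity_loss(uhat_end, u_0) = sum(abs, uhat_end - u_0)

# Weighted continuity penalty over all adjacent pairs of group predictions,
# each assumed to be a (state × time) matrix.
function continuity_penalty(group_preds; continuity_loss = default_continuity_loss,
                            continuity_term = 100)
    total = 0.0
    for k in 1:(length(group_preds) - 1)
        # Last state of group k versus first state of group k + 1.
        total += continuity_loss(group_preds[k][:, end], group_preds[k + 1][:, 1])
    end
    return continuity_term * total
end

# Two groups with a jump of 0.5 in the first state component at the boundary.
group_preds = [[0.0 1.0; 0.0 2.0], [1.5 3.0; 2.0 4.0]]
continuity_penalty(group_preds)                                                # → 50.0
continuity_penalty(group_preds; continuity_term = 10)                          # → 5.0
continuity_penalty(group_preds; continuity_loss = (a, b) -> sum(abs2, a - b))  # → 25.0
```

The last call shows a custom squared-error continuity loss; with a mismatch below 1, squaring shrinks the penalty, so the weight may need to be raised to keep the same pressure on continuity.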
multiple_shoot(p, ode_data, tsteps, ensembleprob, ensemblealg, loss_function,
    [continuity_loss = _default_continuity_loss], solver, group_size;
    continuity_term = 100, kwargs...)
Performs direct multiple shooting on ode_data by solving ensembleprob with ensemblealg, and returns the total loss together with an array of predictions, one per group (smaller interval). In direct multiple shooting, the time interval is split into smaller intervals (groups), and the ODE defined by the neural network is solved on each group separately. The default continuity_term is 100, meaning that any loss arising from a discontinuity between two adjacent groups is scaled by 100.
Arguments:
p: The parameters of the neural network to be trained.
ode_data: Original data to be modelled.
tsteps: Timesteps on which ode_data was calculated.
ensembleprob: Ensemble problem wrapping the ODE problem that the neural network attempts to solve.
ensemblealg: Ensemble algorithm, e.g. EnsembleThreads().
loss_function: Arbitrary function that computes the loss between a group's data and its prediction.
continuity_loss: Function that takes states $\hat{u}_{end}$ of group $k$ and $u_{0}$ of group $k+1$ as input and calculates the prediction continuity loss between them. If no custom continuity_loss is specified, sum(abs, û_end - u_0) is used.
solver: ODE solver algorithm.
group_size: Number of data points per group when ode_data is split into groups of equal size.
continuity_term: Weight term that enforces continuity of predictions across different groups.
kwargs: Additional arguments splatted to the ODE solver. Refer to the Local Sensitivity Analysis and Common Solver Arguments documentation for more details.
continuity_term should be a relatively large number, so that a large penalty is applied whenever the last point of one group does not coincide with the first point of the next group.
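A rough sketch of the ensemble idea follows, under loudly stated assumptions: ode_data is taken to be a (state × time × trajectory) array, Threads.@threads stands in for an ensemble algorithm such as EnsembleThreads(), explicit Euler replaces the ODE solver, and euler_solve and sketch_ensemble_multiple_shoot are hypothetical names. It illustrates the grouping and continuity logic only; it does not claim to reproduce how DiffEqFlux lays out ensemble data.

```julia
# Hypothetical helper: explicit Euler stand-in for an ODE solver.
function euler_solve(f, u0, p, ts; substeps = 10)
    us = zeros(length(u0), length(ts))
    us[:, 1] = u0
    u = copy(u0)
    for j in 2:length(ts)
        h = (ts[j] - ts[j - 1]) / substeps
        for s in 0:(substeps - 1)
            u = u .+ h .* f(u, p, ts[j - 1] + s * h)
        end
        us[:, j] = u
    end
    return us
end

# Hypothetical ensemble variant: ode_data is a (state × time × trajectory) array.
function sketch_ensemble_multiple_shoot(p, ode_data, tsteps, f, loss_function,
                                        group_size; continuity_term = 100)
    nstate, datasize, ntraj = size(ode_data)
    ranges = [i:min(datasize, i + group_size - 1) for i in 1:(group_size - 1):(datasize - 1)]
    preds = map(ranges) do rg
        uhat = zeros(nstate, length(rg), ntraj)
        # Threads stand in for an ensemble algorithm such as EnsembleThreads();
        # each thread writes a distinct trajectory slice, so there is no race.
        Threads.@threads for j in 1:ntraj
            uhat[:, :, j] = euler_solve(f, ode_data[:, first(rg), j], p, tsteps[rg])
        end
        uhat
    end
    loss = 0.0
    for (i, rg) in enumerate(ranges)
        loss += loss_function(ode_data[:, rg, :], preds[i])
        if i > 1
            # Boundary mismatch summed over all trajectories at once.
            loss += continuity_term * sum(abs, preds[i - 1][:, end, :] - preds[i][:, 1, :])
        end
    end
    return loss, preds
end

# Usage: two trajectories of du/dt = -u with different initial conditions.
tsteps = collect(range(0.0, 1.0; length = 10))
ode_data = zeros(1, 10, 2)
ode_data[1, :, 1] = exp.(-tsteps)          # trajectory from u0 = 1
ode_data[1, :, 2] = 2 .* exp.(-tsteps)     # trajectory from u0 = 2
decay(u, p, t) = -p[1] .* u
sqloss(u, uhat) = sum(abs2, u .- uhat)
loss_good, preds = sketch_ensemble_multiple_shoot([1.0], ode_data, tsteps, decay, sqloss, 5)
loss_bad, _ = sketch_ensemble_multiple_shoot([2.0], ode_data, tsteps, decay, sqloss, 5)
```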
DiffEqFlux.group_ranges (Function)
group_ranges(datasize, groupsize)
Get ranges that partition data of length datasize into groups of groupsize observations. Adjacent groups share their boundary observation. If the data cannot be divided evenly into such groups, the last group contains the remaining observations.
Arguments:
datasize: number of data points to be partitioned.
groupsize: maximum number of observations in each group.
Example:
julia> group_ranges(10, 5)
3-element Vector{UnitRange{Int64}}:
1:5
5:9
9:10
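The partition in the example can be reproduced by starting a new group every groupsize - 1 observations and clipping the last group at datasize. The function below, sketch_group_ranges, is a hypothetical re-implementation written for illustration, not the DiffEqFlux source; its argument check is an assumption.

```julia
# Hypothetical re-implementation for illustration; not the DiffEqFlux source.
function sketch_group_ranges(datasize::Integer, groupsize::Integer)
    # Assumed guard: the step groupsize - 1 must be positive, and a group
    # cannot be larger than the data.
    2 <= groupsize <= datasize ||
        throw(ArgumentError("need 2 <= groupsize <= datasize"))
    # Start a new group every groupsize - 1 points, so adjacent groups
    # share exactly one observation; the last group is clipped at datasize.
    return [i:min(datasize, i + groupsize - 1) for i in 1:(groupsize - 1):(datasize - 1)]
end

sketch_group_ranges(10, 5)   # → [1:5, 5:9, 9:10], matching the example above
sketch_group_ranges(9, 5)    # → [1:5, 5:9], an even split with no short final group
```

Because of the one-point overlap, the split is even exactly when datasize - 1 is a multiple of groupsize - 1.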