sciml_train and GalacticOptim.jl
sciml_train is a heuristic-based training function built using GalacticOptim.jl. It incorporates knowledge from many high-level benchmarks in an attempt to do the right thing by default, for example when choosing the optimizer and the automatic differentiation (AD) backend.
GalacticOptim.jl has a broader scope than a typical global optimization package: it seeks to bring together all of the optimization packages it can find, local and global, into one unified Julia interface. This means you learn one package and you learn them all! GalacticOptim.jl adds a few high-level features, such as integration with automatic differentiation, which make it fairly simple to use in most cases while still exposing all of the options through a single unified interface.
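To illustrate the "learn one package" idea, here is a minimal sketch of calling GalacticOptim.jl directly, modeled on the Rosenbrock example from the GalacticOptim.jl README. It assumes GalacticOptim.jl and Optim.jl are installed in a version contemporary with sciml_train, and that the returned solution exposes a `minimizer` field. The same problem is solved once with a derivative-free method and once with a gradient-based method whose gradients are generated by ForwardDiff.

```julia
using GalacticOptim, Optim

# Rosenbrock function; p holds the problem parameters (a, b)
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
x0 = zeros(2)
p = [1.0, 100.0]

# Derivative-free local optimizer from Optim.jl
prob = OptimizationProblem(rosenbrock, x0, p)
sol_nm = solve(prob, NelderMead())

# Same objective, but GalacticOptim builds the gradient with ForwardDiff,
# so a gradient-based optimizer such as BFGS can be used
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoForwardDiff())
prob_ad = OptimizationProblem(optf, x0, p)
sol_bfgs = solve(prob_ad, BFGS())
```

Only the optimizer argument to `solve` changes between the two calls, which is the point of the unified interface.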
sciml_train API
DiffEqFlux.sciml_train
Unconstrained Optimization
function sciml_train(loss, _θ, opt = DEFAULT_OPT, adtype = DEFAULT_AD,
                     _data = DEFAULT_DATA, args...;
                     cb = (args...) -> false, maxiters = get_maxiters(data),
                     kwargs...)
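A minimal unconstrained call, as a sketch: it assumes a DiffEqFlux.jl version that still provides sciml_train, Flux.jl for the `ADAM` optimizer, and that the result exposes `minimizer`. The quadratic loss and its target are hypothetical, chosen only so the minimum is known.

```julia
using DiffEqFlux, Flux, Optim

# Hypothetical loss: a quadratic whose minimum is at θ = [1.0, 2.0]
target = [1.0, 2.0]
loss(θ) = sum(abs2, θ .- target)

θ0 = zeros(2)

# ADAM does not stop on its own, so maxiters is passed explicitly
res = DiffEqFlux.sciml_train(loss, θ0, ADAM(0.1); maxiters = 500)
res.minimizer
```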
Box Constrained Optimization
function sciml_train(loss, θ, opt = DEFAULT_OPT, adtype = DEFAULT_AD,
                     data = DEFAULT_DATA, args...;
                     lower_bounds, upper_bounds,
                     cb = (args...) -> false, maxiters = get_maxiters(data),
                     kwargs...)
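A sketch of the box-constrained form, assuming the same DiffEqFlux.jl era and that GalacticOptim.jl forwards `lower_bounds`/`upper_bounds` to Optim.jl's `Fminbox` wrapper, which keeps iterates inside the box. The Rosenbrock-type loss is illustrative; its unconstrained minimum (1, 1) deliberately lies outside the chosen box, so the bounds are active.

```julia
using DiffEqFlux, Optim

# Rosenbrock-type loss; its unconstrained minimum (1, 1) is outside the box below
loss(θ) = (1.0 - θ[1])^2 + 100.0 * (θ[2] - θ[1]^2)^2

lb = [-1.0, -1.0]
ub = [0.5, 0.5]
θ0 = [0.0, 0.0]   # strictly inside the box

# Fminbox wraps an inner gradient-based optimizer (here BFGS) with bound handling
res = DiffEqFlux.sciml_train(loss, θ0, Fminbox(BFGS());
                             lower_bounds = lb, upper_bounds = ub)
```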
Optimizer Choices and Arguments
For a full definition of the allowed optimizers and arguments, please see the GalacticOptim.jl documentation. Because sciml_train is an interface over GalacticOptim.jl, all of GalacticOptim.jl's optimizers and arguments can be used from here.
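As a sketch of swapping optimizers, the calls below pass two Optim.jl optimizers through sciml_train: a derivative-free method and a quasi-Newton method that uses AD-generated gradients. This assumes a DiffEqFlux.jl version that still provides sciml_train and that the result exposes `minimizer`; the loss is hypothetical.

```julia
using DiffEqFlux, Optim

# Hypothetical loss with its minimum at θ = [3.0, -1.0]
loss(θ) = sum(abs2, θ .- [3.0, -1.0])
θ0 = zeros(2)

# Derivative-free optimizer from Optim.jl
res_nm = DiffEqFlux.sciml_train(loss, θ0, NelderMead())

# Quasi-Newton optimizer from Optim.jl; gradients come from the AD backend
res_lbfgs = DiffEqFlux.sciml_train(loss, θ0, LBFGS())
```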
Loss Functions and Callbacks
Loss functions in sciml_train treat the first returned value as the loss. For example, if the loss returns (1.0, [2.0]), then the value the optimizer sees is 1.0. The returned values are also passed to the callback function, which has the form cb(p, args...), where p is the current parameter vector and args are the values returned by the loss (the loss value first, followed by any extra returns). This allows intermediate results, such as predictions, to be reused instead of recalculated. The callback function must return a boolean: if it returns true, the optimizer ends the optimization prematurely. It is called after every successful step, where what counts as a step is defined in an optimizer-dependent manner.
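The following sketch shows the loss/callback convention: the loss returns `(loss, prediction)`, only the first value is optimized, and the callback receives the parameters followed by both returned values. It assumes a DiffEqFlux.jl version that still provides sciml_train and Flux.jl for `ADAM`; the linear-fit data and the `history` vector are hypothetical illustrations.

```julia
using DiffEqFlux, Flux

# Hypothetical data sampled from y = 2x + 1
xs = collect(range(0, 1; length = 20))
ys = 2 .* xs .+ 1

# Returns (loss, prediction); the optimizer only uses the first value
function loss(θ)
    pred = θ[1] .* xs .+ θ[2]
    return sum(abs2, pred .- ys), pred
end

history = Float64[]

# Receives the parameters, then every value the loss returned
callback = function (θ, l, pred)
    push!(history, l)      # reuse the already-computed loss value
    return l < 1e-6        # returning true stops the optimization early
end

res = DiffEqFlux.sciml_train(loss, zeros(2), ADAM(0.1);
                             cb = callback, maxiters = 1000)
```

Because `pred` is handed to the callback, plotting or logging predictions does not require a second forward pass.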
Default AD Choice
The current default AD choice depends on the number of parameters: for fewer than 100 parameters ForwardDiff.jl (forward mode) is used; otherwise Zygote.jl (reverse mode) is used. Further refinements to this heuristic are planned.
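The heuristic can be overridden through the positional `adtype` argument. This sketch forces reverse-mode AD with Zygote even though the problem has only 10 parameters; it assumes a DiffEqFlux.jl version that still provides sciml_train, that GalacticOptim.jl is importable in the same environment, and that the result exposes `minimizer`. The loss is hypothetical.

```julia
using DiffEqFlux, Optim, GalacticOptim

# Hypothetical loss with its minimum at the origin
loss(θ) = sum(abs2, θ)
θ0 = ones(10)

# 10 parameters would normally select ForwardDiff; request Zygote explicitly
res = DiffEqFlux.sciml_train(loss, θ0, BFGS(), GalacticOptim.AutoZygote())
```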
Default Optimizer Choice
By default, if the loss function is deterministic, then an optimizer chain of ADAM -> BFGS is used; otherwise ADAM alone is used (and a choice of maxiters is required).
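A similar two-stage strategy can also be written out by hand, which gives control over each stage. This is a sketch, not the internal default: the ADAM step size and iteration budget below are arbitrary illustrative choices, not the values sciml_train uses internally. It assumes a DiffEqFlux.jl version that still provides sciml_train, Flux.jl for `ADAM`, and that the result exposes `minimizer`.

```julia
using DiffEqFlux, Flux, Optim

# Hypothetical deterministic loss (Rosenbrock), minimum at (1, 1)
loss(θ) = (1.0 - θ[1])^2 + 100.0 * (θ[2] - θ[1]^2)^2
θ0 = zeros(2)

# Stage 1: ADAM for a fixed budget (maxiters is required for ADAM)
res1 = DiffEqFlux.sciml_train(loss, θ0, ADAM(0.01); maxiters = 200)

# Stage 2: BFGS warm-started from the ADAM result to converge tightly
res2 = DiffEqFlux.sciml_train(loss, res1.minimizer, BFGS())
```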