This type of neural network, often abbreviated as ResNet, was introduced by He et al. (2015). We will discuss the motivation for and architecture of these networks, and their relation to ODEs. This will naturally lead to the concept of neural ODEs.
Consider a neural network with $L$ layers. Each layer $l$ has its own weights $\theta_l$. During training, we optimize the loss function $\Pi$. If the network is very deep, say $L \gg 1$, and for the gradient of $\Pi$ it holds that
$$\left\| \frac{\partial \Pi}{\partial \theta_l} \right\| \approx 0 \quad \text{for } l \ll L,$$
then the contribution of the first layers is negligible, as the influence of their weights on $\Pi$ is small. This leads to an effective cut-off in depth, and the benefit of deep networks in terms of generalization capabilities is lost.
He et al. (2015) demonstrate that taking a deeper network can actually lead to an increase in training and testing error.
Thus, beyond a certain point, increasing the depth of a network can be counterproductive. Given this, we would like to come up with a network architecture that addresses the problem of vanishing gradients by ensuring that
$$x_{l+1} \to x_l \quad \text{as} \quad \theta_l \to 0,$$
where $x_l$ denotes the output of the $l$-th layer. This means requiring that when the weights of the network approach small values, the network should approach the identity mapping, and not the null mapping.
The problem of vanishing gradients outlined above is alleviated by introducing a deep residual learning framework. Instead of hoping that each few stacked layers directly fit a desired underlying mapping, we explicitly let these layers fit a residual mapping. Formally, denoting the desired underlying mapping as $\mathcal{H}(x)$, we let the stacked nonlinear layers fit another mapping $\mathcal{F}(x) := \mathcal{H}(x) - x$. The original mapping is recast into $\mathcal{F}(x) + x$. It is easier to optimize the residual mapping than to optimize the original, unreferenced mapping. In the extreme case, if an identity mapping were optimal, it would be easier to push the residual to zero than to fit an identity mapping by a stack of nonlinear layers.
The formulation $\mathcal{F}(x) + x$ can be realized by feedforward neural networks with "shortcut connections" that skip one or more layers, forming so-called residual blocks. In our case, the shortcut connections simply perform the identity mapping, and their outputs are added to the outputs of the stacked layers.
Identity shortcut connections add neither extra parameters nor computational complexity. The entire network can still be trained end-to-end with backpropagation and can be easily implemented using common libraries without modifying the solvers.
To illustrate this we implement a simple fully connected neural network in Flux.jl.
using Flux
using BenchmarkTools: @benchmark

# Toy regression task: learn the identity function on a 1-row matrix of inputs.
actual(x) = x
x_train = hcat(-10:10...)
y_train = actual.(x_train)

# A plain stack of dense layers (no skip connections).
dense = Chain(
    Dense(1 => 1, tanh),
    Dense(1 => 1, tanh),
    Dense(1 => 1, tanh),
    Dense(1 => 1, identity),
)

loader = Flux.DataLoader((x_train, y_train), batchsize=8, shuffle=true);

# Train until the mean squared error on the training set drops below 0.01.
function loop(model)
    optim = Flux.setup(Adam(), model)
    for epoch in 1:10_000
        Flux.train!(model, loader, optim) do m, x, y
            y_hat = m(x)
            Flux.mse(y_hat, y)
        end
        Flux.mse(model(x_train), x_train) < 0.01 && break
    end
end
@benchmark loop(dense)

BenchmarkTools.Trial: 3010 samples with 1 evaluation.
Range (min … max): 124.333 μs … 15.881 ms ┊ GC (min … max): 0.00% … 52.93%
Time (median): 1.139 ms ┊ GC (median): 0.00%
Time (mean ± σ): 1.659 ms ± 1.938 ms ┊ GC (mean ± σ): 18.10% ± 15.30%
Memory estimate: 169.88 KiB, allocs estimate: 2050.

Comparing these benchmark results to a residual network with the same number of layers reveals the benefits of these skip connections.
# The same number of layers, but the three hidden layers are wrapped in a
# skip connection whose output is added (+) to its input.
resnet = Chain(
    SkipConnection(
        Chain(
            Dense(1 => 1, tanh),
            Dense(1 => 1, tanh),
            Dense(1 => 1, tanh),
        ),
        +,
    ),
    Dense(1 => 1, identity),
)
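To make explicit what the skip connection computes, here is a minimal sketch (the inner chain and input below are arbitrary, chosen only for illustration): SkipConnection(inner, +) adds the input of the block to the output of the wrapped layers.

using Flux

# Hypothetical inner block; any chain with matching input/output size works.
inner = Chain(Dense(1 => 1, tanh), Dense(1 => 1, tanh))
block = SkipConnection(inner, +)

x = rand(Float32, 1, 5)
block(x) ≈ inner(x) .+ x   # true: the shortcut adds the identity path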
@benchmark loop(resnet)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 88.465 μs … 10.694 ms ┊ GC (min … max): 0.00% … 82.72%
Time (median): 93.265 μs ┊ GC (median): 0.00%
Time (mean ± σ): 113.836 μs ± 406.215 μs ┊ GC (mean ± σ): 15.60% ± 4.33%
Memory estimate: 104.50 KiB, allocs estimate: 1197.

Let us first consider a residual block with a single linear layer, which can be mathematically formulated in the following manner:
$$x_{l+1} = x_l + \sigma(W_l x_l + b_l),$$
where $W_l$ and $b_l$ are the weights and bias of the layer and $\sigma$ is the activation function. This can easily be rewritten as
$$\frac{x_{l+1} - x_l}{\Delta t} = \frac{1}{\Delta t}\,\sigma(W_l x_l + b_l)$$
for some scalar $\Delta t > 0$.
Now consider a first-order system of (possibly nonlinear) ODEs, where given the IVP
$$\frac{\mathrm{d}x(t)}{\mathrm{d}t} = f(x(t), t), \quad x(0) = x_0, \quad t \in [0, T], \tag{1}$$
we want to find $x(T)$. In order to solve this numerically, we can uniformly divide the temporal domain with a time-step $\Delta t = T/N$ and temporal nodes $t_k = k\,\Delta t$, where $0 \le k \le N$. Define the discrete solution as $x_k \approx x(t_k)$. Then, given $x_k$, we can use a time-integrator to approximate the solution $x_{k+1}$ at the next node. We can consider a method motivated by the forward Euler integrator, where the LHS of (1) is approximated by
$$\frac{\mathrm{d}x(t_k)}{\mathrm{d}t} \approx \frac{x_{k+1} - x_k}{\Delta t}.$$
The RHS is approximated using parameters $\theta_k = (W_k, b_k)$ as
$$f(x(t_k), t_k) \approx \frac{1}{\Delta t}\,\sigma(W_k x_k + b_k),$$
where we allow the parameters to be different at each time-step. Putting these two together, we get $x_{k+1} = x_k + \sigma(W_k x_k + b_k)$, which is exactly the relation of the ResNet block above. In other words, a ResNet is nothing but a discretization of a nonlinear system of ODEs. We make some comments to further strengthen this connection.
In a fully trained ResNet we are given $x_0$ and the weights of the network, and we predict $x_N$. In a system of ODEs, we are given $x(0)$ and $f$, and we predict $x(T)$.
Training the ResNet means determining the parameters $\theta_k$ of the network so that $x_N$ is as close as possible to $y_i$ when $x_0 = x_i$, for $1 \le i \le S$, where $\{(x_i, y_i)\}_{i=1}^{S}$ is the training data. When viewed from the analogous ODE point of view, training means determining the RHS $f$ by requiring $x(T)$ to be as close as possible to $y_i$ when $x(0) = x_i$, for $1 \le i \le S$.
In a ResNet we are looking for a single set of weights $\{\theta_k\}_{k=0}^{N-1}$ that will map $x_i$ to $y_i$ for all $1 \le i \le S$; in the ODE analogy, a single RHS $f$ plays this role.
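As a small sanity check of this correspondence, the following sketch (with hypothetical weights and $\Delta t = 1$, so the $1/\Delta t$ factor is absorbed into the activation term) confirms that one forward Euler step coincides with one residual-block update.

# Hypothetical weights and bias for a 2-dimensional state.
act = tanh
W = [0.1 -0.3; 0.2 0.4]
b = [0.05, -0.1]

f(x) = act.(W * x .+ b)                  # parameterized RHS of the ODE (Δt = 1)
euler_step(x) = x .+ f(x)                # one forward Euler step with Δt = 1
resnet_block(x) = x .+ act.(W * x .+ b)  # one residual block

x = [2.0, 0.0]
euler_step(x) ≈ resnet_block(x)          # true: the two updates are identical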
Motivated by the connection between ResNets and ODEs, neural ODEs were proposed by Chen et al. (2019). Consider a system of ODEs given by
$$\frac{\mathrm{d}x(t)}{\mathrm{d}t} = \mathcal{F}(x(t), t; \theta), \quad x(0) = x_0, \quad t \in [0, T]. \tag{2}$$
Given $x_0$, we wish to find $x(T)$. The RHS, i.e., $\mathcal{F}(x, t; \theta)$, is defined using a feed-forward neural network with parameters $\theta$. The input to the network is $(x, t)$ while the output is $\mathcal{F}(x, t; \theta)$, having the same dimensions as $x$. With this description, the system (2) is solved using a suitable time-marching scheme, such as forward Euler, Runge-Kutta, etc.
How do we use neural ODEs to solve a regression problem? Assume that you are given the labelled training data $\{(x_i, y_i)\}_{i=1}^{S}$. Both $x_i$ and $y_i$ are assumed to have the same dimension $d$. The key idea is to think of the $x_i$ as points in $d$-dimensional space that represent the initial state of the system and the $y_i$ as points that represent the final state. Then the regression problem becomes finding the dynamics, that is the RHS of (2), that will map the initial to the final points with minimal error. This means finding the parameters $\theta$ such that
$$\Pi(\theta) = \sum_{i=1}^{S} \left\| x(T; \theta, x_i) - y_i \right\|^2$$
is minimal. Here, $x(T; \theta, x_i)$ is the solution at time $T$ to (2) with initial condition $x(0) = x_i$, and the RHS is represented by the feed-forward neural network $\mathcal{F}(x, t; \theta)$.
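As a sketch of this objective in code (the two-dimensional RHS network, time horizon, and data below are hypothetical placeholders), one could write:

using OrdinaryDiffEq, Flux

# Hypothetical RHS network F(x; θ) and labelled data (xs, ys).
rhs_net = Dense(2 => 2, tanh)
xs = [Float32[2, 0], Float32[0, 1]]   # initial states x_i
ys = [Float32[1, 1], Float32[-1, 0]]  # target final states y_i
T = 1.5f0

# Integrate (2) from x0 up to time T and return the final state.
function solve_to_T(x0)
    prob = ODEProblem((x, p, t) -> rhs_net(x), x0, (0f0, T))
    solve(prob, Tsit5())[end]
end

# The regression loss: squared distances between predicted and target final states.
loss = sum(sum(abs2, solve_to_T(x) .- y) for (x, y) in zip(xs, ys))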
If we interpret the number of time steps $N$ in a neural ODE as the number of hidden layers in a ResNet, then the computational cost for both methods is $\mathcal{O}(N)$. This is the cost associated with performing one forward propagation and one backward propagation. However, the memory cost associated with storing the weights of each layer is different. For a neural ODE, all weights are associated with the feed-forward network representing $\mathcal{F}(x, t; \theta)$. Thus, the number of weights is independent of the number of time steps used to solve the ODE. On the other hand, for a ResNet the number of weights increases linearly with the number of layers, therefore the cost of storing them scales as $\mathcal{O}(N)$.
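To make the memory argument concrete, here is a small sketch (layer widths, depth, and step counts are arbitrary): stacking more residual blocks multiplies the parameter count, whereas a hand-rolled forward Euler loop standing in for the ODE solver reuses one network at every step, so its parameter count is fixed.

using Flux

n_params(m) = sum(length, Flux.params(m))

# ResNet-style: every additional block brings its own weights.
resnet_depth(N) = Chain([SkipConnection(Dense(2 => 2, tanh), +) for _ in 1:N]...)
n_params(resnet_depth(10)) == 10 * n_params(resnet_depth(1))   # true

# Neural-ODE-style: one network f is reused at every time step.
f = Dense(2 => 2, tanh)
function euler_solve(x, N)
    Δt = 1f0 / N
    for _ in 1:N                 # N forward Euler steps, all sharing f's weights
        x = x .+ Δt .* f(x)
    end
    return x
end

euler_solve(Float32[2, 0], 100)  # more steps, but n_params(f) stays the same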
With a neural ODE we can take the limit $\Delta t \to 0$ and study the convergence, since the size of the network remains unchanged. This is not feasible for ResNets, where $\Delta t \to 0$ corresponds to a network depth $N \to \infty$.
A ResNet uses a forward Euler type method, but with neural ODEs any type of numerical ODE solver is feasible. Consider, for example, higher-order explicit time-integration schemes like the Runge-Kutta methods, which converge to the true solution at a faster rate.
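The following sketch illustrates this on the scalar test problem $\mathrm{d}x/\mathrm{d}t = -x$ with exact solution $e^{-t}$ (chosen purely for illustration): halving the step size roughly halves the forward Euler error, while the fourth-order RK4 error drops by roughly a factor of sixteen.

using OrdinaryDiffEq

# Test problem dx/dt = -x, x(0) = 1, exact solution x(t) = exp(-t).
prob = ODEProblem((x, p, t) -> -x, 1.0, (0.0, 1.0))
exact = exp(-1.0)

for dt in (0.1, 0.05, 0.025)
    err_euler = abs(solve(prob, Euler(), dt=dt)[end] - exact)                # O(dt)
    err_rk4   = abs(solve(prob, RK4(), dt=dt, adaptive=false)[end] - exact)  # O(dt^4)
    println("dt = $dt  Euler error = $err_euler  RK4 error = $err_rk4")
end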
Fit a neural network on the dynamics of the following ODE system with initial value:
$$\frac{\mathrm{d}u}{\mathrm{d}t} = A^{\top} u^{3}, \quad A = \begin{pmatrix} -0.1 & 2.0 \\ -2.0 & -0.1 \end{pmatrix}, \quad u(0) = \begin{pmatrix} 2 \\ 0 \end{pmatrix}, \quad t \in [0, 1.5],$$
where the cube $u^{3}$ is applied elementwise. The original equations are unknown to the neural network; the data set only consists of samples that evolve in time as the ODE system describes. Feel free to train the model with more epochs for better results on your own machine.
using OrdinaryDiffEq
using DiffEqFlux, Flux, PlotlyJS

# Ground-truth dynamics: du/dt = Aᵀu³ with a cubic nonlinearity.
u0 = Float32[2.0; 0.0]
n_samples = 16
tspan = (0.0f0, 1.5f0)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

# Generate the training data by solving the true ODE at n_samples time points.
t = range(tspan[1], tspan[2], length=n_samples)
prob = DiffEqFlux.ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat=t))
# The neural ODE's RHS: a small network applied to the cubed state.
model = Chain(
    x -> x.^3,
    Dense(2, 50, tanh),
    Dense(50, 2),
)
n_ode = NeuralODE(model, tspan, Tsit5(), saveat=t, reltol=1e-7, abstol=1e-9)

# Train by matching the neural ODE's trajectory to the true solution data.
ps = Flux.params(n_ode)
loss_n_ode() = sum(abs2, ode_data .- n_ode(u0))
data = Iterators.repeated((), 25) # epochs
Flux.train!(loss_n_ode, ps, data, ADAM(0.1))
# Plot the true trajectories against the neural ODE's predictions.
pred = n_ode(u0)
traces = [
    scatter(x=t, y=ode_data[1, :], name="u_1", mode="lines+markers"),
    scatter(x=t, y=ode_data[2, :], name="u_2", mode="lines+markers"),
    scatter(x=t, y=pred[1, :], name="u_1 pred", mode="lines+markers"),
    scatter(x=t, y=pred[2, :], name="u_2 pred", mode="lines+markers"),
]
plt = plot(traces)

We are not learning a solution to the original ODE. Instead, we are learning the tiny ODE system from which the ODE solution is generated. The neural network inside the neural ODE layer learns the right-hand side $\mathcal{F}(u; \theta)$. Thus, it has learned a compact representation of how the time series behaves, and it can easily extrapolate to what would happen with different initial conditions. It is also a very flexible method for learning such representations if your data is unevenly spaced: just pass in the desired time steps and the ODE solver takes care of it.
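For example, reusing the trained model from above, a sketch of such an evaluation over a longer, unevenly spaced time grid and from a different (hypothetical) initial state could look like this:

# Unevenly spaced prediction times and a new initial state (both hypothetical).
t_uneven = Float32[0.0, 0.1, 0.3, 0.7, 1.5, 2.5]
n_ode_long = NeuralODE(model, (0.0f0, 2.5f0), Tsit5(), saveat=t_uneven)
pred_new = n_ode_long(Float32[1.0, 1.0])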
Model the dynamics of the Lotka-Volterra system using a neural ODE. Refer to the Julia code above for help. See also notebooks/neural_ode.jl.
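As a starting point for the exercise, a sketch of the ground-truth dynamics used to generate training data might look as follows (the parameter values are arbitrary placeholders); the neural ODE itself is then set up exactly as in the example above.

# Lotka-Volterra dynamics with hypothetical parameters.
function lotka_volterra!(du, u, p, t)
    α, β, γ, δ = 1.5f0, 1.0f0, 3.0f0, 1.0f0
    du[1] =  α * u[1] - β * u[1] * u[2]   # prey
    du[2] = -γ * u[2] + δ * u[1] * u[2]   # predator
end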
Ray, Pinti and Oberai, Deep Learning and Computational Physics (Lecture Notes), 2023, https://arxiv.org/pdf/2301.00942.pdf.
He, Zhang, Ren and Sun, Deep Residual Learning for Image Recognition, 2015, https://arxiv.org/pdf/1512.03385.pdf.
Chen, Rubanova, Bettencourt and Duvenaud, Neural Ordinary Differential Equations, 2019, https://arxiv.org/pdf/1806.07366.pdf.
Rackauckas, Innes, Ma, Bettencourt, White and Dixit, DiffEqFlux.jl – A Julia Library for Neural Differential Equations, 2019, https://julialang.org/blog/2019/01/fluxdiffeq/.