JIT Compilation
torchode is fully JIT compilable. By JIT compiling your model together with the ODE solver, you can speed up both training and inference because the comparatively slow Python interpreter is removed from the forward pass. The actual computations, e.g. matrix multiplications, can then be scheduled more quickly, so the utilization of your CPU/GPU increases while the wall-clock time of the computation decreases. However, your model can only be JIT compiled if it is written in TorchScript, a subset of Python.
Because of the way PyTorch's JIT works, we also can no longer use the simple solve_ivp interface. Instead, we have to construct the solver components ourselves before handing them over to the compiler. This is necessary because the JIT compiler requires all "dynamic" parts of the computation to be fixed at compile time, i.e. after compilation the code can only deal with tensors and literals and cannot use objects with dynamic behavior such as custom classes.
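To make the "fixed at compile time" requirement more concrete, here is a minimal sketch, independent of torchode, of how TorchScript treats ordinary Python code: every value gets a static type, so a function whose return type depends on a runtime branch is rejected when it is compiled rather than when it runs. The function names `scale` and `tensor_or_message` are hypothetical, and the sketch catches a generic `Exception` because the exact error class may differ between PyTorch versions.

```python
import torch


def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Statically typed: always takes and returns a Tensor, so it compiles.
    return x * factor


def tensor_or_message(x: torch.Tensor, flag: bool):
    # Returns a Tensor on one branch and a str on the other; TorchScript's
    # static type checker cannot unify the two return types.
    if flag:
        return x
    return "not a tensor"


scaled = torch.jit.script(scale)
print(scaled(torch.ones(2), 3.0))

try:
    torch.jit.script(tensor_or_message)
    compiled_dynamic = True
except Exception as err:  # exact exception class varies across versions
    compiled_dynamic = False
    print("TorchScript rejected the function:", type(err).__name__)
```

The same check applies to your own model: if `torch.jit.script` fails on it, it will also fail once the model is embedded in the solver.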
Let's begin by importing everything we need in this example.
import torch
import torch.nn as nn
import torchode as to
torch.random.manual_seed(180819023);
Now we define a simple neural ODE given by an MLP with two hidden layers.
class Model(nn.Module):
    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_features)
        )

    def forward(self, t, y):
        # torchode calls the vector field as f(t, y); this autonomous ODE ignores t.
        return self.layers(y)
n_features = 5
model = Model(n_features=n_features, n_hidden=32)
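Because the model is compiled as part of the solver, a quick way to catch TorchScript incompatibilities early is to script the model on its own first. The following is a minimal sketch that repeats the Model class so it runs standalone; the zero-dimensional time tensor `t` and the random input `y` are illustrative assumptions, since this model ignores `t` anyway.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_features),
        )

    def forward(self, t, y):
        # Unannotated arguments are treated as Tensors by TorchScript.
        return self.layers(y)


model = Model(n_features=5, n_hidden=32)
# Raises an error here if the model uses Python features TorchScript lacks.
model_jit = torch.jit.script(model)

t = torch.tensor(0.0)
y = torch.randn(3, 5)
# The scripted model should compute the same function as the eager one.
print(torch.allclose(model(t, y), model_jit(t, y)))
```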
Next, we construct the solver components and assemble them into the solver AutoDiffAdjoint, which computes parameter gradients by backpropagating through the solver. Note that we pass the term, which wraps the model, to both the step method and the step size controller so that the model is fixed when we JIT compile the solver.
dev = torch.device("cpu")
term = to.ODETerm(model)
step_method = to.Dopri5(term=term)
step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)
adjoint = to.AutoDiffAdjoint(step_method, step_size_controller).to(dev)
Next, we compile the solver. Because the model is a submodule of the solver (through the term), it is compiled along with it.
adjoint_jit = torch.jit.script(adjoint)
As a last step, we have to combine the initial condition y0 and the evaluation points t_eval into a problem instance.
batch_size = 3
t_eval = torch.tile(torch.linspace(0.0, 3.0, 10), (batch_size, 1))
problem = to.InitialValueProblem(y0=torch.zeros((batch_size, n_features)).to(dev), t_eval=t_eval.to(dev))
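Here t_eval holds one row of evaluation points per batch element, which is why the 1-D grid is tiled to shape (batch_size, 10). The rows do not have to be identical: torchode can solve each sample on its own time grid. A short sketch of the shapes involved, using torch only; the second and third grids are arbitrary examples.

```python
import torch

batch_size, n_features, n_points = 3, 5, 10

# Same grid for every sample: tile a 1-D grid into (batch_size, n_points).
t_shared = torch.tile(torch.linspace(0.0, 3.0, n_points), (batch_size, 1))

# A different grid per sample: stack one row per batch element.
t_individual = torch.stack([
    torch.linspace(0.0, 3.0, n_points),
    torch.linspace(0.0, 1.5, n_points),
    torch.linspace(1.0, 2.0, n_points),
])

y0 = torch.zeros((batch_size, n_features))
print(t_shared.shape, t_individual.shape, y0.shape)
```

Either tensor can serve as t_eval for to.InitialValueProblem, as long as its first dimension matches the batch dimension of y0.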
Next, we solve the problem with both the regular and the compiled solver. Both report the same statistics (function evaluations, steps, and so on) and approximately the same result. The results are not bitwise identical because of JIT compilation: most likely, the compiler reorders some floating-point operations, and since floating-point arithmetic is not associative, this leads to small numerical differences.
sol = adjoint.solve(problem)
sol_jit = adjoint_jit.solve(problem)
print(sol.stats)
print(sol_jit.stats)
print("Max absolute difference", float((sol.ys - sol_jit.ys).abs().max()))
{'n_f_evals': tensor([38, 38, 38]), 'n_steps': tensor([6, 6, 6]), 'n_accepted': tensor([6, 6, 6]), 'n_initialized': tensor([10, 10, 10])}
{'n_f_evals': tensor([38, 38, 38]), 'n_steps': tensor([6, 6, 6]), 'n_accepted': tensor([6, 6, 6]), 'n_initialized': tensor([10, 10, 10])}
Max absolute difference 1.2114644050598145e-05
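The reordering explanation rests on the fact that floating-point addition is not associative, so regrouping the same operations can change the last bits of a result. A tiny pure-Python illustration with float64 values; the solver above runs in float32, whose machine epsilon (about 1.2e-7) is much coarser than float64's (about 2.2e-16), so such discrepancies are correspondingly larger there.

```python
import math

a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c   # evaluated as 0.30000000000000004 + 0.3
right = a + (b + c)  # evaluated as 0.1 + 0.5

print(left, right)
print(left == right)              # exact comparison fails
print(math.isclose(left, right))  # comparison with a tolerance succeeds
```

This is why outputs of the compiled and uncompiled solver should be compared with a tolerance (e.g. via torch.allclose) rather than for exact equality.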
Finally, we compare the runtime of the two solvers.
%%timeit
adjoint.solve(problem)
6.63 ms ± 742 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# A second warm-up run. The second call to the compiled solver triggers further compilation
# (presumably TorchScript optimizing the graph it profiled during the first call), which we don't want to measure.
adjoint_jit.solve(problem);
%%timeit
adjoint_jit.solve(problem)
3.63 ms ± 89.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
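%%timeit is an IPython magic, so it is unavailable in a plain script. The following is a minimal sketch of the same measurement with the standard timeit module, including warm-up calls that are excluded from the timing; `solve_once` is a hypothetical stand-in for `adjoint_jit.solve(problem)`, and the 7 runs of 100 loops mirror the output above.

```python
import statistics
import timeit


def solve_once():
    # Hypothetical stand-in for `adjoint_jit.solve(problem)`.
    return sum(i * i for i in range(1_000))


# Warm-up calls, excluded from the measurement, since a scripted solver
# may still compile during its first couple of calls.
for _ in range(2):
    solve_once()

n_runs, n_loops = 7, 100
run_totals = timeit.repeat(solve_once, repeat=n_runs, number=n_loops)
per_loop = [total / n_loops for total in run_totals]

mean_ms = statistics.mean(per_loop) * 1e3
std_ms = statistics.pstdev(per_loop) * 1e3  # spread across the 7 runs
print(f"{mean_ms:.3f} ms ± {std_ms:.3f} ms per loop "
      f"(mean ± std. dev. of {n_runs} runs, {n_loops} loops each)")
```

On a GPU, CUDA kernels are launched asynchronously, so the timed function should call torch.cuda.synchronize() before returning; otherwise the measurement mostly captures kernel launch overhead.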