Recreate torchdiffeq's defaults in torchode
import torch
import torch.nn as nn
import torchode as to
import torchdiffeq as tde
torch.random.manual_seed(180819023);
Consider a randomly initialized MLP with two hidden layers.
class Model(nn.Module):
    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_features)
        )

    def forward(self, t, y):
        return self.layers(y)

n_features = 5
model = Model(n_features, 16)
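Both torchdiffeq and torchode call the dynamics as f(t, y), which is why forward accepts a time argument even though this model ignores it. As a quick sanity check, the sketch below repeats the class so that it runs on its own and confirms that the output has the same shape as the state and does not depend on t. The names y, t, dy and dy_later are illustrative only.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_features),
        )

    def forward(self, t, y):
        # t is part of the solver-facing signature but unused here,
        # so the dynamics are autonomous (time-independent).
        return self.layers(y)


model = Model(n_features=5, n_hidden=16)
y = torch.randn(16, 5)          # a batch of 16 states
t = torch.tensor(0.0)

dy = model(t, y)                # derivative dy/dt at state y and time t
dy_later = model(t + 1.0, y)    # same state, different time

print(dy.shape)
print(torch.equal(dy, dy_later))
```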
We would like to solve the ODE defined by this model for the following initial values y0, evaluated at the time points t_eval.
batch_size = 16
n_steps = 10
y0 = torch.randn((batch_size, n_features))
t_eval = torch.linspace(0.0, 1.0, n_steps)
With torchdiffeq that looks as follows.
sol_tde = tde.odeint(model, y0, t_eval)
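In torchode, we set up the individual components and then combine them into a solver. AutoDiffAdjoint backpropagates by differentiating through the solver's operations (discretize-then-optimize). Dopri5 is torchdiffeq's default method, and the tolerances below match torchdiffeq's defaults (rtol=1e-7, atol=1e-9).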
term = to.ODETerm(model)
step_method = to.Dopri5(term=term)
step_size_controller = to.IntegralController(atol=1e-9, rtol=1e-7, term=term)
adjoint = to.AutoDiffAdjoint(step_method, step_size_controller)
The solver stored in adjoint can now be reused for any problem, including the one above. To solve it, we create a problem instance and pass it to the solver. Note that t_eval has to be repeated for each sample in the batch, because torchode solves a separate ODE for each sample and therefore expects t_eval to have shape (batch_size, n_steps).
problem = to.InitialValueProblem(y0=y0, t_eval=t_eval.repeat((batch_size, 1)))
sol = adjoint.solve(problem)
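To see what the repeat call produces, the sketch below builds the batched time grid with plain PyTorch. It also shows expand, which yields the same values as a view without copying; whether torchode accepts such a non-contiguous view is not verified here, so treat that variant as an assumption. The per-sample grid at the end is a hypothetical illustration of why the batch dimension exists at all.

```python
import torch

batch_size, n_steps = 16, 10
t_eval = torch.linspace(0.0, 1.0, n_steps)        # shape (n_steps,)

# One row of evaluation times per sample (copies the data)
t_repeated = t_eval.repeat((batch_size, 1))       # shape (batch_size, n_steps)

# Same values as a broadcast view, without copying
t_expanded = t_eval.expand(batch_size, n_steps)

print(t_repeated.shape)
print(torch.equal(t_repeated, t_expanded))

# Hypothetical: each sample gets its own end time between 1 and 2,
# so every row holds a different grid.
t_end = torch.linspace(1.0, 2.0, batch_size)      # shape (batch_size,)
t_per_sample = t_eval[None, :] * t_end[:, None]   # shape (batch_size, n_steps)
print(t_per_sample[0, -1].item(), t_per_sample[-1, -1].item())
```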
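Comparing the two solutions shows that they are very close. torchdiffeq returns the time dimension first, so we transpose its output to match torchode's batch-first layout.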
abs_err = (sol.ys - sol_tde.transpose(0, 1)).abs()
abs_err.mean().item(), abs_err.max().item()
(3.3444638347646105e-07, 6.198883056640625e-06)
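The transpose in the comparison is needed because the two results use different layouts: (n_steps, batch_size, n_features) for torchdiffeq and (batch_size, n_steps, n_features) for torchode. The sketch below reproduces the alignment on synthetic tensors, so no solver is required, and adds a tolerance-based check with torch.allclose. The tensor names and the 1e-6 perturbation are made up for illustration.

```python
import torch

torch.manual_seed(0)
batch_size, n_steps, n_features = 16, 10, 5

# Stand-ins for solver output: identical values stored in two layouts
ys_batch_first = torch.randn(batch_size, n_steps, n_features)
ys_time_first = ys_batch_first.transpose(0, 1).contiguous()

# Bring the time-first tensor into the batch-first layout before comparing
aligned = ys_time_first.transpose(0, 1)
abs_err = (ys_batch_first - aligned).abs()
print(abs_err.max().item())

# Real solvers differ by small numerical errors; mimic that with a
# constant offset and compare using an explicit absolute tolerance.
perturbed = aligned + 1e-6
print(torch.allclose(ys_batch_first, perturbed, rtol=0.0, atol=1e-5))
```

An explicit tolerance makes the notion of "close" checkable, for instance in a regression test, instead of relying on reading the printed errors.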
Finally, let's look at the solution statistics that torchode gives us.
sol.stats
{'n_f_evals': tensor([50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]),
'n_steps': tensor([5, 5, 6, 8, 5, 6, 6, 5, 6, 7, 5, 5, 5, 5, 5, 7]),
'n_accepted': tensor([5, 5, 6, 7, 5, 6, 5, 5, 5, 7, 5, 5, 5, 5, 5, 7]),
'n_initialized': tensor([10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10])}
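The statistics are reported per sample, so they can be combined with ordinary tensor operations. As a small sketch, and assuming that n_steps counts attempted steps while n_accepted counts those that passed the error check, the difference gives the number of rejected steps per sample. The values are copied from the output above.

```python
import torch

# Values copied from the sol.stats output above
n_steps = torch.tensor([5, 5, 6, 8, 5, 6, 6, 5, 6, 7, 5, 5, 5, 5, 5, 7])
n_accepted = torch.tensor([5, 5, 6, 7, 5, 6, 5, 5, 5, 7, 5, 5, 5, 5, 5, 7])

# Assumed interpretation: attempted minus accepted = rejected
n_rejected = n_steps - n_accepted
print(n_rejected.tolist())

# Totals over the whole batch
print(n_steps.sum().item(), n_accepted.sum().item(), n_rejected.sum().item())
```

Under that reading, only three of the 91 attempted steps in the batch were rejected.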