Logging Solver Information
One of the key features of torchode is that all components are replaceable and any component can log its own outputs into a dictionary called stats. This means that you can inject your own code and log any information that is relevant for your use case. In this example, we will create a step size controller wrapper that logs the start time t and the size dt of each step, together with every accept decision, i.e. whether each step was accepted or rejected by the step size controller.
We begin by importing relevant modules and defining a generic model class.
```python
import torch
import torch.nn as nn
import torchode as to
from torchode.step_size_controllers import StepSizeController

torch.random.manual_seed(180819023)
```
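```python
class Model(nn.Module):
    def __init__(self, n_features, n_hidden):
        super().__init__()
        # A small MLP that parameterizes the right-hand side dy/dt = f(t, y)
        self.layers = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_hidden),
            nn.Softplus(),
            nn.Linear(n_hidden, n_features)
        )

    def forward(self, t, y):
        # The dynamics are autonomous: t is part of the interface but unused
        return self.layers(y)
```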
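Written out, the Model above defines an autonomous ODE, since forward ignores t. Writing $d$ for n_features, $h$ for n_hidden, $W_i, b_i$ for the weights and biases of the three nn.Linear layers (notation introduced here), and $\sigma(x) = \log(1 + e^x)$ for the element-wise softplus with PyTorch's default $\beta = 1$:

```latex
\frac{\mathrm{d}y}{\mathrm{d}t} = f_\theta(t, y)
  = W_3\,\sigma\!\left(W_2\,\sigma\!\left(W_1 y + b_1\right) + b_2\right) + b_3,
\qquad
W_1 \in \mathbb{R}^{h \times d},\;
W_2 \in \mathbb{R}^{h \times h},\;
W_3 \in \mathbb{R}^{d \times h}
```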
Now we define the wrapper that will track the step size data. By deferring the actual functionality to another controller, we can re-use the existing controller implementations and focus on collecting the information that we care about, in this case the start time t of each step, the step size dt and whether each step was accepted.
To define this custom controller, we only have to implement the StepSizeController interface, i.e. the methods init, adapt_step_size and merge_states, and forward each call to the wrapped controller instance. In init, we additionally create empty lists in the statistics dictionary of the current solve to hold t, dt and the accept decisions. The adapt_step_size method first lets the wrapped controller make its decision and then appends the step's t, dt and accept flag to those lists.
Note that you could track information about the stepping method, e.g. dopri5, in a similar way by wrapping it in a class that implements the SingleStepMethod interface.
```python
class StepSizeControllerTracker(StepSizeController):
    """A wrapper that collects time step and step acceptance information."""

    def __init__(self, controller: StepSizeController):
        super().__init__()
        # The wrapped controller makes all actual step size decisions
        self.controller = controller

    def init(self, term, problem, method_order: int, dt0, *, stats, args):
        # Create empty logs in the stats dictionary of the current solve
        stats["all_t"] = []
        stats["all_dt"] = []
        stats["all_accept"] = []
        return self.controller.init(
            term, problem, method_order, dt0, stats=stats, args=args
        )

    def adapt_step_size(self, t0, dt, y0, step, state, stats):
        # Let the wrapped controller decide first ...
        accept, dt_next, state, status = self.controller.adapt_step_size(
            t0, dt, y0, step, state, stats
        )
        # ... then record the start time, size and outcome of the attempted step
        stats["all_t"].append(t0)
        stats["all_dt"].append(dt)
        stats["all_accept"].append(accept)
        return accept, dt_next, state, status

    def merge_states(self, running, current, previous):
        return self.controller.merge_states(running, current, previous)
```
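To see the delegation pattern in isolation, here is a framework-free sketch in plain Python. ToyController, TrackingController and integrate are hypothetical names invented for illustration; they mimic the shape of the torchode wrapper above (a shared stats dictionary, forwarding to the wrapped controller, recording after the decision), but they are not part of torchode's API, and the acceptance rule is deliberately trivial.

```python
from typing import Any, Dict, Tuple


class ToyController:
    """Hypothetical stand-in for a real step size controller (not part of torchode).

    Accepts a step if dt <= max_dt; the next proposal doubles dt after an
    accepted step and halves it after a rejected one.
    """

    def __init__(self, max_dt: float):
        self.max_dt = max_dt

    def init(self, stats: Dict[str, Any]) -> None:
        pass  # a real controller would set up its internal state here

    def adapt_step_size(self, t0: float, dt: float, stats: Dict[str, Any]) -> Tuple[bool, float]:
        if dt <= self.max_dt:
            return True, 2.0 * dt
        return False, 0.5 * dt


class TrackingController:
    """Forwards every call to the wrapped controller and logs its decisions in stats."""

    def __init__(self, controller: ToyController):
        self.controller = controller

    def init(self, stats: Dict[str, Any]) -> None:
        stats["all_t"], stats["all_dt"], stats["all_accept"] = [], [], []
        self.controller.init(stats)

    def adapt_step_size(self, t0: float, dt: float, stats: Dict[str, Any]) -> Tuple[bool, float]:
        accept, dt_next = self.controller.adapt_step_size(t0, dt, stats)
        stats["all_t"].append(t0)
        stats["all_dt"].append(dt)
        stats["all_accept"].append(accept)
        return accept, dt_next


def integrate(controller, t_end: float, dt0: float) -> Dict[str, Any]:
    """Hypothetical minimal stepping loop: t only advances on accepted steps."""
    stats: Dict[str, Any] = {}
    controller.init(stats)
    t, dt = 0.0, dt0
    while t < t_end:
        dt = min(dt, t_end - t)  # never step past t_end
        accept, dt_next = controller.adapt_step_size(t, dt, stats)
        if accept:
            t += dt
        dt = dt_next
    return stats


stats = integrate(TrackingController(ToyController(max_dt=0.5)), t_end=1.0, dt0=1.0)
print(stats["all_t"])       # [0.0, 0.0, 0.5]
print(stats["all_accept"])  # [False, True, True]
```

Because the tracker logs every call rather than only accepted steps, a rejected attempt shows up in this toy loop as a repeated start time: t = 0.0 appears twice, once for the rejected step with dt = 1.0 and once for the accepted retry with dt = 0.5.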
Next, we construct a solver and wrap the step size controller with our tracker.
```python
n_features = 5
batch_size = 3
model = Model(n_features=n_features, n_hidden=32)
dev = torch.device("cpu")

term = to.ODETerm(model)
step_method = to.Dopri5(term=term)
step_size_controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)
# Wrap the controller so that every step decision is also logged in stats
step_size_controller = StepSizeControllerTracker(step_size_controller)
adjoint = to.AutoDiffAdjoint(step_method, step_size_controller).to(dev)
```
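The accept decisions that the tracker records are made by the wrapped IntegralController, based on the atol and rtol passed above. To make the role of these tolerances concrete, the sketch below shows the textbook form of an integral (I) step size controller in plain Python. It is an illustration under assumptions, not torchode's source: the function name integral_controller_step is hypothetical, and the RMS error norm, the safety factor and the clipping bounds are common defaults that torchode's IntegralController may implement differently.

```python
import math
from typing import Sequence, Tuple


def integral_controller_step(
    y0: Sequence[float],
    y1: Sequence[float],
    error_estimate: Sequence[float],
    dt: float,
    atol: float,
    rtol: float,
    q: int,
    safety: float = 0.9,
    factor_min: float = 0.2,
    factor_max: float = 10.0,
) -> Tuple[bool, float]:
    """Textbook integral (I) step size control; returns (accept, dt_next).

    q is the lower of the two orders of the embedded Runge-Kutta pair
    (4 for Dopri5). safety, factor_min and factor_max are assumed values.
    """
    # Scale each error component by the mixed absolute/relative tolerance
    scaled = [
        e / (atol + rtol * max(abs(a), abs(b)))
        for e, a, b in zip(error_estimate, y0, y1)
    ]
    # RMS norm of the scaled error: a value <= 1 means the step meets the tolerance
    err = math.sqrt(sum(s * s for s in scaled) / len(scaled))
    accept = err <= 1.0
    # Propose the next step size from the error, clipped to [factor_min, factor_max]
    factor = factor_max if err == 0.0 else safety * err ** (-1.0 / (q + 1))
    factor = min(factor_max, max(factor_min, factor))
    return accept, dt * factor


# An error of about twice the tolerance: the step is rejected and dt shrinks
accept, dt_next = integral_controller_step(
    [1.0], [1.0], [2e-3], dt=0.1, atol=1e-6, rtol=1e-3, q=4
)
print(accept)  # False
```

The same decision drives both outputs: a scaled error above 1 rejects the step, and the exponent -1/(q+1) shrinks or grows dt according to how far the error is from the tolerance.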
Finally, we generate some example data and solve the ODE defined by the randomly initialized model.
```python
# Evaluate each of the batch_size solutions at 10 points in [0, 3]
t_eval = torch.tile(torch.linspace(0.0, 3.0, 10), (batch_size, 1))
problem = to.InitialValueProblem(
    y0=torch.zeros((batch_size, n_features)).to(dev), t_eval=t_eval.to(dev)
)
sol = adjoint.solve(problem)
```
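We can now inspect the statistics recorded in sol.stats and confirm that our custom step size controller collected the data. Each stacked tensor has one row per call to adapt_step_size and one column per batch sample. Since all three samples start from the same initial value y0 = 0 under the same model, they follow identical trajectories and therefore take identical steps; in this run, all six steps were accepted.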
```python
print(torch.stack(sol.stats["all_t"]))
```

```text
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e-04, 1.0000e-04, 1.0000e-04],
        [1.1000e-03, 1.1000e-03, 1.1000e-03],
        [1.1100e-02, 1.1100e-02, 1.1100e-02],
        [1.1110e-01, 1.1110e-01, 1.1110e-01],
        [9.3798e-01, 9.3798e-01, 9.3798e-01]], grad_fn=<StackBackward0>)
```
```python
print(torch.stack(sol.stats["all_dt"]))
```

```text
tensor([[1.0000e-04, 1.0000e-04, 1.0000e-04],
        [1.0000e-03, 1.0000e-03, 1.0000e-03],
        [1.0000e-02, 1.0000e-02, 1.0000e-02],
        [1.0000e-01, 1.0000e-01, 1.0000e-01],
        [8.2688e-01, 8.2688e-01, 8.2688e-01],
        [2.0620e+00, 2.0620e+00, 2.0620e+00]], grad_fn=<StackBackward0>)
```
```python
print(torch.stack(sol.stats["all_accept"]))
```

```text
tensor([[True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True]])
```
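To turn these logs into per-sample numbers, one option is to convert a single column to Python lists, e.g. torch.stack(sol.stats["all_t"])[:, 0].detach().tolist() for the first sample, and summarize it. The helper summarize_sample below is a hypothetical sketch, not part of torchode. The hard-coded lists are the first column of the printed output above, rounded as printed, so t_reached comes out as about 2.99998 rather than exactly 3.0, the end of t_eval.

```python
from typing import Dict, List


def summarize_sample(
    all_t: List[float], all_dt: List[float], all_accept: List[bool]
) -> Dict[str, float]:
    """Summarize one batch sample's log (one entry per adapt_step_size call)."""
    n_attempts = len(all_accept)
    n_accepted = sum(1 for a in all_accept if a)
    # End points t + dt of all accepted steps; the last one is how far the solver got
    accepted_ends = [t + dt for t, dt, a in zip(all_t, all_dt, all_accept) if a]
    return {
        "n_attempts": n_attempts,
        "n_accepted": n_accepted,
        "n_rejected": n_attempts - n_accepted,
        "acceptance_rate": n_accepted / n_attempts,
        "t_reached": accepted_ends[-1] if accepted_ends else all_t[0],
    }


# First batch column of the printed output above (values rounded as printed)
all_t = [0.0, 1.0e-4, 1.1e-3, 1.11e-2, 1.111e-1, 9.3798e-1]
all_dt = [1.0e-4, 1.0e-3, 1.0e-2, 1.0e-1, 8.2688e-1, 2.0620]
all_accept = [True] * 6

summary = summarize_sample(all_t, all_dt, all_accept)
print(summary["n_rejected"])  # 0
```

The t_reached check is a cheap sanity test for a logging wrapper: if the accepted steps do not chain up to the final evaluation time, the tracker is probably recording the wrong quantity, e.g. the proposed dt_next instead of the attempted dt.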