Coverage for src / qsmile / models / fitting.py: 98%
43 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-04 21:47 +0000
1"""Smile fitting engine."""
3from __future__ import annotations
5import dataclasses
6from dataclasses import dataclass
7from typing import Any, Generic
9import numpy as np
10from numpy.typing import ArrayLike, NDArray
11from scipy.optimize import least_squares
13from qsmile.data.vols import SmileData
14from qsmile.models.protocol import M, SmileModel
@dataclass
class SmileResult(Generic[M]):
    """Result of a smile model fit.

    Generic over ``M`` so that ``result.params`` preserves the
    concrete model type (e.g. ``SVIModel``).

    Attributes
    ----------
    params : M
        Fitted model instance carrying the optimised parameter values.
    residuals : NDArray[np.float64]
        Per-observation residuals (model minus observed values in native
        coordinates).
    rmse : float
        Root mean square error of the fit.
    success : bool
        Whether the optimiser reported convergence.
    """

    params: M
    residuals: NDArray[np.float64]
    rmse: float
    success: bool

    def evaluate(self, x: ArrayLike) -> NDArray[np.float64] | np.float64:
        """Compute the fitted model's output at arbitrary ``x``.

        Thin delegation to ``self.params.evaluate``.

        Parameters
        ----------
        x : ArrayLike
            Point(s) in the model's native x-coordinate.

        Returns
        -------
        NDArray[np.float64] | np.float64
            Model output in the model's native y-coordinate.
        """
        return self.params.evaluate(x)
def _context_for_model(model: type[SmileModel], sd: SmileData) -> dict[str, Any]:
    """Gather the non-parameter constructor context ``model`` needs from ``sd``.

    Every dataclass field of ``model`` that is absent from ``model.param_names``
    counts as context (e.g. ``expiry`` and ``forward`` for SABR). Each such
    field is looked up by name on ``sd.metadata``; fields the metadata does
    not expose are simply left out. Models that are not dataclasses need no
    context and get an empty dict.
    """
    if not dataclasses.is_dataclass(model):
        return {}

    fitted_names = set(model.param_names)
    extra_names = (
        field.name for field in dataclasses.fields(model) if field.name not in fitted_names
    )
    meta = sd.metadata
    return {name: getattr(meta, name) for name in extra_names if hasattr(meta, name)}
def _residuals(
    x: NDArray[np.float64],
    model: type[SmileModel],
    x_obs: NDArray[np.float64],
    y_obs: NDArray[np.float64],
    context: dict[str, Any],
) -> NDArray[np.float64]:
    """Objective for ``least_squares``: model prediction minus observation.

    ``x`` is the flat parameter vector proposed by the optimiser. It is
    rebuilt into a model instance via ``model.from_array`` (together with the
    fixed ``context`` keyword arguments), evaluated at ``x_obs``, cast to
    float64, and compared against ``y_obs``.
    """
    candidate = model.from_array(x, **context)
    predicted = candidate.evaluate(x_obs)
    return np.asarray(predicted, dtype=np.float64) - y_obs
def fit(
    chain: SmileData,
    model: type[M],
    initial_guess: M | None = None,
    *,
    max_nfev: int = 10_000,
) -> SmileResult[M]:
    """Fit a smile model to market data.

    Parameters
    ----------
    chain : SmileData
        Market data to fit. Uses mid values for fitting.
        Internally transforms to the model's native coordinates.
    model : type[M]
        A model class (e.g. ``SVIModel``) that defines native coordinates,
        bounds, evaluation, and initial-guess heuristic.
    initial_guess : M, optional
        Initial parameter guess (e.g. an ``SVIModel(...)`` instance).
        If None, the model's heuristic initial guess is computed from data.
        Whichever guess is used, the starting vector is clipped into
        ``model.bounds`` so an out-of-bounds start does not abort the fit.
    max_nfev : int, default 10_000
        Maximum number of residual evaluations allowed to
        :func:`scipy.optimize.least_squares`.

    Returns
    -------
    SmileResult[M]
        Fitted parameters, residuals, RMSE, and convergence status.

    Raises
    ------
    ValueError
        If ``chain`` yields no observations in the model's native coordinates.
    """
    sd = chain.transform(model.native_x_coord, model.native_y_coord)
    x_obs = sd.x
    y_obs = sd.y_mid

    # With no data the optimiser "succeeds" trivially and the RMSE is NaN;
    # fail loudly instead of returning a meaningless result.
    if np.size(x_obs) == 0:
        raise ValueError("cannot fit a smile model to zero observations")

    lower, upper = model.bounds
    raw_x0 = initial_guess.to_array() if initial_guess is not None else model.initial_guess(x_obs, y_obs)
    # least_squares raises ValueError when any component of x0 lies outside
    # [lower, upper]; project the start point onto the feasible box instead.
    x0 = np.clip(np.asarray(raw_x0, dtype=np.float64), lower, upper)

    # Non-parameter fields (e.g. expiry/forward) are held fixed during the fit.
    context = _context_for_model(model, sd)

    result = least_squares(
        _residuals,
        x0,
        args=(model, x_obs, y_obs, context),
        bounds=(lower, upper),
        method="trf",
        max_nfev=max_nfev,
    )

    fitted_params = model.from_array(result.x, **context)
    # result.fun is the residual vector (model - observed) at result.x.
    residuals = result.fun
    rmse = float(np.sqrt(np.mean(residuals**2)))

    return SmileResult(
        params=fitted_params,
        residuals=residuals,
        rmse=rmse,
        success=bool(result.success),
    )