Coverage for src/qsmile/models/result.py: 100%
30 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 22:47 +0000
1"""Smile fitting engine."""
3from __future__ import annotations
5from dataclasses import dataclass
7import numpy as np
8from numpy.typing import NDArray
9from scipy.optimize import least_squares
11from qsmile.data.meta import SmileMetadata
12from qsmile.data.vols import VolData
13from qsmile.models.base import SmileModel
@dataclass
class SmileResult:
    """Outcome of fitting a smile model to market data.

    Attributes
    ----------
    model : SmileModel
        The fitted, coordinate-aware model instance.
    residuals : NDArray[np.float64]
        Residual for each observation, defined as model value minus observed
        value, expressed in the model's native coordinates.
    rmse : float
        Root mean square of ``residuals``.
    success : bool
        ``True`` when the optimiser reported convergence.
    """

    # Fitted model instance.
    model: SmileModel
    # Per-observation residuals (model - observed) in native coordinates.
    residuals: NDArray[np.float64]
    # Root mean square error of the fit.
    rmse: float
    # Convergence flag reported by the optimiser.
    success: bool
def _residuals(
    x: NDArray[np.float64],
    model: type[SmileModel],
    x_obs: NDArray[np.float64],
    y_obs: NDArray[np.float64],
    metadata: SmileMetadata,
) -> NDArray[np.float64]:
    """Return ``model - observed`` for a trial parameter vector.

    This is the residual function handed to ``least_squares``.

    Parameters
    ----------
    x : NDArray[np.float64]
        Trial parameter vector proposed by the optimiser.
    model : type[SmileModel]
        Model class used to rebuild an instance from ``x``.
    x_obs : NDArray[np.float64]
        Observed abscissae in the model's native coordinates.
    y_obs : NDArray[np.float64]
        Observed ordinates in the model's native coordinates.
    metadata : SmileMetadata
        Forwarded unchanged to ``model.from_array``.

    Returns
    -------
    NDArray[np.float64]
        Elementwise difference between the model evaluated at ``x_obs``
        (cast to float64) and ``y_obs``.
    """
    candidate = model.from_array(x, metadata=metadata)
    predicted = candidate._evaluate(x_obs)
    # Cast to float64 so the optimiser always receives a float array.
    return np.asarray(predicted, dtype=np.float64) - y_obs
def fit(
    chain: VolData,
    model: type[SmileModel],
    initial_guess: SmileModel | None = None,
    *,
    max_nfev: int | None = 10_000,
) -> SmileResult:
    """Fit a smile model to market data.

    Parameters
    ----------
    chain : VolData
        Market data to fit. Mid values are used as the observations.
        The data is first transformed into the model's native coordinates
        (``model.native_x_coord`` / ``model.native_y_coord``).
    model : type[SmileModel]
        A model class (e.g. ``SVIModel``) that defines native coordinates,
        bounds, evaluation, and an initial-guess heuristic.
    initial_guess : SmileModel, optional
        Starting parameters (e.g. an ``SVIModel(...)`` instance). If None,
        ``model.initial_guess`` is computed from the transformed data.
    max_nfev : int or None, optional
        Maximum number of residual evaluations allowed to the optimiser.
        Keyword-only. Defaults to 10 000, the previously hard-coded budget.
        ``None`` defers to SciPy's own default.

    Returns
    -------
    SmileResult
        Fitted model, residuals, RMSE, and convergence status.
    """
    # Move the data into the model's native coordinates; residuals are measured there.
    native = chain.transform(model.native_x_coord, model.native_y_coord)
    x_obs = native.x
    y_obs = native.y_mid
    metadata = native.metadata

    if initial_guess is not None:
        x0 = initial_guess.to_array()
    else:
        x0 = model.initial_guess(x_obs, y_obs)

    # NOTE(review): least_squares rejects a starting point outside the bounds;
    # confirm heuristic and user-supplied guesses always lie within model.bounds.
    lower, upper = model.bounds

    result = least_squares(
        _residuals,
        x0,
        args=(model, x_obs, y_obs, metadata),
        bounds=(lower, upper),
        method="trf",
        max_nfev=max_nfev,
    )

    # Rebuild a model instance (not a raw parameter array) from the optimum.
    fitted_model = model.from_array(result.x, metadata=metadata)
    # Residual vector (model - observed) reported by the optimiser at result.x.
    residuals = result.fun
    rmse = float(np.sqrt(np.mean(residuals**2)))

    return SmileResult(
        model=fitted_model,
        residuals=residuals,
        rmse=rmse,
        success=bool(result.success),
    )