Coverage for src / qsmile / models / fitting.py: 98%

43 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-04 21:47 +0000

1"""Smile fitting engine.""" 

2 

3from __future__ import annotations 

4 

5import dataclasses 

6from dataclasses import dataclass 

7from typing import Any, Generic 

8 

9import numpy as np 

10from numpy.typing import ArrayLike, NDArray 

11from scipy.optimize import least_squares 

12 

13from qsmile.data.vols import SmileData 

14from qsmile.models.protocol import M, SmileModel 

15 

16 

@dataclass
class SmileResult(Generic[M]):
    """Outcome of fitting a smile model to market data.

    The class is generic in ``M``, so ``result.params`` keeps the concrete
    model type that was fitted (for example ``SVIModel``).

    Attributes:
    ----------
    params : M
        Model instance holding the fitted parameter values.
    residuals : NDArray[np.float64]
        Per-observation residuals: model value minus observed value, in the
        model's native coordinates.
    rmse : float
        Root mean square error of the fit.
    success : bool
        Convergence flag reported by the optimiser.
    """

    params: M
    residuals: NDArray[np.float64]
    rmse: float
    success: bool

    def evaluate(self, x: ArrayLike) -> NDArray[np.float64] | np.float64:
        """Evaluate the fitted model at ``x`` (native coordinates).

        This is a thin pass-through to ``self.params.evaluate``.
        """
        fitted_model = self.params
        return fitted_model.evaluate(x)

44 

45 

def _context_for_model(model: type[SmileModel], sd: SmileData) -> dict[str, Any]:
    """Collect the model's non-fitted constructor fields from ``sd.metadata``.

    Any dataclass field of ``model`` whose name is not listed in
    ``model.param_names`` (e.g. ``expiry``, ``forward`` for SABR) is treated
    as context. Each such field is looked up by name on ``sd.metadata``.
    Fields the metadata does not provide are simply left out.

    Returns an empty dict when ``model`` is not a dataclass.
    """
    if not dataclasses.is_dataclass(model):
        return {}
    fitted_names = set(model.param_names)
    metadata = sd.metadata
    return {
        field.name: getattr(metadata, field.name)
        for field in dataclasses.fields(model)
        if field.name not in fitted_names and hasattr(metadata, field.name)
    }

62 

63 

def _residuals(
    x: NDArray[np.float64],
    model: type[SmileModel],
    x_obs: NDArray[np.float64],
    y_obs: NDArray[np.float64],
    context: dict[str, Any],
) -> NDArray[np.float64]:
    """Return model-minus-observed values for one candidate parameter vector.

    ``least_squares`` calls this with the current parameter vector ``x``.
    The remaining arguments arrive through ``args``. ``context`` holds the
    non-fitted constructor fields, which are forwarded to ``model.from_array``.
    """
    candidate = model.from_array(x, **context)
    predicted = np.asarray(candidate.evaluate(x_obs), dtype=np.float64)
    return predicted - y_obs

75 

76 

def fit(
    chain: SmileData,
    model: type[M],
    initial_guess: M | None = None,
) -> SmileResult[M]:
    """Fit a smile model to market data.

    Parameters
    ----------
    chain : SmileData
        Market data to fit. Uses mid values for fitting.
        Internally transforms to the model's native coordinates.
    model : type[M]
        A model class (e.g. ``SVIModel``) that defines native coordinates,
        bounds, evaluation, and initial-guess heuristic.
    initial_guess : M, optional
        Initial parameter guess (e.g. an ``SVIModel(...)`` instance).
        If None, the model's heuristic initial guess is computed from data.
        The starting vector, whether supplied here or produced by the
        heuristic, is clipped into ``model.bounds`` before optimisation.
        A guess that is already within the bounds is used unchanged.

    Returns:
    -------
    SmileResult[M]
        Fitted parameters, residuals (model minus observed, native
        coordinates), RMSE, and convergence status.
    """
    sd = chain.transform(model.native_x_coord, model.native_y_coord)
    x_obs = sd.x
    y_obs = sd.y_mid

    raw_x0 = initial_guess.to_array() if initial_guess is not None else model.initial_guess(x_obs, y_obs)

    lower, upper = model.bounds
    # least_squares rejects any starting point outside the bounds with
    # ValueError("`x0` is infeasible."). Project onto the bound box so a guess
    # that lands slightly outside still produces a fit. np.clip leaves an
    # already-feasible guess unchanged, and infinite bounds are handled.
    x0 = np.clip(np.asarray(raw_x0, dtype=np.float64), lower, upper)

    # Non-fitted constructor fields (e.g. expiry/forward), taken from metadata.
    context = _context_for_model(model, sd)

    result = least_squares(
        _residuals,
        x0,
        args=(model, x_obs, y_obs, context),
        bounds=(lower, upper),
        method="trf",
        max_nfev=10_000,
    )

    fitted_params = model.from_array(result.x, **context)
    # result.fun is _residuals evaluated at the solution: model minus observed.
    residuals = result.fun
    rmse = float(np.sqrt(np.mean(residuals**2)))

    return SmileResult(
        params=fitted_params,
        residuals=residuals,
        rmse=rmse,
        success=bool(result.success),
    )
128 )