Coverage for src/qsmile/models/base.py: 100%

57 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 22:47 +0000

1"""SmileModel abstract base class.""" 

2 

3from __future__ import annotations 

4 

5from abc import ABC, abstractmethod 

6from dataclasses import dataclass, field, replace 

7from typing import TYPE_CHECKING, ClassVar, Self 

8 

9import numpy as np 

10from numpy.typing import ArrayLike, NDArray 

11 

12from qsmile.core.coords import XCoord, YCoord 

13 

14if TYPE_CHECKING: 

15 import matplotlib.figure 

16 

17 from qsmile.data.meta import SmileMetadata 

18 

19 

20@dataclass 

21class SmileModel(ABC): 

22 """Abstract base for dataclass-based smile models. 

23 

24 Provides coordinate-aware evaluation, transformation, plotting, 

25 and default serialisation. Subclasses must define: 

26 

27 - Dataclass fields for the fitted parameters 

28 - ``native_x_coord``, ``native_y_coord``, ``param_names``, ``bounds`` ClassVars 

29 - ``_evaluate(x)`` instance method (raw formula in native coordinates) 

30 - ``initial_guess(x, y)`` static method 

31 - ``__post_init__()`` for validation (optional) 

32 """ 

33 

34 native_x_coord: ClassVar[XCoord] 

35 native_y_coord: ClassVar[YCoord] 

36 param_names: ClassVar[tuple[str, ...]] 

37 bounds: ClassVar[tuple[list[float], list[float]]] 

38 

39 metadata: SmileMetadata = field(repr=False) 

40 current_x_coord: XCoord = field(init=False) 

41 current_y_coord: YCoord = field(init=False) 

42 

43 def __post_init__(self) -> None: 

44 """Set current coords to native if not already set.""" 

45 if not hasattr(self, "_coords_set"): 

46 self.current_x_coord = self.__class__.native_x_coord 

47 self.current_y_coord = self.__class__.native_y_coord 

48 

49 @property 

50 def params(self) -> dict[str, float]: 

51 """Parameter name-to-value mapping.""" 

52 return {name: getattr(self, name) for name in self.param_names} 

53 

54 def to_array(self) -> NDArray[np.float64]: 

55 """Pack fitted parameters into a flat array using ``param_names`` order.""" 

56 return np.array([getattr(self, name) for name in self.param_names]) 

57 

58 @classmethod 

59 def from_array(cls, x: NDArray[np.float64], *, metadata: SmileMetadata) -> Self: 

60 """Reconstruct an instance from a flat parameter array. 

61 

62 Fitted parameters are mapped from *x* using ``param_names``. 

63 """ 

64 params = {name: float(x[i]) for i, name in enumerate(cls.param_names)} 

65 return cls(**params, metadata=metadata) 

66 

67 @abstractmethod 

68 def _evaluate(self, x: ArrayLike) -> NDArray[np.float64] | np.float64: 

69 """Compute model output at x values in native coordinates.""" 

70 ... 

71 

72 @staticmethod 

73 @abstractmethod 

74 def initial_guess(x: NDArray[np.float64], y: NDArray[np.float64]) -> NDArray[np.float64]: 

75 """Compute a heuristic initial guess from observed data.""" 

76 ... 

77 

78 def evaluate(self, x: ArrayLike) -> NDArray[np.float64] | np.float64: 

79 """Evaluate at *x* in current coordinates, transforming as needed.""" 

80 from qsmile.core.maps import ( 

81 apply_x_chain, 

82 apply_y_chain, 

83 compose_x_maps, 

84 compose_y_maps, 

85 ) 

86 

87 x_arr = np.asarray(x, dtype=np.float64) 

88 

89 # If already in native coords, skip transforms 

90 if ( 

91 self.current_x_coord == self.__class__.native_x_coord 

92 and self.current_y_coord == self.__class__.native_y_coord 

93 ): 

94 return self._evaluate(x_arr) 

95 

96 # Transform x: current → native 

97 x_chain = compose_x_maps(self.current_x_coord, self.__class__.native_x_coord) 

98 native_x = apply_x_chain(x_arr, x_chain, self.metadata) 

99 

100 # Evaluate in native coords 

101 native_y = np.asarray(self._evaluate(native_x), dtype=np.float64) 

102 

103 # Transform y: native → current 

104 y_chain = compose_y_maps(self.__class__.native_y_coord, self.current_y_coord) 

105 return apply_y_chain( 

106 native_y, 

107 native_x, 

108 y_chain, 

109 self.metadata, 

110 self.__class__.native_x_coord, 

111 self.current_x_coord, 

112 ) 

113 

114 def transform(self, target_x: XCoord, target_y: YCoord) -> Self: 

115 """Return a copy expressed in the target coordinate system.""" 

116 new = replace(self) 

117 object.__setattr__(new, "current_x_coord", target_x) 

118 object.__setattr__(new, "current_y_coord", target_y) 

119 return new 

120 

121 def plot( 

122 self, 

123 *, 

124 title: str = "Smile Model", 

125 n_points: int = 200, 

126 std_range: tuple[float, float] = (-5.0, 2.0), 

127 ax=None, 

128 **kwargs, 

129 ) -> matplotlib.figure.Figure: 

130 """Plot the model curve in current coordinates. 

131 

132 Parameters 

133 ---------- 

134 std_range : tuple[float, float] 

135 Plot range in standardised-strike units (sigma * sqrt(T)) as (lo, hi). 

136 Default (-5.0, 2.0). 

137 """ 

138 from qsmile.core.maps import apply_x_chain, compose_x_maps 

139 from qsmile.core.plot import plot_line 

140 

141 # Build grid in standardised-strike space, then map to current coords 

142 std_grid = np.linspace(std_range[0], std_range[1], n_points) 

143 

144 to_current = compose_x_maps(XCoord.StandardisedStrike, self.current_x_coord) 

145 x_grid = apply_x_chain(std_grid, to_current, self.metadata) 

146 

147 y_grid = np.asarray(self.evaluate(x_grid), dtype=np.float64) 

148 

149 return plot_line( 

150 x_grid, 

151 y_grid, 

152 xlabel=self.current_x_coord.name, 

153 ylabel=self.current_y_coord.name, 

154 title=title, 

155 ax=ax, 

156 **kwargs, 

157 )