Coverage for src/qsmile/models/base.py: 100%
57 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 22:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 22:47 +0000
1"""SmileModel abstract base class."""
3from __future__ import annotations
5from abc import ABC, abstractmethod
6from dataclasses import dataclass, field, replace
7from typing import TYPE_CHECKING, ClassVar, Self
9import numpy as np
10from numpy.typing import ArrayLike, NDArray
12from qsmile.core.coords import XCoord, YCoord
14if TYPE_CHECKING:
15 import matplotlib.figure
17 from qsmile.data.meta import SmileMetadata
20@dataclass
21class SmileModel(ABC):
22 """Abstract base for dataclass-based smile models.
24 Provides coordinate-aware evaluation, transformation, plotting,
25 and default serialisation. Subclasses must define:
27 - Dataclass fields for the fitted parameters
28 - ``native_x_coord``, ``native_y_coord``, ``param_names``, ``bounds`` ClassVars
29 - ``_evaluate(x)`` instance method (raw formula in native coordinates)
30 - ``initial_guess(x, y)`` static method
31 - ``__post_init__()`` for validation (optional)
32 """
34 native_x_coord: ClassVar[XCoord]
35 native_y_coord: ClassVar[YCoord]
36 param_names: ClassVar[tuple[str, ...]]
37 bounds: ClassVar[tuple[list[float], list[float]]]
39 metadata: SmileMetadata = field(repr=False)
40 current_x_coord: XCoord = field(init=False)
41 current_y_coord: YCoord = field(init=False)
43 def __post_init__(self) -> None:
44 """Set current coords to native if not already set."""
45 if not hasattr(self, "_coords_set"):
46 self.current_x_coord = self.__class__.native_x_coord
47 self.current_y_coord = self.__class__.native_y_coord
49 @property
50 def params(self) -> dict[str, float]:
51 """Parameter name-to-value mapping."""
52 return {name: getattr(self, name) for name in self.param_names}
54 def to_array(self) -> NDArray[np.float64]:
55 """Pack fitted parameters into a flat array using ``param_names`` order."""
56 return np.array([getattr(self, name) for name in self.param_names])
58 @classmethod
59 def from_array(cls, x: NDArray[np.float64], *, metadata: SmileMetadata) -> Self:
60 """Reconstruct an instance from a flat parameter array.
62 Fitted parameters are mapped from *x* using ``param_names``.
63 """
64 params = {name: float(x[i]) for i, name in enumerate(cls.param_names)}
65 return cls(**params, metadata=metadata)
67 @abstractmethod
68 def _evaluate(self, x: ArrayLike) -> NDArray[np.float64] | np.float64:
69 """Compute model output at x values in native coordinates."""
70 ...
72 @staticmethod
73 @abstractmethod
74 def initial_guess(x: NDArray[np.float64], y: NDArray[np.float64]) -> NDArray[np.float64]:
75 """Compute a heuristic initial guess from observed data."""
76 ...
78 def evaluate(self, x: ArrayLike) -> NDArray[np.float64] | np.float64:
79 """Evaluate at *x* in current coordinates, transforming as needed."""
80 from qsmile.core.maps import (
81 apply_x_chain,
82 apply_y_chain,
83 compose_x_maps,
84 compose_y_maps,
85 )
87 x_arr = np.asarray(x, dtype=np.float64)
89 # If already in native coords, skip transforms
90 if (
91 self.current_x_coord == self.__class__.native_x_coord
92 and self.current_y_coord == self.__class__.native_y_coord
93 ):
94 return self._evaluate(x_arr)
96 # Transform x: current → native
97 x_chain = compose_x_maps(self.current_x_coord, self.__class__.native_x_coord)
98 native_x = apply_x_chain(x_arr, x_chain, self.metadata)
100 # Evaluate in native coords
101 native_y = np.asarray(self._evaluate(native_x), dtype=np.float64)
103 # Transform y: native → current
104 y_chain = compose_y_maps(self.__class__.native_y_coord, self.current_y_coord)
105 return apply_y_chain(
106 native_y,
107 native_x,
108 y_chain,
109 self.metadata,
110 self.__class__.native_x_coord,
111 self.current_x_coord,
112 )
114 def transform(self, target_x: XCoord, target_y: YCoord) -> Self:
115 """Return a copy expressed in the target coordinate system."""
116 new = replace(self)
117 object.__setattr__(new, "current_x_coord", target_x)
118 object.__setattr__(new, "current_y_coord", target_y)
119 return new
121 def plot(
122 self,
123 *,
124 title: str = "Smile Model",
125 n_points: int = 200,
126 std_range: tuple[float, float] = (-5.0, 2.0),
127 ax=None,
128 **kwargs,
129 ) -> matplotlib.figure.Figure:
130 """Plot the model curve in current coordinates.
132 Parameters
133 ----------
134 std_range : tuple[float, float]
135 Plot range in standardised-strike units (sigma * sqrt(T)) as (lo, hi).
136 Default (-5.0, 2.0).
137 """
138 from qsmile.core.maps import apply_x_chain, compose_x_maps
139 from qsmile.core.plot import plot_line
141 # Build grid in standardised-strike space, then map to current coords
142 std_grid = np.linspace(std_range[0], std_range[1], n_points)
144 to_current = compose_x_maps(XCoord.StandardisedStrike, self.current_x_coord)
145 x_grid = apply_x_chain(std_grid, to_current, self.metadata)
147 y_grid = np.asarray(self.evaluate(x_grid), dtype=np.float64)
149 return plot_line(
150 x_grid,
151 y_grid,
152 xlabel=self.current_x_coord.name,
153 ylabel=self.current_y_coord.name,
154 title=title,
155 ax=ax,
156 **kwargs,
157 )