Coverage for src / qsmile / models / protocol.py: 100%
37 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-04 21:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-04 21:47 +0000
1"""SmileModel protocol and AbstractSmileModel base class."""
3from __future__ import annotations
5from abc import ABC, abstractmethod
6from dataclasses import dataclass
7from typing import Any, ClassVar, Protocol, Self, TypeVar, runtime_checkable
9import numpy as np
10from numpy.typing import ArrayLike, NDArray
12from qsmile.core.coords import XCoord, YCoord
@runtime_checkable
class SmileModel(Protocol):
    """Protocol that every smile model class must satisfy.

    A conforming class acts as both a model definition (class-level
    metadata such as native coordinates and bounds) and a fitted
    parameter container (instance-level evaluation and serialisation).

    The metadata members are declared as ``ClassVar``. Under PEP 544 a plain
    annotation in a protocol denotes an *instance* variable, and type
    checkers do not accept a class variable as its implementation. Declaring
    them ``ClassVar`` lets class-level implementations (such as subclasses of
    ``AbstractSmileModel``) satisfy the protocol.

    Example::

        result = fit(sd, SVIModel)  # result.params is an SVIModel instance
        result.params.evaluate(k)
    """

    # Coordinate system of the x values accepted by ``evaluate``.
    native_x_coord: ClassVar[XCoord]
    # Coordinate system of the values produced by ``evaluate``.
    native_y_coord: ClassVar[YCoord]
    # Names of the fitted parameters; ``AbstractSmileModel`` uses this order
    # when packing/unpacking parameters in ``to_array``/``from_array``.
    param_names: ClassVar[tuple[str, ...]]
    # (lower, upper) bound lists; presumably one entry per name in
    # ``param_names`` — TODO confirm against the fitting code.
    bounds: ClassVar[tuple[list[float], list[float]]]

    def evaluate(self, x: ArrayLike) -> NDArray[np.float64] | np.float64:
        """Compute model output at *x* values given in native coordinates.

        *x* is expressed in ``native_x_coord``; the result is an array or a
        NumPy scalar of ``float64``.
        """
        ...

    def to_array(self) -> NDArray[np.float64]:
        """Pack the fitted parameters into a flat ``float64`` array."""
        ...

    @classmethod
    def from_array(cls, x: NDArray[np.float64], **kwargs: Any) -> Self:
        """Reconstruct an instance from a flat parameter array.

        *x* holds the fitted parameter values; ``**kwargs`` carries any
        additional constructor arguments the implementation needs.
        """
        ...

    @staticmethod
    def initial_guess(x: NDArray[np.float64], y: NDArray[np.float64]) -> NDArray[np.float64]:
        """Compute a heuristic initial parameter guess from observed (*x*, *y*) data."""
        ...
@dataclass
class AbstractSmileModel(ABC):
    """Abstract base for dataclass-based smile models.

    Provides default ``to_array()`` and ``from_array()`` implementations
    that derive serialisation from ``param_names``. Subclasses must define:

    - Dataclass fields for the fitted parameters
    - ``native_x_coord``, ``native_y_coord``, ``param_names``, ``bounds`` ClassVars
    - ``evaluate(x)`` instance method
    - ``initial_guess(x, y)`` static method
    - ``__post_init__()`` for validation (optional)
    """

    # Class-level model metadata; declared as ClassVar so @dataclass does not
    # turn them into constructor fields.
    native_x_coord: ClassVar[XCoord]
    native_y_coord: ClassVar[YCoord]
    param_names: ClassVar[tuple[str, ...]]
    bounds: ClassVar[tuple[list[float], list[float]]]

    def to_array(self) -> NDArray[np.float64]:
        """Pack fitted parameters into a flat ``float64`` array.

        Values are read from the attributes named in ``param_names``, in that
        order. The dtype is forced to ``float64`` so that integer-valued
        parameters still produce the array type the signature promises.
        """
        return np.array([getattr(self, name) for name in self.param_names], dtype=np.float64)

    @classmethod
    def from_array(cls, x: NDArray[np.float64], **kwargs: Any) -> Self:
        """Reconstruct an instance from a flat parameter array.

        Fitted parameters are mapped positionally from *x* using
        ``param_names``. Additional keyword arguments (e.g. ``expiry``,
        ``forward``) are forwarded to the constructor for context fields.

        *x* may be any array-like; it is converted to ``float64`` and
        flattened before mapping.

        Raises ``ValueError`` if *x* does not contain exactly
        ``len(param_names)`` values. A short or over-long array almost
        always means parameters from a different model were passed, so it
        is rejected instead of being silently truncated.
        """
        values = np.asarray(x, dtype=np.float64).ravel()
        names = cls.param_names
        if values.size != len(names):
            raise ValueError(
                f"{cls.__name__}.from_array expected {len(names)} values for {names}, got {values.size}"
            )
        params = {name: float(value) for name, value in zip(names, values)}
        return cls(**params, **kwargs)

    @abstractmethod
    def evaluate(self, x: ArrayLike) -> NDArray[np.float64] | np.float64:
        """Compute model output at *x* values given in native coordinates."""
        ...

    @staticmethod
    @abstractmethod
    def initial_guess(x: NDArray[np.float64], y: NDArray[np.float64]) -> NDArray[np.float64]:
        """Compute a heuristic initial parameter guess from observed (*x*, *y*) data."""
        ...
# Type variable ranging over smile-model types; the bound restricts it to
# types that satisfy the ``SmileModel`` protocol.
M = TypeVar("M", bound=SmileModel)