Coverage for src / qsmile / models / protocol.py: 100%

37 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-04 21:47 +0000

1"""SmileModel protocol and AbstractSmileModel base class.""" 

2 

3from __future__ import annotations 

4 

5from abc import ABC, abstractmethod 

6from dataclasses import dataclass 

7from typing import Any, ClassVar, Protocol, Self, TypeVar, runtime_checkable 

8 

9import numpy as np 

10from numpy.typing import ArrayLike, NDArray 

11 

12from qsmile.core.coords import XCoord, YCoord 

13 

14 

15@runtime_checkable 

16class SmileModel(Protocol): 

17 """Protocol that every smile model class must satisfy. 

18 

19 A conforming class acts as both a model definition (class-level 

20 metadata such as native coordinates and bounds) and a fitted 

21 parameter container (instance-level evaluation and serialisation). 

22 

23 Example:: 

24 

25 result = fit(sd, SVIModel) # result.params is an SVIModel instance 

26 result.params.evaluate(k) 

27 """ 

28 

29 native_x_coord: XCoord 

30 native_y_coord: YCoord 

31 param_names: tuple[str, ...] 

32 bounds: tuple[list[float], list[float]] 

33 

34 def evaluate(self, x: ArrayLike) -> NDArray[np.float64] | np.float64: 

35 """Compute model output at x values in native coordinates.""" 

36 ... 

37 

38 def to_array(self) -> NDArray[np.float64]: 

39 """Pack parameters into a flat array.""" 

40 ... 

41 

42 @classmethod 

43 def from_array(cls, x: NDArray[np.float64], **kwargs: Any) -> Self: 

44 """Reconstruct an instance from a flat parameter array.""" 

45 ... 

46 

47 @staticmethod 

48 def initial_guess(x: NDArray[np.float64], y: NDArray[np.float64]) -> NDArray[np.float64]: 

49 """Compute a heuristic initial guess from observed data.""" 

50 ... 

51 

52 

53@dataclass 

54class AbstractSmileModel(ABC): 

55 """Abstract base for dataclass-based smile models. 

56 

57 Provides default ``to_array()`` and ``from_array()`` implementations 

58 that derive serialisation from ``param_names``. Subclasses must define: 

59 

60 - Dataclass fields for the fitted parameters 

61 - ``native_x_coord``, ``native_y_coord``, ``param_names``, ``bounds`` ClassVars 

62 - ``evaluate(x)`` instance method 

63 - ``initial_guess(x, y)`` static method 

64 - ``__post_init__()`` for validation (optional) 

65 """ 

66 

67 native_x_coord: ClassVar[XCoord] 

68 native_y_coord: ClassVar[YCoord] 

69 param_names: ClassVar[tuple[str, ...]] 

70 bounds: ClassVar[tuple[list[float], list[float]]] 

71 

72 def to_array(self) -> NDArray[np.float64]: 

73 """Pack fitted parameters into a flat array using ``param_names`` order.""" 

74 return np.array([getattr(self, name) for name in self.param_names]) 

75 

76 @classmethod 

77 def from_array(cls, x: NDArray[np.float64], **kwargs: Any) -> Self: 

78 """Reconstruct an instance from a flat parameter array. 

79 

80 Fitted parameters are mapped from *x* using ``param_names``. 

81 Additional keyword arguments (e.g. ``expiry``, ``forward``) 

82 are forwarded to the constructor for context fields. 

83 """ 

84 params = {name: float(x[i]) for i, name in enumerate(cls.param_names)} 

85 return cls(**params, **kwargs) 

86 

87 @abstractmethod 

88 def evaluate(self, x: ArrayLike) -> NDArray[np.float64] | np.float64: 

89 """Compute model output at x values in native coordinates.""" 

90 ... 

91 

92 @staticmethod 

93 @abstractmethod 

94 def initial_guess(x: NDArray[np.float64], y: NDArray[np.float64]) -> NDArray[np.float64]: 

95 """Compute a heuristic initial guess from observed data.""" 

96 ... 

97 

98 

# Type variable constrained to classes that satisfy the SmileModel protocol.
M = TypeVar("M", bound=SmileModel)