Coverage for src / qsmile / models / svi.py: 96%

52 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-04 21:47 +0000

1"""SVI (Stochastic Volatility Inspired) raw parameterisation.""" 

2 

3from __future__ import annotations 

4 

5from dataclasses import dataclass 

6from typing import ClassVar 

7 

8import numpy as np 

9from numpy.typing import ArrayLike, NDArray 

10 

11from qsmile.core.coords import XCoord, YCoord 

12from qsmile.models.protocol import AbstractSmileModel 

13 

14 

@dataclass
class SVIModel(AbstractSmileModel):
    """Raw SVI parameterisation: model definition and fitted parameters.

    The SVI raw parameterisation models total implied variance as:

        w(k) = a + b * (rho * (k - m) + sqrt((k - m)^2 + sigma^2))

    where k = ln(K/F) is log-moneyness.

    Pass this class to ``fit()`` as the model, and receive instances
    back as fitted parameters::

        result = fit(sd, model=SVIModel)
        result.params  # → SVIModel instance
        result.params.evaluate(k)

    Parameters
    ----------
    a : float
        Vertical translation (overall variance level). Must not be NaN.
    b : float
        Slope of the wings. Must be >= 0.
    rho : float
        Correlation / rotation. Must be in (-1, 1).
    m : float
        Horizontal translation (log-moneyness shift). Must not be NaN.
    sigma : float
        Curvature at the vertex. Must be > 0.

    Raises
    ------
    ValueError
        If any parameter violates the constraints above. NaN is rejected
        for every parameter.
    """

    a: float
    b: float
    rho: float
    m: float
    sigma: float

    # -- Class-level model metadata (excluded from dataclass fields) --

    native_x_coord: ClassVar[XCoord] = XCoord.LogMoneynessStrike
    native_y_coord: ClassVar[YCoord] = YCoord.TotalVariance
    param_names: ClassVar[tuple[str, ...]] = ("a", "b", "rho", "m", "sigma")
    # (lower, upper) per-parameter limits, ordered as ``param_names``;
    # presumably consumed by the fitter as optimiser box bounds (confirm in
    # the fitting code). rho stays strictly inside (-1, 1) and sigma strictly
    # positive, so finite points in the box satisfy ``__post_init__``.
    bounds: ClassVar[tuple[list[float], list[float]]] = (
        [-np.inf, 0.0, -0.999, -np.inf, 1e-8],
        [np.inf, np.inf, 0.999, np.inf, np.inf],
    )

    def __post_init__(self) -> None:
        """Validate SVI parameter constraints.

        Each check is written as the negation of the condition that must
        hold (e.g. ``not b >= 0`` instead of ``b < 0``). NaN fails every
        ordered comparison, so this form rejects it instead of letting it
        slip through.

        Raises
        ------
        ValueError
            If ``a`` or ``m`` is NaN, ``b`` is negative or NaN, ``rho`` is
            outside (-1, 1) or NaN, or ``sigma`` is non-positive or NaN.
        """
        if np.isnan(self.a):
            msg = f"a must not be NaN, got {self.a}"
            raise ValueError(msg)
        if not self.b >= 0:
            msg = f"b must be non-negative, got {self.b}"
            raise ValueError(msg)
        if not (-1 < self.rho < 1):
            msg = f"rho must be in (-1, 1), got {self.rho}"
            raise ValueError(msg)
        if np.isnan(self.m):
            msg = f"m must not be NaN, got {self.m}"
            raise ValueError(msg)
        if not self.sigma > 0:
            msg = f"sigma must be positive, got {self.sigma}"
            raise ValueError(msg)

    def evaluate(self, x: ArrayLike) -> NDArray[np.float64] | np.float64:
        """Compute SVI total variance at the given log-moneyness values.

        w(k) = a + b * (rho * (k - m) + sqrt((k - m)^2 + sigma^2))

        Parameters
        ----------
        x : ArrayLike
            Log-moneyness values. The input is converted to a float64 array
            and the formula is applied element-wise.

        Returns
        -------
        NDArray[np.float64] | np.float64
            Total variance with the same shape as ``x``. Scalar input gives
            a NumPy scalar.
        """
        k = np.asarray(x, dtype=np.float64)
        d = k - self.m
        return self.a + self.b * (self.rho * d + np.sqrt(d**2 + self.sigma**2))

    def implied_vol(self, k: ArrayLike, expiry: float) -> NDArray[np.float64] | np.float64:
        """Compute SVI implied volatility from total variance.

        sigma_IV = sqrt(w(k) / T)

        Parameters
        ----------
        k : ArrayLike
            Log-moneyness values.
        expiry : float
            Time to expiration in years. Must be positive.

        Returns
        -------
        NDArray[np.float64] | np.float64
            Implied volatility, element-wise over ``k``. Where ``w(k) < 0``
            (possible when ``a`` is sufficiently negative), ``np.sqrt``
            returns NaN for that entry.

        Raises
        ------
        ValueError
            If ``expiry`` is not strictly positive. NaN is also rejected.
        """
        # `not expiry > 0` also rejects NaN, which `expiry <= 0` would accept.
        if not expiry > 0:
            msg = f"expiry must be positive, got {expiry}"
            raise ValueError(msg)
        w = self.evaluate(k)
        return np.sqrt(w / expiry)

    @staticmethod
    def initial_guess(x: NDArray[np.float64], y: NDArray[np.float64]) -> NDArray[np.float64]:
        """Compute a heuristic initial guess for SVI parameters from market data.

        Parameters
        ----------
        x : NDArray[np.float64]
            Log-moneyness values.
        y : NDArray[np.float64]
            Observed total variance values, aligned with ``x``.

        Returns
        -------
        NDArray[np.float64]
            ``[a, b, rho, m, sigma]`` in ``param_names`` order. ``b`` and
            ``sigma`` are floored at 0.01, and ``rho`` is clipped to
            [-0.9, 0.9], which keeps the guess inside ``bounds``.
        """
        # a: total variance at the observation closest to k=0 (ATM).
        # NOTE(review): since m0 is set to the same point and w(m) = a + b*sigma,
        # this guess ignores the b*sigma offset; presumably the fit corrects it.
        atm_idx = int(np.argmin(np.abs(x)))
        a0 = float(y[atm_idx])

        # Estimate slope and curvature from a quadratic fit: w ≈ c0 + c1*k + c2*k².
        if len(x) >= 3:
            # polyfit returns coefficients highest power first.
            coeffs = np.polyfit(x, y, 2)
            c2, c1, _c0 = coeffs
            b0 = max(abs(c1) + 2 * abs(c2), 0.01)
            # |c1| <= b0 by construction; the clip keeps rho well inside its bounds.
            rho0 = np.clip(c1 / b0, -0.9, 0.9)
        else:
            # Too few points for a quadratic: use the spread of y and no skew.
            b0 = max(float(np.std(y)) * 2, 0.01)
            rho0 = 0.0

        m0 = float(x[atm_idx])
        sigma0 = max(float(np.std(x)) * 0.5, 0.01)

        return np.array([a0, b0, rho0, m0, sigma0])