Coverage for src / qsmile / models / svi.py: 96%
52 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-04 21:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-04 21:47 +0000
1"""SVI (Stochastic Volatility Inspired) raw parameterisation."""
3from __future__ import annotations
5from dataclasses import dataclass
6from typing import ClassVar
8import numpy as np
9from numpy.typing import ArrayLike, NDArray
11from qsmile.core.coords import XCoord, YCoord
12from qsmile.models.protocol import AbstractSmileModel
@dataclass
class SVIModel(AbstractSmileModel):
    """Raw SVI parameterisation: model definition and fitted parameters.

    The SVI raw parameterisation models total implied variance as:

        w(k) = a + b * (rho * (k - m) + sqrt((k - m)^2 + sigma^2))

    where k = ln(K/F) is log-moneyness.

    Pass this class to ``fit()`` as the model, and receive instances
    back as fitted parameters::

        result = fit(sd, model=SVIModel)
        result.params  # → SVIModel instance
        result.params.evaluate(k)

    Parameters
    ----------
    a : float
        Vertical translation (overall variance level). Must be finite.
    b : float
        Slope of the wings. Must be finite and >= 0.
    rho : float
        Correlation / rotation. Must be in (-1, 1).
    m : float
        Horizontal translation (log-moneyness shift). Must be finite.
    sigma : float
        Curvature at the vertex. Must be finite and > 0.

    Raises
    ------
    ValueError
        On construction, if any parameter violates the constraints above.
        NaN is rejected for every parameter.
    """

    a: float
    b: float
    rho: float
    m: float
    sigma: float

    # -- Class-level model metadata (excluded from dataclass fields) --

    native_x_coord: ClassVar[XCoord] = XCoord.LogMoneynessStrike
    native_y_coord: ClassVar[YCoord] = YCoord.TotalVariance
    param_names: ClassVar[tuple[str, ...]] = ("a", "b", "rho", "m", "sigma")
    # (lower, upper) bounds, each listed in ``param_names`` order; presumably
    # handed to the optimiser by ``fit()``. rho is kept just inside (-1, 1) and
    # sigma strictly above 0 so that finite points in the box pass
    # ``__post_init__`` validation.
    bounds: ClassVar[tuple[list[float], list[float]]] = (
        [-np.inf, 0.0, -0.999, -np.inf, 1e-8],
        [np.inf, np.inf, 0.999, np.inf, np.inf],
    )

    def __post_init__(self) -> None:
        """Validate SVI parameter constraints.

        Each bound check is written as ``not (value <op> bound)``, not as the
        inverted comparison. NaN compares False to everything, so this form
        makes NaN fail the check instead of silently passing it.

        Raises
        ------
        ValueError
            If ``b`` is not >= 0, ``rho`` is not in (-1, 1), ``sigma`` is not
            > 0, or any parameter is NaN or infinite.
        """
        if not (self.b >= 0):
            msg = f"b must be non-negative, got {self.b}"
            raise ValueError(msg)
        if not (-1 < self.rho < 1):
            msg = f"rho must be in (-1, 1), got {self.rho}"
            raise ValueError(msg)
        if not (self.sigma > 0):
            msg = f"sigma must be positive, got {self.sigma}"
            raise ValueError(msg)
        # The checks above already reject NaN in b/rho/sigma. This catches the
        # remaining non-finite cases: NaN/inf in a or m, +inf in b or sigma.
        for name in self.param_names:
            value = getattr(self, name)
            if not np.isfinite(value):
                msg = f"{name} must be finite, got {value}"
                raise ValueError(msg)

    def evaluate(self, x: ArrayLike) -> NDArray[np.float64] | np.float64:
        """Compute SVI total variance at the given log-moneyness values.

        w(k) = a + b * (rho * (k - m) + sqrt((k - m)^2 + sigma^2))

        Parameters
        ----------
        x : ArrayLike
            Log-moneyness value(s); converted to a float64 array.

        Returns
        -------
        NDArray[np.float64] | np.float64
            Total variance with the same shape as ``x``; a scalar input gives
            an ``np.float64``.
        """
        k = np.asarray(x, dtype=np.float64)
        d = k - self.m
        return self.a + self.b * (self.rho * d + np.sqrt(d**2 + self.sigma**2))

    def implied_vol(self, k: ArrayLike, expiry: float) -> NDArray[np.float64] | np.float64:
        """Compute SVI implied volatility from total variance.

        sigma_IV = sqrt(w(k) / T)

        Parameters
        ----------
        k : ArrayLike
            Log-moneyness values.
        expiry : float
            Time to expiration in years. Must be positive.

        Returns
        -------
        NDArray[np.float64] | np.float64
            Implied volatility at each ``k``. Points where the model's total
            variance is negative come out as NaN.

        Raises
        ------
        ValueError
            If ``expiry`` is not strictly positive (including NaN).
        """
        # Negated form so that a NaN expiry is rejected as well.
        if not (expiry > 0):
            msg = f"expiry must be positive, got {expiry}"
            raise ValueError(msg)
        w = self.evaluate(k)
        # ``a`` is unbounded below, so w can be negative for some parameter
        # sets. np.sqrt then yields NaN (with numpy's invalid-value warning).
        return np.sqrt(w / expiry)

    @staticmethod
    def initial_guess(x: NDArray[np.float64], y: NDArray[np.float64]) -> NDArray[np.float64]:
        """Compute a heuristic initial guess for SVI parameters from market data.

        Parameters
        ----------
        x : NDArray[np.float64]
            Log-moneyness values (any array-like of numbers is accepted).
        y : NDArray[np.float64]
            Observed total variance values, aligned with ``x``.

        Returns
        -------
        NDArray[np.float64]
            ``[a, b, rho, m, sigma]``, in ``param_names`` order. ``b`` and
            ``sigma`` are at least 0.01, and ``rho`` is clipped to [-0.9, 0.9].

        Raises
        ------
        ValueError
            If ``x`` is empty.
        """
        xs = np.asarray(x, dtype=np.float64)
        ys = np.asarray(y, dtype=np.float64)
        if xs.size == 0:
            msg = "initial_guess requires at least one data point"
            raise ValueError(msg)

        # a: total variance at the quote closest to k=0. Since
        # w(m) = a + b*sigma, this overstates a by about b*sigma. It is only
        # a starting point.
        atm_idx = int(np.argmin(np.abs(xs)))
        a0 = float(ys[atm_idx])

        # Estimate slope and curvature from a quadratic fit: w ≈ c0 + c1*k + c2*k²
        if len(xs) >= 3:
            c2, c1, _c0 = np.polyfit(xs, ys, 2)
            b0 = max(abs(c1) + 2 * abs(c2), 0.01)
            # b0 >= |c1|, so |c1 / b0| <= 1; the clip keeps rho away from ±1.
            rho0 = float(np.clip(c1 / b0, -0.9, 0.9))
        else:
            b0 = max(float(np.std(ys)) * 2, 0.01)
            rho0 = 0.0

        m0 = float(xs[atm_idx])
        sigma0 = max(float(np.std(xs)) * 0.5, 0.01)

        return np.array([a0, b0, rho0, m0, sigma0])