Coverage for src / qsmile / models / sabr.py: 97%
73 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-04 21:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-04 21:47 +0000
1"""SABR stochastic volatility model — Hagan et al. (2002) lognormal approximation."""
3from __future__ import annotations
5from dataclasses import dataclass
6from typing import ClassVar
8import numpy as np
9from numpy.typing import ArrayLike, NDArray
11from qsmile.core.coords import XCoord, YCoord
12from qsmile.models.protocol import AbstractSmileModel
@dataclass
class SABRModel(AbstractSmileModel):
    """SABR model with Hagan (2002) lognormal implied volatility approximation.

    The SABR model describes the dynamics of a forward rate F and its
    stochastic volatility alpha via:

        dF = alpha * F^beta * dW_1
        dalpha = nu * alpha * dW_2
        <dW_1, dW_2> = rho * dt

    The Hagan et al. (2002) formula maps these parameters to a
    closed-form lognormal implied volatility approximation.

    Fitted parameters (included in the parameter vector):
        alpha, beta, rho, nu

    Context fields (NOT included in the parameter vector):
        expiry, forward

    Parameters
    ----------
    alpha : float
        Initial volatility. Must be > 0.
    beta : float
        CEV exponent. Must be in [0, 1].
    rho : float
        Correlation between forward and vol. Must be in (-1, 1).
    nu : float
        Vol-of-vol. Must be >= 0.
    expiry : float
        Time to expiry in years. Must be > 0.
    forward : float
        Forward price. Must be > 0.

    Raises
    ------
    ValueError
        If any parameter violates its constraint (NaN is rejected as well).
    """

    alpha: float
    beta: float
    rho: float
    nu: float
    expiry: float
    forward: float

    # -- Class-level model metadata --

    native_x_coord: ClassVar[XCoord] = XCoord.LogMoneynessStrike
    native_y_coord: ClassVar[YCoord] = YCoord.Volatility
    param_names: ClassVar[tuple[str, ...]] = ("alpha", "beta", "rho", "nu")
    # (lower, upper) box bounds, ordered like ``param_names``. rho is capped at
    # +/-0.999, i.e. strictly inside the (-1, 1) range checked in __post_init__.
    bounds: ClassVar[tuple[list[float], list[float]]] = (
        [1e-8, 0.0, -0.999, 0.0],
        [np.inf, 1.0, 0.999, np.inf],
    )

    def __post_init__(self) -> None:
        """Validate SABR parameter constraints.

        Every check is written as ``not (valid condition)`` so that NaN, which
        fails all comparisons, is rejected instead of silently accepted.
        """
        if not (self.alpha > 0):
            msg = f"alpha must be positive, got {self.alpha}"
            raise ValueError(msg)
        if not (0 <= self.beta <= 1):
            msg = f"beta must be in [0, 1], got {self.beta}"
            raise ValueError(msg)
        if not (-1 < self.rho < 1):
            msg = f"rho must be in (-1, 1), got {self.rho}"
            raise ValueError(msg)
        if not (self.nu >= 0):
            msg = f"nu must be non-negative, got {self.nu}"
            raise ValueError(msg)
        if not (self.expiry > 0):
            msg = f"expiry must be positive, got {self.expiry}"
            raise ValueError(msg)
        if not (self.forward > 0):
            msg = f"forward must be positive, got {self.forward}"
            raise ValueError(msg)

    def evaluate(self, x: ArrayLike) -> NDArray[np.float64] | np.float64:
        """Compute Hagan (2002) lognormal implied volatility at log-moneyness values.

        Parameters
        ----------
        x : ArrayLike
            Log-moneyness k = ln(K/F).

        Returns
        -------
        NDArray[np.float64] | np.float64
            Implied volatility (lognormal), one value per entry of ``x``.
        """
        k = np.asarray(x, dtype=np.float64)
        # Convert log-moneyness back to absolute strikes: K = F * exp(k).
        strikes = self.forward * np.exp(k)
        return self._hagan_implied_vol(self.forward, strikes, self.expiry, self.alpha, self.beta, self.rho, self.nu)

    @staticmethod
    def _hagan_implied_vol(
        fwd: float,
        strikes: NDArray[np.float64] | float,
        expiry: float,
        alpha: float,
        beta: float,
        rho: float,
        nu: float,
    ) -> NDArray[np.float64] | np.float64:
        """Hagan et al. (2002) lognormal implied vol approximation.

        With b = beta, l = log(F/K) and m = (F K)^((1-b)/2):

            sigma(K) = alpha / (m * D(l)) * z / x(z) * (1 + C * T)
            D(l)     = 1 + (1-b)^2/24 * l^2 + (1-b)^4/1920 * l^4
            z        = (nu / alpha) * m * l
            x(z)     = log((sqrt(1 - 2 rho z + z^2) + z - rho) / (1 - rho))
            C        = (1-b)^2/24 * alpha^2 / m^2 + rho b nu alpha / (4 m)
                       + (2 - 3 rho^2)/24 * nu^2

        At the money (K = F) we have l = 0, z = 0 and z/x(z) = 1, so the same
        expression reduces exactly to the ATM formula; no separate branch is
        needed. Near z = 0, z/x(z) is evaluated via its Taylor series (see
        ``_z_over_x``), which avoids the 0/0 cancellation without any epsilon
        bias. Requires |rho| < 1.

        Results are floored at 1e-12 to avoid non-positive implied vols from
        numerical issues.
        """
        strikes = np.asarray(strikes, dtype=np.float64)
        omb = 1.0 - beta
        vol_floor = 1e-12

        # Geometric-mean strike sqrt(F K) and its CEV power m = (F K)^((1-b)/2).
        fk_mid = np.sqrt(fwd * strikes)
        fk_mid_b = fk_mid**omb
        log_fk = np.log(fwd / strikes)

        z = (nu / alpha) * fk_mid_b * log_fk
        z_over_x = SABRModel._z_over_x(z, rho)

        log_fk_sq = log_fk**2
        backbone = fk_mid_b * (1.0 + omb**2 / 24.0 * log_fk_sq + omb**4 / 1920.0 * log_fk_sq**2)
        correction = 1.0 + (
            omb**2 / 24.0 * alpha**2 / fk_mid ** (2.0 * omb)
            + 0.25 * rho * beta * nu * alpha / fk_mid_b
            + (2.0 - 3.0 * rho**2) / 24.0 * nu**2
        ) * expiry

        vol = alpha / backbone * z_over_x * correction
        # Clamp to avoid negative implied vol from numerical issues.
        return np.maximum(vol, vol_floor)

    @staticmethod
    def _z_over_x(z: NDArray[np.float64], rho: float) -> NDArray[np.float64]:
        """Return z / x(z) from the Hagan formula, numerically stable for all z.

        x(z) = log((sqrt(1 - 2 rho z + z^2) + z - rho) / (1 - rho)); requires |rho| < 1.
        """
        small = np.abs(z) < 1e-6
        # Series z/x(z) = 1 - rho z / 2 + (2 - 3 rho^2) z^2 / 12 + O(z^3);
        # for |z| < 1e-6 the truncation error is far below double precision.
        series = 1.0 - 0.5 * rho * z + (2.0 - 3.0 * rho**2) / 12.0 * z**2
        # np.where evaluates both branches: give the small entries a harmless
        # stand-in so the exact branch never computes 0/0.
        z_safe = np.where(small, 1.0, z)
        root = np.sqrt(1.0 - 2.0 * rho * z_safe + z_safe**2)
        shift = z_safe - rho
        # root + shift cancels when shift < 0 and |z| is large. Since
        # (root + shift)(root - shift) = 1 - rho^2 and root > |shift|, use the
        # conjugate form there; its denominator is always positive.
        numer = np.where(shift >= 0.0, root + shift, (1.0 - rho**2) / (root - shift))
        exact = z_safe / np.log(numer / (1.0 - rho))
        return np.where(small, series, exact)

    @staticmethod
    def initial_guess(x: NDArray[np.float64], y: NDArray[np.float64]) -> NDArray[np.float64]:
        """Compute a heuristic initial guess for SABR parameters from market data.

        Parameters
        ----------
        x : NDArray[np.float64]
            Log-moneyness values.
        y : NDArray[np.float64]
            Observed implied volatility values.

        Returns
        -------
        NDArray[np.float64]
            Starting vector ``[alpha, beta, rho, nu]`` (``param_names`` order).

        Raises
        ------
        ValueError
            If ``x`` is empty.
        """
        x = np.asarray(x, dtype=np.float64)
        y = np.asarray(y, dtype=np.float64)
        if x.size == 0:
            msg = "initial_guess requires at least one data point"
            raise ValueError(msg)

        # alpha: the implied vol at the point closest to the money.
        # NOTE(review): the ATM formula gives vol ~= alpha / F^(1 - beta), so with
        # beta0 = 0.5 this alpha0 is only on the right scale when F is near 1;
        # the forward is not available to this static method.
        atm_idx = int(np.argmin(np.abs(x)))
        alpha0 = max(float(y[atm_idx]), 0.01)

        # beta: start at 0.5 (between normal and lognormal).
        beta0 = 0.5

        if x.size >= 3:
            # rho: skew direction from the linear slope of iv vs moneyness.
            slope = np.polyfit(x, y, 1)[0]
            rho0 = float(np.clip(slope / (alpha0 + 1e-8), -0.9, 0.9))
            # nu: smile convexity from the quadratic coefficient (highest power first).
            curvature = np.polyfit(x, y, 2)[0]
            nu0 = max(abs(float(curvature)) * 2, 0.1)
        else:
            # Too few points to fit a slope/curvature: fall back to neutral defaults.
            rho0 = 0.0
            nu0 = 0.3

        return np.array([alpha0, beta0, rho0, nu0])