Coverage for src / qsmile / data / vols.py: 100%

123 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 22:47 +0000

1"""Unified smile data container with coordinate transforms.""" 

2 

3from __future__ import annotations 

4 

5from dataclasses import dataclass, field, replace 

6from typing import TYPE_CHECKING 

7 

8import numpy as np 

9import pandas as pd 

10from numpy.typing import ArrayLike, NDArray 

11 

12if TYPE_CHECKING: 

13 import matplotlib.figure 

14 

15from qsmile.core.coords import XCoord, YCoord 

16from qsmile.core.maps import ( 

17 apply_x_chain, 

18 apply_y_chain, 

19 compose_x_maps, 

20 compose_y_maps, 

21) 

22from qsmile.data.meta import SmileMetadata 

23from qsmile.data.strikes import StrikeArray 

24 

25 

@dataclass
class VolData:
    """Coordinate-labelled smile data with bid/ask.

    The quotes are stored once, in the coordinate system that was current
    at construction (the *native* coordinates).  ``transform`` only
    relabels the current coordinates; the ``x``, ``y_bid`` and ``y_ask``
    accessors map the native data into the current coordinates on access.

    Parameters
    ----------
    strikearray : StrikeArray
        Strike-indexed data.  Quotes are read through the keys
        ``("y", "bid")`` and ``("y", "ask")``; the optional keys
        ``("meta", "volume")`` and ``("meta", "open_interest")`` are
        supported.
    current_x_coord : XCoord
        Which X-coordinate system the data is currently expressed in.
    current_y_coord : YCoord
        Which Y-coordinate system the data is currently expressed in.
    metadata : SmileMetadata
        Parameters needed by coordinate transforms.

    Raises
    ------
    ValueError
        On construction, if there are fewer than 3 points, if any bid
        exceeds its ask, if any x is non-positive for ``FixedStrike`` /
        ``MoneynessStrike``, if any y is negative for ``Volatility`` /
        ``Variance`` / ``TotalVariance``, or if volume / open interest
        contain negative values.
    """

    strikearray: StrikeArray
    current_x_coord: XCoord
    current_y_coord: YCoord
    metadata: SmileMetadata
    _native_x_coord: XCoord = field(init=False, repr=False)
    _native_y_coord: YCoord = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Validate inputs and record native coordinates."""
        # Record the native coords only while the ``_native_set`` marker is
        # absent, so an instance that already carries it keeps its originals.
        if not hasattr(self, "_native_set"):
            object.__setattr__(self, "_native_x_coord", self.current_x_coord)
            object.__setattr__(self, "_native_y_coord", self.current_y_coord)
            object.__setattr__(self, "_native_set", True)

        sa = self.strikearray
        n = len(sa)

        if n < 3:
            msg = f"at least 3 data points required, got {n}"
            raise ValueError(msg)

        y_bid = sa.get_values(("y", "bid"))
        y_ask = sa.get_values(("y", "ask"))

        # The bid/ask ordering can only be checked when both sides are present.
        if y_bid is not None and y_ask is not None and np.any(y_bid > y_ask):
            msg = "y_bid must not exceed y_ask"
            raise ValueError(msg)

        x = sa.strikes
        if self.current_x_coord in (XCoord.FixedStrike, XCoord.MoneynessStrike) and np.any(x <= 0):
            msg = f"all x values must be positive for {self.current_x_coord.name}"
            raise ValueError(msg)

        if self.current_y_coord in (YCoord.Volatility, YCoord.Variance, YCoord.TotalVariance):
            # Reuse the arrays fetched above instead of looking them up again.
            for arr in (y_bid, y_ask):
                if arr is not None and np.any(arr < 0):
                    msg = f"y values must be non-negative for {self.current_y_coord.name}"
                    raise ValueError(msg)

        for key in (("meta", "volume"), ("meta", "open_interest")):
            arr = sa.get_values(key)
            if arr is not None and np.any(arr < 0):
                msg = f"{key[1]} must be non-negative"
                raise ValueError(msg)

    # ── native coordinate properties ──────────────────────────────

    @property
    def native_x_coord(self) -> XCoord:
        """X-coordinate system the data was originally constructed in."""
        return self._native_x_coord

    @property
    def native_y_coord(self) -> YCoord:
        """Y-coordinate system the data was originally constructed in."""
        return self._native_y_coord

    # ── convenience accessors (lazy transform) ────────────────────

    def _is_native(self) -> bool:
        """True if current coords match native coords (both X and Y)."""
        return self.current_x_coord == self._native_x_coord and self.current_y_coord == self._native_y_coord

    def _y_in_current_coords(self, key: tuple[str, str]) -> NDArray[np.float64]:
        """Return the stored Y column ``key`` expressed in the current coordinates.

        The stored values are returned untouched when the current
        coordinates are the native ones; otherwise they are pushed through
        the Y-map chain from native to current coordinates.
        """
        native_y = self.strikearray.values(key)
        if self._is_native():
            return native_y
        native_x = self.strikearray.strikes
        y_chain = compose_y_maps(self._native_y_coord, self.current_y_coord)
        return apply_y_chain(
            native_y,
            native_x,
            y_chain,
            self.metadata,
            self._native_x_coord,
            self.current_x_coord,
        )

    @property
    def x(self) -> NDArray[np.float64]:
        """X-coordinate values in current coordinate system."""
        native_x = self.strikearray.strikes
        if self._is_native():
            return native_x
        x_chain = compose_x_maps(self._native_x_coord, self.current_x_coord)
        return apply_x_chain(native_x, x_chain, self.metadata)

    @property
    def y_bid(self) -> NDArray[np.float64]:
        """Y-coordinate bid values in current coordinate system."""
        return self._y_in_current_coords(("y", "bid"))

    @property
    def y_ask(self) -> NDArray[np.float64]:
        """Y-coordinate ask values in current coordinate system."""
        return self._y_in_current_coords(("y", "ask"))

    @property
    def volume(self) -> NDArray[np.float64] | None:
        """Per-point traded volume, or None."""
        return self.strikearray.get_values(("meta", "volume"))

    @property
    def open_interest(self) -> NDArray[np.float64] | None:
        """Per-point open interest, or None."""
        return self.strikearray.get_values(("meta", "open_interest"))

    @property
    def y_mid(self) -> NDArray[np.float64]:
        """Midpoint of bid and ask Y values in current coordinate system."""
        return (self.y_bid + self.y_ask) / 2.0

    def transform(self, target_x: XCoord, target_y: YCoord) -> VolData:
        """Return a copy expressed in the target coordinate system.

        This is lightweight: it shares the same underlying StrikeArray
        and only updates the current coordinate labels.  Property
        accessors apply transforms lazily on access.

        Parameters
        ----------
        target_x : XCoord
            Target X-coordinate system.
        target_y : YCoord
            Target Y-coordinate system.

        Returns
        -------
        VolData
            New VolData in the target coordinates, with the same native
            coordinates as this instance.
        """
        # Build the instance without running __init__/__post_init__: the
        # stored data stays in native coordinates, so the constructor's
        # checks against the *current* coordinates are not run here.
        new = VolData.__new__(VolData)
        object.__setattr__(new, "strikearray", self.strikearray)
        object.__setattr__(new, "current_x_coord", target_x)
        object.__setattr__(new, "current_y_coord", target_y)
        object.__setattr__(new, "metadata", self.metadata)
        object.__setattr__(new, "_native_x_coord", self._native_x_coord)
        object.__setattr__(new, "_native_y_coord", self._native_y_coord)
        object.__setattr__(new, "_native_set", True)
        return new

    @classmethod
    def from_mid_vols(
        cls,
        strikes: NDArray[np.float64],
        ivs: NDArray[np.float64],
        metadata: SmileMetadata,
    ) -> VolData:
        """Create from mid implied vols (setting y_bid = y_ask = ivs).

        Parameters
        ----------
        strikes : NDArray[np.float64]
            Strike prices.
        ivs : NDArray[np.float64]
            Mid implied volatilities, one per strike.
        metadata : SmileMetadata
            Smile metadata.  ``metadata.forward`` must not be ``None``.
            ``sigma_atm`` is always recomputed from the data: it is the
            vol quoted at the strike closest to the forward.

        Returns
        -------
        VolData
            Data in ``FixedStrike`` / ``Volatility`` coordinates.

        Raises
        ------
        TypeError
            If ``metadata.forward`` is ``None``.
        ValueError
            If ``strikes`` and ``ivs`` are not 1-D arrays of equal length,
            or if the resulting data fails the constructor's validation.
        """
        strikes = np.asarray(strikes, dtype=np.float64)
        ivs = np.asarray(ivs, dtype=np.float64)

        if metadata.forward is None:
            msg = "metadata.forward must not be None"
            raise TypeError(msg)

        # Fail early with a clear message; otherwise a mismatch surfaces as an
        # IndexError or an opaque pandas length error further down.
        if strikes.ndim != 1 or strikes.shape != ivs.shape:
            msg = f"strikes and ivs must be 1-D arrays of equal length, got shapes {strikes.shape} and {ivs.shape}"
            raise ValueError(msg)

        atm_idx = int(np.argmin(np.abs(strikes - metadata.forward)))
        sigma_atm = float(ivs[atm_idx])
        meta = replace(metadata, sigma_atm=sigma_atm)

        sa = StrikeArray()
        idx = pd.Index(strikes, dtype=np.float64)
        sa.set(("y", "bid"), pd.Series(ivs, index=idx))
        # Separate copy so the bid and ask columns do not share a buffer.
        sa.set(("y", "ask"), pd.Series(ivs.copy(), index=idx))

        return cls(
            strikearray=sa,
            current_x_coord=XCoord.FixedStrike,
            current_y_coord=YCoord.Volatility,
            metadata=meta,
        )

    def evaluate(self, x: ArrayLike) -> NDArray[np.float64]:
        """Interpolate mid-smile at arbitrary x in current coordinates.

        Uses cubic spline interpolation on ``y_mid``.  Returns ``NaN``
        for points outside the data domain (no extrapolation).  The data
        points are sorted by x before the spline is built, so they need
        not be stored in increasing order; duplicate x values are still
        rejected by ``CubicSpline``.

        Parameters
        ----------
        x : ArrayLike
            X values in the current coordinate system.

        Returns
        -------
        NDArray[np.float64]
            Interpolated mid Y values.
        """
        from scipy.interpolate import CubicSpline

        x_arr = np.asarray(x, dtype=np.float64)
        knots_x = np.asarray(self.x, dtype=np.float64)
        knots_y = np.asarray(self.y_mid, dtype=np.float64)

        # CubicSpline requires strictly increasing x.  A stable sort leaves
        # already-increasing data untouched.
        order = np.argsort(knots_x, kind="stable")
        cs = CubicSpline(knots_x[order], knots_y[order], extrapolate=False)
        return np.asarray(cs(x_arr), dtype=np.float64)

    def plot(self, *, title: str = "Smile Data", ax=None, color="k", **kwargs) -> matplotlib.figure.Figure:
        """Plot bid/ask Y-values as error bars vs X.

        Axis labels are derived from coordinate names.

        Parameters
        ----------
        title : str
            Plot title.
        ax : optional
            Passed through to ``plot_bid_ask`` as ``ax``.
        color : optional
            Passed through to ``plot_bid_ask`` as ``color``.
        **kwargs
            Extra keyword arguments forwarded to ``plot_bid_ask``.

        Returns
        -------
        matplotlib.figure.Figure
            Whatever ``plot_bid_ask`` returns.
        """
        from qsmile.core.plot import plot_bid_ask

        return plot_bid_ask(
            self.x,
            self.y_mid,
            self.y_bid,
            self.y_ask,
            xlabel=self.current_x_coord.name,
            ylabel=self.current_y_coord.name,
            title=title,
            fmt="none",
            color=color,
            ax=ax,
            **kwargs,
        )