Coverage for src / qsmile / data / vols.py: 100%
123 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 22:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 22:47 +0000
1"""Unified smile data container with coordinate transforms."""
3from __future__ import annotations
5from dataclasses import dataclass, field, replace
6from typing import TYPE_CHECKING
8import numpy as np
9import pandas as pd
10from numpy.typing import ArrayLike, NDArray
12if TYPE_CHECKING:
13 import matplotlib.figure
15from qsmile.core.coords import XCoord, YCoord
16from qsmile.core.maps import (
17 apply_x_chain,
18 apply_y_chain,
19 compose_x_maps,
20 compose_y_maps,
21)
22from qsmile.data.meta import SmileMetadata
23from qsmile.data.strikes import StrikeArray
@dataclass
class VolData:
    """Coordinate-labelled smile data with bid/ask.

    The underlying ``StrikeArray`` holds values in the *native* coordinates
    the instance was constructed in. ``current_x_coord`` / ``current_y_coord``
    only label the current view; the ``x``, ``y_bid``, ``y_ask`` and ``y_mid``
    accessors convert from native to current coordinates on every access.

    Parameters
    ----------
    strikearray : StrikeArray
        Strike-indexed data containing at least ``y_bid`` and ``y_ask``
        columns. Optional ``volume`` and ``open_interest`` columns are
        supported.
    current_x_coord : XCoord
        Which X-coordinate system the data is currently expressed in.
    current_y_coord : YCoord
        Which Y-coordinate system the data is currently expressed in.
    metadata : SmileMetadata
        Parameters needed by coordinate transforms.
    """

    # Minimum number of data points accepted. Deliberately unannotated so the
    # dataclass machinery treats it as a plain class attribute, not a field.
    _MIN_POINTS = 3

    strikearray: StrikeArray
    current_x_coord: XCoord
    current_y_coord: YCoord
    metadata: SmileMetadata
    # Coordinates the stored data is expressed in; recorded at construction
    # and carried over unchanged by ``transform``.
    _native_x_coord: XCoord = field(init=False, repr=False)
    _native_y_coord: YCoord = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Validate inputs and record native coordinates.

        Raises
        ------
        ValueError
            If fewer than three data points are supplied, any ``y_bid``
            exceeds ``y_ask``, strikes are non-positive for a strike-based X
            coordinate, bid/ask Y values are negative for a volatility or
            variance Y coordinate, or ``volume`` / ``open_interest`` is
            negative.

        Notes
        -----
        NOTE(review): ``dataclasses.replace`` builds a fresh instance and
        re-runs this hook, so calling it on a *transformed* instance records
        that instance's current coords as native while the stored data is
        still in the original native coords. Use ``transform`` to relabel;
        confirm no caller relies on ``replace`` here.
        """
        # Record native coords only once per instance.
        if not hasattr(self, "_native_set"):
            object.__setattr__(self, "_native_x_coord", self.current_x_coord)
            object.__setattr__(self, "_native_y_coord", self.current_y_coord)
            object.__setattr__(self, "_native_set", True)

        sa = self.strikearray
        n = len(sa)

        if n < self._MIN_POINTS:
            msg = f"at least {self._MIN_POINTS} data points required, got {n}"
            raise ValueError(msg)

        y_bid = sa.get_values(("y", "bid"))
        y_ask = sa.get_values(("y", "ask"))

        if y_bid is not None and y_ask is not None and np.any(y_bid > y_ask):
            msg = "y_bid must not exceed y_ask"
            raise ValueError(msg)

        # Strike-based X coordinates must be strictly positive.
        x = sa.strikes
        if self.current_x_coord in (XCoord.FixedStrike, XCoord.MoneynessStrike) and np.any(x <= 0):
            msg = f"all x values must be positive for {self.current_x_coord.name}"
            raise ValueError(msg)

        if self.current_y_coord in (YCoord.Volatility, YCoord.Variance, YCoord.TotalVariance):
            for arr in (y_bid, y_ask):
                if arr is not None and np.any(arr < 0):
                    msg = f"y values must be non-negative for {self.current_y_coord.name}"
                    raise ValueError(msg)

        for key in (("meta", "volume"), ("meta", "open_interest")):
            arr = sa.get_values(key)
            if arr is not None and np.any(arr < 0):
                msg = f"{key[1]} must be non-negative"
                raise ValueError(msg)

    # ── native coordinate properties ──────────────────────────────

    @property
    def native_x_coord(self) -> XCoord:
        """X-coordinate system the data was originally constructed in."""
        return self._native_x_coord

    @property
    def native_y_coord(self) -> YCoord:
        """Y-coordinate system the data was originally constructed in."""
        return self._native_y_coord

    # ── convenience accessors (lazy transform) ────────────────────

    def _is_native(self) -> bool:
        """True if current coords match native coords."""
        return self.current_x_coord == self._native_x_coord and self.current_y_coord == self._native_y_coord

    @property
    def x(self) -> NDArray[np.float64]:
        """X-coordinate values in current coordinate system."""
        native_x = self.strikearray.strikes
        if self._is_native():
            return native_x
        x_chain = compose_x_maps(self._native_x_coord, self.current_x_coord)
        return apply_x_chain(native_x, x_chain, self.metadata)

    def _y_in_current_coords(self, side: str) -> NDArray[np.float64]:
        """Return the stored ``("y", side)`` column converted to current coords.

        Parameters
        ----------
        side : str
            ``"bid"`` or ``"ask"``.
        """
        native_y = self.strikearray.values(("y", side))
        if self._is_native():
            return native_y
        y_chain = compose_y_maps(self._native_y_coord, self.current_y_coord)
        # apply_y_chain also receives the native strikes and both X labels
        # (presumably because some Y maps depend on X).
        return apply_y_chain(
            native_y,
            self.strikearray.strikes,
            y_chain,
            self.metadata,
            self._native_x_coord,
            self.current_x_coord,
        )

    @property
    def y_bid(self) -> NDArray[np.float64]:
        """Y-coordinate bid values in current coordinate system."""
        return self._y_in_current_coords("bid")

    @property
    def y_ask(self) -> NDArray[np.float64]:
        """Y-coordinate ask values in current coordinate system."""
        return self._y_in_current_coords("ask")

    @property
    def volume(self) -> NDArray[np.float64] | None:
        """Per-point traded volume, or None."""
        return self.strikearray.get_values(("meta", "volume"))

    @property
    def open_interest(self) -> NDArray[np.float64] | None:
        """Per-point open interest, or None."""
        return self.strikearray.get_values(("meta", "open_interest"))

    @property
    def y_mid(self) -> NDArray[np.float64]:
        """Midpoint of bid and ask Y values in current coordinate system."""
        return (self.y_bid + self.y_ask) / 2.0

    def transform(self, target_x: XCoord, target_y: YCoord) -> VolData:
        """Return a copy expressed in the target coordinate system.

        This is lightweight: it shares the same underlying StrikeArray
        and only updates the current coordinate labels. Property
        accessors apply transforms lazily on access. No validation is
        re-run.

        Parameters
        ----------
        target_x : XCoord
            Target X-coordinate system.
        target_y : YCoord
            Target Y-coordinate system.

        Returns
        -------
        VolData
            New instance of the same class in the target coordinates.
        """
        import copy

        # Shallow copy: shares ``strikearray``/``metadata`` with ``self``,
        # keeps the concrete (sub)class and every instance attribute,
        # including the recorded native coords, and does not call
        # ``__init__``/``__post_init__``.
        new = copy.copy(self)
        object.__setattr__(new, "current_x_coord", target_x)
        object.__setattr__(new, "current_y_coord", target_y)
        object.__setattr__(new, "_native_set", True)
        return new

    @classmethod
    def from_mid_vols(
        cls,
        strikes: NDArray[np.float64],
        ivs: NDArray[np.float64],
        metadata: SmileMetadata,
    ) -> VolData:
        """Create from mid implied vols (setting y_bid = y_ask = ivs).

        Parameters
        ----------
        strikes : NDArray[np.float64]
            Strike prices.
        ivs : NDArray[np.float64]
            Mid implied volatilities, one per strike.
        metadata : SmileMetadata
            Smile metadata. ``metadata.forward`` must not be ``None``.
            ``sigma_atm`` is always recomputed from the data.

        Returns
        -------
        VolData
            Instance in ``FixedStrike`` / ``Volatility`` coordinates.

        Raises
        ------
        TypeError
            If ``metadata.forward`` is ``None``.
        ValueError
            If ``strikes`` and ``ivs`` differ in shape, too few points are
            given, or constructor validation fails.
        """
        strikes = np.asarray(strikes, dtype=np.float64)
        ivs = np.asarray(ivs, dtype=np.float64)

        if metadata.forward is None:
            msg = "metadata.forward must not be None"
            raise TypeError(msg)

        # Fail early with clear messages instead of an IndexError from
        # ``ivs[atm_idx]`` or numpy's error on an empty argmin.
        if strikes.shape != ivs.shape:
            msg = f"strikes and ivs must have the same shape, got {strikes.shape} and {ivs.shape}"
            raise ValueError(msg)
        if strikes.size < cls._MIN_POINTS:
            msg = f"at least {cls._MIN_POINTS} data points required, got {strikes.size}"
            raise ValueError(msg)

        # ATM vol = quote at the strike nearest the forward (first on ties).
        atm_idx = int(np.argmin(np.abs(strikes - metadata.forward)))
        sigma_atm = float(ivs[atm_idx])
        meta = replace(metadata, sigma_atm=sigma_atm)

        sa = StrikeArray()
        idx = pd.Index(strikes, dtype=np.float64)
        sa.set(("y", "bid"), pd.Series(ivs, index=idx))
        sa.set(("y", "ask"), pd.Series(ivs.copy(), index=idx))

        return cls(
            strikearray=sa,
            current_x_coord=XCoord.FixedStrike,
            current_y_coord=YCoord.Volatility,
            metadata=meta,
        )

    def evaluate(self, x: ArrayLike) -> NDArray[np.float64]:
        """Interpolate mid-smile at arbitrary x in current coordinates.

        Uses cubic spline interpolation on ``y_mid``. Returns ``NaN``
        for points outside the data domain (no extrapolation).

        Parameters
        ----------
        x : ArrayLike
            X values in the current coordinate system.

        Returns
        -------
        NDArray[np.float64]
            Interpolated mid Y values.

        Raises
        ------
        ValueError
            Propagated from ``CubicSpline`` if the current X values contain
            duplicates or non-finite entries.
        """
        from scipy.interpolate import CubicSpline

        x_arr = np.asarray(x, dtype=np.float64)
        knots_x = np.asarray(self.x, dtype=np.float64)
        knots_y = np.asarray(self.y_mid, dtype=np.float64)

        # CubicSpline requires strictly increasing abscissae. Stored strikes
        # may be unsorted, and the current X coordinate may be decreasing in
        # strike, so sort the knots first; increasing input is unchanged.
        order = np.argsort(knots_x, kind="stable")
        cs = CubicSpline(knots_x[order], knots_y[order], extrapolate=False)
        return np.asarray(cs(x_arr), dtype=np.float64)

    def plot(self, *, title: str = "Smile Data", ax=None, color="k", **kwargs) -> matplotlib.figure.Figure:
        """Plot bid/ask Y-values as error bars vs X.

        Axis labels are derived from coordinate names.

        Parameters
        ----------
        title : str
            Plot title.
        ax : optional
            Passed through to ``plot_bid_ask``; presumably an existing
            matplotlib Axes to draw on, or ``None`` — confirm in the helper.
        color : str
            Colour passed through to ``plot_bid_ask``.
        **kwargs
            Extra keyword arguments forwarded to ``plot_bid_ask``.

        Returns
        -------
        matplotlib.figure.Figure
            Figure returned by ``plot_bid_ask``.
        """
        from qsmile.core.plot import plot_bid_ask

        return plot_bid_ask(
            self.x,
            self.y_mid,
            self.y_bid,
            self.y_ask,
            xlabel=self.current_x_coord.name,
            ylabel=self.current_y_coord.name,
            title=title,
            fmt="none",
            color=color,
            ax=ax,
            **kwargs,
        )