Coverage for src / qsmile / data / vols.py: 100%
61 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-04 21:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-04 21:47 +0000
1"""Unified smile data container with coordinate transforms."""
3from __future__ import annotations
5from dataclasses import dataclass, replace
6from typing import TYPE_CHECKING
8import numpy as np
9from numpy.typing import NDArray
11if TYPE_CHECKING:
12 import matplotlib.figure
14from qsmile.core.coords import XCoord, YCoord
15from qsmile.core.maps import (
16 apply_x_chain,
17 apply_y_chain,
18 compose_x_maps,
19 compose_y_maps,
20)
21from qsmile.data.meta import SmileMetadata
@dataclass
class SmileData:
    """Coordinate-labelled smile data with bid/ask.

    Parameters
    ----------
    x : NDArray[np.float64]
        X-coordinate values.
    y_bid : NDArray[np.float64]
        Y-coordinate bid values.
    y_ask : NDArray[np.float64]
        Y-coordinate ask values.
    x_coord : XCoord
        Which X-coordinate system the data is in.
    y_coord : YCoord
        Which Y-coordinate system the data is in.
    metadata : SmileMetadata
        Parameters needed by coordinate transforms.

    Raises
    ------
    ValueError
        If the inputs fail validation in ``__post_init__`` (non 1-D arrays,
        mismatched lengths, fewer than 3 points, bid above ask, or sign
        constraints of the chosen coordinate systems).
    """

    x: NDArray[np.float64]
    y_bid: NDArray[np.float64]
    y_ask: NDArray[np.float64]
    x_coord: XCoord
    y_coord: YCoord
    metadata: SmileMetadata

    def __post_init__(self) -> None:
        """Coerce ``x``, ``y_bid`` and ``y_ask`` to float64 arrays and validate them.

        Raises
        ------
        ValueError
            If any of ``x``, ``y_bid``, ``y_ask`` is not 1-D; if their lengths
            differ; if fewer than 3 points are given; if any ``y_bid`` exceeds
            the matching ``y_ask``; if any ``x`` is non-positive for
            ``FixedStrike``/``MoneynessStrike``; or if any Y value is negative
            for ``Volatility``/``Variance``/``TotalVariance``.
        """
        self.x = np.asarray(self.x, dtype=np.float64)
        self.y_bid = np.asarray(self.y_bid, dtype=np.float64)
        self.y_ask = np.asarray(self.y_ask, dtype=np.float64)

        # Reject scalars and multi-dimensional input explicitly: len() raises an
        # opaque TypeError on 0-d arrays and only inspects the first axis of N-d ones.
        if self.x.ndim != 1 or self.y_bid.ndim != 1 or self.y_ask.ndim != 1:
            msg = (
                "x, y_bid and y_ask must be 1-D arrays, got "
                f"ndim x={self.x.ndim}, y_bid={self.y_bid.ndim}, y_ask={self.y_ask.ndim}"
            )
            raise ValueError(msg)

        n = len(self.x)
        if len(self.y_bid) != n or len(self.y_ask) != n:
            msg = (
                f"all arrays must have the same length as x ({n}), got y_bid={len(self.y_bid)}, y_ask={len(self.y_ask)}"
            )
            raise ValueError(msg)

        if n < 3:
            msg = f"at least 3 data points required, got {n}"
            raise ValueError(msg)

        # NOTE(review): comparisons with NaN are always False, so NaN quotes pass
        # this check and the sign checks below — confirm whether that is intended.
        if np.any(self.y_bid > self.y_ask):
            msg = "y_bid must not exceed y_ask"
            raise ValueError(msg)

        if self.x_coord in (XCoord.FixedStrike, XCoord.MoneynessStrike) and np.any(self.x <= 0):
            msg = f"all x values must be positive for {self.x_coord.name}"
            raise ValueError(msg)

        if self.y_coord in (YCoord.Volatility, YCoord.Variance, YCoord.TotalVariance) and (
            np.any(self.y_bid < 0) or np.any(self.y_ask < 0)
        ):
            msg = f"y values must be non-negative for {self.y_coord.name}"
            raise ValueError(msg)

    @property
    def y_mid(self) -> NDArray[np.float64]:
        """Midpoint of bid and ask Y values, ``(y_bid + y_ask) / 2``."""
        return (self.y_bid + self.y_ask) / 2.0

    @staticmethod
    def _atm_index(x: NDArray[np.float64], forward: float) -> int:
        """Return the index of the ``x`` value closest to ``forward`` (first one on ties)."""
        return int(np.argmin(np.abs(x - forward)))

    def transform(self, target_x: XCoord, target_y: YCoord) -> SmileData:
        """Re-express data in target coordinate system.

        Bid and ask are mapped independently through the same Y-map chain.
        If the target is ``(FixedStrike, Volatility)`` and ``metadata.sigma_atm``
        is ``None``, it is filled in with the mid vol at the strike closest to
        ``metadata.forward``.

        Parameters
        ----------
        target_x : XCoord
            Target X-coordinate system.
        target_y : YCoord
            Target Y-coordinate system.

        Returns
        -------
        SmileData
            New SmileData in the target coordinates.
        """
        # Transform X
        x_chain = compose_x_maps(self.x_coord, target_x)
        new_x = apply_x_chain(self.x, x_chain, self.metadata)

        # Transform Y (bid and ask independently). The Y maps are given the
        # original x and both coordinate systems — presumably because some Y
        # maps depend on strike/moneyness; confirm in qsmile.core.maps.
        y_chain = compose_y_maps(self.y_coord, target_y)
        new_y_bid = apply_y_chain(self.y_bid, self.x, y_chain, self.metadata, self.x_coord, target_x)
        new_y_ask = apply_y_chain(self.y_ask, self.x, y_chain, self.metadata, self.x_coord, target_x)

        # If we now have vols in FixedStrike and sigma_atm is missing, derive it
        metadata = self.metadata
        if target_y == YCoord.Volatility and target_x == XCoord.FixedStrike and metadata.sigma_atm is None:
            atm_idx = self._atm_index(new_x, metadata.forward)
            sigma_atm = float((new_y_bid[atm_idx] + new_y_ask[atm_idx]) / 2.0)
            metadata = replace(metadata, sigma_atm=sigma_atm)

        return SmileData(
            x=new_x,
            y_bid=new_y_bid,
            y_ask=new_y_ask,
            x_coord=target_x,
            y_coord=target_y,
            metadata=metadata,
        )

    @classmethod
    def from_mid_vols(
        cls,
        strikes: NDArray[np.float64],
        ivs: NDArray[np.float64],
        forward: float,
        expiry: float,
        discount_factor: float = 1.0,
    ) -> SmileData:
        """Create from mid implied vols (setting y_bid = y_ask = ivs).

        ``metadata.sigma_atm`` is set to the vol at the strike closest to
        ``forward`` (nearest point, no interpolation).

        Parameters
        ----------
        strikes : NDArray[np.float64]
            Strike prices.
        ivs : NDArray[np.float64]
            Mid implied volatilities.
        forward : float
            Forward price.
        expiry : float
            Time to expiry in years.
        discount_factor : float
            Discount factor, defaults to 1.0.

        Returns
        -------
        SmileData
            Data in ``(FixedStrike, Volatility)`` coordinates with zero bid/ask spread.

        Raises
        ------
        ValueError
            If the inputs fail SmileData validation (e.g. mismatched lengths,
            fewer than 3 points, non-positive strikes or negative vols).
        """
        # np.array copies, so the result never aliases the caller's arrays.
        strikes = np.array(strikes, dtype=np.float64)
        ivs = np.array(ivs, dtype=np.float64)
        # Build (and thereby validate) first; looking up the ATM vol before
        # validation could raise IndexError/argmin errors on malformed input
        # instead of the descriptive ValueError from __post_init__.
        data = cls(
            x=strikes,
            y_bid=ivs,
            y_ask=ivs.copy(),
            x_coord=XCoord.FixedStrike,
            y_coord=YCoord.Volatility,
            metadata=SmileMetadata(
                forward=forward,
                discount_factor=discount_factor,
                expiry=expiry,
                sigma_atm=None,
            ),
        )
        atm_idx = cls._atm_index(data.x, forward)
        data.metadata = replace(data.metadata, sigma_atm=float(data.y_bid[atm_idx]))
        return data

    def plot(self, *, title: str = "Smile Data") -> matplotlib.figure.Figure:
        """Plot bid/ask Y-values as error bars vs X.

        Axis labels are derived from coordinate names.

        Parameters
        ----------
        title : str
            Figure title, defaults to ``"Smile Data"``.

        Returns
        -------
        matplotlib.figure.Figure
            The figure produced by ``qsmile.core.plot.plot_bid_ask``.
        """
        # Imported lazily, presumably to keep matplotlib optional at import time.
        from qsmile.core.plot import plot_bid_ask

        return plot_bid_ask(
            self.x,
            self.y_mid,
            self.y_bid,
            self.y_ask,
            xlabel=self.x_coord.name,
            ylabel=self.y_coord.name,
            title=title,
        )
184 )