Coverage for src / qsmile / data / vols.py: 100%

61 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-04 21:47 +0000

1"""Unified smile data container with coordinate transforms.""" 

2 

3from __future__ import annotations 

4 

5from dataclasses import dataclass, replace 

6from typing import TYPE_CHECKING 

7 

8import numpy as np 

9from numpy.typing import NDArray 

10 

11if TYPE_CHECKING: 

12 import matplotlib.figure 

13 

14from qsmile.core.coords import XCoord, YCoord 

15from qsmile.core.maps import ( 

16 apply_x_chain, 

17 apply_y_chain, 

18 compose_x_maps, 

19 compose_y_maps, 

20) 

21from qsmile.data.meta import SmileMetadata 

22 

23 

@dataclass
class SmileData:
    """Coordinate-labelled smile data with bid/ask.

    Parameters
    ----------
    x : NDArray[np.float64]
        X-coordinate values (one-dimensional).
    y_bid : NDArray[np.float64]
        Y-coordinate bid values, same length as ``x``.
    y_ask : NDArray[np.float64]
        Y-coordinate ask values, same length as ``x``.
    x_coord : XCoord
        Which X-coordinate system the data is in.
    y_coord : YCoord
        Which Y-coordinate system the data is in.
    metadata : SmileMetadata
        Parameters needed by coordinate transforms.

    Raises
    ------
    ValueError
        If the arrays are not one-dimensional, differ in length, hold fewer
        than three points, have any bid above its ask, contain non-positive
        x values for ``FixedStrike``/``MoneynessStrike``, or contain negative
        y values for ``Volatility``/``Variance``/``TotalVariance``.
    """

    x: NDArray[np.float64]
    y_bid: NDArray[np.float64]
    y_ask: NDArray[np.float64]
    x_coord: XCoord
    y_coord: YCoord
    metadata: SmileMetadata

    def __post_init__(self) -> None:
        """Validate and convert inputs.

        The arrays are copied (not merely viewed) so that later mutation of
        the caller's inputs cannot silently invalidate the checks below,
        e.g. the bid <= ask invariant.
        """
        self.x = np.array(self.x, dtype=np.float64)
        self.y_bid = np.array(self.y_bid, dtype=np.float64)
        self.y_ask = np.array(self.y_ask, dtype=np.float64)

        # len() raises TypeError on 0-d arrays and happily accepts 2-D ones;
        # reject anything that is not a 1-D series with a clear ValueError.
        if self.x.ndim != 1 or self.y_bid.ndim != 1 or self.y_ask.ndim != 1:
            msg = (
                "x, y_bid and y_ask must be one-dimensional, got "
                f"x.ndim={self.x.ndim}, y_bid.ndim={self.y_bid.ndim}, y_ask.ndim={self.y_ask.ndim}"
            )
            raise ValueError(msg)

        n = len(self.x)
        if len(self.y_bid) != n or len(self.y_ask) != n:
            msg = (
                f"all arrays must have the same length as x ({n}), got y_bid={len(self.y_bid)}, y_ask={len(self.y_ask)}"
            )
            raise ValueError(msg)

        if n < 3:
            msg = f"at least 3 data points required, got {n}"
            raise ValueError(msg)

        if np.any(self.y_bid > self.y_ask):
            msg = "y_bid must not exceed y_ask"
            raise ValueError(msg)

        # Strike-like coordinates must be strictly positive.
        if self.x_coord in (XCoord.FixedStrike, XCoord.MoneynessStrike) and np.any(self.x <= 0):
            msg = f"all x values must be positive for {self.x_coord.name}"
            raise ValueError(msg)

        # Volatility and variance quantities cannot be negative.
        if self.y_coord in (YCoord.Volatility, YCoord.Variance, YCoord.TotalVariance) and (
            np.any(self.y_bid < 0) or np.any(self.y_ask < 0)
        ):
            msg = f"y values must be non-negative for {self.y_coord.name}"
            raise ValueError(msg)

    @staticmethod
    def _nearest_index(values: NDArray[np.float64], target: float) -> int:
        """Return the index of the element of ``values`` closest to ``target``."""
        return int(np.argmin(np.abs(values - target)))

    @property
    def y_mid(self) -> NDArray[np.float64]:
        """Midpoint of bid and ask Y values."""
        return (self.y_bid + self.y_ask) / 2.0

    def transform(self, target_x: XCoord, target_y: YCoord) -> SmileData:
        """Re-express data in target coordinate system.

        Parameters
        ----------
        target_x : XCoord
            Target X-coordinate system.
        target_y : YCoord
            Target Y-coordinate system.

        Returns
        -------
        SmileData
            New SmileData in the target coordinates. If the target is
            ``(FixedStrike, Volatility)`` and ``metadata.sigma_atm`` is unset,
            it is filled with the mid volatility at the strike nearest to
            ``metadata.forward``.
        """
        # Transform X
        x_chain = compose_x_maps(self.x_coord, target_x)
        new_x = apply_x_chain(self.x, x_chain, self.metadata)

        # Transform Y (bid and ask independently). The Y chain is given the
        # source x grid plus both X coordinate systems, as apply_y_chain expects.
        y_chain = compose_y_maps(self.y_coord, target_y)
        new_y_bid = apply_y_chain(self.y_bid, self.x, y_chain, self.metadata, self.x_coord, target_x)
        new_y_ask = apply_y_chain(self.y_ask, self.x, y_chain, self.metadata, self.x_coord, target_x)

        # If we now have vols in FixedStrike and sigma_atm is missing, derive it
        # from the mid quote at the strike closest to the forward.
        metadata = self.metadata
        if target_y == YCoord.Volatility and target_x == XCoord.FixedStrike and metadata.sigma_atm is None:
            atm_idx = self._nearest_index(new_x, metadata.forward)
            sigma_atm = float((new_y_bid[atm_idx] + new_y_ask[atm_idx]) / 2.0)
            metadata = replace(metadata, sigma_atm=sigma_atm)

        return SmileData(
            x=new_x,
            y_bid=new_y_bid,
            y_ask=new_y_ask,
            x_coord=target_x,
            y_coord=target_y,
            metadata=metadata,
        )

    @classmethod
    def from_mid_vols(
        cls,
        strikes: NDArray[np.float64],
        ivs: NDArray[np.float64],
        forward: float,
        expiry: float,
        discount_factor: float = 1.0,
    ) -> SmileData:
        """Create from mid implied vols (setting y_bid = y_ask = ivs).

        ``metadata.sigma_atm`` is set to the implied vol at the strike
        nearest to ``forward``.

        Parameters
        ----------
        strikes : NDArray[np.float64]
            Strike prices.
        ivs : NDArray[np.float64]
            Mid implied volatilities, one per strike.
        forward : float
            Forward price.
        expiry : float
            Time to expiry in years.
        discount_factor : float
            Discount factor, defaults to 1.0.

        Returns
        -------
        SmileData
            Data in ``(FixedStrike, Volatility)`` coordinates.

        Raises
        ------
        ValueError
            If ``strikes`` and ``ivs`` are not 1-D arrays of equal length, or
            if the constructor's validation fails.
        """
        strikes = np.asarray(strikes, dtype=np.float64)
        ivs = np.asarray(ivs, dtype=np.float64)
        # Validate before indexing ivs: otherwise a shorter ivs (or a scalar)
        # surfaces as an IndexError instead of a ValueError.
        if strikes.ndim != 1 or strikes.shape != ivs.shape:
            msg = (
                "strikes and ivs must be one-dimensional arrays of equal length, "
                f"got shapes {strikes.shape} and {ivs.shape}"
            )
            raise ValueError(msg)
        sigma_atm = float(ivs[cls._nearest_index(strikes, forward)])
        return cls(
            x=strikes,
            y_bid=ivs,
            y_ask=ivs.copy(),
            x_coord=XCoord.FixedStrike,
            y_coord=YCoord.Volatility,
            metadata=SmileMetadata(
                forward=forward,
                discount_factor=discount_factor,
                expiry=expiry,
                sigma_atm=sigma_atm,
            ),
        )

    def plot(self, *, title: str = "Smile Data") -> matplotlib.figure.Figure:
        """Plot bid/ask Y-values as error bars vs X.

        Axis labels are derived from coordinate names.

        Parameters
        ----------
        title : str
            Figure title, defaults to ``"Smile Data"``.

        Returns
        -------
        matplotlib.figure.Figure
            The figure produced by ``qsmile.core.plot.plot_bid_ask``.
        """
        # Imported lazily so matplotlib is only needed when plotting.
        from qsmile.core.plot import plot_bid_ask

        return plot_bid_ask(
            self.x,
            self.y_mid,
            self.y_bid,
            self.y_ask,
            xlabel=self.x_coord.name,
            ylabel=self.y_coord.name,
            title=title,
        )