Coverage for src / qsmile / core / maps.py: 86%

90 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 22:47 +0000

1"""Coordinate transform maps and composition.""" 

2 

3from __future__ import annotations 

4 

5from collections.abc import Callable 

6from typing import TYPE_CHECKING 

7 

8import numpy as np 

9from numpy.typing import NDArray 

10 

11from qsmile.core.coords import XCoord, YCoord 

12 

13if TYPE_CHECKING: 

14 from qsmile.data.meta import SmileMetadata 

15 

# Signatures of the single-step coordinate maps registered further down.
#   X-map: (x_values, metadata) -> transformed x_values
#   Y-map: (y_values, x_values, metadata) -> transformed y_values
# Element types are written as strings so nothing is evaluated at import time
# (SmileMetadata is only imported under TYPE_CHECKING in this module).
XMapFn = Callable[
    ["NDArray[np.float64]", "SmileMetadata"],
    "NDArray[np.float64]",
]
YMapFn = Callable[
    ["NDArray[np.float64]", "NDArray[np.float64]", "SmileMetadata"],
    "NDArray[np.float64]",
]

24 

# Ordered coordinate ladders. Conversions only ever move between neighbouring
# rungs (see _ladder_path), so the order below fixes which single-step maps
# must be present in _X_MAPS / _Y_MAPS.
X_LADDER: list[XCoord] = [
    XCoord.FixedStrike,  # absolute strike K
    XCoord.MoneynessStrike,  # K / forward
    XCoord.LogMoneynessStrike,  # ln(K / forward)
    XCoord.StandardisedStrike,  # ln(K / forward) / (sigma_atm * sqrt(texpiry))
]

Y_LADDER: list[YCoord] = [
    YCoord.Price,  # Black76 call price
    YCoord.Volatility,  # implied volatility
    YCoord.Variance,  # volatility squared
    YCoord.TotalVariance,  # variance * texpiry
]

39 

40 

41# --- X-map functions --- 

42 

43 

44def _fixed_to_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

45 if meta.forward is None: 

46 msg = "forward is required for FixedStrike to MoneynessStrike transform" 

47 raise TypeError(msg) 

48 return x / meta.forward 

49 

50 

51def _moneyness_to_fixed(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

52 if meta.forward is None: 

53 msg = "forward is required for MoneynessStrike to FixedStrike transform" 

54 raise TypeError(msg) 

55 return x * meta.forward 

56 

57 

58def _moneyness_to_log_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

59 return np.log(x) 

60 

61 

62def _log_moneyness_to_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

63 return np.exp(x) 

64 

65 

66def _log_moneyness_to_standardised(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

67 if meta.sigma_atm is None: 

68 msg = "sigma_atm is required for StandardisedStrike transforms" 

69 raise ValueError(msg) 

70 return x / (meta.sigma_atm * np.sqrt(meta.texpiry)) 

71 

72 

73def _standardised_to_log_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

74 if meta.sigma_atm is None: 

75 msg = "sigma_atm is required for StandardisedStrike transforms" 

76 raise ValueError(msg) 

77 return x * meta.sigma_atm * np.sqrt(meta.texpiry) 

78 

79 

80# --- Y-map functions --- 

81 

82 

83def _vol_to_variance(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

84 return y**2 

85 

86 

87def _variance_to_vol(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

88 return np.sqrt(np.maximum(y, 0.0)) 

89 

90 

91def _variance_to_total_variance( 

92 y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata 

93) -> NDArray[np.float64]: 

94 return y * meta.texpiry 

95 

96 

97def _total_variance_to_variance( 

98 y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata 

99) -> NDArray[np.float64]: 

100 return y / meta.texpiry 

101 

102 

def _vol_to_price(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
    """Turn implied volatilities into Black76 call prices via ``black76_call``.

    ``x`` has to hold absolute strikes (FixedStrike coordinates); converting it
    beforehand is the caller's job.

    Raises:
        TypeError: if ``meta.forward`` or ``meta.discount_factor`` is None.
    """
    # Function-local import (presumably to avoid an import cycle with black76 — confirm).
    from qsmile.core.black76 import black76_call

    forward = meta.forward
    discount = meta.discount_factor
    if forward is None or discount is None:
        message = "forward and discount_factor are required for vol-to-price transform"
        raise TypeError(message)
    call_prices = black76_call(forward, x, discount, y, meta.texpiry)
    return np.asarray(call_prices, dtype=np.float64)

117 

118 

119def _price_to_vol(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

120 """Convert Black76 call prices to implied volatilities. 

121 

122 x must be in FixedStrike coordinates (absolute strikes). 

123 """ 

124 from qsmile.core.black76 import black76_implied_vol 

125 

126 if meta.forward is None or meta.discount_factor is None: 

127 msg = "forward and discount_factor are required for price-to-vol transform" 

128 raise TypeError(msg) 

129 n = len(y) 

130 result = np.empty(n, dtype=np.float64) 

131 for i in range(n): 

132 result[i] = black76_implied_vol( 

133 float(y[i]), 

134 meta.forward, 

135 float(x[i]), 

136 meta.discount_factor, 

137 meta.texpiry, 

138 is_call=True, 

139 ) 

140 return result 

141 

142 

143# --- Registries --- 

144 

# X-map registry: (source, target) -> single-step map function. Every pair of
# neighbouring rungs on X_LADDER is registered in both directions, which is what
# lets compose_x_maps walk the ladder either way.
_X_MAPS: dict[tuple[XCoord, XCoord], XMapFn] = {
    # FixedStrike <-> MoneynessStrike: divide / multiply by meta.forward
    (XCoord.FixedStrike, XCoord.MoneynessStrike): _fixed_to_moneyness,
    (XCoord.MoneynessStrike, XCoord.FixedStrike): _moneyness_to_fixed,
    # MoneynessStrike <-> LogMoneynessStrike: natural log / exp
    (XCoord.MoneynessStrike, XCoord.LogMoneynessStrike): _moneyness_to_log_moneyness,
    (XCoord.LogMoneynessStrike, XCoord.MoneynessStrike): _log_moneyness_to_moneyness,
    # LogMoneynessStrike <-> StandardisedStrike: scale by meta.sigma_atm * sqrt(meta.texpiry)
    (XCoord.LogMoneynessStrike, XCoord.StandardisedStrike): _log_moneyness_to_standardised,
    (XCoord.StandardisedStrike, XCoord.LogMoneynessStrike): _standardised_to_log_moneyness,
}

154 

# Y-map registry: (source, target) -> single-step map function. Every pair of
# neighbouring rungs on Y_LADDER is registered in both directions, which is what
# lets compose_y_maps walk the ladder either way.
_Y_MAPS: dict[tuple[YCoord, YCoord], YMapFn] = {
    # Price <-> Volatility: Black76; needs FixedStrike x plus forward and discount_factor
    (YCoord.Price, YCoord.Volatility): _price_to_vol,
    (YCoord.Volatility, YCoord.Price): _vol_to_price,
    # Volatility <-> Variance: square / square root (negative variance floored at 0)
    (YCoord.Volatility, YCoord.Variance): _vol_to_variance,
    (YCoord.Variance, YCoord.Volatility): _variance_to_vol,
    # Variance <-> TotalVariance: multiply / divide by meta.texpiry
    (YCoord.Variance, YCoord.TotalVariance): _variance_to_total_variance,
    (YCoord.TotalVariance, YCoord.Variance): _total_variance_to_variance,
}

164 

165 

166def _ladder_path(ladder: list, source: object, target: object) -> list: 

167 """Return the sequence of ladder steps from source to target.""" 

168 src_idx = ladder.index(source) 

169 tgt_idx = ladder.index(target) 

170 if src_idx == tgt_idx: 

171 return [] 

172 step = 1 if tgt_idx > src_idx else -1 

173 return [(ladder[i], ladder[i + step]) for i in range(src_idx, tgt_idx, step)] 

174 

175 

def compose_x_maps(
    source: XCoord,
    target: XCoord,
) -> list[tuple[XCoord, XCoord, XMapFn]]:
    """Build the ordered X-map steps that convert ``source`` coordinates into ``target``.

    Each entry is ``(from_coord, to_coord, map_fn)`` in application order; an
    empty list means source and target coincide. Raises ``ValueError`` if either
    coordinate is not on X_LADDER.
    """
    return [(*step, _X_MAPS[step]) for step in _ladder_path(X_LADDER, source, target)]

183 

184 

def compose_y_maps(
    source: YCoord,
    target: YCoord,
) -> list[tuple[YCoord, YCoord, YMapFn]]:
    """Build the ordered Y-map steps that convert ``source`` values into ``target``.

    Each entry is ``(from_coord, to_coord, map_fn)`` in application order; an
    empty list means source and target coincide. Raises ``ValueError`` if either
    coordinate is not on Y_LADDER.
    """
    return [(*step, _Y_MAPS[step]) for step in _ladder_path(Y_LADDER, source, target)]

192 

193 

def apply_x_chain(
    x: NDArray[np.float64],
    chain: list[tuple[XCoord, XCoord, XMapFn]],
    meta: SmileMetadata,
) -> NDArray[np.float64]:
    """Feed ``x`` through every X-map in ``chain``, first to last.

    Only the map function of each ``(from, to, fn)`` entry is used; the
    coordinate labels are ignored. With an empty chain, ``x`` itself is returned.
    """
    current = x
    for _from, _to, step_fn in chain:
        current = step_fn(current, meta)
    return current

204 

205 

def apply_y_chain(
    y: NDArray[np.float64],
    x: NDArray[np.float64],
    chain: list[tuple[YCoord, YCoord, YMapFn]],
    meta: SmileMetadata,
    x_coord: XCoord,
    target_x: XCoord,
) -> NDArray[np.float64]:
    """Apply a chain of Y-maps sequentially.

    For the Price↔Volatility step, X must be in FixedStrike. If ``x_coord`` is
    anything else, ``x`` is converted to FixedStrike for that step only; the
    conversion is computed at most once per call. ``x`` is neither converted in
    place nor returned.

    Args:
        y: Values in the chain's starting Y coordinate.
        x: Abscissae expressed in ``x_coord``.
        chain: ``(source, target, map_fn)`` steps, e.g. from ``compose_y_maps``.
        meta: Metadata forwarded to every map function.
        x_coord: Coordinate system ``x`` is expressed in.
        target_x: Accepted for interface compatibility; not used here.

    Returns:
        The Y values after every step of ``chain`` has been applied.
    """
    price_vol_steps = {
        (YCoord.Price, YCoord.Volatility),
        (YCoord.Volatility, YCoord.Price),
    }
    # FixedStrike view of x, built lazily: only Price<->Volatility steps need it.
    fixed_x: NDArray[np.float64] | None = None

    result = y
    for source, target, fn in chain:
        if (source, target) in price_vol_steps and x_coord != XCoord.FixedStrike:
            if fixed_x is None:
                to_fixed = compose_x_maps(x_coord, XCoord.FixedStrike)
                fixed_x = apply_x_chain(x, to_fixed, meta)
            result = fn(result, fixed_x, meta)
        else:
            # No defensive copy: the Y-maps in this module build new arrays rather
            # than writing into x (assumed to hold for black76_call too — confirm).
            result = fn(result, x, meta)
    return result