Coverage for src / qsmile / core / maps.py: 94%

78 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-04 21:47 +0000

1"""Coordinate transform maps and composition.""" 

2 

3from __future__ import annotations 

4 

5from collections.abc import Callable 

6from typing import TYPE_CHECKING 

7 

8import numpy as np 

9from numpy.typing import NDArray 

10 

11from qsmile.core.coords import XCoord, YCoord 

12 

13if TYPE_CHECKING: 

14 from qsmile.data.meta import SmileMetadata 

15 

# Type aliases for the map callables stored in the registries.
# X-maps: (x_array, metadata) -> x_array
# Y-maps: (y_array, x_array, metadata) -> y_array
# The element types are spelled as strings (forward references) because these
# aliases are ordinary runtime expressions, not annotations, and SmileMetadata
# is only imported under TYPE_CHECKING.
XMapFn = Callable[["NDArray[np.float64]", "SmileMetadata"], "NDArray[np.float64]"]
YMapFn = Callable[
    ["NDArray[np.float64]", "NDArray[np.float64]", "SmileMetadata"],
    "NDArray[np.float64]",
]

# Ordered ladders. Maps are only registered between adjacent rungs, so a
# source -> target conversion walks through every rung in between, in order
# (see _ladder_path / compose_x_maps / compose_y_maps).
X_LADDER: list[XCoord] = [
    XCoord.FixedStrike,
    XCoord.MoneynessStrike,
    XCoord.LogMoneynessStrike,
    XCoord.StandardisedStrike,
]

Y_LADDER: list[YCoord] = [
    YCoord.Price,
    YCoord.Volatility,
    YCoord.Variance,
    YCoord.TotalVariance,
]

39 

40 

41# --- X-map functions --- 

42 

43 

44def _fixed_to_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

45 return x / meta.forward 

46 

47 

48def _moneyness_to_fixed(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

49 return x * meta.forward 

50 

51 

52def _moneyness_to_log_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

53 return np.log(x) 

54 

55 

56def _log_moneyness_to_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

57 return np.exp(x) 

58 

59 

60def _log_moneyness_to_standardised(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

61 if meta.sigma_atm is None: 

62 msg = "sigma_atm is required for StandardisedStrike transforms" 

63 raise ValueError(msg) 

64 return x / (meta.sigma_atm * np.sqrt(meta.expiry)) 

65 

66 

67def _standardised_to_log_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

68 if meta.sigma_atm is None: 

69 msg = "sigma_atm is required for StandardisedStrike transforms" 

70 raise ValueError(msg) 

71 return x * meta.sigma_atm * np.sqrt(meta.expiry) 

72 

73 

74# --- Y-map functions --- 

75 

76 

77def _vol_to_variance(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

78 return y**2 

79 

80 

81def _variance_to_vol(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

82 return np.sqrt(y) 

83 

84 

85def _variance_to_total_variance( 

86 y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata 

87) -> NDArray[np.float64]: 

88 return y * meta.expiry 

89 

90 

91def _total_variance_to_variance( 

92 y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata 

93) -> NDArray[np.float64]: 

94 return y / meta.expiry 

95 

96 

97def _vol_to_price(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]: 

98 """Convert implied volatilities to Black76 call prices. 

99 

100 x must be in FixedStrike coordinates (absolute strikes). 

101 """ 

102 from qsmile.core.black76 import black76_call 

103 

104 return np.asarray( 

105 black76_call(meta.forward, x, meta.discount_factor, y, meta.expiry), 

106 dtype=np.float64, 

107 ) 

108 

109 

def _price_to_vol(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
    """Invert Black76 call prices ``y`` into implied volatilities, one strike at a time.

    ``x`` has to hold absolute strikes (FixedStrike coordinates), aligned
    element-wise with ``y``. Returns a new float64 array of the same length as ``y``.
    """
    from qsmile.core.black76 import black76_implied_vol

    vols = np.empty(len(y), dtype=np.float64)
    for idx, price in enumerate(y):
        # The solver is scalar, so each price/strike pair is inverted separately.
        vols[idx] = black76_implied_vol(
            float(price),
            meta.forward,
            float(x[idx]),
            meta.discount_factor,
            meta.expiry,
            is_call=True,
        )
    return vols

129 

130 

# --- Registries ---

# X-map registry: (source, target) -> map function.
# Only adjacent X_LADDER rungs are registered, in both directions; compose_x_maps
# looks up each adjacent step of a ladder path here, so a missing pair would
# surface there as a KeyError.
_X_MAPS: dict[tuple[XCoord, XCoord], XMapFn] = {
    (XCoord.FixedStrike, XCoord.MoneynessStrike): _fixed_to_moneyness,
    (XCoord.MoneynessStrike, XCoord.FixedStrike): _moneyness_to_fixed,
    (XCoord.MoneynessStrike, XCoord.LogMoneynessStrike): _moneyness_to_log_moneyness,
    (XCoord.LogMoneynessStrike, XCoord.MoneynessStrike): _log_moneyness_to_moneyness,
    (XCoord.LogMoneynessStrike, XCoord.StandardisedStrike): _log_moneyness_to_standardised,
    (XCoord.StandardisedStrike, XCoord.LogMoneynessStrike): _standardised_to_log_moneyness,
}

# Y-map registry: (source, target) -> map function.
# Same adjacency rule as _X_MAPS, over Y_LADDER. The Price<->Volatility entries
# are the only functions here that read the X array (they need absolute
# strikes, which is why apply_y_chain converts X to FixedStrike for them).
_Y_MAPS: dict[tuple[YCoord, YCoord], YMapFn] = {
    (YCoord.Price, YCoord.Volatility): _price_to_vol,
    (YCoord.Volatility, YCoord.Price): _vol_to_price,
    (YCoord.Volatility, YCoord.Variance): _vol_to_variance,
    (YCoord.Variance, YCoord.Volatility): _variance_to_vol,
    (YCoord.Variance, YCoord.TotalVariance): _variance_to_total_variance,
    (YCoord.TotalVariance, YCoord.Variance): _total_variance_to_variance,
}

152 

153 

154def _ladder_path(ladder: list, source: object, target: object) -> list: 

155 """Return the sequence of ladder steps from source to target.""" 

156 src_idx = ladder.index(source) 

157 tgt_idx = ladder.index(target) 

158 if src_idx == tgt_idx: 

159 return [] 

160 step = 1 if tgt_idx > src_idx else -1 

161 return [(ladder[i], ladder[i + step]) for i in range(src_idx, tgt_idx, step)] 

162 

163 

def compose_x_maps(source: XCoord, target: XCoord) -> list[tuple[XCoord, XCoord, XMapFn]]:
    """Build the ``(from, to, fn)`` steps that carry X values from ``source`` to ``target``.

    Steps follow X_LADDER one rung at a time; an empty list means no conversion.
    """
    steps = _ladder_path(X_LADDER, source, target)
    return [(*step, _X_MAPS[step]) for step in steps]

171 

172 

def compose_y_maps(source: YCoord, target: YCoord) -> list[tuple[YCoord, YCoord, YMapFn]]:
    """Build the ``(from, to, fn)`` steps that carry Y values from ``source`` to ``target``.

    Steps follow Y_LADDER one rung at a time; an empty list means no conversion.
    """
    steps = _ladder_path(Y_LADDER, source, target)
    return [(*step, _Y_MAPS[step]) for step in steps]

180 

181 

def apply_x_chain(
    x: NDArray[np.float64],
    chain: list[tuple[XCoord, XCoord, XMapFn]],
    meta: SmileMetadata,
) -> NDArray[np.float64]:
    """Feed ``x`` through each map in ``chain`` in order, passing ``meta`` to every step.

    With an empty chain the input object itself is returned.
    """
    out = x
    for _from_coord, _to_coord, step_fn in chain:
        out = step_fn(out, meta)
    return out

def apply_y_chain(
    y: NDArray[np.float64],
    x: NDArray[np.float64],
    chain: list[tuple[YCoord, YCoord, YMapFn]],
    meta: SmileMetadata,
    x_coord: XCoord,
    target_x: XCoord,
) -> NDArray[np.float64]:
    """Apply a chain of Y-maps sequentially.

    Each step is called as ``fn(y_values, x_values, meta)``. The
    Price<->Volatility steps need X in FixedStrike (absolute strikes), so when
    ``x_coord`` is not FixedStrike a FixedStrike copy of ``x`` is computed once,
    on the first such step, and reused. That copy is only passed as a temporary
    argument: ``x`` is not converted in place, so all other steps receive X in
    ``x_coord`` and nothing needs converting back.

    ``y`` is in the chain's source coordinate and ``x`` in ``x_coord``; ``meta``
    is forwarded to every map and to the X conversion. ``target_x`` is not used
    here and is kept only for call-site compatibility.

    Returns the Y values in the chain's final target coordinate (``y`` itself
    when ``chain`` is empty).
    """
    price_vol_steps = (
        (YCoord.Price, YCoord.Volatility),
        (YCoord.Volatility, YCoord.Price),
    )
    # FixedStrike version of x; only built if a Price<->Volatility step occurs.
    fixed_x = x if x_coord == XCoord.FixedStrike else None

    result = y
    for s, t, fn in chain:
        if (s, t) in price_vol_steps:
            if fixed_x is None:
                to_fixed = compose_x_maps(x_coord, XCoord.FixedStrike)
                fixed_x = apply_x_chain(x, to_fixed, meta)
            result = fn(result, fixed_x, meta)
        else:
            # Non-pricing Y-maps do not depend on the X coordinate.
            result = fn(result, x, meta)
    return result