Coverage for src / qsmile / core / maps.py: 94%
78 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-04 21:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-04 21:47 +0000
1"""Coordinate transform maps and composition."""
3from __future__ import annotations
5from collections.abc import Callable
6from typing import TYPE_CHECKING
8import numpy as np
9from numpy.typing import NDArray
11from qsmile.core.coords import XCoord, YCoord
13if TYPE_CHECKING:
14 from qsmile.data.meta import SmileMetadata
# Map-function signatures, kept as string forward references because
# SmileMetadata is only imported under TYPE_CHECKING.
#   X-map: (x_values, metadata) -> x_values in the neighbouring X coordinate
#   Y-map: (y_values, x_values, metadata) -> y_values in the neighbouring Y coordinate
XMapFn = Callable[["NDArray[np.float64]", "SmileMetadata"], "NDArray[np.float64]"]
YMapFn = Callable[["NDArray[np.float64]", "NDArray[np.float64]", "SmileMetadata"], "NDArray[np.float64]"]

# Coordinate ladders. Maps are registered only between neighbouring rungs, so
# converting between any two coordinates walks the ladder one rung at a time.
X_LADDER: list[XCoord] = [XCoord.FixedStrike, XCoord.MoneynessStrike, XCoord.LogMoneynessStrike, XCoord.StandardisedStrike]
Y_LADDER: list[YCoord] = [YCoord.Price, YCoord.Volatility, YCoord.Variance, YCoord.TotalVariance]
41# --- X-map functions ---
44def _fixed_to_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
45 return x / meta.forward
48def _moneyness_to_fixed(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
49 return x * meta.forward
52def _moneyness_to_log_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
53 return np.log(x)
56def _log_moneyness_to_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
57 return np.exp(x)
60def _log_moneyness_to_standardised(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
61 if meta.sigma_atm is None:
62 msg = "sigma_atm is required for StandardisedStrike transforms"
63 raise ValueError(msg)
64 return x / (meta.sigma_atm * np.sqrt(meta.expiry))
67def _standardised_to_log_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
68 if meta.sigma_atm is None:
69 msg = "sigma_atm is required for StandardisedStrike transforms"
70 raise ValueError(msg)
71 return x * meta.sigma_atm * np.sqrt(meta.expiry)
74# --- Y-map functions ---
77def _vol_to_variance(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
78 return y**2
81def _variance_to_vol(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
82 return np.sqrt(y)
85def _variance_to_total_variance(
86 y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata
87) -> NDArray[np.float64]:
88 return y * meta.expiry
91def _total_variance_to_variance(
92 y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata
93) -> NDArray[np.float64]:
94 return y / meta.expiry
def _vol_to_price(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
    """Price Black76 calls from implied volatilities.

    ``x`` must hold absolute strikes (FixedStrike coordinates). Forward,
    discount factor and expiry are read from ``meta``; whatever
    ``black76_call`` returns is coerced to a float64 array.
    """
    from qsmile.core.black76 import black76_call

    prices = black76_call(meta.forward, x, meta.discount_factor, y, meta.expiry)
    return np.asarray(prices, dtype=np.float64)
def _price_to_vol(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
    """Convert Black76 call prices to implied volatilities.

    x must be in FixedStrike coordinates (absolute strikes).

    ``y`` and ``x`` are broadcast against each other, so 0-d, 1-d and
    N-d inputs are accepted and the result has the broadcast shape. For
    equal-length 1-D inputs this matches the previous element-by-element
    behaviour. Previously, an ``x`` longer than ``y`` was silently
    truncated; mismatched, non-broadcastable shapes now raise instead.

    Raises:
        ValueError: if ``y`` and ``x`` cannot be broadcast together.
    """
    from qsmile.core.black76 import black76_implied_vol

    prices, strikes = np.broadcast_arrays(
        np.asarray(y, dtype=np.float64),
        np.asarray(x, dtype=np.float64),
    )
    # Loop-invariant metadata, read once rather than on every iteration.
    forward = meta.forward
    discount_factor = meta.discount_factor
    expiry = meta.expiry

    result = np.empty(prices.shape, dtype=np.float64)
    # The inverter is called point by point with plain Python floats, as before.
    for idx in np.ndindex(prices.shape):
        result[idx] = black76_implied_vol(
            float(prices[idx]),
            forward,
            float(strikes[idx]),
            discount_factor,
            expiry,
            is_call=True,
        )
    return result
# --- Registries ---

# Each registry maps (source, target) -> the function for that single rung.
# Both directions of every adjacent ladder pair are registered; entries are
# built from (lower, upper, upward_fn, downward_fn) rung tuples in ladder
# order, inserting the upward map before the downward one.
_X_MAPS: dict[tuple[XCoord, XCoord], XMapFn] = {
    pair: fn
    for lower, upper, upward, downward in (
        (XCoord.FixedStrike, XCoord.MoneynessStrike, _fixed_to_moneyness, _moneyness_to_fixed),
        (
            XCoord.MoneynessStrike,
            XCoord.LogMoneynessStrike,
            _moneyness_to_log_moneyness,
            _log_moneyness_to_moneyness,
        ),
        (
            XCoord.LogMoneynessStrike,
            XCoord.StandardisedStrike,
            _log_moneyness_to_standardised,
            _standardised_to_log_moneyness,
        ),
    )
    for pair, fn in (((lower, upper), upward), ((upper, lower), downward))
}

_Y_MAPS: dict[tuple[YCoord, YCoord], YMapFn] = {
    pair: fn
    for lower, upper, upward, downward in (
        (YCoord.Price, YCoord.Volatility, _price_to_vol, _vol_to_price),
        (YCoord.Volatility, YCoord.Variance, _vol_to_variance, _variance_to_vol),
        (YCoord.Variance, YCoord.TotalVariance, _variance_to_total_variance, _total_variance_to_variance),
    )
    for pair, fn in (((lower, upper), upward), ((upper, lower), downward))
}
154def _ladder_path(ladder: list, source: object, target: object) -> list:
155 """Return the sequence of ladder steps from source to target."""
156 src_idx = ladder.index(source)
157 tgt_idx = ladder.index(target)
158 if src_idx == tgt_idx:
159 return []
160 step = 1 if tgt_idx > src_idx else -1
161 return [(ladder[i], ladder[i + step]) for i in range(src_idx, tgt_idx, step)]
def compose_x_maps(
    source: XCoord,
    target: XCoord,
) -> list[tuple[XCoord, XCoord, XMapFn]]:
    """Build the ordered chain of X-maps that carries *source* coordinates to *target*.

    Each entry is ``(from, to, fn)`` for one rung of ``X_LADDER``, with
    ``fn`` looked up in the X-map registry. Equal coordinates give an empty
    chain. A coordinate missing from ``X_LADDER`` surfaces as a ValueError
    from the ladder lookup.
    """
    return [(*rung, _X_MAPS[rung]) for rung in _ladder_path(X_LADDER, source, target)]
def compose_y_maps(
    source: YCoord,
    target: YCoord,
) -> list[tuple[YCoord, YCoord, YMapFn]]:
    """Build the ordered chain of Y-maps that carries *source* values to *target*.

    Each entry is ``(from, to, fn)`` for one rung of ``Y_LADDER``, with
    ``fn`` looked up in the Y-map registry. Equal coordinates give an empty
    chain. A coordinate missing from ``Y_LADDER`` surfaces as a ValueError
    from the ladder lookup.
    """
    return [(*rung, _Y_MAPS[rung]) for rung in _ladder_path(Y_LADDER, source, target)]
def apply_x_chain(
    x: NDArray[np.float64],
    chain: list[tuple[XCoord, XCoord, XMapFn]],
    meta: SmileMetadata,
) -> NDArray[np.float64]:
    """Feed *x* through every map in *chain*, left to right.

    Each step's output is the next step's input, and the coordinate labels
    in the chain entries are not consulted. An empty chain hands back *x*
    itself (same object, not a copy).
    """
    values = x
    for _from_coord, _to_coord, map_fn in chain:
        values = map_fn(values, meta)
    return values
def apply_y_chain(
    y: NDArray[np.float64],
    x: NDArray[np.float64],
    chain: list[tuple[YCoord, YCoord, YMapFn]],
    meta: SmileMetadata,
    x_coord: XCoord,
    target_x: XCoord,
) -> NDArray[np.float64]:
    """Apply a chain of Y-maps sequentially.

    Every Y-map is called as ``fn(y, x, meta)``. The Price<->Volatility maps
    need absolute strikes. For those steps, ``x`` is converted from
    ``x_coord`` to FixedStrike, and the converted values are passed in
    place of ``x``. The conversion is computed at most once and is never
    written back: ``x`` is not modified and no X values are returned.

    Args:
        y: Y values in the source coordinate of the first chain step.
        x: X values expressed in ``x_coord``.
        chain: ``(source, target, fn)`` steps, e.g. from ``compose_y_maps``.
        meta: Smile metadata forwarded to every map.
        x_coord: Coordinate system that ``x`` is expressed in.
        target_x: Not used by this function; kept for interface compatibility.

    Returns:
        The Y values after every step has been applied (``y`` itself when
        ``chain`` is empty).

    Raises:
        ValueError: propagated from the X-maps, e.g. when a StandardisedStrike
            conversion is needed but ``meta.sigma_atm`` is None.
    """
    # Steps implemented via Black76, which must see absolute strikes.
    black76_steps = (
        (YCoord.Price, YCoord.Volatility),
        (YCoord.Volatility, YCoord.Price),
    )
    fixed_x: NDArray[np.float64] | None = None

    result = y
    for s, t, fn in chain:
        step_x = x
        if (s, t) in black76_steps and x_coord != XCoord.FixedStrike:
            if fixed_x is None:
                # Lazily convert once; reused if the chain has several Black76 steps.
                fixed_x = apply_x_chain(x, compose_x_maps(x_coord, XCoord.FixedStrike), meta)
            step_x = fixed_x
        result = fn(result, step_x, meta)
    return result