Coverage for src / qsmile / core / maps.py: 86%
90 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 22:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 22:47 +0000
1"""Coordinate transform maps and composition."""
3from __future__ import annotations
5from collections.abc import Callable
6from typing import TYPE_CHECKING
8import numpy as np
9from numpy.typing import NDArray
11from qsmile.core.coords import XCoord, YCoord
13if TYPE_CHECKING:
14 from qsmile.data.meta import SmileMetadata
# Type aliases for map functions.
# X-maps: (x_array, metadata) -> x_array
# Y-maps: (y_array, x_array, metadata) -> y_array
# The type arguments are written as strings, so they are not evaluated when
# this module is imported (SmileMetadata is only imported under TYPE_CHECKING).
XMapFn = Callable[["NDArray[np.float64]", "SmileMetadata"], "NDArray[np.float64]"]
YMapFn = Callable[
    ["NDArray[np.float64]", "NDArray[np.float64]", "SmileMetadata"],
    "NDArray[np.float64]",
]
# Ordered ladders.
# Order matters: _ladder_path walks between neighbouring entries, and the
# registries below define a map for each neighbouring pair in both directions.
X_LADDER: list[XCoord] = [
    XCoord.FixedStrike,
    XCoord.MoneynessStrike,
    XCoord.LogMoneynessStrike,
    XCoord.StandardisedStrike,
]

Y_LADDER: list[YCoord] = [
    YCoord.Price,
    YCoord.Volatility,
    YCoord.Variance,
    YCoord.TotalVariance,
]
# --- X-map functions ---
44def _fixed_to_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
45 if meta.forward is None:
46 msg = "forward is required for FixedStrike to MoneynessStrike transform"
47 raise TypeError(msg)
48 return x / meta.forward
51def _moneyness_to_fixed(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
52 if meta.forward is None:
53 msg = "forward is required for MoneynessStrike to FixedStrike transform"
54 raise TypeError(msg)
55 return x * meta.forward
58def _moneyness_to_log_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
59 return np.log(x)
62def _log_moneyness_to_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
63 return np.exp(x)
66def _log_moneyness_to_standardised(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
67 if meta.sigma_atm is None:
68 msg = "sigma_atm is required for StandardisedStrike transforms"
69 raise ValueError(msg)
70 return x / (meta.sigma_atm * np.sqrt(meta.texpiry))
73def _standardised_to_log_moneyness(x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
74 if meta.sigma_atm is None:
75 msg = "sigma_atm is required for StandardisedStrike transforms"
76 raise ValueError(msg)
77 return x * meta.sigma_atm * np.sqrt(meta.texpiry)
# --- Y-map functions ---
83def _vol_to_variance(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
84 return y**2
87def _variance_to_vol(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
88 return np.sqrt(np.maximum(y, 0.0))
91def _variance_to_total_variance(
92 y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata
93) -> NDArray[np.float64]:
94 return y * meta.texpiry
97def _total_variance_to_variance(
98 y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata
99) -> NDArray[np.float64]:
100 return y / meta.texpiry
103def _vol_to_price(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
104 """Convert implied volatilities to Black76 call prices.
106 x must be in FixedStrike coordinates (absolute strikes).
107 """
108 from qsmile.core.black76 import black76_call
110 if meta.forward is None or meta.discount_factor is None:
111 msg = "forward and discount_factor are required for vol-to-price transform"
112 raise TypeError(msg)
113 return np.asarray(
114 black76_call(meta.forward, x, meta.discount_factor, y, meta.texpiry),
115 dtype=np.float64,
116 )
119def _price_to_vol(y: NDArray[np.float64], x: NDArray[np.float64], meta: SmileMetadata) -> NDArray[np.float64]:
120 """Convert Black76 call prices to implied volatilities.
122 x must be in FixedStrike coordinates (absolute strikes).
123 """
124 from qsmile.core.black76 import black76_implied_vol
126 if meta.forward is None or meta.discount_factor is None:
127 msg = "forward and discount_factor are required for price-to-vol transform"
128 raise TypeError(msg)
129 n = len(y)
130 result = np.empty(n, dtype=np.float64)
131 for i in range(n):
132 result[i] = black76_implied_vol(
133 float(y[i]),
134 meta.forward,
135 float(x[i]),
136 meta.discount_factor,
137 meta.texpiry,
138 is_call=True,
139 )
140 return result
# --- Registries ---
# X-map registry: (source, target) -> map function.
# One entry per neighbouring pair of X_LADDER, in each direction.
_X_MAPS: dict[tuple[XCoord, XCoord], XMapFn] = {
    (XCoord.FixedStrike, XCoord.MoneynessStrike): _fixed_to_moneyness,
    (XCoord.MoneynessStrike, XCoord.FixedStrike): _moneyness_to_fixed,
    (XCoord.MoneynessStrike, XCoord.LogMoneynessStrike): _moneyness_to_log_moneyness,
    (XCoord.LogMoneynessStrike, XCoord.MoneynessStrike): _log_moneyness_to_moneyness,
    (XCoord.LogMoneynessStrike, XCoord.StandardisedStrike): _log_moneyness_to_standardised,
    (XCoord.StandardisedStrike, XCoord.LogMoneynessStrike): _standardised_to_log_moneyness,
}
# Y-map registry: (source, target) -> map function.
# One entry per neighbouring pair of Y_LADDER, in each direction.
_Y_MAPS: dict[tuple[YCoord, YCoord], YMapFn] = {
    (YCoord.Price, YCoord.Volatility): _price_to_vol,
    (YCoord.Volatility, YCoord.Price): _vol_to_price,
    (YCoord.Volatility, YCoord.Variance): _vol_to_variance,
    (YCoord.Variance, YCoord.Volatility): _variance_to_vol,
    (YCoord.Variance, YCoord.TotalVariance): _variance_to_total_variance,
    (YCoord.TotalVariance, YCoord.Variance): _total_variance_to_variance,
}
166def _ladder_path(ladder: list, source: object, target: object) -> list:
167 """Return the sequence of ladder steps from source to target."""
168 src_idx = ladder.index(source)
169 tgt_idx = ladder.index(target)
170 if src_idx == tgt_idx:
171 return []
172 step = 1 if tgt_idx > src_idx else -1
173 return [(ladder[i], ladder[i + step]) for i in range(src_idx, tgt_idx, step)]
def compose_x_maps(
    source: XCoord,
    target: XCoord,
) -> list[tuple[XCoord, XCoord, XMapFn]]:
    """Return the chain of X-maps needed to go from source to target.

    Each element is ``(step_source, step_target, map_fn)`` for one step
    between neighbouring entries of ``X_LADDER``, in application order. The
    list is empty when ``source`` and ``target`` are the same coordinate.
    """
    steps = _ladder_path(X_LADDER, source, target)
    return [(s, t, _X_MAPS[(s, t)]) for s, t in steps]
def compose_y_maps(
    source: YCoord,
    target: YCoord,
) -> list[tuple[YCoord, YCoord, YMapFn]]:
    """Return the chain of Y-maps needed to go from source to target.

    Each element is ``(step_source, step_target, map_fn)`` for one step
    between neighbouring entries of ``Y_LADDER``, in application order. The
    list is empty when ``source`` and ``target`` are the same coordinate.
    """
    steps = _ladder_path(Y_LADDER, source, target)
    return [(s, t, _Y_MAPS[(s, t)]) for s, t in steps]
def apply_x_chain(
    x: NDArray[np.float64],
    chain: list[tuple[XCoord, XCoord, XMapFn]],
    meta: SmileMetadata,
) -> NDArray[np.float64]:
    """Apply a chain of X-maps sequentially.

    Each map receives the previous map's output together with ``meta``.
    An empty chain returns ``x`` itself (no copy is made).
    """
    result = x
    # Only the map function of each (source, target, fn) triple is needed here.
    for _source, _target, fn in chain:
        result = fn(result, meta)
    return result
def apply_y_chain(
    y: NDArray[np.float64],
    x: NDArray[np.float64],
    chain: list[tuple[YCoord, YCoord, YMapFn]],
    meta: SmileMetadata,
    x_coord: XCoord,
    target_x: XCoord,
) -> NDArray[np.float64]:
    """Apply a chain of Y-maps sequentially.

    For the Price↔Volatility step, X must be in FixedStrike.
    If ``x_coord`` is not FixedStrike, a FixedStrike version of X is derived
    and passed to that step only; every other step receives X in ``x_coord``.

    ``target_x`` is not used by this function. An empty chain returns ``y``
    itself.
    """
    result = y
    # The maps are handed a private copy of X, never the caller's array.
    current_x = x.copy()
    # X and x_coord do not change while the chain runs, so the FixedStrike
    # conversion is computed lazily and at most once.
    fixed_x = None

    for s, t, fn in chain:
        needs_fixed = (s == YCoord.Price and t == YCoord.Volatility) or (
            s == YCoord.Volatility and t == YCoord.Price
        )
        if needs_fixed and x_coord != XCoord.FixedStrike:
            if fixed_x is None:
                to_fixed = compose_x_maps(x_coord, XCoord.FixedStrike)
                fixed_x = apply_x_chain(current_x, to_fixed, meta)
            result = fn(result, fixed_x, meta)
        else:
            result = fn(result, current_x, meta)

    return result