Coverage for src / qsmile / data / strikes.py: 98%
56 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 22:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 22:47 +0000
1"""Strike-indexed columnar data with hierarchical MultiIndex columns."""
3from __future__ import annotations
5import numpy as np
6import pandas as pd
7from numpy.typing import NDArray
class StrikeArray:
    """A mutable collection of named columns indexed by strike price.

    Columns are stored in a ``pd.DataFrame`` with a two-level ``MultiIndex``
    on columns (level-0 = category, level-1 = field).  Callers address
    columns directly via ``tuple[str, str]`` keys such as ``("call", "bid")``.

    When a new column is added whose strike index differs from the current
    global index, all columns are reindexed to the sorted union of strikes;
    strikes for which a column has no value hold NaN.

    Lookups (``has``, ``values``, ``get_values``) only match *complete*
    ``(category, field)`` keys.  A partial key such as ``("call",)`` or a
    bare ``"call"`` is treated as absent rather than selecting every column
    of that category.
    """

    __slots__ = ("_df",)

    # Number of elements in a complete column key: (category, field).
    _KEY_LEVELS = 2

    def __init__(self) -> None:
        """Create an empty StrikeArray (no strikes, no columns)."""
        idx = pd.Index([], dtype=np.float64, name="strike")
        cols = pd.MultiIndex.from_tuples([], names=["category", "field"])
        self._df: pd.DataFrame = pd.DataFrame(index=idx, columns=cols, dtype=np.float64)

    # ── internal helpers ──────────────────────────────────────────

    def _has_column(self, key: tuple[str, str]) -> bool:
        """Return True only if *key* is a complete key of a stored column."""
        # ``key in MultiIndex`` also succeeds for partial keys (a bare
        # level-0 label or a 1-tuple), which would make the accessors hand
        # back a whole sub-frame instead of one column.  Require a full
        # (category, field) tuple before consulting the index.
        if not isinstance(key, tuple) or len(key) != self._KEY_LEVELS:
            return False
        return key in self._df.columns

    # ── setters ───────────────────────────────────────────────────

    def set(self, key: tuple[str, str], series: pd.Series) -> None:
        """Add or replace a column, updating the global strike index.

        Args:
            key: ``(category, field)`` column key.
            series: Values indexed by strike.  Index and values are cast to
                ``float64``; the strikes need not be sorted.

        Raises:
            ValueError: If the strikes of *series* contain duplicates.
        """
        idx = series.index.astype(np.float64)
        vals = series.values.astype(np.float64)

        # An empty index never has duplicates, so no length guard is needed.
        if idx.has_duplicates:
            msg = "strikes must not contain duplicates"
            raise ValueError(msg)

        # Sort the incoming data by strike.
        order = np.argsort(idx)
        sorted_idx = pd.Index(idx[order], dtype=np.float64, name="strike")
        sorted_vals = vals[order]

        if len(self._df.index) == 0:
            new_index = sorted_idx
        else:
            # Grow the global index to the union of old and new strikes.
            new_index = self._df.index.union(sorted_idx).astype(np.float64).rename("strike")

        # Reindex existing columns only if the strike index actually changed.
        if not self._df.index.equals(new_index):
            self._df = self._df.reindex(new_index)

        # Align the new column to the global index; strikes it does not
        # cover become NaN.
        col_series = pd.Series(sorted_vals, index=sorted_idx, dtype=np.float64)
        col_aligned = col_series.reindex(new_index)

        # Assign positionally (plain array) as a hierarchical column.
        self._df[key] = col_aligned.to_numpy()

    # ── read accessors ────────────────────────────────────────────

    @property
    def strikes(self) -> NDArray[np.float64]:
        """Common strike index as a ``float64`` NDArray."""
        return self._df.index.to_numpy(dtype=np.float64)

    @property
    def columns(self) -> list[tuple[str, str]]:
        """Column keys, in the order they are stored in the DataFrame."""
        return list(self._df.columns)

    def values(self, key: tuple[str, str]) -> NDArray[np.float64]:
        """Get column values as an NDArray.

        Raises:
            KeyError: If no column is stored under the complete key *key*.
        """
        if not self._has_column(key):
            raise KeyError(key)
        return self._df[key].to_numpy(dtype=np.float64)

    def get_values(self, key: tuple[str, str]) -> NDArray[np.float64] | None:
        """Get column values as an NDArray, or None if the column is absent."""
        if not self._has_column(key):
            return None
        return self._df[key].to_numpy(dtype=np.float64)

    def has(self, key: tuple[str, str]) -> bool:
        """Check whether a column exists under the complete key *key*."""
        return self._has_column(key)

    def __len__(self) -> int:
        """Return the number of strikes."""
        return len(self._df.index)

    # ── operations ────────────────────────────────────────────────

    def filter(self, mask: NDArray[np.bool_]) -> StrikeArray:
        """Apply a boolean mask to all columns, returning a new StrikeArray.

        Args:
            mask: Boolean array with one entry per strike; rows where it is
                True are kept.

        Returns:
            A new ``StrikeArray`` holding a copy of the selected rows; this
            instance is left unchanged.
        """
        # Row selection leaves the column index untouched, so the result
        # keeps the same hierarchical columns as the source frame.
        filtered = self._df.iloc[mask].copy()
        filtered.index.name = "strike"

        sa = StrikeArray()
        sa._df = filtered
        return sa

    def to_dataframe(self) -> pd.DataFrame:
        """Return a copy of the internal DataFrame with hierarchical columns."""
        return self._df.copy()