Coverage for src/qsmile/data/strikes.py: 98%

56 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 22:47 +0000

1"""Strike-indexed columnar data with hierarchical MultiIndex columns.""" 

2 

3from __future__ import annotations 

4 

5import numpy as np 

6import pandas as pd 

7from numpy.typing import NDArray 

8 

9 

class StrikeArray:
    """A mutable collection of named columns indexed by strike price.

    Columns are stored in a ``pd.DataFrame`` with a two-level ``MultiIndex``
    on columns (level-0 = category, level-1 = field). Callers address
    columns directly via ``tuple[str, str]`` keys such as ``("call", "bid")``.

    When a new column is added whose strike index differs from the current
    global index, all columns are reindexed to the sorted union of strikes,
    so a column holds NaN at strikes it was not given. The strike index only
    grows: replacing a column never drops strikes.
    """

    __slots__ = ("_df",)

    def __init__(self) -> None:
        """Create an empty StrikeArray."""
        idx = pd.Index([], dtype=np.float64, name="strike")
        cols = pd.MultiIndex.from_tuples([], names=["category", "field"])
        self._df: pd.DataFrame = pd.DataFrame(index=idx, columns=cols, dtype=np.float64)

    # ── internal helpers ──────────────────────────────────────────

    def _is_full_key(self, key: object) -> bool:
        """Return True if *key* is a tuple with one entry per column level."""
        return isinstance(key, tuple) and len(key) == self._df.columns.nlevels

    def _has_column(self, key: object) -> bool:
        """Return True only if *key* names one complete, existing column.

        ``MultiIndex.__contains__`` also matches partial keys (e.g.
        ``"call" in columns`` is True whenever ``"call"`` is a level-0
        label), so a bare category would otherwise pass as a column.
        """
        return self._is_full_key(key) and key in self._df.columns

    # ── setters ───────────────────────────────────────────────────

    def set(self, key: tuple[str, str], series: pd.Series) -> None:
        """Add or replace a column, updating the global strike index.

        Args:
            key: Full column key ``(category, field)``.
            series: Values indexed by strike. Index and values are cast to
                float64; the series need not be sorted by strike.

        Raises:
            TypeError: If *key* is not a ``(category, field)`` tuple.
            ValueError: If the series index contains duplicate strikes.
        """
        # Reject partial / non-tuple keys: on MultiIndex columns pandas would
        # otherwise silently pad them (``"call"`` -> ``("call", "")``).
        if not self._is_full_key(key):
            msg = f"key must be a (category, field) tuple, got {key!r}"
            raise TypeError(msg)

        idx = series.index.astype(np.float64)
        vals = series.values.astype(np.float64)

        if idx.has_duplicates:
            msg = "strikes must not contain duplicates"
            raise ValueError(msg)

        # Sort the incoming column by strike (no duplicates, so no ties).
        order = np.argsort(idx)
        sorted_idx = pd.Index(idx[order], dtype=np.float64, name="strike")
        sorted_vals = vals[order]

        if len(self._df.index) == 0:
            new_index = sorted_idx
        else:
            # Index.union returns the sorted union by default (sort=None).
            merged = self._df.index.union(sorted_idx)
            new_index = merged.astype(np.float64).rename("strike")

        # Reindex existing columns only when the strike set actually changed;
        # strikes new to them are filled with NaN.
        if not self._df.index.equals(new_index):
            self._df = self._df.reindex(new_index)

        # Align the new column to the global index, then store it.
        col_series = pd.Series(sorted_vals, index=sorted_idx, dtype=np.float64)
        self._df[key] = col_series.reindex(new_index).to_numpy()

    # ── read accessors ────────────────────────────────────────────

    @property
    def strikes(self) -> NDArray[np.float64]:
        """Common strike index as a sorted float64 array.

        A copy is returned so that mutating the result cannot corrupt the
        internal index.
        """
        return self._df.index.to_numpy(dtype=np.float64, copy=True)

    @property
    def columns(self) -> list[tuple[str, str]]:
        """Column keys in insertion order."""
        return list(self._df.columns)

    def values(self, key: tuple[str, str]) -> NDArray[np.float64]:
        """Get column values as a float64 array (a copy).

        Raises:
            KeyError: If *key* is not a complete, existing column key.
        """
        arr = self.get_values(key)
        if arr is None:
            raise KeyError(key)
        return arr

    def get_values(self, key: tuple[str, str]) -> NDArray[np.float64] | None:
        """Get column values as a float64 array (a copy), or None if absent.

        Only complete ``(category, field)`` keys match; a bare category such
        as ``"call"`` returns None.
        """
        if not self._has_column(key):
            return None
        # copy=True: do not hand out a view into the DataFrame's storage.
        return self._df[key].to_numpy(dtype=np.float64, copy=True)

    def has(self, key: tuple[str, str]) -> bool:
        """Check whether a complete ``(category, field)`` column exists."""
        return self._has_column(key)

    def __len__(self) -> int:
        """Return the number of strikes."""
        return len(self._df.index)

    # ── operations ────────────────────────────────────────────────

    def filter(self, mask: NDArray[np.bool_]) -> StrikeArray:
        """Apply a boolean mask to all columns, returning a new StrikeArray.

        Args:
            mask: One boolean per strike, applied positionally. Array-likes
                (lists, boolean ``pd.Series``) are converted with
                ``np.asarray``; pandas raises if the length does not match.
        """
        # iloc refuses a boolean pd.Series ("indexable as a mask"); a plain
        # ndarray is accepted, so normalise first.
        mask_arr = np.asarray(mask)
        filtered = self._df.iloc[mask_arr].copy()
        filtered.index.name = "strike"
        # Ensure columns retain MultiIndex
        if not isinstance(filtered.columns, pd.MultiIndex):
            filtered.columns = self._df.columns
        sa = StrikeArray()
        sa._df = filtered
        return sa

    def to_dataframe(self) -> pd.DataFrame:
        """Return a copy of the internal DataFrame with hierarchical columns."""
        return self._df.copy()