# python/lsst/pipe/tasks/ssp/util.py

from datetime import datetime
from typing import Optional

import astropy
import astropy.units as u
import numpy as np
import numpy.ma as ma
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
from astropy.constants import R_earth
from astropy.coordinates import (
    EarthLocation,
    angular_separation,
    get_body_barycentric_posvel,
    get_sun,
    solar_system_ephemeris,
)
from astropy.time import Time
from astroquery.mpc import MPC

def assoc_validate(dia, assoc):
    # verify coordinates and times match
    dia = dia[["ra", "dec", "midpointMjdTai"]].iloc[assoc["dia_index"].values]

    # verify coordinates match
    obs = astropy.coordinates.SkyCoord(ra=dia["ra"].values, dec=dia["dec"].values, unit="deg")
    mpc = astropy.coordinates.SkyCoord(ra=assoc["mpc_ra"].values, dec=assoc["mpc_dec"].values, unit="deg")
    sep = obs.separation(mpc)

    print("Separation range (arcsec): ", sep.min().arcsec, sep.max().arcsec)
    assert sep.min().arcsec >= 0
    assert sep.max().arcsec <= 0.005  # FIXME: this should be further
    # tightened once we start submitting extra precision to the MPC

    # verify times match
    t_utc = Time(dia["midpointMjdTai"].to_numpy(), format="mjd", scale="tai").utc
    midpoint_utc = pd.to_datetime(t_utc.to_datetime())
    mpc_time = pd.to_datetime(assoc["mpc_obstime"])
    delta_sec = (mpc_time - midpoint_utc).dt.total_seconds()

    print("Time difference range (sec): ", delta_sec.min(), delta_sec.max())

    # FIXME: this was relaxed as the USDF replica's obstime datatype is borked
    # and rounds (or truncates?) the timestamps to 1 second. E-mailed Dan S.
    # to get it fixed.
    # assert abs(delta_sec).max() < 0.01
    assert abs(delta_sec).max() < 0.51

    print(f"All OK, {len(assoc):,} observations.")

def packed_ascii_to_uint64_le(mpc_packed):
    """
    Convert a pandas string[pyarrow] column of ASCII strings (<= 8 bytes)
    to little-endian uint64 by left-padding with ASCII spaces to 8 chars.
    """

    # Step 1: Convert pandas Series → real pyarrow.StringArray
    arr = pa.array(mpc_packed, type=pa.string())

    # Step 2: Left-pad to length 8 with ASCII spaces (works on older PyArrow!)
    #         ascii_lpad takes (string_array, target_length, pad_char)
    padded = pc.ascii_lpad(arr, 8, " ")  # returns string array padded to 8 chars

    # Step 3: Convert padded string → binary
    bin_arr = pc.cast(padded, pa.binary())

    # Step 4: Slice each to exactly 8 bytes
    fixed = pc.binary_slice(bin_arr, 0, 8)

    # Step 5: Flatten chunks into a single contiguous array
    if isinstance(fixed, pa.ChunkedArray):
        fixed = fixed.combine_chunks()

    # Step 6: Extract the contiguous values buffer
    buf = fixed.buffers()[2]

    # Step 7: Interpret every 8 bytes as a little-endian uint64
    return np.frombuffer(buf, dtype="<u8")
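# Usage sketch (not part of the original module): round-trips two packed
# MPC-style designations through the uint64 encoding; the helper name is
# hypothetical.
def _example_packed_ascii_to_uint64_le():
    ids = pd.Series(["00001", "K25B01A"], dtype="string[pyarrow]")
    keys = packed_ascii_to_uint64_le(ids)
    # The buffer behind the uint64s is the space-padded ASCII, so the
    # original strings can be recovered by re-reading it 8 bytes at a time.
    raw = keys.tobytes()
    labels = [raw[i:i + 8].decode("ascii").lstrip() for i in range(0, len(raw), 8)]
    assert labels == list(ids)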

def solar_elongation_ndarray(ra_deg, dec_deg, t):
    """
    Very fast computation of solar elongation (ICRS great-circle separation)
    using astropy.coordinates.angular_separation.

    Parameters
    ----------
    ra_deg : ndarray
        Target RA in degrees (ICRS).
    dec_deg : ndarray
        Target Dec in degrees (ICRS).
    t : astropy.time.Time
        Observation times.

    Returns
    -------
    elong_deg : ndarray
        Solar elongation in degrees.
    """

    # Get Sun coordinates.
    # FIXME: get_sun dominates the runtime for large arrays. Should probably
    # evaluate it at a few epochs, then fit a spline or piecewise polynomial.
    sun = get_sun(t).icrs  # cheap transformation to ICRS

    # Extract Sun RA/Dec arrays (radian floats)
    sun_ra = sun.ra.radian
    sun_dec = sun.dec.radian

    # Convert input to radians
    ra = np.radians(ra_deg)
    dec = np.radians(dec_deg)

    # Fast great-circle angular separation
    sep = angular_separation(ra, dec, sun_ra, sun_dec)

    # Convert to degrees for return
    return np.degrees(sep)
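# Usage sketch (not part of the original module): elongation for two fields
# at the same epoch; the helper name and coordinates are made up.
def _example_solar_elongation():
    t = Time(["2026-01-01T00:00:00"] * 2, scale="utc")
    elong = solar_elongation_ndarray(np.array([0.0, 180.0]), np.zeros(2), t)
    # Great-circle separations are bounded by [0, 180] degrees.
    assert np.all((elong >= 0.0) & (elong <= 180.0))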

def group_by(arrs, key, func, out=None, check_grouped=True):
    """
    Group multiple NumPy arrays by arrs[0][key], assuming the key column
    is already grouped (equal keys are contiguous), but group blocks may
    appear in any order.

    If out is provided:
        func(row_view, *subarrs)
    If out is None:
        func(*subarrs)

    Parameters
    ----------
    arrs : list/tuple of ndarray
        Equal-length arrays. arrs[0] contains the grouping key.
    key : str
        Column name in arrs[0] to group by.
    func : callable
        Called either as func(row, *subarrs) or func(*subarrs).
    out : ndarray or None
        Optional preallocated structured output array.
    check_grouped : bool
        If True, verify that key values are grouped contiguously.

    Returns
    -------
    ndarray or dict
    """
    arr0 = arrs[0]
    keys = arr0[key]

    # ---------- Grouped-contiguous check ----------
    if check_grouped:
        seen = set()
        current = keys[0]
        seen.add(current)

        for i in range(1, len(keys)):
            k = keys[i]
            if k != current:
                # Key changed
                if k in seen:
                    raise ValueError(
                        f"Key '{key}' is not properly grouped. "
                        f"Value {k} reappears at index {i} "
                        f"after a different key was encountered."
                    )
                seen.add(k)
                current = k

    # ---------- Find true group boundaries ----------
    unique_keys, idx_start, counts = np.unique(keys, return_index=True, return_counts=True)
    idx_end = idx_start + counts
    n_groups = len(unique_keys)

    # ---------- Preallocated output path ----------
    if out is not None:
        if len(out) < n_groups:
            raise ValueError(f"Out array too small: need {n_groups}, have {len(out)}")

        for out_idx, (start, end) in enumerate(zip(idx_start, idx_end)):
            subarrs = tuple(a[start:end] for a in arrs)
            row = out[out_idx]  # writable structured scalar
            func(row, *subarrs)
            if out_idx % 100 == 0:
                print(f"[{datetime.now().isoformat()}] count={out_idx}")

        return out

    # ---------- Dict output path ----------
    results = {}
    for keyval, start, end in zip(unique_keys, idx_start, idx_end):
        subarrs = tuple(a[start:end] for a in arrs)
        results[keyval] = func(*subarrs)

    return results
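# Usage sketch (not part of the original module): per-object mean magnitudes
# via the dict path; the helper name and column names are hypothetical.
def _example_group_by():
    recs = np.array([(1, 0.5), (1, 1.5), (2, 2.0)],
                    dtype=[("ssObjectId", "i8"), ("mag", "f8")])
    means = group_by([recs], "ssObjectId", lambda sub: sub["mag"].mean())
    assert means[1] == 1.0 and means[2] == 2.0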

def values_grouped(a: np.ndarray) -> bool:
    """
    Return True if each distinct value in 1D array `a`
    appears in a single contiguous block (all duplicates grouped).
    """
    a = np.asarray(a)
    if a.ndim != 1:
        raise ValueError("a must be 1D")
    if a.size == 0:
        return True

    # 1) True where a new group starts: first element, or value != previous
    group_starts = np.concatenate(([True], a[1:] != a[:-1]))

    # 2) Values for each group (one per contiguous block)
    group_vals = a[group_starts]

    # 3) Check that no value appears in more than one group,
    #    i.e., all group_vals are unique
    return np.unique(group_vals).size == group_vals.size
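# Usage sketch (not part of the original module); the helper name is
# hypothetical.
def _example_values_grouped():
    assert values_grouped(np.array([3, 3, 1, 2, 2]))  # blocks: 3 | 1 | 2
    assert not values_grouped(np.array([1, 2, 1]))    # 1 reappears after 2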

def earthlocation_from_obscode(obscode: str) -> EarthLocation:
    """
    Convert an MPC observatory code (e.g. 'X05') to an EarthLocation,
    using MPC.get_observatory_codes() columns:
        Code, Longitude, cos, sin, Name.
    """
    tbl = MPC.get_observatory_codes()
    row = tbl[tbl["Code"] == obscode]
    if len(row) != 1:
        raise ValueError(f"Unknown or ambiguous obscode {obscode!r}")
    row = row[0]

    # Handle missing ground positions (spacecraft, etc.)
    if ma.is_masked(row["Longitude"]) or ma.is_masked(row["cos"]) or ma.is_masked(row["sin"]):
        raise ValueError(f"Obscode {obscode!r} has no ground position (spacecraft?)")

    lon = (row["Longitude"] * u.deg).to(u.rad).value  # radians
    rho_cosphi = float(row["cos"])
    rho_sinphi = float(row["sin"])

    # Geocentric Cartesian coordinates in Earth radii
    x_er = rho_cosphi * np.cos(lon)
    y_er = rho_cosphi * np.sin(lon)
    z_er = rho_sinphi

    # Convert Earth radii -> meters
    x = (x_er * R_earth).to(u.m)
    y = (y_er * R_earth).to(u.m)
    z = (z_er * R_earth).to(u.m)

    return EarthLocation.from_geocentric(x, y, z)
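# Usage sketch (not part of the original module): resolves the Rubin
# Observatory code. Requires network access, since astroquery fetches the
# MPC observatory table; the helper name is hypothetical.
def _example_earthlocation_from_obscode():
    loc = earthlocation_from_obscode("X05")  # X05 = Rubin Observatory
    print(loc.geodetic)  # longitude, latitude, height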

def observatory_barycentric_posvel(obscode: str, obstime: Time):
    """
    Barycentric (ICRS) position and velocity of an observatory given an
    MPC obscode, using JPL DE440 for the Earth ephemeris.

    Returns
    -------
    r_bary : Quantity, shape (3, ...)
        Barycentric position in AU.
    v_bary : Quantity, shape (3, ...)
        Barycentric velocity in AU/day.
    """
    loc = earthlocation_from_obscode(obscode)

    # Geocentric position & velocity of the site in GCRS (Earth center)
    obsgeoloc, obsgeovel = loc.get_gcrs_posvel(obstime)

    # Earth barycentric pos/vel in ICRS using DE440
    with solar_system_ephemeris.set("de440"):
        earth_pos, earth_vel = get_body_barycentric_posvel("earth", obstime)

    # ---- convert everything to SI ----
    # Earth (ICRS, barycentric)
    r_earth_si = earth_pos.xyz.to(u.m)
    v_earth_si = earth_vel.xyz.to(u.m / u.s)

    # Site (GCRS, geocentric)
    r_site_geo_si = getattr(obsgeoloc, "xyz", obsgeoloc).to(u.m)
    v_site_geo_si = getattr(obsgeovel, "xyz", obsgeovel).to(u.m / u.s)

    # ---- barycentric site vectors in SI ----
    r_site_bary_si = r_earth_si + r_site_geo_si
    v_site_bary_si = v_earth_si + v_site_geo_si

    # ---- convert to AU, AU/day ----
    r_site_bary = r_site_bary_si.to(u.au)
    v_site_bary = v_site_bary_si.to(u.au / u.day)

    return r_site_bary, v_site_bary
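# Usage sketch (not part of the original module). Requires network on first
# use (MPC table and DE440 ephemeris download); the helper name and epoch
# are hypothetical.
def _example_observatory_barycentric_posvel():
    t = Time("2026-01-01T00:00:00", scale="utc")
    r, v = observatory_barycentric_posvel("X05", t)
    # A ground site should sit roughly 1 au from the solar-system barycenter.
    assert abs(np.linalg.norm(r.to_value(u.au)) - 1.0) < 0.05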

#
# Serialization
#


def struct_to_parquet(
    arr: np.ndarray,
    path: str,
    *,
    chunk_size: Optional[int] = None,
    row_group_size: Optional[int] = None,
) -> None:
    """
    Write a large NumPy structured array to a Parquet file using PyArrow.

    Designed for dtypes like dia_dtype / orbit_dtype with up to ~1e8 rows.
    """

    if arr.dtype.names is None:
        raise TypeError("struct_to_parquet expects a structured NumPy array (got dtype.names is None).")

    n_rows = len(arr)
    if n_rows == 0:
        return

    # Heuristic default chunk size
    if chunk_size is None:
        if n_rows <= 10_000_000:
            chunk_size = n_rows
        else:
            chunk_size = 1_000_000

    def _numpy_to_arrow_array(col: np.ndarray, name: str) -> pa.Array:
        """
        Convert a 1D NumPy column view to a PyArrow Array.

        - Object / unicode / bytes → Arrow string(), with padding stripped
          for fixed-width 'S'/'U' dtypes.
        - Numeric types → Arrow infers type from NumPy.
        """
        kind = col.dtype.kind

        if kind == "O":
            # Already Python objects (str); Arrow will handle them fine.
            return pa.array(col, type=pa.string())

        if kind == "S":
            # Fixed-width bytes, padded with b"\x00".
            # Decode + strip trailing NULs.
            # NOTE: adjust encoding if you know it's not ASCII/UTF-8.
            decoded = np.char.decode(col, "utf-8", errors="ignore")
            stripped = np.char.rstrip(decoded, "\x00")
            return pa.array(stripped, type=pa.string())

        if kind == "U":
            # Fixed-width unicode, padded with U+0000.
            stripped = np.char.rstrip(col, "\x00")
            return pa.array(stripped, type=pa.string())

        # Numeric / bool dtypes: let Arrow infer
        return pa.array(col)

    def _chunk_to_table(chunk: np.ndarray) -> pa.Table:
        arrays = []
        fields = []

        for name in chunk.dtype.names:
            col = chunk[name]
            arr_arrow = _numpy_to_arrow_array(col, name)
            arrays.append(arr_arrow)
            fields.append(pa.field(name, arr_arrow.type))

        schema = pa.schema(fields)
        return pa.Table.from_arrays(arrays, schema=schema)

    writer: Optional[pq.ParquetWriter] = None
    try:
        for start in range(0, n_rows, chunk_size):
            end = min(start + chunk_size, n_rows)
            chunk = arr[start:end]
            table = _chunk_to_table(chunk)

            if writer is None:
                writer = pq.ParquetWriter(path, table.schema)
            writer.write_table(table, row_group_size=row_group_size)
    finally:
        if writer is not None:
            writer.close()
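# Usage sketch (not part of the original module): writes a tiny structured
# array and reads it back; the helper name, dtype, and path are hypothetical.
def _example_struct_to_parquet(path="/tmp/example.parquet"):
    arr = np.array([(1, 2.5, "a"), (2, 3.5, "bb")],
                   dtype=[("id", "i8"), ("x", "f8"), ("name", "U4")])
    struct_to_parquet(arr, path)
    tbl = pq.read_table(path)
    assert tbl.num_rows == len(arr)
    assert tbl.column("name").to_pylist() == ["a", "bb"]  # padding stripped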

# Jupiter's semimajor axis in AU (J2000-ish)
A_JUP = 5.2044


def tisserand_jupiter(a, e, inc_deg, a_j=A_JUP):
    """
    Compute Tisserand parameter with respect to Jupiter.

    Parameters
    ----------
    a : float or ndarray
        Semimajor axis of the small body [AU].
    e : float or ndarray
        Eccentricity.
    inc_deg : float or ndarray
        Inclination [degrees], typically to the ecliptic.
    a_j : float
        Semimajor axis of Jupiter [AU]. Default ~5.2044.

    Returns
    -------
    T_J : float or ndarray
        Tisserand parameter with respect to Jupiter.
    """
    inc_rad = np.deg2rad(inc_deg)
    return (a_j / a) + 2.0 * np.cos(inc_rad) * np.sqrt((a / a_j) * (1.0 - e**2))
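# Worked check (not part of the original module): for Jupiter's own elements
# (a = a_j, e = 0, i = 0) the formula reduces to 1 + 2 = 3 exactly, while a
# Ceres-like main-belt orbit gives T_J > 3. The helper name is hypothetical.
def _example_tisserand_jupiter():
    assert np.isclose(tisserand_jupiter(A_JUP, 0.0, 0.0), 3.0)
    assert tisserand_jupiter(2.77, 0.08, 10.6) > 3.0  # roughly Ceres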

def unpack(df, to_numpy=True):
    """
    Return all DataFrame columns as a tuple.

    Parameters
    ----------
    df : pandas.DataFrame
        Input dataframe.
    to_numpy : bool, default True
        If True, return each column as a NumPy array.
        If False, return each column as a pandas Series.

    Returns
    -------
    tuple
        Tuple of columns in the original order.
    """
    if to_numpy:
        return tuple(df[col].to_numpy() for col in df.columns)
    else:
        return tuple(df[col] for col in df.columns)
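# Usage sketch (not part of the original module); the helper name and
# columns are hypothetical.
def _example_unpack():
    df = pd.DataFrame({"ra": [10.0, 11.0], "dec": [-5.0, -5.5]})
    ra, dec = unpack(df)
    assert isinstance(ra, np.ndarray) and ra.tolist() == [10.0, 11.0]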

def argjoin(a, v):
    """
    Perform an efficient inner join between two 1-D NumPy arrays, returning
    the index pairs that match by value.

    Parameters
    ----------
    a : ndarray
        The left-hand array to join on. Must be 1-dimensional.
    v : ndarray
        The right-hand array to join on. Must be 1-dimensional.

    Returns
    -------
    aidx : ndarray (int)
        Indices into `a` selecting the rows that participate in the join.
    vidx : ndarray (int)
        Indices into `v` selecting the corresponding matching rows.

    After the join:

        a[aidx] == v[vidx]

    is guaranteed to be true for all elements.

    Notes
    -----
    This function implements a pure NumPy analogue of an SQL-style
    INNER JOIN on the key columns `a` and `v`. Note that if `a` contains
    duplicate keys, each element of `v` is paired with only one of them
    (searchsorted returns a single insertion point), not the full
    cross-product an SQL join would produce.

    The algorithm:

    1. Sort `a` to produce a permutation `i` so that `a[i]` is sorted.
    2. Use `np.searchsorted(a[i], v)` to find, for each element of `v`,
       the candidate matching location in the sorted array.
    3. Map these positions back to the coordinates of the original array `a`
       using the permutation `i`.
    4. Filter out non-matches (values in `v` not present in `a`).
       The remaining pairs form the inner join.

    Complexity
    ----------
    Sorting:   O(len(a) log len(a))
    Searching: O(len(v) log len(a))
    Total:     O(n log n)

    This is optimal for join-like operations on unsorted arrays in NumPy.

    Examples
    --------
    >>> a = np.array(["b", "a", "c", "b"])
    >>> v = np.array(["a", "b", "x", "b"])

    >>> aidx, vidx = argjoin(a, v)
    >>> a[aidx]
    array(['a', 'b', 'b'])
    >>> v[vidx]
    array(['a', 'b', 'b'])
    """
    # 1. Sort a, remembering the permutation
    i = np.argsort(a)
    ai = a[i]

    # 2. Locate each element of v within the sorted array
    idx = np.searchsorted(ai, v)

    # Clip to avoid out-of-range indices when v contains values > max(a)
    idx = np.clip(idx, 0, len(ai) - 1)

    # 3. Map positions in sorted array back to original array indices
    aidx_candidate = i[idx]

    # 4. Keep only true matches (this implements an INNER JOIN)
    mask = ai[idx] == v
    vidx = np.flatnonzero(mask)

    # Final matched indices in a
    aidx = aidx_candidate[vidx]

    assert np.all(a[aidx] == v[vidx])
    return aidx, vidx