Coverage for python / lsst / pipe / tasks / ssp / util.py: 11%
183 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 09:05 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 09:05 +0000
1import numpy as np
2import astropy
3from astropy.time import Time
4import pandas as pd
5import astropy.units as u
6from astropy.coordinates import get_sun, angular_separation
7import numpy.ma as ma
8from astropy.constants import R_earth
9from astropy.coordinates import (
10 EarthLocation,
11 solar_system_ephemeris,
12 get_body_barycentric_posvel,
13)
14from astroquery.mpc import MPC
15from typing import Optional
16import pyarrow as pa
17import pyarrow.parquet as pq
18from datetime import datetime
19import pyarrow.compute as pc
def assoc_validate(dia, assoc):
    """
    Validate that DIA source rows match their associated MPC observations.

    Checks that the (ra, dec) coordinates agree to within 0.005 arcsec and
    that the observation times agree to within 0.51 seconds, raising
    AssertionError otherwise. Prints the observed ranges and a summary.

    Parameters
    ----------
    dia : pandas.DataFrame
        DIA sources with at least "ra", "dec", "midpointMjdTai" columns.
    assoc : pandas.DataFrame
        Associations with "dia_index", "mpc_ra", "mpc_dec", "mpc_obstime"
        columns.
    """
    # Select only the DIA rows that participate in the association.
    dia = dia[["ra", "dec", "midpointMjdTai"]].iloc[assoc["dia_index"].values]

    # verify coordinates match
    obs = astropy.coordinates.SkyCoord(ra=dia["ra"].values, dec=dia["dec"].values, unit="deg")
    mpc = astropy.coordinates.SkyCoord(ra=assoc["mpc_ra"].values, dec=assoc["mpc_dec"].values, unit="deg")
    sep = obs.separation(mpc)

    print("Separation difference range (arcsec): ", sep.min().arcsec, sep.max().arcsec)
    assert sep.min().arcsec >= 0
    assert sep.max().arcsec <= 0.005  # FIXME: this should be further
    # tightened once we start submitting extra precision to the MPC

    # verify times match: DIA midpoints are TAI MJDs, MPC obstimes are UTC
    t_utc = Time(dia["midpointMjdTai"].to_numpy(), format="mjd", scale="tai").utc
    midpoint_utc = pd.to_datetime(t_utc.to_datetime())
    mpc_time = pd.to_datetime(assoc["mpc_obstime"])
    delta_sec = (mpc_time - midpoint_utc).dt.total_seconds()

    print("Time difference range (sec): ", delta_sec.min(), delta_sec.max())

    # FIXME: this was relaxed as USDF replica's obstime datatype is borked
    # and rounds (or truncates?) the timestamps to 1 second. E-mailed Dan S.
    # to get it fixed.
    # assert abs(delta_sec).max() < 0.01
    assert abs(delta_sec).max() < 0.51

    print(f"All OK, {len(assoc):,} observations.")
def packed_ascii_to_uint64_le(mpc_packed):
    """
    Encode a pandas string[pyarrow] column of ASCII strings (each at most
    8 bytes) as little-endian uint64 values, left-padding every string
    with ASCII spaces to exactly 8 characters first.
    """
    # Materialize the pandas Series as a genuine pyarrow string array.
    strings = pa.array(mpc_packed, type=pa.string())

    # Left-pad with spaces so every entry is exactly 8 characters wide.
    # ascii_lpad signature: (string_array, target_length, pad_char);
    # this also works on older PyArrow releases.
    eight_wide = pc.ascii_lpad(strings, 8, " ")

    # Reinterpret the padded strings as binary and trim each to 8 bytes.
    as_bytes = pc.binary_slice(pc.cast(eight_wide, pa.binary()), 0, 8)

    # Collapse any chunking so the values buffer is one contiguous run.
    if isinstance(as_bytes, pa.ChunkedArray):
        as_bytes = as_bytes.combine_chunks()

    # The values buffer holds all payload bytes back-to-back; view every
    # 8-byte run as a little-endian unsigned 64-bit integer.
    return np.frombuffer(as_bytes.buffers()[2], dtype="<u8")
def solar_elongation_ndarray(ra_deg, dec_deg, t):
    """
    Fast computation of solar elongation (ICRS great-circle separation
    from the Sun) using astropy.coordinates.angular_separation.

    Parameters
    ----------
    ra_deg : ndarray
        Target RA in degrees (ICRS).
    dec_deg : ndarray
        Target Dec in degrees (ICRS).
    t : astropy.time.Time
        Observation times.

    Returns
    -------
    elong_deg : ndarray
        Solar elongation in degrees.
    """
    # Sun position at each observation time, transformed to ICRS once.
    # FIXME: This is sloooooww af. Should probably extract it in a few points,
    # then fit a spline or piecewise poly.
    sun_icrs = get_sun(t).icrs

    # angular_separation wants radians everywhere.
    target_ra_rad = np.radians(ra_deg)
    target_dec_rad = np.radians(dec_deg)

    # Great-circle separation between each target and the Sun.
    separation_rad = angular_separation(
        target_ra_rad, target_dec_rad, sun_icrs.ra.radian, sun_icrs.dec.radian
    )

    return np.degrees(separation_rad)
def group_by(arrs, key, func, out=None, check_grouped=True):
    """
    Apply `func` to groups of rows from several equal-length NumPy arrays,
    grouping by arrs[0][key]. The key column must already be grouped
    (equal keys contiguous), although the group blocks themselves may
    appear in any order.

    If out is provided:
        func(row_view, *subarrs)
    If out is None:
        func(*subarrs)

    Parameters
    ----------
    arrs : list/tuple of ndarray
        Equal-length arrays. arrs[0] contains the grouping key.
    key : str
        Column name in arrs[0] to group by.
    func : callable
        Called either as func(row, *subarrs) or func(*subarrs).
    out : ndarray or None
        Optional preallocated structured output array.
    check_grouped : bool
        If True, verify that key values are grouped contiguously.

    Returns
    -------
    ndarray or dict
    """
    keys = arrs[0][key]

    # ---------- Grouped-contiguous check ----------
    if check_grouped:
        prev = keys[0]
        observed = {prev}
        for pos in range(1, len(keys)):
            val = keys[pos]
            if val == prev:
                continue
            # A new block begins here; it must not repeat an earlier key.
            if val in observed:
                raise ValueError(
                    f"Key '{key}' is not properly grouped. "
                    f"Value {val} reappears at index {pos} "
                    f"after a different key was encountered."
                )
            observed.add(val)
            prev = val

    # ---------- Find true group boundaries ----------
    unique_keys, idx_start, counts = np.unique(keys, return_index=True, return_counts=True)
    idx_end = idx_start + counts

    # ---------- Preallocated output path ----------
    if out is not None:
        n_groups = len(unique_keys)
        if len(out) < n_groups:
            raise ValueError(f"Out array too small: need {n_groups}, have {len(out)}")

        for out_idx, (start, end) in enumerate(zip(idx_start, idx_end)):
            slices = tuple(a[start:end] for a in arrs)
            # out[out_idx] is a writable structured scalar.
            func(out[out_idx], *slices)
            if out_idx % 100 == 0:
                print(f"[{datetime.now().isoformat()}] count={out_idx}")

        return out

    # ---------- Dict output path ----------
    return {
        keyval: func(*(a[start:end] for a in arrs))
        for keyval, start, end in zip(unique_keys, idx_start, idx_end)
    }
def values_grouped(a: np.ndarray) -> bool:
    """
    Return True if every distinct value of the 1D array `a` occupies a
    single contiguous run (i.e. all duplicates are adjacent).
    """
    a = np.asarray(a)
    if a.ndim != 1:
        raise ValueError("a must be 1D")
    if a.size == 0:
        return True

    # Mark positions where a run of equal values begins; the first
    # element always starts a run.
    run_starts = np.concatenate(([True], a[1:] != a[:-1]))

    # One representative value per contiguous run.
    run_values = a[run_starts]

    # Grouped iff no value heads more than one run (all run heads unique).
    return run_values.size == np.unique(run_values).size
def earthlocation_from_obscode(obscode: str) -> EarthLocation:
    """
    Convert an MPC observatory code (e.g. 'X05') to an EarthLocation,
    using MPC.get_observatory_codes() columns:
    Code, Longitude, cos, sin, Name.
    """
    codes = MPC.get_observatory_codes()
    matches = codes[codes["Code"] == obscode]
    if len(matches) != 1:
        raise ValueError(f"Unknown or ambiguous obscode {obscode!r}")
    entry = matches[0]

    # Spacecraft and similar entries have no ground position (masked columns).
    if ma.is_masked(entry["Longitude"]) or ma.is_masked(entry["cos"]) or ma.is_masked(entry["sin"]):
        raise ValueError(f"Obscode {obscode!r} has no ground position (spacecraft?)")

    # Longitude in radians plus the parallax constants rho*cos(phi') and
    # rho*sin(phi'), which are expressed in units of Earth radii.
    lon_rad = (entry["Longitude"] * u.deg).to(u.rad).value
    rho_cos = float(entry["cos"])
    rho_sin = float(entry["sin"])

    # Geocentric Cartesian coordinates, converted Earth radii -> meters.
    x = (rho_cos * np.cos(lon_rad) * R_earth).to(u.m)
    y = (rho_cos * np.sin(lon_rad) * R_earth).to(u.m)
    z = (rho_sin * R_earth).to(u.m)

    return EarthLocation.from_geocentric(x, y, z)
def observatory_barycentric_posvel(obscode: str, obstime: Time):
    """
    Barycentric (ICRS) position and velocity of an observatory given an
    MPC obscode, using JPL DE440 for the Earth ephemeris.

    Returns
    -------
    r_bary : Quantity, shape (3, ...)
        Barycentric position in AU.
    v_bary : Quantity, shape (3, ...)
        Barycentric velocity in AU/day.
    """
    site = earthlocation_from_obscode(obscode)

    # Site position & velocity relative to the Earth's center (GCRS).
    geo_pos, geo_vel = site.get_gcrs_posvel(obstime)

    # Earth's barycentric state in ICRS from the DE440 ephemeris.
    with solar_system_ephemeris.set("de440"):
        earth_pos, earth_vel = get_body_barycentric_posvel("earth", obstime)

    # ---- everything in SI ----
    # Earth (ICRS, barycentric)
    earth_r = earth_pos.xyz.to(u.m)
    earth_v = earth_vel.xyz.to(u.m / u.s)

    # Site (GCRS, geocentric); the getattr tolerates either a
    # CartesianRepresentation (has .xyz) or a bare Quantity.
    site_r = getattr(geo_pos, "xyz", geo_pos).to(u.m)
    site_v = getattr(geo_vel, "xyz", geo_vel).to(u.m / u.s)

    # ---- barycentric site vectors, converted to AU and AU/day ----
    r_bary = (earth_r + site_r).to(u.au)
    v_bary = (earth_v + site_v).to(u.au / u.day)

    return r_bary, v_bary
298#
299# Serialization
300#
def struct_to_parquet(
    arr: np.ndarray,
    path: str,
    *,
    chunk_size: Optional[int] = None,
    row_group_size: Optional[int] = None,
) -> None:
    """
    Write a large NumPy structured array to a Parquet file using PyArrow.

    Designed for dtypes like dia_dtype / orbit_dtype with up to ~1e8 rows.

    Parameters
    ----------
    arr : np.ndarray
        Structured array to serialize; dtype must have named fields.
    path : str
        Destination Parquet file path.
    chunk_size : int, optional
        Rows per Arrow table handed to the writer. Defaults to the whole
        array for <= 10M rows, else 1M rows per chunk.
    row_group_size : int, optional
        Row-group size forwarded to ParquetWriter.write_table.

    Raises
    ------
    TypeError
        If `arr` is not a structured array.
    """
    if arr.dtype.names is None:
        raise TypeError("struct_to_parquet expects a structured NumPy array (dtype.names is None).")

    n_rows = len(arr)
    if n_rows == 0:
        # Nothing to write; note that no (empty) file is created either.
        return

    # Heuristic default chunk size
    if chunk_size is None:
        chunk_size = n_rows if n_rows <= 10_000_000 else 1_000_000

    def _numpy_to_arrow_array(col: np.ndarray) -> pa.Array:
        """
        Convert a 1D NumPy column view to a PyArrow Array.

        - Object / unicode / bytes → Arrow string(), with padding stripped
          for fixed-width 'S'/'U' dtypes.
        - Numeric types → Arrow infers type from NumPy.
        """
        kind = col.dtype.kind

        if kind == "O":
            # Already Python objects (str); Arrow will handle them fine.
            return pa.array(col, type=pa.string())

        if kind == "S":
            # Fixed-width bytes, padded with b"\x00".
            # Decode + strip trailing NULs.
            # NOTE: adjust encoding if you know it's not ASCII/UTF-8.
            decoded = np.char.decode(col, "utf-8", errors="ignore")
            return pa.array(np.char.rstrip(decoded, "\x00"), type=pa.string())

        if kind == "U":
            # Fixed-width unicode, padded with U+0000.
            return pa.array(np.char.rstrip(col, "\x00"), type=pa.string())

        # Numeric / bool dtypes: let Arrow infer
        return pa.array(col)

    def _chunk_to_table(chunk: np.ndarray) -> pa.Table:
        """Build an Arrow Table from one structured-array slice."""
        arrays = []
        fields = []
        for name in chunk.dtype.names:
            arrow_col = _numpy_to_arrow_array(chunk[name])
            arrays.append(arrow_col)
            fields.append(pa.field(name, arrow_col.type))
        return pa.Table.from_arrays(arrays, schema=pa.schema(fields))

    writer: Optional[pq.ParquetWriter] = None
    try:
        for start in range(0, n_rows, chunk_size):
            end = min(start + chunk_size, n_rows)
            table = _chunk_to_table(arr[start:end])

            # Lazily open the writer with the schema of the first chunk.
            if writer is None:
                writer = pq.ParquetWriter(path, table.schema)
            writer.write_table(table, row_group_size=row_group_size)
    finally:
        if writer is not None:
            writer.close()
# Jupiter's semimajor axis in AU (J2000-ish)
A_JUP = 5.2044


def tisserand_jupiter(a, e, inc_deg, a_j=A_JUP):
    """
    Compute Tisserand parameter with respect to Jupiter.

    Parameters
    ----------
    a : float or ndarray
        Semimajor axis of the small body [AU].
    e : float or ndarray
        Eccentricity.
    inc_deg : float or ndarray
        Inclination [degrees], typically to the ecliptic.
    a_j : float
        Semimajor axis of Jupiter [AU]. Default ~5.2044.

    Returns
    -------
    T_J : float or ndarray
        Tisserand parameter with respect to Jupiter.
    """
    cos_inc = np.cos(np.deg2rad(inc_deg))
    return a_j / a + 2.0 * cos_inc * np.sqrt((a / a_j) * (1.0 - e**2))
def unpack(df, to_numpy=True):
    """
    Return all DataFrame columns as a tuple.

    Parameters
    ----------
    df : pandas.DataFrame
        Input dataframe.
    to_numpy : bool, default True
        If True, return each column as a NumPy array.
        If False, return each column as a pandas Series.

    Returns
    -------
    tuple
        Tuple of columns in the original order.
    """
    columns = (df[name] for name in df.columns)
    if to_numpy:
        return tuple(series.to_numpy() for series in columns)
    return tuple(columns)
def argjoin(a, v):
    """
    Perform an efficient inner join between two 1-D NumPy arrays, returning
    the index pairs that match by value.

    Parameters
    ----------
    a : ndarray
        The left-hand array to join on. Must be 1-dimensional.
    v : ndarray
        The right-hand array to join on. Must be 1-dimensional.

    Returns
    -------
    aidx : ndarray (int)
        Indices into `a` selecting the rows that participate in the join.
    vidx : ndarray (int)
        Indices into `v` selecting the corresponding matching rows.

        After the join:
            a[aidx] == v[vidx]
        is guaranteed to be true for all elements.

    Notes
    -----
    This function implements a pure NumPy equivalent of an SQL-style
    INNER JOIN on the key columns `a` and `v`. When `a` contains duplicate
    values, each element of `v` is paired with a single occurrence in `a`.

    The algorithm:

    1. Sort `a` to produce a permutation `i` so that `a[i]` is sorted.
    2. Use `np.searchsorted(a[i], v)` to find, for each element of `v`,
       the candidate matching location in the sorted array.
    3. Map these positions back to the coordinates of the original array `a`
       using the permutation `i`.
    4. Filter out non-matches (values in `v` not present in `a`).
       The remaining pairs form the inner join.

    Complexity
    ----------
    Sorting:   O(len(a) log len(a))
    Searching: O(len(v) log len(a))
    Total:     O(n log n)

    This is optimal for join-like operations on unsorted arrays in NumPy.

    Examples
    --------
    >>> a = np.array(["b", "a", "c", "b"])
    >>> v = np.array(["a", "b", "x", "b"])

    >>> aidx, vidx = argjoin(a, v)
    >>> a[aidx]
    array(['a', 'b', 'b'])
    >>> v[vidx]
    array(['a', 'b', 'b'])
    """
    # Degenerate case: an empty `a` cannot match anything. Without this
    # guard, np.clip(idx, 0, -1) below would turn every candidate position
    # into -1, and i[idx] would raise IndexError on the empty permutation.
    if len(a) == 0:
        return np.array([], dtype=np.intp), np.array([], dtype=np.intp)

    # 1. Sort a, remembering the permutation
    i = np.argsort(a)
    ai = a[i]

    # 2. Locate each element of v within the sorted array
    idx = np.searchsorted(ai, v)

    # Clip to avoid out-of-range indices when v contains values > max(a)
    idx = np.clip(idx, 0, len(ai) - 1)

    # 3. Map positions in sorted array back to original array indices
    aidx_candidate = i[idx]

    # 4. Keep only true matches (this implements an INNER JOIN)
    mask = ai[idx] == v
    vidx = np.flatnonzero(mask)

    # Final matched indices in a
    aidx = aidx_candidate[vidx]

    assert np.all(a[aidx] == v[vidx])
    return aidx, vidx