def _diff_count_changed(
    conn: "duckdb.DuckDBPyConnection",
    key_join: str,
    qc: str,
    atol: float,
    rtol: float,
) -> int:
    """Count joined rows whose value in column *qc* differs between tables.

    A tolerance-aware numeric comparison is attempted first: a row counts as
    changed when both values are non-NULL and
    ``|old - new| > atol + rtol * |new|``, or when exactly one side is NULL.
    If DuckDB rejects that query (e.g. the column cannot be cast to DOUBLE),
    the count falls back to an exact ``IS DISTINCT FROM`` comparison.

    Args:
        conn: Connection holding the ``__old`` and ``__new`` temp tables.
        key_join: SQL join predicate between aliases ``o`` and ``n``.
        qc: Already-quoted column identifier.
        atol: Absolute tolerance, bound as a query parameter.
        rtol: Relative tolerance, bound as a query parameter.

    Raises:
        duckdb.Error: If the exact-comparison fallback also fails.
    """
    joined = f"FROM __old o JOIN __new n ON {key_join} "
    try:
        # The two OR-ed conditions are mutually exclusive, so one scan of the
        # join yields the same total as counting them separately.
        numeric_sql = (
            f"SELECT COUNT(*) {joined}"
            f"WHERE (o.{qc} IS NOT NULL AND n.{qc} IS NOT NULL "
            f"AND ABS(CAST(o.{qc} AS DOUBLE) - CAST(n.{qc} AS DOUBLE)) "
            f"> ? + ? * ABS(CAST(n.{qc} AS DOUBLE))) "
            f"OR ((o.{qc} IS NULL) != (n.{qc} IS NULL))"
        )
        # Tolerances are bound rather than interpolated so values such as
        # ``inf`` cannot produce malformed SQL text.
        return conn.execute(numeric_sql, [atol, rtol]).fetchone()[0]
    except duckdb.Error:
        # Non-numeric column: fall back to exact comparison.
        exact_sql = (
            f"SELECT COUNT(*) {joined}"
            f"WHERE o.{qc} IS DISTINCT FROM n.{qc}"
        )
        return conn.execute(exact_sql).fetchone()[0]


def _diff_delta_stats(
    conn: "duckdb.DuckDBPyConnection",
    key_join: str,
    qc: str,
) -> dict[str, float]:
    """Return mean and max numeric deltas (new - old) for column *qc*.

    The statistics cover every joined row whose old and new values are
    distinct under exact comparison, so differences that fall within the
    caller's tolerance are included as well.

    Returns:
        ``{"mean_delta": ..., "max_delta": ...}``, or an empty dict when the
        column is not numeric or no non-NULL delta exists.
    """
    delta_sql = (
        f"SELECT "
        f"AVG(CAST(n.{qc} AS DOUBLE) - CAST(o.{qc} AS DOUBLE)), "
        f"MAX(ABS(CAST(n.{qc} AS DOUBLE) - CAST(o.{qc} AS DOUBLE))) "
        f"FROM __old o JOIN __new n ON {key_join} "
        f"WHERE o.{qc} IS DISTINCT FROM n.{qc}"
    )
    try:
        row = conn.execute(delta_sql).fetchone()
        if row and row[0] is not None:
            return {"mean_delta": float(row[0]), "max_delta": float(row[1])}
    except (duckdb.Error, TypeError):
        # Deliberate best effort: delta stats do not apply to non-numeric
        # columns, and their absence must not fail the comparison.
        pass
    return {}


def diff(
    old: str | Path,
    new: str | Path,
    *,
    keys: str | list[str],
    label_time: str,
    atol: float = DEFAULT_ATOL,
    rtol: float = DEFAULT_RTOL,
) -> DiffResult:
    """Compare two training datasets stored as Parquet.

    Both inputs are loaded into an in-memory DuckDB connection, their schemas
    are compared, and every column present in both (other than the key and
    label-time columns) is compared row by row after joining on
    ``keys`` + ``label_time``.

    Args:
        old: Path of the baseline dataset, read with ``read_parquet``.
        new: Path of the dataset to compare against the baseline.
        keys: One column name or a list of column names identifying a row.
        label_time: Name of the label-time column; it is always part of the
            join and is never value-compared.
        atol: Absolute tolerance for numeric comparisons.
        rtol: Relative tolerance for numeric comparisons (scaled by the
            magnitude of the new value).

    Returns:
        A ``DiffResult`` carrying both row counts, one ``schema_changes``
        entry per column (``"+"`` added, ``"-"`` removed, ``"~"`` changed,
        ``"="`` unchanged, ``"?"`` comparison failed) and, for changed
        columns, a ``value_changes`` entry with ``changed_count``,
        ``changed_pct`` and, when numeric, ``mean_delta`` / ``max_delta``.

    Raises:
        duckdb.Error: If either dataset cannot be loaded or described.
        TypeError, ValueError: If ``atol`` or ``rtol`` is not convertible to
            ``float``.
    """
    keys_list = [keys] if isinstance(keys, str) else list(keys)
    # Normalise once so the values can be bound as DOUBLE query parameters.
    atol, rtol = float(atol), float(rtol)

    conn = duckdb.connect()
    try:
        conn.execute(
            f"CREATE TEMP TABLE __old AS SELECT * FROM read_parquet({_ql(old)})"
        )
        conn.execute(
            f"CREATE TEMP TABLE __new AS SELECT * FROM read_parquet({_ql(new)})"
        )

        old_count = conn.execute("SELECT COUNT(*) FROM __old").fetchone()[0]
        new_count = conn.execute("SELECT COUNT(*) FROM __new").fetchone()[0]
        old_cols = {c[0] for c in conn.execute("DESCRIBE __old").fetchall()}
        new_cols = {c[0] for c in conn.execute("DESCRIBE __new").fetchall()}

        result = DiffResult(old_rows=old_count, new_rows=new_count)

        meta_cols = set(keys_list) | {label_time}
        common = (old_cols & new_cols) - meta_cols

        result.schema_changes.extend(
            {"type": "+", "column": col, "detail": "(new column)"}
            for col in sorted(new_cols - old_cols)
        )
        result.schema_changes.extend(
            {"type": "-", "column": col, "detail": "(removed)"}
            for col in sorted(old_cols - new_cols)
        )

        # De-duplicate while preserving order; building from a single list
        # also keeps the predicate valid when ``keys`` is empty.
        join_cols = list(dict.fromkeys([*keys_list, label_time]))
        key_join = " AND ".join(
            f"o.{_qi(c)} = n.{_qi(c)}" for c in join_cols
        )

        # Number of rows produced by the join; computed lazily the first time
        # a changed column needs a percentage.
        joined_rows: int | None = None

        for col in sorted(common):
            qc = _qi(col)
            try:
                changed = _diff_count_changed(conn, key_join, qc, atol, rtol)
                if changed > 0:
                    if joined_rows is None:
                        joined_rows = conn.execute(
                            f"SELECT COUNT(*) FROM __old o "
                            f"JOIN __new n ON {key_join}"
                        ).fetchone()[0]
                    # ``changed`` is counted over joined rows, so the joined
                    # row count is the matching denominator.
                    pct = changed / joined_rows if joined_rows > 0 else 0.0
                    stats_entry: dict[str, Any] = {
                        "changed_count": changed,
                        "changed_pct": pct,
                    }
                    stats_entry.update(_diff_delta_stats(conn, key_join, qc))
                    result.value_changes[col] = stats_entry
                    result.schema_changes.append(
                        {
                            "type": "~",
                            "column": col,
                            "detail": f"{changed} values changed ({pct:.1%})",
                        }
                    )
                else:
                    result.schema_changes.append(
                        {"type": "=", "column": col, "detail": "unchanged"}
                    )
            except (duckdb.Error, TypeError) as exc:
                logger.warning("Column comparison failed for %s: %s", col, exc)
                result.schema_changes.append(
                    {"type": "?", "column": col, "detail": "comparison failed"}
                )
        return result
    finally:
        conn.close()