Skip to content

diff()

Compare two training datasets for schema drift and value changes.

diff(old, new, *, keys, label_time, atol=DEFAULT_ATOL, rtol=DEFAULT_RTOL)

Compare two training datasets.

Source code in src/timefence/engine.py
def diff(
    old: str | Path,
    new: str | Path,
    *,
    keys: str | list[str],
    label_time: str,
    atol: float = DEFAULT_ATOL,
    rtol: float = DEFAULT_RTOL,
) -> DiffResult:
    """Compare two training datasets stored as Parquet files.

    Both files are loaded into temporary DuckDB tables. Columns present in
    only one file are reported as added/removed. Every other non-key column
    is compared row by row, where rows are paired by equality on ``keys``
    plus ``label_time``.

    A pair of non-null values counts as changed when
    ``|old - new| > atol + rtol * |new|`` after casting both to DOUBLE.
    A null on exactly one side also counts as changed. Columns that cannot
    be cast to DOUBLE fall back to an exact ``IS DISTINCT FROM`` comparison.

    Args:
        old: Path to the baseline Parquet file.
        new: Path to the Parquet file being compared against ``old``.
        keys: Key column name, or list of key column names.
        label_time: Name of the label-time column; it is part of the row
            matching condition and is excluded from value comparison.
        atol: Absolute tolerance for numeric comparison.
        rtol: Relative tolerance, scaled by the magnitude of the new value.

    Returns:
        A ``DiffResult`` carrying both row counts, one ``schema_changes``
        entry per added/removed/compared column, and a ``value_changes``
        entry for every column with at least one changed value.
        ``changed_pct`` is the changed count divided by the number of
        matched (joined) rows.

    Raises:
        duckdb.Error: If loading or inspecting either file fails. Failures
            while comparing an individual column are logged and recorded
            as a ``"?"`` schema change instead of being raised.
    """
    keys_list = [keys] if isinstance(keys, str) else list(keys)

    # Bind tolerances as query parameters instead of formatting them into the
    # SQL text: a non-finite value (e.g. ``inf``) would otherwise render as
    # invalid SQL and silently demote every column to exact comparison.
    tolerance_params = [float(atol), float(rtol)]

    conn = duckdb.connect()
    try:
        conn.execute(
            f"CREATE TEMP TABLE __old AS SELECT * FROM read_parquet({_ql(old)})"
        )
        conn.execute(
            f"CREATE TEMP TABLE __new AS SELECT * FROM read_parquet({_ql(new)})"
        )

        old_count = conn.execute("SELECT COUNT(*) FROM __old").fetchone()[0]
        new_count = conn.execute("SELECT COUNT(*) FROM __new").fetchone()[0]

        old_cols = {c[0] for c in conn.execute("DESCRIBE __old").fetchall()}
        new_cols = {c[0] for c in conn.execute("DESCRIBE __new").fetchall()}

        result = DiffResult(old_rows=old_count, new_rows=new_count)

        # Key and label-time columns define row identity, so they are never
        # value-compared themselves.
        meta_cols = set(keys_list) | {label_time}
        added = new_cols - old_cols
        removed = old_cols - new_cols
        common = (old_cols & new_cols) - meta_cols

        for col in sorted(added):
            result.schema_changes.append(
                {"type": "+", "column": col, "detail": "(new column)"}
            )
        for col in sorted(removed):
            result.schema_changes.append(
                {"type": "-", "column": col, "detail": "(removed)"}
            )

        key_join = " AND ".join(f"o.{_qi(k)} = n.{_qi(k)}" for k in keys_list)
        key_join += f" AND o.{_qi(label_time)} = n.{_qi(label_time)}"
        joined_from = f"FROM __old o JOIN __new n ON {key_join}"

        # Changed counts are taken over joined rows, so the percentage must
        # use the joined row count as its denominator. min(old, new) is only
        # equal to it when every key matches exactly once.
        try:
            joined_rows = conn.execute(
                f"SELECT COUNT(*) {joined_from}"
            ).fetchone()[0]
        except duckdb.Error:
            # The join itself is invalid (e.g. a key column is missing); the
            # per-column comparisons below will report the failure.
            joined_rows = min(old_count, new_count)

        for col in sorted(common):
            qc = _qi(col)
            try:
                # Try tolerance-aware numeric comparison first
                try:
                    change_sql = (
                        f"SELECT COUNT(*) {joined_from} "
                        f"WHERE o.{qc} IS NOT NULL AND n.{qc} IS NOT NULL "
                        f"AND ABS(CAST(o.{qc} AS DOUBLE) - CAST(n.{qc} AS DOUBLE)) "
                        f"> CAST(? AS DOUBLE) "
                        f"+ CAST(? AS DOUBLE) * ABS(CAST(n.{qc} AS DOUBLE))"
                    )
                    changed = conn.execute(
                        change_sql, tolerance_params
                    ).fetchone()[0]
                    # Also count null-vs-non-null differences
                    null_diff_sql = (
                        f"SELECT COUNT(*) {joined_from} "
                        f"WHERE (o.{qc} IS NULL) != (n.{qc} IS NULL)"
                    )
                    changed += conn.execute(null_diff_sql).fetchone()[0]
                except duckdb.Error:
                    # Non-numeric: fall back to exact comparison.
                    # (duckdb.ConversionException derives from duckdb.Error.)
                    change_sql = (
                        f"SELECT COUNT(*) {joined_from} "
                        f"WHERE o.{qc} IS DISTINCT FROM n.{qc}"
                    )
                    changed = conn.execute(change_sql).fetchone()[0]

                if changed > 0:
                    pct = changed / joined_rows if joined_rows > 0 else 0.0

                    stats_entry: dict[str, Any] = {
                        "changed_count": changed,
                        "changed_pct": pct,
                    }

                    # Compute numeric delta stats when possible.
                    # NOTE: this aggregates over every row whose values are
                    # not exactly equal, including differences that fall
                    # within the tolerance used for the count above.
                    try:
                        delta_sql = (
                            f"SELECT "
                            f"AVG(CAST(n.{qc} AS DOUBLE) - CAST(o.{qc} AS DOUBLE)), "
                            f"MAX(ABS(CAST(n.{qc} AS DOUBLE) - CAST(o.{qc} AS DOUBLE))) "
                            f"{joined_from} "
                            f"WHERE o.{qc} IS DISTINCT FROM n.{qc}"
                        )
                        delta_row = conn.execute(delta_sql).fetchone()
                        if delta_row and delta_row[0] is not None:
                            stats_entry["mean_delta"] = float(delta_row[0])
                            stats_entry["max_delta"] = float(delta_row[1])
                    except (duckdb.Error, TypeError):
                        pass  # Non-numeric column, delta stats not applicable

                    result.value_changes[col] = stats_entry
                    result.schema_changes.append(
                        {
                            "type": "~",
                            "column": col,
                            "detail": f"{changed} values changed ({pct:.1%})",
                        }
                    )
                else:
                    result.schema_changes.append(
                        {"type": "=", "column": col, "detail": "unchanged"}
                    )
            except (duckdb.Error, TypeError) as exc:
                logger.warning("Column comparison failed for %s: %s", col, exc)
                result.schema_changes.append(
                    {"type": "?", "column": col, "detail": "comparison failed"}
                )

        return result
    finally:
        conn.close()

Example

# Compare two Parquet training sets; rows are paired on the key column(s)
# together with the label_time column.
result = timefence.diff(
    old="train_v1.parquet",
    new="train_v2.parquet",
    keys=["user_id"],
    label_time="label_time",
    atol=1e-10,  # absolute tolerance for numeric comparisons
    rtol=1e-7,  # relative tolerance, scaled by the new value's magnitude
)

Returns: DiffResult

Attribute Type Description
.old_rows int Row count in old dataset.
.new_rows int Row count in new dataset.
.schema_changes list[dict] Schema changes: type (+ added, - removed, ~ changed, = unchanged, ? comparison failed), column, detail.
.value_changes dict[str, dict] Per-column stats, keyed by column name, for columns with at least one changed value: changed_count, changed_pct, mean_delta, max_delta. mean_delta and max_delta are only present for numeric columns.