"""Main orchestrator for the PySUS data pipeline.
Manages file downloads, local state tracking, catalog attachment,
Parquet conversion, and query execution across multiple backends.
"""
import enum
from collections.abc import Callable
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Literal
import anyio
import duckdb
import pandas as pd
from pysus import CACHEPATH
from sqlalchemy import DateTime, Enum, Integer, String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker
from sqlalchemy.pool import NullPool
from .dadosgov import DadosGovClient
from .ducklake.client import DuckLake
from .extensions import Parquet
from .ftp import FTPClient
from .models import BaseLocalFile, BaseRemoteFile
if TYPE_CHECKING:
from duckdb import DuckDBPyConnection
[docs]
class Base(DeclarativeBase):
    """Declarative base shared by the PySUS ORM models such as ``LocalFileState``."""

    pass
[docs]
@enum.unique
class DownloadStatus(enum.Enum):
    """Lifecycle states stored in ``LocalFileState.status`` for each tracked file."""

    # Column default for a freshly created tracking row.
    PENDING = "pending"
    # Written when a transfer starts.
    DOWNLOADING = "downloading"
    # The file (or its Parquet conversion) is available locally.
    COMPLETED = "completed"
    # The transfer or its post-processing raised an error.
    FAILED = "failed"
    # NOTE(review): not assigned anywhere in the code shown here; presumably
    # marks a tracked file that is no longer on disk — confirm.
    MISSING = "missing"
[docs]
class LocalFileState(Base):
    """One row per file PySUS has placed (or tried to place) in its cache.

    Rows are keyed by the local filesystem path and remember the remote
    file they came from, its catalogue metadata (year, month, state, group)
    and the outcome of the last download/conversion attempt.
    """

    __tablename__ = "local_file_state"

    # Nullability comes from the ``Mapped[...]`` annotation: a plain type is
    # NOT NULL, ``X | None`` is nullable.
    path: Mapped[str] = mapped_column(String, primary_key=True)
    remote_path: Mapped[str] = mapped_column(String)
    client_name: Mapped[str] = mapped_column(String)
    year: Mapped[int | None] = mapped_column(Integer)
    month: Mapped[int | None] = mapped_column(Integer)
    state: Mapped[str | None] = mapped_column(String)
    group: Mapped[str | None] = mapped_column(String)
    status: Mapped[DownloadStatus] = mapped_column(
        Enum(DownloadStatus),
        default=DownloadStatus.PENDING,
    )
    # Not populated by the code shown here.
    sha256: Mapped[str | None] = mapped_column(String)
    # Naive timestamp holding the current UTC time.
    last_synced: Mapped[datetime] = mapped_column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )
[docs]
class PySUS:
"""Central orchestrator for downloading and querying PySUS datasets."""
def __init__(self, db_path: Path = CACHEPATH / "config.db"):
    """Set up the state database and the (initially empty) client slots.

    Parameters
    ----------
    db_path : Path, optional
        Location of the DuckDB file holding PySUS state. A ``str`` is also
        accepted (it is converted with ``Path``). Its parent directory is
        created if needed and doubles as the download cache root.
        Defaults to ``CACHEPATH / "config.db"``.
    """
    database_file = Path(db_path)
    cache_root = database_file.parent
    cache_root.mkdir(parents=True, exist_ok=True)
    self.cachepath = cache_root

    # NullPool: connections are opened per ``engine.connect()`` and not pooled.
    self.engine = create_engine(
        "duckdb:///" + database_file.resolve().as_posix(),
        poolclass=NullPool,
    )
    Base.metadata.create_all(self.engine)
    self.Session = sessionmaker(bind=self.engine)

    # Remote clients are created lazily by the ``get_*`` helpers.
    self._ducklake: DuckLake | None = None
    self._ftp: FTPClient | None = None
    self._dadosgov: DadosGovClient | None = None
async def __aenter__(self):
    """Enter the async context: ensure DuckLake is ready and return self.

    Delegates to ``get_ducklake`` so that a DuckLake client created earlier
    (for example by ``query``) is reused instead of being replaced and
    leaked, and so the catalog is loaded and attached in one place only.
    """
    await self.get_ducklake()
    return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Close every opened client and dispose of the engine.

    Clients are closed in the order DuckLake, FTP, DadosGov, followed by
    ``engine.dispose()``. Every step runs even if an earlier one raises;
    the error is re-raised once all clean-up is done. The cached client
    references are cleared first, so a later ``get_*`` call (or re-entering
    the context) creates fresh clients instead of returning closed ones.
    Exceptions from the ``async with`` body are never suppressed.
    """
    import contextlib

    clients = (self._ducklake, self._ftp, self._dadosgov)
    self._ducklake = None
    self._ftp = None
    self._dadosgov = None

    async with contextlib.AsyncExitStack() as stack:
        # The stack unwinds LIFO: register the engine first so it is disposed
        # last, and the clients in reverse so they close in the order above.
        stack.callback(self.engine.dispose)
        for client in reversed(clients):
            if client is not None:
                stack.push_async_callback(client.close)
[docs]
async def get_ducklake(self) -> DuckLake:
    """Return the DuckLake client, initializing it lazily if needed.

    On first use the DuckLake catalog is loaded and attached read-only to
    the state engine under the name ``ducklake``. The client is cached only
    after both steps succeed. A failure is therefore retried on the next
    call instead of leaving a half-initialised client cached.
    """
    if self._ducklake is None:
        import contextlib

        ducklake = DuckLake()
        try:
            await ducklake._load_catalog()
            self._attach_client_catalog(
                "ducklake",
                str(ducklake.catalog_path),
            )
        except Exception:
            # Best-effort release of the partially initialised client; the
            # original error is the one the caller needs to see.
            with contextlib.suppress(Exception):
                await ducklake.close()
            raise
        if self._ducklake is None:
            self._ducklake = ducklake
        else:
            # Another task finished initialising while we were awaiting;
            # keep its client and release ours.
            with contextlib.suppress(Exception):
                await ducklake.close()
    return self._ducklake
[docs]
async def get_dadosgov(self, access_token: str | None) -> DadosGovClient:
    """Return the DadosGov client, connecting lazily if needed.

    ``access_token`` is only used for the first successful connection; once
    a client is cached it is returned as-is. The client is cached only
    after ``connect()`` succeeds, so a failed attempt is retried on the
    next call.
    """
    if self._dadosgov is None:
        import contextlib

        client = DadosGovClient()
        try:
            await client.connect(token=access_token)
        except Exception:
            # Best-effort release of the half-open client.
            with contextlib.suppress(Exception):
                await client.close()
            raise
        if self._dadosgov is None:
            self._dadosgov = client
        else:
            # Another task connected first; keep its client.
            with contextlib.suppress(Exception):
                await client.close()
    return self._dadosgov
[docs]
async def get_ftp(self) -> FTPClient:
    """Return the FTP client, connecting lazily if needed.

    The client is cached only after ``connect()`` succeeds, so a failed
    connection attempt is retried on the next call.
    """
    if self._ftp is None:
        import contextlib

        client = FTPClient()
        try:
            await client.connect()
        except Exception:
            # Best-effort release of the half-open client.
            with contextlib.suppress(Exception):
                await client.close()
            raise
        if self._ftp is None:
            self._ftp = client
        else:
            # Another task connected first; keep its client.
            with contextlib.suppress(Exception):
                await client.close()
    return self._ftp
[docs]
async def get_local_file(
    self,
    file: BaseRemoteFile,
) -> BaseLocalFile | None:
    """Look up a previously downloaded copy of ``file``.

    Records are matched on the remote path and the lower-cased client name,
    and must have status COMPLETED. Records whose file no longer exists on
    disk are ignored, so a deleted cache entry cannot shadow a usable one.
    When several usable copies exist, a ``.parquet`` conversion is
    preferred.

    Returns
    -------
    BaseLocalFile or None
        The cached file wrapped by ``ExtensionFactory.instantiate``, or
        None when no usable copy is recorded.
    """
    from pysus.api.extensions import ExtensionFactory

    with self.Session() as session:
        records = (
            session.query(LocalFileState)
            .filter_by(
                remote_path=str(file.path),
                client_name=str(file.client.name.lower()),
                status=DownloadStatus.COMPLETED,
            )
            .all()
        )
        stored_paths = [str(r.path) for r in records]

    # Skip rows whose file was removed behind our back.
    on_disk = [p for p in stored_paths if Path(p).exists()]
    if not on_disk:
        return None
    chosen = next((p for p in on_disk if p.endswith(".parquet")), on_disk[0])
    return await ExtensionFactory.instantiate(chosen)
def _attach_client_catalog(self, name: str, path: str):
    """Attach an external DuckDB catalog to the engine if not attached.

    The catalog is attached read-only under the alias ``name``. Whether it
    is already attached is decided by looking up its absolute path in the
    ``path`` column of ``duckdb_databases()``.
    """
    abs_path = str(Path(path).absolute())
    # The ATTACH statement is assembled as text, so escape it by hand:
    # double any single quote in the path literal and any double quote in
    # the (quoted) alias.
    path_literal = "'" + abs_path.replace("'", "''") + "'"
    alias = '"' + name.replace('"', '""') + '"'
    with self.engine.connect() as conn:
        q = "SELECT database_name FROM duckdb_databases() WHERE path = ?"
        existing = conn.exec_driver_sql(q, (abs_path,)).fetchone()
        if not existing:
            conn.exec_driver_sql(
                f"ATTACH {path_literal} AS {alias} (READ_ONLY)",
            )
def _get_dest_path(self, file: BaseRemoteFile) -> Path:
    """Compute where ``file`` is stored inside the local download cache.

    Layout: ``<cachepath>/downloads/<client>/<dataset>[/<group>]/<basename>``.
    Client and dataset names are lower-cased. The group directory is used
    only when the file has a truthy ``group`` whose ``name`` is non-empty.
    """
    target_dir = self.cachepath.joinpath(
        "downloads",
        file.client.name.lower(),
        file.dataset.name.lower(),
    )
    grp = getattr(file, "group", None)
    subdir = getattr(grp, "name", "") if grp else ""
    if subdir:
        target_dir = target_dir / subdir
    return target_dir / file.basename
async def _update_state(
    self,
    local_path: Path,
    remote_path: str,
    client_name: str,
    status: DownloadStatus,
    year: int | None = None,
    month: int | None = None,
    state: str | None = None,
    group: str | None = None,
):
    """Create or update the LocalFileState record for ``local_path``.

    A new record stores every given field. For an existing record,
    ``status`` and ``last_synced`` are always refreshed, and any of
    ``year``, ``month``, ``state`` and ``group`` that are not None are
    written too. None arguments leave the stored values untouched. This
    way a metadata-less DOWNLOADING/FAILED update does not erase metadata,
    and metadata supplied later is not lost.
    """
    metadata = {"year": year, "month": month, "state": state, "group": group}
    with self.Session() as session:
        # ``path`` is the primary key, so a direct identity lookup suffices.
        record = session.get(LocalFileState, str(local_path))
        if record is None:
            record = LocalFileState(
                path=str(local_path),
                remote_path=str(remote_path),
                client_name=client_name,
                **metadata,
            )
            session.add(record)
        else:
            for field, value in metadata.items():
                if value is not None:
                    setattr(record, field, value)
        record.status = status
        record.last_synced = datetime.now(timezone.utc).replace(tzinfo=None)
        session.commit()
[docs]
async def download(
    self,
    file: BaseRemoteFile,
    token: str | None = None,
    callback: Callable | None = None,
    timeout: float | None = None,
) -> BaseLocalFile:
    """Download a remote file and return a local file handle.

    Skips re-download if a matching local copy already exists: a COMPLETED
    record for the same remote file whose file is on disk and has the same
    size as the remote one. A recorded copy that is missing on disk or
    whose size differs is forgotten (record and file removed) and
    downloaded again.

    Parameters
    ----------
    file : BaseRemoteFile
        The remote file to download.
    token : str, optional
        Access token for authenticated clients (e.g. DadosGov).
    callback : Callable, optional
        Progress callback invoked during the download.
    timeout : float, optional
        Maximum seconds to wait for the download. ``None`` (default)
        means no timeout.

    Returns
    -------
    BaseLocalFile
        The downloaded file wrapped in the appropriate handler.

    Raises
    ------
    RuntimeError
        If the download fails for any reason. This includes an
        unrecognised client (a ``ValueError``) and an elapsed *timeout*.
        The underlying exception is chained as ``__cause__``. On failure
        the record is marked FAILED and the partial file is removed.
    """
    from pysus.api.extensions import ExtensionFactory

    existing_local = await self.get_local_file(file)
    if existing_local is not None:
        if existing_local.path.exists() and existing_local.size == file.size:
            return existing_local
        # NOTE(review): get_local_file prefers a .parquet conversion, whose
        # size will generally differ from the remote file's, so a converted
        # copy is discarded here and the raw file fetched again.
        await self._delete_record(str(existing_local.path))
        existing_local.path.unlink(missing_ok=True)

    client_name = file.client.name.lower()
    remote_path = str(file.path)
    local_path = self._get_dest_path(file)
    local_path.parent.mkdir(parents=True, exist_ok=True)
    # Not every remote file has a ``group`` (see _get_dest_path).
    group_name = getattr(getattr(file, "group", None), "name", None)

    await self._update_state(
        local_path,
        remote_path,
        client_name,
        DownloadStatus.DOWNLOADING,
    )

    client: DuckLake | FTPClient | DadosGovClient
    try:
        if client_name == "ducklake":
            client = await self.get_ducklake()
        elif client_name == "ftp":
            client = await self.get_ftp()
        elif client_name == "dadosgov":
            client = await self.get_dadosgov(token)
        else:
            raise ValueError(
                f"No download logic for client: {client_name}",
            )

        # fail_after(None) imposes no time limit.
        with anyio.fail_after(timeout):
            await client._download_file(file, local_path, callback)

        await self._update_state(
            local_path=local_path,
            remote_path=remote_path,
            client_name=client_name,
            status=DownloadStatus.COMPLETED,
            year=file.year,
            month=file.month,
            state=file.state,
            group=group_name,
        )
        return await ExtensionFactory.instantiate(local_path)
    except Exception as e:  # noqa: B902
        await self._update_state(
            local_path,
            remote_path,
            client_name,
            DownloadStatus.FAILED,
        )
        local_path.unlink(missing_ok=True)
        raise RuntimeError(
            f"Unexpected error downloading {file.basename}: {e}",
        ) from e
async def _delete_record(self, path: str):
    """Drop the tracking row keyed by ``path``; a missing row is a no-op."""
    with self.Session() as session:
        victim = session.get(LocalFileState, path)
        if victim is None:
            return
        session.delete(victim)
        session.commit()
[docs]
async def download_to_parquet(
    self,
    file: BaseRemoteFile,
    token: str | None = None,
    callback: Callable[[int, int], None] | None = None,
    timeout: float | None = None,
    add_dv: bool = True,
) -> Parquet:
    """Download a file and convert it to Parquet format.

    If a completed Parquet copy of ``file`` is already cached on disk, it is
    returned directly instead of downloading and converting again. If the
    downloaded file already is Parquet, it is returned without conversion.
    Otherwise the raw download is converted, the Parquet copy is recorded
    as COMPLETED, and the raw file and its record are removed.

    Parameters
    ----------
    file : BaseRemoteFile
        The remote file to download and convert.
    token : str, optional
        Access token for authenticated clients.
    callback : Callable[[int, int], None], optional
        Progress callback, passed to both the download and the conversion.
    timeout : float, optional
        Maximum seconds to wait for the download.
    add_dv : bool, optional
        Whether to apply the IBGE verification digit on load
        (default True).

    Returns
    -------
    Parquet
        The Parquet file handler.

    Raises
    ------
    NotImplementedError
        If the downloaded file is neither Parquet nor convertible to it.
    RuntimeError
        Propagated from ``download`` when the download fails.
    """
    cached = await self.get_local_file(file)
    if isinstance(cached, Parquet) and cached.path.exists():
        cached.add_dv = add_dv
        return cached

    local_file = await self.download(
        file=file,
        token=token,
        callback=callback,
        timeout=timeout,
    )
    original_path = local_file.path

    if isinstance(local_file, Parquet):
        # The remote file already is Parquet: nothing to convert.
        parquet_file = local_file
    elif hasattr(local_file, "to_parquet"):
        parquet_file = await local_file.to_parquet(callback=callback)
    else:
        raise NotImplementedError(
            f"{local_file} can't be converted to Parquet",
        )

    parquet_file.add_dv = add_dv
    await self._update_state(
        local_path=parquet_file.path,
        remote_path=str(file.path),
        client_name=file.client.name.lower(),
        status=DownloadStatus.COMPLETED,
        year=file.year,
        month=file.month,
        state=file.state,
        group=getattr(getattr(file, "group", None), "name", None),
    )
    if original_path != parquet_file.path:
        # The raw download is superseded by its Parquet conversion.
        original_path.unlink(missing_ok=True)
        await self._delete_record(str(original_path))
    return parquet_file
[docs]
def get_local_hierarchy(self):
    """Build a nested dict of cached files grouped by client and dataset.

    Downloads are laid out by ``_get_dest_path`` as
    ``<cachepath>/downloads/<client>/<dataset>[/<group>]/<file>``, and the
    dataset and group keys are read from that layout.

    For a path outside that layout, the group is the recorded one. The
    dataset is the directory above the file's parent when that parent is
    the recorded group; otherwise it is the parent directory, or
    ``"Other"`` for very short paths.

    Returns
    -------
    dict
        Nested dict keyed by
        ``{CLIENT: {dataset: {group: [files]}}}``, where ``group`` is
        ``""`` for ungrouped files. Each file entry is a dict with
        ``name``, ``status``, ``path`` and ``record`` keys.
    """

    def classify(record, path_obj: Path) -> tuple[str, str]:
        """Return the ``(dataset, group)`` keys for one cached file."""
        recorded_group = record.group or ""
        root = self.cachepath / "downloads" / str(record.client_name).lower()
        try:
            rel = path_obj.relative_to(root).parts
        except ValueError:
            rel = ()
        if len(rel) >= 2:
            # rel is (<dataset>, [<group>,] <file>)
            group = rel[1] if len(rel) == 3 else recorded_group
            return rel[0], group
        parts = path_obj.parts
        if recorded_group and len(parts) > 3 and parts[-2] == recorded_group:
            return parts[-3], recorded_group
        return (parts[-2] if len(parts) > 2 else "Other"), recorded_group

    hierarchy: dict = {}
    with self.Session() as session:
        for record in session.query(LocalFileState).all():
            path_obj = Path(str(record.path))
            dataset, group = classify(record, path_obj)
            files = (
                hierarchy.setdefault(record.client_name.upper(), {})
                .setdefault(dataset, {})
                .setdefault(group, [])
            )
            files.append(
                {
                    "name": path_obj.name,
                    "status": record.status,
                    "path": record.path,
                    "record": record,
                }
            )
    return hierarchy
[docs]
def get_completed_remote_paths(self) -> set[str]:
    """Collect the remote path of every file whose status is COMPLETED.

    Returns
    -------
    set of str
        Values of ``local_file_state.remote_path`` for completed rows.
    """
    is_completed = LocalFileState.status == DownloadStatus.COMPLETED
    with self.Session() as session:
        rows = session.query(LocalFileState.remote_path).filter(is_completed).all()
    return {str(remote_path) for (remote_path,) in rows}
[docs]
async def query(
    self,
    client: Literal["DadosGov", "FTP"] | None = None,
    dataset: str | None = None,
    group: str | None = None,
    state: str | None = None,
    year: int | None = None,
    month: int | None = None,
):
    """Forward a dataset query to the DuckLake catalog.

    The DuckLake client is obtained through ``get_ducklake`` (created on
    first use) and every filter is passed through unchanged as a keyword
    argument to its ``query`` coroutine.
    """
    lake = await self.get_ducklake()
    filters = {
        "client": client,
        "dataset": dataset,
        "group": group,
        "state": state,
        "year": year,
        "month": month,
    }
    return await lake.query(**filters)
[docs]
def read_parquet(
    self,
    paths: list[Path],
    sql: str | None = None,
    mode: Literal["union", "intersection", "strict"] = "union",
    add_dv: bool = True,
) -> "DuckDBPyConnection | pd.DataFrame":
    """Read Parquet files with optional schema handling and SQL post-processing.

    Parameters
    ----------
    paths : list of Path
        One or more Parquet file paths to read.
    sql : str, optional
        Either a complete ``SELECT`` statement that refers to the combined
        data as ``t`` (``t`` is bound as a CTE, so it may appear anywhere:
        ``FROM t``, ``JOIN t`` …), or a select-list expression such as
        ``"COUNT(*)"``, evaluated as ``SELECT <sql> FROM (...) AS t``.
        Surrounding whitespace and a trailing ``;`` are ignored.
    mode : {"union", "intersection", "strict"}, optional
        Schema resolution mode (default ``"union"``):

        - ``"union"`` keeps every column of every file.
        - ``"intersection"`` keeps only the (name, type) columns shared by
          all files, ordered by name.
        - ``"strict"`` requires all files to have the same set of
          (name, type) columns.

        Files are always combined by column name, not position.
    add_dv : bool, optional
        When True, columns for which ``is_geocode_column`` is true are
        passed through the IBGE verification-digit function.

    Returns
    -------
    DuckDBPyConnection
        The connection returned by ``duckdb.execute`` with the query result
        pending (fetch it with e.g. ``.df()`` or ``.fetchall()``). The
        annotation still mentions ``pd.DataFrame`` for compatibility, but a
        DataFrame is never returned directly.
        NOTE(review): module-level ``duckdb.execute`` uses DuckDB's shared
        default connection, so a later call presumably replaces a result
        that has not been fetched yet — confirm.

    Raises
    ------
    ValueError
        If no paths are provided, if the schema mode is ``"strict"`` and
        the files have differing schemas, or if the mode is
        ``"intersection"`` and the files share no column.
    """
    from pysus.api.utils import add_dv as _add_dv_fn
    from pysus.api.utils import is_geocode_column

    if not paths:
        raise ValueError("No paths provided")

    def quote_literal(value) -> str:
        """Render ``value`` as a single-quoted SQL string literal."""
        return "'" + str(value).replace("'", "''") + "'"

    def quote_ident(name) -> str:
        """Render ``name`` as a double-quoted SQL identifier."""
        return '"' + str(name).replace('"', '""') + '"'

    def get_columns(path: Path) -> set[tuple[str, str]]:
        """Return the schema of a Parquet file as (name, type) pairs."""
        result = duckdb.execute(
            f"SELECT * FROM read_parquet({quote_literal(path)}) LIMIT 0"
        )
        return {(col[0], str(col[1])) for col in result.description}

    file_list = "[" + ", ".join(quote_literal(p) for p in paths) + "]"
    # union_by_name aligns columns by name instead of position, so files
    # listing the same columns in a different order (or, in union mode,
    # different columns) are combined correctly.
    source = f"read_parquet({file_list}, union_by_name=True)"
    projection = "*"

    # Schemas are only inspected by the modes that need them.
    if mode == "strict":
        schemas = [get_columns(p) for p in paths]
        reference = schemas[0]
        for i, schema in enumerate(schemas):
            if schema != reference:
                raise ValueError(
                    f"Schema mismatch: file {i} has columns "
                    f"{[c[0] for c in sorted(schema)]}, "
                    f"expected {[c[0] for c in sorted(reference)]}"
                )
    elif mode == "intersection":
        common_columns = set.intersection(*(get_columns(p) for p in paths))
        if not common_columns:
            # DuckDB cannot represent a zero-column result.
            raise ValueError("The given Parquet files share no common columns")
        projection = ", ".join(
            quote_ident(name) for name, _ in sorted(common_columns)
        )

    query = f"SELECT {projection} FROM {source}"

    statement = sql.strip().rstrip(";").strip() if sql else ""
    if statement:
        if statement.upper().startswith("SELECT"):
            query = f"WITH t AS ({query}) {statement}"
        else:
            query = f"SELECT {statement} FROM ({query}) AS t"

    if not add_dv:
        return duckdb.execute(query)

    # Discover the output columns without materialising any rows.
    probe = duckdb.execute(f"SELECT * FROM ({query}) AS _t LIMIT 0")
    columns = [col[0] for col in probe.description]
    geocode_cols = {c for c in columns if is_geocode_column(c)}
    if not geocode_cols:
        return duckdb.execute(query)

    try:
        duckdb.create_function(
            "__pysus_add_dv",
            _add_dv_fn,
            null_handling="special",
        )
    except duckdb.NotImplementedException:
        # DuckDB raises this when a function of that name is already
        # registered, e.g. by an earlier call.
        pass

    selects = ", ".join(
        f"__pysus_add_dv({quote_ident(c)}) AS {quote_ident(c)}"
        if c in geocode_cols
        else quote_ident(c)
        for c in columns
    )
    return duckdb.execute(f"SELECT {selects} FROM ({query}) AS _t")