Source code for pysus.api.ducklake.client

"""High-level client for DuckLake S3-based dataset catalog.

Provides authentication, catalog synchronization, dataset querying,
and file download capabilities backed by a local DuckDB engine.
"""

from collections.abc import Callable
from pathlib import Path
from typing import Any, Literal

import boto3
import httpx
from anyio import sleep, to_thread
from botocore.config import Config
from pydantic import BaseModel, PrivateAttr, SecretStr
from pysus import CACHEPATH
from pysus.api.models import BaseRemoteClient, BaseRemoteFile
from sqlalchemy import create_engine
from sqlalchemy.orm import contains_eager, joinedload, sessionmaker
from sqlalchemy.pool import StaticPool

from .catalog import CatalogDataset, CatalogFile, DatasetGroup
from .models import DuckDataset, File


class CatalogDatasetAdapter:
    """Expose a ``CatalogDataset`` ORM row through the interface ``File`` uses.

    Parameters
    ----------
    catalog_dataset : CatalogDataset
        ORM record whose ``name``, ``long_name`` and ``description`` are
        copied onto the adapter.
    ducklake : DuckLake
        Owning client; it is stored under both ``ducklake`` and ``client``.
    """

    def __init__(self, catalog_dataset: CatalogDataset, ducklake):
        record = catalog_dataset
        self.name = record.name
        # Optional text columns may be NULL; normalise them to "".
        self.long_name = record.long_name or ""
        self.description = record.description or ""
        self.group_definitions: dict[str, str] = {}
        # The same client object is reachable under two attribute names.
        self.ducklake = self.client = ducklake

    @property
    def content(self):
        """Files of this dataset, as returned by the client's ``query``.

        Returns
        -------
        object
            Whatever ``self.ducklake.query`` returns for this dataset name
            (upper-cased). With the ``DuckLake`` client ``query`` is a
            coroutine function, so the result must be awaited.
        """
        upper_name = self.name.upper()
        return self.ducklake.query(dataset=upper_name)
class DatasetGroupAdapter:
    """Expose a ``DatasetGroup`` ORM row through the interface ``File`` uses.

    Parameters
    ----------
    dataset_group : DatasetGroup
        ORM record whose ``name``, ``long_name`` and ``description`` are
        copied onto the adapter.
    dataset : CatalogDataset
        The dataset this group belongs to, stored as ``self.dataset``.
    """

    def __init__(self, dataset_group: DatasetGroup, dataset):
        group = dataset_group
        self.name = group.name
        # Optional text columns may be NULL; normalise them to "".
        self.long_name = group.long_name or ""
        self.description = group.description or ""
        self.dataset = dataset

    def __str__(self):
        """Render the group as its short name.

        Returns
        -------
        str
            ``self.name``.
        """
        return self.name

    @property
    async def files(self):
        """Files of this group (currently always empty).

        Returns
        -------
        list
            An empty list; accessing the property yields an awaitable.
        """
        return list()

    async def _fetch_files(self):
        """Placeholder fetch hook; currently returns an empty list."""
        return list()

    async def search(self, **kwargs):
        """Search this group for files (stub implementation).

        Parameters
        ----------
        ``**kwargs``
            Filter criteria; currently ignored.

        Returns
        -------
        list
            Always an empty list.
        """
        return list()
class DuckLakeCredentials(BaseModel):
    """Key pair for authenticated access to the S3-compatible object storage.

    Attributes
    ----------
    access_key : SecretStr
        S3 access key ID; read the raw value with ``get_secret_value()``.
    secret_key : SecretStr
        S3 secret access key; read the raw value with ``get_secret_value()``.
    """

    # Both fields are required and keep their original declaration order.
    access_key: SecretStr
    secret_key: SecretStr
class DuckLake(BaseRemoteClient):
    """Client for the DuckLake S3-based public health dataset catalog.

    The catalog is a DuckDB database file that is mirrored from object
    storage into a local cache directory and queried through SQLAlchemy.

    Parameters
    ----------
    endpoint : str, optional
        S3-compatible object storage endpoint.
    region : str, optional
        Storage region name.
    bucket : str, optional
        Bucket name containing the catalog.
    credentials : DuckLakeCredentials, optional
        Credentials for authenticated S3 operations.
    engine : object, optional
        Pre-configured SQLAlchemy engine to reuse.
    """

    endpoint: str = "nbg1.your-objectstorage.com"
    region: str = "nbg1"
    bucket: str = "pysus"
    credentials: DuckLakeCredentials | None = None

    _cache_dir: Path = PrivateAttr()
    _catalog_local: Path = PrivateAttr()
    _catalog_remote: str = "public/catalog.db"
    _s3_client: Any = PrivateAttr(default=None)
    _engine: Any = PrivateAttr(default=None)
    _Session: Any = PrivateAttr(default=None)

    def __init__(self, engine=None, **data):
        """Initialize the DuckLake client with an optional existing engine.

        Parameters
        ----------
        engine : object, optional
            Pre-configured SQLAlchemy engine instead of creating a new one.
        ``**data``
            Additional fields passed to the Pydantic base model.
        """
        super().__init__(**data)
        self._engine = engine
        self._cache_dir = Path(CACHEPATH) / "ducklake"
        self._cache_dir.mkdir(parents=True, exist_ok=True)
        self._catalog_local = self._cache_dir / "catalog.db"

    @property
    def name(self) -> str:
        """Return the short name of this client.

        Returns
        -------
        str
            The client short name.
        """
        return "DuckLake"

    @property
    def long_name(self) -> str:
        """Return the human-readable name of this client.

        Returns
        -------
        str
            The client display name.
        """
        return "PySUS s3 Client"

    @property
    def description(self) -> str:
        """Return a description of this client.

        Returns
        -------
        str
            A description string (currently empty).
        """
        return ""  # TODO:

    @property
    def catalog_path(self) -> Path:
        """Return the local path to the downloaded catalog database.

        Returns
        -------
        Path
            Filesystem path to the local catalog database file.
        """
        return self._catalog_local

    @property
    def _catalog_url(self) -> str:
        """Return the remote URL of the catalog database file."""
        return f"https://{self.endpoint}/{self.bucket}/{self._catalog_remote}"

    @property
    def _is_authenticated(self) -> bool:
        """Return whether the client has credentials configured."""
        return self.credentials is not None

    async def datasets(self, **kwargs) -> list[DuckDataset]:
        """Return all datasets from the catalog as DuckDataset instances.

        Parameters
        ----------
        ``**kwargs``
            Additional filter arguments (currently unused).

        Returns
        -------
        list[DuckDataset]
            List of all datasets in the catalog.
        """
        if not self._Session:
            await self.connect()

        def _fetch():
            with self._Session() as session:
                results = (
                    session.query(CatalogDataset)
                    .options(
                        joinedload(CatalogDataset.groups).joinedload(
                            DatasetGroup.files
                        ),
                        joinedload(CatalogDataset.files),
                    )
                    .all()
                )
                # Detach the eagerly loaded rows so they outlive the session.
                session.expunge_all()
                return results

        records = await to_thread.run_sync(_fetch)
        return [DuckDataset(record=rec, client=self) for rec in records]

    async def login(
        self,
        access_key: str | None = None,
        secret_key: str | None = None,
        **kwargs,
    ) -> None:
        """Authenticate with S3 credentials and reconnect to the catalog.

        Parameters
        ----------
        access_key : str, optional
            S3 access key ID. If omitted, credentials are cleared.
        secret_key : str, optional
            S3 secret access key. If omitted, credentials are cleared.
        ``**kwargs``
            Additional arguments (currently unused).
        """
        if access_key and secret_key:
            self.credentials = DuckLakeCredentials(
                access_key=SecretStr(access_key),
                secret_key=SecretStr(secret_key),
            )
        else:
            self.credentials = None
            # Do not keep an S3 client built from previously cleared keys.
            self._s3_client = None

        await self.connect(force=True)

        if self._is_authenticated:
            self._s3_client = await to_thread.run_sync(
                self._get_s3_client,
            )

    def _setup_engine(self):
        """Create and configure the DuckDB engine with S3 settings."""

        def _sql_literal(value) -> str:
            # Quote as a SQL string literal; embedded quotes are doubled so
            # a credential containing "'" cannot break the SET statement.
            return "'" + str(value).replace("'", "''") + "'"

        engine = create_engine(
            f"duckdb:///{self._catalog_local}",
            poolclass=StaticPool,
        )

        with engine.connect() as conn:
            conn.exec_driver_sql("INSTALL ducklake; LOAD ducklake;")

            has_pysus = conn.exec_driver_sql(
                """
                SELECT 1 FROM information_schema.schemata
                WHERE schema_name = 'pysus'
                """
            ).fetchone()

            if has_pysus:
                conn.exec_driver_sql("SET search_path='pysus,main';")
            else:
                conn.exec_driver_sql("SET search_path='main';")

            s3_cfg = {
                "s3_endpoint": self.endpoint,
                "s3_region": self.region,
                "s3_url_style": "path",
                "s3_use_ssl": "true",
            }

            if self.credentials is not None:
                s3_cfg["s3_access_key_id"] = (
                    self.credentials.access_key.get_secret_value()
                )
                s3_cfg["s3_secret_access_key"] = (
                    self.credentials.secret_key.get_secret_value()
                )

            for key, value in s3_cfg.items():
                conn.exec_driver_sql(f"SET {key}={_sql_literal(value)}")

            conn.commit()

        return engine

    async def connect(self, force: bool = False):
        """Connect to the catalog, downloading it first if necessary.

        Parameters
        ----------
        force : bool, optional
            Whether to re-download and re-connect even if already connected.
        """
        if self._engine and not force:
            if not self._Session:
                self._Session = sessionmaker(bind=self._engine)
            return

        if self._engine is not None:
            # Release the previous engine (which may hold the local catalog
            # file open) before the catalog can be replaced on disk.
            stale_engine = self._engine
            self._engine = None
            self._Session = None
            await to_thread.run_sync(stale_engine.dispose)

        await self._load_catalog()
        self._engine = await to_thread.run_sync(self._setup_engine)
        self._Session = sessionmaker(bind=self._engine)

    async def close(self):
        """Dispose the engine, then upload the catalog if authenticated.

        The S3 client is dropped afterwards, even if the upload fails.

        Raises
        ------
        Exception
            Any error raised while uploading the catalog (authenticated
            clients only) is propagated.
        """
        if self._engine:
            await to_thread.run_sync(self._engine.dispose)
            self._engine = None
            self._Session = None

        if self._is_authenticated:
            try:
                await self._upload_catalog()
            finally:
                self._s3_client = None

    async def _download_file(
        self,
        file: BaseRemoteFile,
        output: Path,
        callback: Callable[[int, int], None] | None = None,
    ) -> Path:
        """Download a single file from object storage to the local path.

        A partially written ``output`` is removed if the transfer fails or
        is cancelled, so a truncated file is never left behind.

        Raises
        ------
        ValueError
            If ``file`` is not a DuckLake ``File``.
        """
        if not isinstance(file, File):
            raise ValueError("DuckLake File was not properly instantiated")

        url = f"https://{self.endpoint}/{self.bucket}/{file.record.path}"

        async with httpx.AsyncClient(follow_redirects=True) as client:
            async with client.stream("GET", url) as r:
                r.raise_for_status()
                total = int(r.headers.get("Content-Length", 0))
                downloaded = 0
                try:
                    with open(output, "wb") as f:
                        async for chunk in r.aiter_bytes(
                            chunk_size=1024 * 1024,
                        ):
                            await to_thread.run_sync(f.write, chunk)
                            downloaded += len(chunk)
                            if callback:
                                callback(downloaded, total)
                except BaseException:
                    # BaseException also covers task cancellation.
                    Path(output).unlink(missing_ok=True)
                    raise

        return output

    async def _download_catalog(self, client: httpx.AsyncClient):
        """Download the catalog database from remote storage with retries.

        Data is streamed into a sibling ``.part`` file and atomically moved
        over the local catalog only once complete, so a failed attempt never
        corrupts an existing cached copy. Network transport errors and local
        I/O errors are retried; HTTP status errors are raised immediately.
        """
        max_retries = 5
        partial = self._catalog_local.with_name(
            self._catalog_local.name + ".part",
        )

        for attempt in range(max_retries):
            try:
                async with client.stream("GET", self._catalog_url) as r:
                    r.raise_for_status()
                    with open(partial, "wb") as f:
                        async for chunk in r.aiter_bytes(
                            chunk_size=1024 * 1024,
                        ):
                            await to_thread.run_sync(f.write, chunk)
                partial.replace(self._catalog_local)
                return
            except (OSError, httpx.TransportError) as exc:
                if attempt < max_retries - 1:
                    await sleep(1)
                    continue
                try:
                    partial.unlink(missing_ok=True)
                except OSError:
                    pass  # best-effort cleanup; report the real failure
                raise exc

    def _get_s3_client(self):
        """Create and return a boto3 S3 client for the configured endpoint.

        Raises
        ------
        ConnectionError
            If no credentials are configured.
        """
        if not self.credentials:
            raise ConnectionError("S3 Credentials not found")

        return boto3.client(
            "s3",
            endpoint_url=f"https://{self.endpoint}",
            aws_access_key_id=self.credentials.access_key.get_secret_value(),
            aws_secret_access_key=(
                self.credentials.secret_key.get_secret_value()
            ),
            region_name=self.region,
            config=Config(signature_version="s3v4"),
        )

    async def _load_catalog(self):
        """Download remote catalog if the local copy is outdated or missing.

        Freshness is judged by comparing the local file size with the
        remote ``content-length``. If the remote cannot be probed at all and
        the download also fails, an existing non-empty cached catalog is
        kept so the client still works offline.
        """
        local_size = -1
        if self._catalog_local.exists():
            try:
                local_size = self._catalog_local.stat().st_size
            except OSError:
                pass

        async with httpx.AsyncClient(follow_redirects=True) as client:
            probe_failed = False
            try:
                head = await client.head(self._catalog_url)
                head.raise_for_status()
                remote_size = int(head.headers.get("content-length", 0))
            except (httpx.HTTPError, ValueError):
                probe_failed = True
                remote_size = 0

            if remote_size == local_size:
                return

            try:
                await self._download_catalog(client)
            except (OSError, httpx.HTTPError):
                if probe_failed and local_size > 0:
                    # Remote unreachable: fall back to the cached catalog.
                    return
                raise

    async def _upload_catalog(self):
        """Upload the local catalog database to remote storage.

        Raises
        ------
        PermissionError
            If the client has no credentials.
        """
        if not self._is_authenticated:
            raise PermissionError(
                "Admin credentials required to upload catalog.",
            )

        if self._s3_client is None:
            # Credentials may have been supplied to the constructor rather
            # than via login(), in which case no S3 client exists yet.
            self._s3_client = await to_thread.run_sync(self._get_s3_client)

        s3 = self._s3_client

        def _upload():
            s3.upload_file(
                str(self._catalog_local),
                self.bucket,
                self._catalog_remote,
            )

        await to_thread.run_sync(_upload)

    async def query(
        self,
        client: Literal["FTP", "DadosGov"] | None = None,
        dataset: str | None = None,
        group: str | None = None,
        state: str | None = None,
        year: int | None = None,
        month: int | None = None,
    ) -> list[File]:
        """Filter catalog files by client, dataset, group, state, year.

        When ``client`` is omitted, files from both sources are returned, but
        a DadosGov file is dropped when an FTP file shares its dataset,
        year, month and file stem (a trailing ``.csv`` on the DadosGov
        stem is ignored).

        Parameters
        ----------
        client : Literal["FTP", "DadosGov"], optional
            Source client to filter by.
        dataset : str, optional
            Dataset name to filter by.
        group : str, optional
            Group name pattern to filter by (case-insensitive ILIKE).
        state : str, optional
            Two-letter state code to filter by.
        year : int, optional
            Year to filter by.
        month : int, optional
            Month to filter by.

        Returns
        -------
        list[:class:`~pysus.api.ducklake.models.File`]
            List of matching file objects.
        """
        if not self._Session:
            await self.connect()

        def _query():
            with self._Session() as session:
                q = session.query(CatalogFile)

                if dataset:
                    q = (
                        q.join(CatalogFile.dataset)
                        .options(contains_eager(CatalogFile.dataset))
                        .filter(CatalogDataset.name == dataset.lower())
                    )
                else:
                    q = q.options(joinedload(CatalogFile.dataset))

                if group:
                    q = (
                        q.join(CatalogFile.group)
                        .options(contains_eager(CatalogFile.group))
                        .filter(DatasetGroup.name.ilike(group))
                    )
                else:
                    q = q.options(joinedload(CatalogFile.group))

                if state:
                    q = q.filter(CatalogFile.state == state.upper())
                if year:
                    q = q.filter(CatalogFile.year == year)
                if month:
                    q = q.filter(CatalogFile.month == month)

                results = q.all()
                # Detach rows so their loaded relationships stay usable.
                session.expunge_all()
                return results

        records = await to_thread.run_sync(_query)

        if client:
            prefix = f"public/data/{client.lower()}/"
            records = [r for r in records if r.path.startswith(prefix)]
        else:
            ftp = [r for r in records if r.path.startswith("public/data/ftp/")]
            dadosgov = [
                r for r in records if r.path.startswith("public/data/dadosgov/")
            ]

            ftp_keys = {
                (r.dataset_id, r.year, r.month, Path(r.path).stem)
                for r in ftp
            }

            def has_ftp_match(r):
                stem = Path(r.path).stem.removesuffix(".csv")
                return (r.dataset_id, r.year, r.month, stem) in ftp_keys

            records = ftp + [r for r in dadosgov if not has_ftp_match(r)]

        return [
            File(
                path=r.path,
                record=r,
                dataset=CatalogDatasetAdapter(r.dataset, self),
                group=(
                    DatasetGroupAdapter(r.group, r.dataset)
                    if r.group
                    else None
                ),
            )
            for r in records
        ]