Source code for pysus.api.ducklake.client

"""High-level client for DuckLake S3-based public health dataset catalog.

Provides authentication, dataset discovery, and file download
capabilities backed by per-dataset DuckDB engines.
"""

import asyncio
from collections.abc import Callable
from pathlib import Path

from anyio import to_thread
from pydantic import Field, PrivateAttr, SecretStr
from pysus.api.models import BaseRemoteClient, BaseRemoteFile
from pysus.api.types import DUCKLAKE

from .catalog.adapters import (
    CatalogAdapter,
    ColumnsAdapter,
    DatasetAdapter,
    DuckLakeCredentials,
)
from .catalog.orm.default import Dataset
from .functional import download_http
from .models import DuckDataset, File


class DuckLake(BaseRemoteClient):
    """Client for the DuckLake dataset catalog.

    Owns two adapters -- one for the dataset catalog and one for the
    columns database -- and hands out ``DuckDataset`` objects built from
    the catalog's ``Dataset`` rows. Usable as an async context manager:
    entering connects both adapters, exiting closes everything.
    """

    # Credentials forwarded to every adapter; replaced by ``login``.
    credentials: DuckLakeCredentials | None = None
    # Default for ``close(update_catalog=None)``; also forwarded to the
    # adapters and datasets. Excluded from model serialization.
    update_on_close: bool = Field(default=False, exclude=True)

    # Datasets produced by the most recent ``datasets()`` call; these are
    # the ones ``close`` shuts down.
    _datasets: list[DuckDataset] = PrivateAttr(default_factory=list)
    _catalog_adap: CatalogAdapter = PrivateAttr()
    _columns_adap: ColumnsAdapter = PrivateAttr()

    def __init__(
        self,
        engine=None,
        columns_engine=None,
        update_on_close: bool = False,
        **data,
    ) -> None:
        """Build the client and its catalog and columns adapters.

        Args:
            engine: Passed as ``engine`` to ``CatalogAdapter``.
            columns_engine: Passed as ``engine`` to ``ColumnsAdapter``.
            update_on_close: Stored on the client and forwarded to both
                adapters.
            **data: Model fields (e.g. ``credentials``) handed to the
                base class initializer.
        """
        super().__init__(**data)
        self.update_on_close = update_on_close
        self._catalog_adap = CatalogAdapter(
            engine=engine,
            credentials=self.credentials,
            update_on_close=self.update_on_close,
        )
        self._columns_adap = ColumnsAdapter(
            engine=columns_engine,
            credentials=self.credentials,
            update_on_close=self.update_on_close,
        )

    @property
    def name(self) -> str:
        """Return the ``DUCKLAKE`` client identifier."""
        return DUCKLAKE

    @property
    def long_name(self) -> str:
        """Return the human-readable client name."""
        return "PySUS s3 Client"

    @property
    def description(self) -> str:
        """Return the client description (currently empty)."""
        return ""

    @property
    def catalog_path(self) -> Path:
        """Return the catalog adapter's ``db_local`` path."""
        return self._catalog_adap.db_local

    @property
    def columns_path(self) -> Path:
        """Return the columns adapter's ``db_local`` path."""
        return self._columns_adap.db_local

    async def datasets(self, **kwargs) -> list[DuckDataset]:
        """Load every ``Dataset`` row and wrap each in a ``DuckDataset``.

        The query runs in a worker thread inside the catalog adapter's
        async context. The resulting list replaces ``self._datasets``.

        Args:
            **kwargs: Accepted but not used by this implementation.

        Returns:
            One ``DuckDataset`` per catalog record, each with its own
            ``DatasetAdapter``.
        """

        def _fetch():
            with self._catalog_adap.get_session() as session:
                results = session.query(Dataset).all()
                # Detach the rows so they can be used after the session
                # block exits.
                session.expunge_all()
                return results

        duck_datasets: list[DuckDataset] = []

        async with self._catalog_adap:
            records = await to_thread.run_sync(_fetch)

            for rec in records:
                dataset_adapter = DatasetAdapter(
                    name=str(rec.name),
                    dataset_id=int(rec.id),
                    credentials=self.credentials,
                    update_on_close=self.update_on_close,
                )
                duck_datasets.append(
                    DuckDataset(
                        record=rec,
                        client=self,
                        adapter=dataset_adapter,
                        update_on_close=self.update_on_close,
                    )
                )

        self._datasets = duck_datasets
        return duck_datasets

    async def login(self, **kwargs) -> None:
        """Store credentials and reconnect both adapters with them.

        Args:
            **kwargs: Must contain non-empty ``access_key`` and
                ``secret_key`` values.

        Raises:
            ValueError: If either key is missing or empty.
        """
        access_key = kwargs.get("access_key")
        secret_key = kwargs.get("secret_key")

        if not access_key or not secret_key:
            raise ValueError(
                "DuckLake authentication requires 'access_key' and 'secret_key'"
            )

        self.credentials = DuckLakeCredentials(
            access_key=SecretStr(access_key),
            secret_key=SecretStr(secret_key),
        )

        self._catalog_adap.credentials = self.credentials
        self._columns_adap.credentials = self.credentials

        # force=True so the adapters reconnect with the new credentials.
        await self._catalog_adap.connect(force=True)
        await self._columns_adap.connect(force=True)

    async def connect(self, force: bool = False) -> None:
        """Connect the catalog adapter, then the columns adapter.

        Args:
            force: Forwarded unchanged to each adapter's ``connect``.
        """
        await self._catalog_adap.connect(force=force)
        await self._columns_adap.connect(force=force)

    async def close(self, update_catalog: bool | None = None) -> None:
        """Close tracked datasets, then the catalog and columns adapters.

        Both adapters are closed even when an earlier close raises, so a
        failing dataset cannot leave them open. The exception still
        propagates to the caller once the remaining closes have run.

        Args:
            update_catalog: Value passed to every ``close`` call. ``None``
                means "use ``self.update_on_close``".
        """
        should_update = (
            self.update_on_close if update_catalog is None else update_catalog
        )

        try:
            for ds in self._datasets:
                await ds.close(update_catalog=should_update)
        finally:
            try:
                await self._catalog_adap.close(update=should_update)
            finally:
                await self._columns_adap.close(update=should_update)

    async def download(
        self,
        file: BaseRemoteFile,
        output: Path,
        callback: Callable[[int, int], None] | None = None,
    ) -> Path:
        """Download ``file`` to ``output`` via ``download_http``.

        Args:
            file: Must be a DuckLake ``File``; its ``record.path`` is used
                as the remote path.
            output: Local destination path.
            callback: Optional callable forwarded to ``download_http``.

        Returns:
            ``output``, unchanged.

        Raises:
            ValueError: If ``file`` is not a ``File`` instance.
        """
        if not isinstance(file, File):
            raise ValueError("DuckLake File was not properly instantiated")

        await download_http(
            remote_path=file.record.path,
            local_path=output,
            callback=callback,
        )
        return output

    async def __aenter__(self):
        """Connect both adapters and return the client."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the client using its ``update_on_close`` setting."""
        await self.close(update_catalog=None)

    def __del__(self) -> None:
        """Best-effort cleanup that never updates the catalog.

        Schedules ``close`` on the running event loop when there is one,
        otherwise runs it to completion with ``asyncio.run``. Every
        failure is swallowed because a finalizer must not raise.
        """
        # __init__ may have failed before the adapters were created.
        if not hasattr(self, "_catalog_adap") or not hasattr(
            self, "_columns_adap"
        ):
            return

        try:
            # get_running_loop() raises RuntimeError when no loop is
            # running, so a returned loop needs no further check.
            loop = asyncio.get_running_loop()
            loop.create_task(self.close(update_catalog=False))
        except RuntimeError:
            try:
                asyncio.run(self.close(update_catalog=False))
            except Exception:  # noqa
                pass
        except Exception:  # noqa
            pass
# Rebuild DuckDataset's model now that DuckLake is defined, supplying the
# class under the name "DuckLake" for annotation resolution.
# NOTE(review): presumably DuckDataset refers to DuckLake through a forward
# reference to avoid a circular import -- confirm in .models.
DuckDataset.model_rebuild(_types_namespace={"DuckLake": DuckLake})