"""High-level client for DuckLake S3-based dataset catalog.
Provides authentication, catalog synchronization, dataset querying,
and file download capabilities backed by a local DuckDB engine.
"""
from collections.abc import Callable
from pathlib import Path
from typing import Any, Literal
import boto3
import httpx
from anyio import sleep, to_thread
from botocore.config import Config
from pydantic import BaseModel, PrivateAttr, SecretStr
from pysus import CACHEPATH
from pysus.api.models import BaseRemoteClient, BaseRemoteFile
from sqlalchemy import create_engine
from sqlalchemy.orm import contains_eager, joinedload, sessionmaker
from sqlalchemy.pool import StaticPool
from .catalog import CatalogDataset, CatalogFile, DatasetGroup
from .models import DuckDataset, File
class CatalogDatasetAdapter:
    """Present a ``CatalogDataset`` ORM row through the interface ``File`` uses.

    Parameters
    ----------
    catalog_dataset : CatalogDataset
        ORM record whose ``name``, ``long_name`` and ``description`` are
        copied onto the adapter (missing text fields become ``""``).
    ducklake : DuckLake
        Owning client. It is stored under two attribute names,
        ``ducklake`` and ``client``.
    """

    def __init__(self, catalog_dataset: CatalogDataset, ducklake):
        self.name = catalog_dataset.name
        # Optional text columns are normalised to empty strings.
        self.long_name = catalog_dataset.long_name or ""
        self.description = catalog_dataset.description or ""
        self.group_definitions: dict[str, str] = {}
        # ``client`` is kept as an alias of ``ducklake``.
        self.client = self.ducklake = ducklake

    @property
    def content(self):
        """Files belonging to this dataset, looked up through the client.

        Returns
        -------
        Awaitable[list]
            The result of ``self.ducklake.query(dataset=...)``. For a
            ``DuckLake`` client, ``query`` is a coroutine function, so the
            caller must ``await`` this value to obtain the file list.
        """
        dataset_key = self.name.upper()
        return self.ducklake.query(dataset=dataset_key)
class DatasetGroupAdapter:
    """Present a ``DatasetGroup`` ORM row through the interface ``File`` uses.

    The listing members (``files``, ``_fetch_files`` and ``search``) are
    placeholders: each one produces an empty list.

    Parameters
    ----------
    dataset_group : DatasetGroup
        ORM record providing ``name``, ``long_name`` and ``description``.
    dataset : CatalogDataset
        Dataset the group belongs to, stored unchanged as ``self.dataset``.
    """

    def __init__(self, dataset_group: DatasetGroup, dataset):
        self.name = dataset_group.name
        # Optional text columns are normalised to empty strings.
        self.long_name = dataset_group.long_name or ""
        self.description = dataset_group.description or ""
        self.dataset = dataset

    def __str__(self):
        """Render the group as its short name.

        Returns
        -------
        str
            ``self.name``.
        """
        return self.name

    @property
    async def files(self):
        """Files in this group (placeholder).

        The getter is a coroutine function, so reading ``group.files``
        produces an awaitable; ``await group.files`` evaluates to ``[]``.

        Returns
        -------
        list
            Always empty.
        """
        no_files: list = []
        return no_files

    async def _fetch_files(self):
        """Placeholder for a remote file fetch; always returns ``[]``."""
        no_files: list = []
        return no_files

    async def search(self, **kwargs):
        """Search this group's files (placeholder).

        Parameters
        ----------
        ``**kwargs``
            Filter criteria. They are accepted and currently ignored.

        Returns
        -------
        list
            Always empty.
        """
        no_files: list = []
        return no_files
class DuckLakeCredentials(BaseModel):
    """S3 key pair used for authenticated object-storage operations.

    Both values are held as pydantic ``SecretStr`` instances; the client
    reads the raw strings with ``get_secret_value()``.

    Parameters
    ----------
    access_key : SecretStr
        S3 access key ID.
    secret_key : SecretStr
        S3 secret access key.
    """

    # S3 access key ID.
    access_key: SecretStr
    # S3 secret access key.
    secret_key: SecretStr
class DuckLake(BaseRemoteClient):
    """Client for the DuckLake S3-based public health dataset catalog.

    The catalog is a DuckDB database published at
    ``https://{endpoint}/{bucket}/public/catalog.db``. It is mirrored into
    ``CACHEPATH/ducklake/catalog.db`` and queried through a SQLAlchemy
    engine using the ``duckdb`` dialect.

    Parameters
    ----------
    endpoint : str, optional
        S3-compatible object storage endpoint.
    region : str, optional
        Storage region name.
    bucket : str, optional
        Bucket name containing the catalog.
    credentials : DuckLakeCredentials, optional
        Credentials for authenticated S3 operations.
    engine : object, optional
        Pre-configured SQLAlchemy engine to reuse.
    """

    endpoint: str = "nbg1.your-objectstorage.com"
    region: str = "nbg1"
    bucket: str = "pysus"
    credentials: DuckLakeCredentials | None = None

    _cache_dir: Path = PrivateAttr()
    _catalog_local: Path = PrivateAttr()
    _catalog_remote: str = "public/catalog.db"
    _s3_client: Any = PrivateAttr(default=None)
    _engine: Any = PrivateAttr(default=None)
    _Session: Any = PrivateAttr(default=None)

    def __init__(self, engine=None, **data):
        """Initialize the DuckLake client with an optional existing engine.

        Parameters
        ----------
        engine : object, optional
            Pre-configured SQLAlchemy engine instead of creating a new one.
        ``**data``
            Additional fields passed to the Pydantic base model.
        """
        super().__init__(**data)
        self._engine = engine
        self._cache_dir = Path(CACHEPATH) / "ducklake"
        self._cache_dir.mkdir(parents=True, exist_ok=True)
        self._catalog_local = self._cache_dir / "catalog.db"

    @property
    def name(self) -> str:
        """Return the short name of this client.

        Returns
        -------
        str
            The client short name.
        """
        return "DuckLake"

    @property
    def long_name(self) -> str:
        """Return the human-readable name of this client.

        Returns
        -------
        str
            The client display name.
        """
        return "PySUS s3 Client"

    @property
    def description(self) -> str:
        """Return a description of this client.

        Returns
        -------
        str
            A description string (currently empty).
        """
        return ""  # TODO: write a description

    @property
    def catalog_path(self) -> Path:
        """Return the local path to the downloaded catalog database.

        Returns
        -------
        Path
            Filesystem path to the local catalog database file.
        """
        return self._catalog_local

    @property
    def _catalog_url(self) -> str:
        """Return the remote URL of the catalog database file."""
        return f"https://{self.endpoint}/{self.bucket}/{self._catalog_remote}"

    @property
    def _is_authenticated(self) -> bool:
        """Return whether the client has credentials configured."""
        return self.credentials is not None

    async def datasets(self, **kwargs) -> list[DuckDataset]:
        """Return all datasets from the catalog as DuckDataset instances.

        Connects first if no session factory exists yet. Groups, their
        files, and each dataset's files are eager-loaded and the records
        are detached from the session before being wrapped.

        Parameters
        ----------
        ``**kwargs``
            Additional filter arguments (currently unused).

        Returns
        -------
        list[DuckDataset]
            List of all datasets in the catalog.
        """
        if self._Session is None:
            await self.connect()

        def _fetch():
            with self._Session() as session:
                results = (
                    session.query(CatalogDataset)
                    .options(
                        joinedload(CatalogDataset.groups).joinedload(
                            DatasetGroup.files
                        ),
                        joinedload(CatalogDataset.files),
                    )
                    .all()
                )
                # Detach so the records stay usable after the session closes.
                session.expunge_all()
                return results

        records = await to_thread.run_sync(_fetch)
        return [DuckDataset(record=rec, client=self) for rec in records]

    async def login(
        self,
        access_key: str | None = None,
        secret_key: str | None = None,
        **kwargs,
    ) -> None:
        """Authenticate with S3 credentials and reconnect to the catalog.

        Parameters
        ----------
        access_key : str, optional
            S3 access key ID. If omitted, credentials are cleared.
        secret_key : str, optional
            S3 secret access key. If omitted, credentials are cleared.
        ``**kwargs``
            Additional arguments (currently unused).
        """
        if access_key and secret_key:
            self.credentials = DuckLakeCredentials(
                access_key=SecretStr(access_key),
                secret_key=SecretStr(secret_key),
            )
        else:
            self.credentials = None
        # Never keep an S3 client built from previous credentials.
        self._s3_client = None
        await self.connect(force=True)
        if self._is_authenticated:
            self._s3_client = await to_thread.run_sync(
                self._get_s3_client,
            )

    def _setup_engine(self):
        """Create and configure the DuckDB engine with S3 settings.

        A ``StaticPool`` is used, so the settings applied here on the single
        pooled connection are seen by every later session.
        """
        engine = create_engine(
            f"duckdb:///{self._catalog_local}",
            poolclass=StaticPool,
        )
        with engine.connect() as conn:
            conn.exec_driver_sql("INSTALL ducklake; LOAD ducklake;")
            has_pysus = conn.exec_driver_sql(
                """
                SELECT 1 FROM information_schema.schemata WHERE
                schema_name = 'pysus'
                """
            ).fetchone()
            if has_pysus:
                conn.exec_driver_sql("SET search_path='pysus,main';")
            else:
                conn.exec_driver_sql("SET search_path='main';")
            s3_cfg = {
                "s3_endpoint": self.endpoint,
                "s3_region": self.region,
                "s3_url_style": "path",
                "s3_use_ssl": "true",
            }
            if self.credentials and self._is_authenticated:
                s3_cfg["s3_access_key_id"] = (
                    self.credentials.access_key.get_secret_value()
                )
                s3_cfg["s3_secret_access_key"] = (
                    self.credentials.secret_key.get_secret_value()
                )
            for key, value in s3_cfg.items():
                # Double embedded quotes so a value cannot break out of
                # the SQL string literal.
                escaped = str(value).replace("'", "''")
                conn.exec_driver_sql(f"SET {key}='{escaped}'")
            conn.commit()
        return engine

    async def connect(self, force: bool = False):
        """Connect to the catalog, downloading it first if necessary.

        When ``force`` is true and an engine already exists, that engine is
        disposed before the local catalog file is refreshed and reopened.
        This also applies to an engine passed to the constructor.

        Parameters
        ----------
        force : bool, optional
            Whether to re-download and re-connect even if already connected.
        """
        if self._engine is not None and not force:
            if self._Session is None:
                self._Session = sessionmaker(bind=self._engine)
            return
        if self._engine is not None:
            # Release the old engine's connection to the catalog file before
            # the file may be replaced; otherwise it would leak.
            old_engine = self._engine
            self._engine = None
            self._Session = None
            await to_thread.run_sync(old_engine.dispose)
        await self._load_catalog()
        self._engine = await to_thread.run_sync(self._setup_engine)
        self._Session = sessionmaker(bind=self._engine)

    async def close(self):
        """Dispose the engine, then upload the catalog if authenticated.

        The S3 client is dropped afterwards. Errors raised while uploading
        (including from creating the S3 client) propagate to the caller.
        """
        if self._engine is not None:
            await to_thread.run_sync(self._engine.dispose)
            self._engine = None
            self._Session = None
        if self._is_authenticated:
            await self._upload_catalog()
        self._s3_client = None

    async def _download_file(
        self,
        file: BaseRemoteFile,
        output: Path,
        callback: Callable[[int, int], None] | None = None,
    ) -> Path:
        """Download a single file from object storage to ``output``.

        Data is streamed into ``<output>.part`` and moved onto ``output``
        only after the whole body has been written. A failed or cancelled
        download therefore never leaves a truncated file at ``output``.

        Parameters
        ----------
        file : BaseRemoteFile
            Must be a DuckLake ``File``; its ``record.path`` is the object key.
        output : Path
            Destination path.
        callback : Callable[[int, int], None], optional
            Called after each chunk with ``(bytes_downloaded, total)``.
            ``total`` is 0 when the server sends no Content-Length.

        Returns
        -------
        Path
            ``output``.

        Raises
        ------
        ValueError
            If ``file`` is not a DuckLake ``File``.
        httpx.HTTPError
            On transport failures or non-2xx responses.
        """
        if not isinstance(file, File):
            raise ValueError("DuckLake File was not properly instantiated")
        url = f"https://{self.endpoint}/{self.bucket}/{file.record.path}"
        target = Path(output)
        partial = target.with_name(target.name + ".part")
        try:
            async with httpx.AsyncClient(follow_redirects=True) as client:
                async with client.stream("GET", url) as r:
                    r.raise_for_status()
                    total = int(r.headers.get("Content-Length", 0))
                    downloaded = 0
                    with open(partial, "wb") as f:
                        async for chunk in r.aiter_bytes(
                            chunk_size=1024 * 1024,
                        ):
                            await to_thread.run_sync(f.write, chunk)
                            downloaded += len(chunk)
                            if callback:
                                callback(downloaded, total)
            partial.replace(target)
        except BaseException:
            # Includes cancellation: never leave a half-written file behind.
            partial.unlink(missing_ok=True)
            raise
        return output

    async def _download_catalog(self, client: httpx.AsyncClient):
        """Download the catalog database from remote storage with retries.

        Up to 5 attempts are made, with a 1 s pause between them. Both
        network/HTTP errors (``httpx.HTTPError``) and local I/O errors
        (``OSError``) are retried. Each attempt writes into
        ``catalog.db.part``, which replaces ``catalog.db`` only once it is
        complete, so an existing catalog is never truncated.
        """
        max_retries = 5
        partial = self._catalog_local.with_name(
            self._catalog_local.name + ".part",
        )
        for attempt in range(1, max_retries + 1):
            try:
                async with client.stream("GET", self._catalog_url) as r:
                    r.raise_for_status()
                    with open(partial, "wb") as f:
                        async for chunk in r.aiter_bytes(
                            chunk_size=1024 * 1024,
                        ):
                            await to_thread.run_sync(f.write, chunk)
                partial.replace(self._catalog_local)
                return
            except (httpx.HTTPError, OSError):
                if attempt == max_retries:
                    raise
            finally:
                # No-op after a successful replace; removes leftovers otherwise.
                partial.unlink(missing_ok=True)
            await sleep(1)

    def _get_s3_client(self):
        """Create and return a boto3 S3 client for the configured endpoint.

        Raises
        ------
        ConnectionError
            If no credentials are configured.
        """
        if not self.credentials:
            raise ConnectionError("S3 Credentials not found")
        return boto3.client(
            "s3",
            endpoint_url=f"https://{self.endpoint}",
            aws_access_key_id=self.credentials.access_key.get_secret_value(),
            aws_secret_access_key=(
                self.credentials.secret_key.get_secret_value()
            ),
            region_name=self.region,
            config=Config(signature_version="s3v4"),
        )

    async def _load_catalog(self):
        """Download the remote catalog if the local copy is missing or outdated.

        The local file size is compared with the remote ``Content-Length``
        from a HEAD request. A download happens when there is no local
        file, when the sizes differ, or when the remote size cannot be
        determined.
        """
        try:
            local_size: int | None = self._catalog_local.stat().st_size
        except OSError:
            local_size = None  # missing or unreadable local copy
        async with httpx.AsyncClient(follow_redirects=True) as client:
            try:
                head = await client.head(self._catalog_url)
                head.raise_for_status()
                remote_size: int | None = int(head.headers["content-length"])
            except (httpx.HTTPError, KeyError, ValueError):
                # ``None`` (not 0) so an empty local file is never mistaken
                # for an up-to-date one.
                remote_size = None
            if remote_size is None or remote_size != local_size:
                await self._download_catalog(client)

    async def _upload_catalog(self):
        """Upload the local catalog database to remote storage.

        The boto3 client is created on demand, so this also works when the
        credentials were given to the constructor rather than via ``login``.

        Raises
        ------
        PermissionError
            If the client has no credentials.
        """
        if not self._is_authenticated:
            raise PermissionError(
                "Admin credentials required to upload catalog.",
            )
        if self._s3_client is None:
            self._s3_client = await to_thread.run_sync(self._get_s3_client)
        s3_client = self._s3_client

        def _upload():
            s3_client.upload_file(
                str(self._catalog_local),
                self.bucket,
                self._catalog_remote,
            )

        await to_thread.run_sync(_upload)

    async def query(
        self,
        client: Literal["FTP", "DadosGov"] | None = None,
        dataset: str | None = None,
        group: str | None = None,
        state: str | None = None,
        year: int | None = None,
        month: int | None = None,
    ) -> list[File]:
        """Filter catalog files by client, dataset, group, state, year, month.

        When ``client`` is given, only files under
        ``public/data/<client lowercased>/`` are returned. When it is
        omitted, the following rules apply:

        - All files under ``public/data/ftp/`` are kept.
        - Files under ``public/data/dadosgov/`` are kept unless an FTP file
          exists with the same dataset id, year, month and file stem. A
          trailing ``.csv`` is stripped from the DadosGov stem before
          comparing.
        - Files under any other prefix are dropped.

        Parameters
        ----------
        client : Literal["FTP", "DadosGov"], optional
            Source client to filter by.
        dataset : str, optional
            Dataset name to filter by (compared lowercased).
        group : str, optional
            Group name pattern to filter by (case-insensitive ILIKE).
        state : str, optional
            Two-letter state code to filter by (compared uppercased).
        year : int, optional
            Year to filter by.
        month : int, optional
            Month to filter by.

        Returns
        -------
        list[:class:`~pysus.api.ducklake.models.File`]
            List of matching file objects.
        """
        if self._Session is None:
            await self.connect()

        def _query():
            with self._Session() as session:
                q = session.query(CatalogFile)
                if dataset:
                    q = (
                        q.join(CatalogFile.dataset)
                        .options(contains_eager(CatalogFile.dataset))
                        .filter(CatalogDataset.name == dataset.lower())
                    )
                else:
                    q = q.options(joinedload(CatalogFile.dataset))
                if group:
                    q = (
                        q.join(CatalogFile.group)
                        .options(contains_eager(CatalogFile.group))
                        .filter(DatasetGroup.name.ilike(group))
                    )
                else:
                    q = q.options(joinedload(CatalogFile.group))
                if state:
                    q = q.filter(CatalogFile.state == state.upper())
                if year:
                    q = q.filter(CatalogFile.year == year)
                if month:
                    q = q.filter(CatalogFile.month == month)
                results = q.all()
                # Detach so relationships stay readable after the session.
                session.expunge_all()
                return results

        records = await to_thread.run_sync(_query)
        if client:
            prefix = f"public/data/{client.lower()}/"
            records = [r for r in records if r.path.startswith(prefix)]
        else:
            ftp = [r for r in records if r.path.startswith("public/data/ftp/")]
            dadosgov = [
                r for r in records if r.path.startswith("public/data/dadosgov/")
            ]
            ftp_keys = {
                (r.dataset_id, r.year, r.month, Path(r.path).stem) for r in ftp
            }

            def has_ftp_match(r):
                stem = Path(r.path).stem.removesuffix(".csv")
                return (r.dataset_id, r.year, r.month, stem) in ftp_keys

            records = ftp + [r for r in dadosgov if not has_ftp_match(r)]
        return [
            File(
                path=r.path,
                record=r,
                dataset=CatalogDatasetAdapter(r.dataset, self),
                group=(
                    DatasetGroupAdapter(r.group, r.dataset) if r.group else None
                ),
            )
            for r in records
        ]