"""Map file extensions and MIME types to their handler classes."""
import asyncio
import csv
import gzip
import shutil
import sys
import tarfile
import zipfile
from collections.abc import AsyncGenerator, Callable
from datetime import datetime
from pathlib import Path
from typing import ClassVar
import chardet
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from anyio import to_thread
from dbfread import DBF as DBFReader
from pydantic import Field, PrivateAttr
from pysus import CACHEPATH
from pysus.api.models import BaseCompressedFile, BaseLocalFile, BaseTabularFile
from .types import FileType
try:
from pyreaddbc import dbc2dbf
DBC_IMPORT = True
except ImportError:
DBC_IMPORT = False
[docs]
class File(BaseLocalFile):
    """Represent a generic local file with no format-specific handling.

    The file is treated as an opaque sequence of bytes: ``load`` returns the
    whole content at once and ``stream`` yields it piece by piece.
    """

    type: FileType = Field("FILE")

    async def load(self) -> bytes:
        """Read the entire file contents into memory as bytes.

        Returns
        -------
        bytes
            The full content of ``self.path``, read in a worker thread.
        """
        return await to_thread.run_sync(self.path.read_bytes)

    async def stream(
        self,
        chunk_size: int = 1024 * 1024,
    ) -> AsyncGenerator[bytes, None]:
        """Yield the file contents in chunks of at most ``chunk_size`` bytes.

        Parameters
        ----------
        chunk_size : int, default 1 MiB
            Maximum number of bytes requested per read.

        Yields
        ------
        bytes
            Consecutive, non-empty pieces of the file.
        """
        # Opening and every read are delegated to a worker thread so slow
        # disks do not stall the event loop; awaiting the thread also hands
        # control back to the loop between chunks.
        handle = await to_thread.run_sync(open, self.path, "rb")
        try:
            while chunk := await to_thread.run_sync(handle.read, chunk_size):
                yield chunk
        finally:
            # Runs on exhaustion, on error and when the consumer stops early.
            handle.close()
[docs]
class Directory(BaseLocalFile):
    """Represent a directory on the local filesystem.

    Each direct child of the directory is exposed as a handler object built
    by ``ExtensionFactory.instantiate``.
    """

    type: FileType = Field("DIR")

    def __repr__(self) -> str:
        """Return the directory base name followed by a slash."""
        return f"{self.basename}/"

    async def load(self) -> list[BaseLocalFile]:
        """Build handler objects for every entry directly inside the directory.

        Returns
        -------
        list[BaseLocalFile]
            One handler per directory entry, or an empty list when the path
            does not exist.
        """
        # Imported at call time, as in the rest of this file, so the factory
        # name is resolved only when the method runs.
        from pysus.api.extensions import ExtensionFactory

        if not self.path.exists():
            return []

        pending = [
            ExtensionFactory.instantiate(entry)
            for entry in list(self.path.iterdir())
        ]
        handlers = await asyncio.gather(*pending)
        return list(handlers)

    async def stream(
        self,
        chunksize: int = 10000,
    ) -> AsyncGenerator[BaseLocalFile, None]:
        """Yield a handler object for each entry, one at a time.

        Parameters
        ----------
        chunksize : int, default 10000
            Accepted but not used by this implementation.

        Yields
        ------
        BaseLocalFile
            The handler for the next directory entry.

        Notes
        -----
        Unlike ``load``, there is no existence check here, so a missing
        directory surfaces as the error raised by ``Path.iterdir``.
        """
        from pysus.api.extensions import ExtensionFactory

        for entry in self.path.iterdir():
            handler = await ExtensionFactory.instantiate(entry)
            yield handler
[docs]
class CSV(BaseTabularFile):
    """Represent a CSV file with automatic encoding and separator detection.

    The detected encoding and delimiter are cached on the instance, so the
    file is only probed once regardless of which accessor is used first.
    """

    type: FileType = Field("CSV")
    # Cached detection results; ``None`` means "not probed yet".
    _encoding: str | None = PrivateAttr(default=None)
    _sep: str | None = PrivateAttr(default=None)

    def _detect_encoding(self) -> str:
        """Detect (blocking) and cache the file's character encoding.

        The first 300 KiB are handed to ``chardet``. When nothing is detected,
        or plain ASCII is reported, ``iso-8859-1`` is used instead.
        """
        if self._encoding is None:
            with open(self.path, "rb") as f:
                guess = chardet.detect(f.read(1024 * 300))["encoding"]
            if guess is None or guess.lower() == "ascii":
                guess = "iso-8859-1"
            self._encoding = guess
        return self._encoding

    def _sniff_sep(self) -> str:
        """Sniff (blocking) and cache the delimiter, defaulting to a comma."""
        if self._sep is None:
            encoding = self._detect_encoding()
            try:
                with open(self.path, encoding=encoding) as f:
                    sample = f.read(1024 * 10)
                self._sep = csv.Sniffer().sniff(sample).delimiter
            except (csv.Error, ValueError):
                # ``csv.Error`` is what the sniffer raises when it cannot
                # decide (it is not a ValueError); ValueError still covers
                # decoding failures while reading the sample.
                self._sep = ","
        return self._sep

    @property
    def columns(self) -> list[str]:
        """Return the column names from the CSV header row.

        The header is parsed with the same detected encoding and delimiter
        that ``load`` and ``stream`` use, so the three always agree.
        """
        header = pd.read_csv(
            self.path,
            sep=self._sniff_sep(),
            nrows=0,
            encoding=self._detect_encoding(),
        )
        return header.columns.tolist()

    @property
    def rows(self) -> int:
        """Return the number of lines in the file minus the header line."""
        with open(self.path, "rb") as f:
            line_count = sum(1 for _ in f)
        # An empty file must report 0 rather than -1.
        return max(0, line_count - 1)

    async def _get_encoding(self) -> str:
        """Return the cached encoding, detecting it in a worker thread."""
        if self._encoding is not None:
            return self._encoding
        return await to_thread.run_sync(self._detect_encoding)

    async def _get_sep(self) -> str:
        """Return the cached delimiter, sniffing it in a worker thread."""
        if self._sep is not None:
            return self._sep
        return await to_thread.run_sync(self._sniff_sep)

    async def load(self) -> pd.DataFrame:
        """Read the entire CSV into a DataFrame.

        Returns
        -------
        pandas.DataFrame
            The parsed file; column dtypes are left to pandas' inference.
        """
        encoding = await self._get_encoding()
        separator = await self._get_sep()

        def _read_sync():
            """Parse the whole file; runs in a worker thread."""
            return pd.read_csv(
                self.path, sep=separator, encoding=encoding, low_memory=False
            )

        return await to_thread.run_sync(_read_sync)

    async def stream(
        self,
        chunk_size: int = 10000,
    ) -> AsyncGenerator[pd.DataFrame, None]:
        """Yield the CSV in DataFrames of at most ``chunk_size`` rows.

        Parameters
        ----------
        chunk_size : int, default 10000
            Number of rows per yielded DataFrame.

        Yields
        ------
        pandas.DataFrame
            Consecutive slices of the file. Unlike ``load``, every column is
            read as ``str``.
        """
        encoding = await self._get_encoding()
        separator = await self._get_sep()

        def _open_reader():
            """Create the chunked reader; runs in a worker thread."""
            return pd.read_csv(
                self.path,
                sep=separator,
                encoding=encoding,
                chunksize=chunk_size,
                dtype=str,
                low_memory=False,
            )

        reader = await to_thread.run_sync(_open_reader)
        # The reader owns an open file handle; the ``with`` block releases it
        # even when the consumer abandons the generator early.
        with reader:
            for chunk in reader:
                yield chunk
                await asyncio.sleep(0)
[docs]
class Parquet(BaseTabularFile):
    """Represent a Parquet file with optional date and integer type parsing.

    When ``add_dv`` is true, loaded frames are additionally passed through
    ``_apply_add_dv``.
    """

    type: FileType = Field("PARQUET")
    add_dv: bool = True

    @property
    def schema(self) -> pa.Schema:
        """Return the Parquet schema as a PyArrow Schema object."""
        return pq.read_schema(self.path)

    @property
    def columns(self) -> list[str]:
        """Return the column names from the Parquet schema."""
        return self.schema.names

    @property
    def rows(self) -> int:
        """Return the number of rows recorded in the Parquet metadata."""
        return pq.read_metadata(self.path).num_rows

    @staticmethod
    def _apply_add_dv(df: pd.DataFrame) -> pd.DataFrame:
        """Run ``add_dv`` over every column flagged by ``is_geocode_column``.

        Matching columns are cast to ``str`` and overwritten on ``df`` itself;
        the same frame is returned. Docstring upstream describes this as
        applying the IBGE verification digit -- the helpers live in
        ``pysus.api.utils`` and are not visible here.
        """
        from pysus.api.utils import add_dv, is_geocode_column

        for col in [c for c in df.columns if is_geocode_column(c)]:
            df[col] = df[col].astype(str).apply(add_dv)
        return df

    async def load(self, parse: bool = True) -> pd.DataFrame:
        """Read the entire Parquet file into a DataFrame.

        Parameters
        ----------
        parse : bool, default True
            Whether to run ``parse_dftypes`` on the result.

        Returns
        -------
        pandas.DataFrame
            The loaded (and optionally parsed / digit-augmented) data.
        """

        def _load():
            """Read and post-process the file; runs in a worker thread."""
            df = pd.read_parquet(self.path, engine="pyarrow")
            if parse:
                df = self.parse_dftypes(df)
            if self.add_dv:
                df = self._apply_add_dv(df)
            return df

        return await to_thread.run_sync(_load)

    async def stream(
        self, chunk_size: int = 10000, parse: bool = False
    ) -> AsyncGenerator[pd.DataFrame, None]:
        """Yield the Parquet file in batches of at most ``chunk_size`` rows.

        Parameters
        ----------
        chunk_size : int, default 10000
            Batch size passed to ``ParquetFile.iter_batches``.
        parse : bool, default False
            Whether to run ``parse_dftypes`` on each batch.

        Yields
        ------
        pandas.DataFrame
            One DataFrame per record batch.
        """
        parquet_file = await to_thread.run_sync(pq.ParquetFile, self.path)

        # A file without row groups has nothing to iterate over.
        if parquet_file.metadata.num_row_groups == 0:
            return

        for batch in parquet_file.iter_batches(batch_size=chunk_size):
            df = batch.to_pandas()
            if parse:
                df = self.parse_dftypes(df)
            if self.add_dv:
                df = self._apply_add_dv(df)
            yield df
            await asyncio.sleep(0)

    @staticmethod
    def parse_dftypes(df: pd.DataFrame) -> pd.DataFrame:
        """Convert known date and integer columns to their proper types.

        Columns named ``DT_NOTIFIC``, ``DT_SIN_PRI``, ``DT_NASC`` or
        ``DT_INTER`` are parsed as ``%Y%m%d`` dates; ``CODMUNRES`` and
        ``IDADE`` are parsed as integers. Values that cannot be converted are
        kept unchanged. Whitespace-only strings are then replaced by empty
        strings and ``convert_dtypes`` is applied.

        Parameters
        ----------
        df : pandas.DataFrame
            Frame to convert. The matched columns are overwritten on this
            object before the final (new) frame is produced.

        Returns
        -------
        pandas.DataFrame
            The converted frame.
        """

        def str_to_int(string):
            """Convert a digit string to int, returning the input otherwise."""
            if pd.isna(string):
                return string
            clean = str(string).replace(" ", "")
            # ``isdecimal`` (not ``isnumeric``): characters such as "²" or
            # "½" are numeric yet make ``int()`` raise ValueError.
            return int(clean) if clean.isdecimal() else string

        def str_to_date(string):
            """Convert a ``YYYYMMDD`` string to a date, else return the input."""
            if not isinstance(string, str):
                return string
            try:
                return datetime.strptime(string, "%Y%m%d").date()
            except ValueError:
                return string

        cols_to_date = ["DT_NOTIFIC", "DT_SIN_PRI", "DT_NASC", "DT_INTER"]
        cols_to_int = ["CODMUNRES", "IDADE"]

        for col in df.columns:
            if col in cols_to_date:
                df[col] = df[col].map(str_to_date)
            elif col in cols_to_int:
                df[col] = df[col].map(str_to_int)

        df = df.replace(r"^\s+$", "", regex=True)
        return df.convert_dtypes()
[docs]
class DBF(BaseTabularFile):
    """Represent a dBASE (DBF) file.

    Records are read raw and decoded as cp1252 by ``decode_column``.
    """

    type: FileType = Field("DBF")

    @property
    def columns(self) -> list[str]:
        """Return the field names from the DBF file."""
        return DBFReader(self.path, load=False).field_names

    @property
    def rows(self) -> int:
        """Return the number of records in the DBF file."""
        return len(DBFReader(self.path, load=False))

    def decode_column(self, value):
        """Decode a raw DBF value, handling byte strings and null bytes.

        Parameters
        ----------
        value : bytes or str or Any
            The value to decode.

        Returns
        -------
        str or Any
            The decoded and stripped string, or the original value if it is
            neither bytes nor str.
        """
        if isinstance(value, bytes):
            value = value.decode(encoding="cp1252", errors="replace")
        if isinstance(value, str):
            return value.replace("\x00", "").strip()
        return value

    async def load(self) -> pd.DataFrame:
        """Read the entire DBF file into a DataFrame.

        Returns
        -------
        pandas.DataFrame
            All records, with every cell passed through ``decode_column``.
        """

        def _load():
            """Read and decode every record; runs in a worker thread."""
            dbf = DBFReader(self.path, encoding="cp1252", raw=True)
            return pd.DataFrame(iter(dbf)).map(self.decode_column)

        return await to_thread.run_sync(_load)

    async def stream(
        self,
        chunk_size: int = 30000,
    ) -> AsyncGenerator[pd.DataFrame, None]:
        """Yield the DBF records in DataFrames of ``chunk_size`` rows.

        Parameters
        ----------
        chunk_size : int, default 30000
            Number of records per yielded DataFrame; the last one may be
            shorter.

        Yields
        ------
        pandas.DataFrame
            Decoded records.
        """

        def _get_db():
            """Open the DBF reader; runs in a worker thread."""
            return DBFReader(self.path, encoding="cp1252", raw=True)

        dbf_file = await to_thread.run_sync(_get_db)

        records = []
        for count, record in enumerate(dbf_file, start=1):
            records.append(record)
            if count % chunk_size == 0:
                yield pd.DataFrame(records).map(self.decode_column)
                records = []
                await asyncio.sleep(0)

        # Flush the final, partially filled chunk.
        if records:
            yield pd.DataFrame(records).map(self.decode_column)

    async def _write_parquet(
        self,
        target: Path,
        chunk_size: int,
        callback: Callable[[int, int], None] | None,
    ) -> None:
        """Stream every record of the DBF into one Parquet file at ``target``.

        ``callback`` (if given) receives ``(rows_done, total_rows)`` after
        each full chunk and once more after the trailing partial chunk. A DBF
        without records still produces a file whose schema is built from
        ``self.columns``.
        """
        reader = DBFReader(self.path, encoding="cp1252", raw=True)
        total_rows = len(reader)
        writer = None
        pending = []

        def _flush():
            """Append the buffered records to the file and reset the buffer."""
            nonlocal writer, pending
            frame = pd.DataFrame(pending).map(self.decode_column)
            table = pa.Table.from_pandas(frame)
            # The writer is created lazily: its schema comes from the first
            # chunk that is actually written.
            if writer is None:
                writer = pq.ParquetWriter(str(target), table.schema)
            writer.write_table(table)
            pending = []

        try:
            for count, record in enumerate(reader, start=1):
                pending.append(record)
                if count % chunk_size == 0:
                    _flush()
                    if callback:
                        callback(count, total_rows)
                    await asyncio.sleep(0)

            if pending:
                _flush()
                if callback:
                    callback(total_rows, total_rows)

            if writer is None:
                df_empty = pd.DataFrame(columns=pd.Index(self.columns))
                table_empty = pa.Table.from_pandas(df_empty)
                writer = pq.ParquetWriter(str(target), table_empty.schema)
        finally:
            if writer is not None:
                writer.close()

    async def to_parquet(
        self,
        output_path: str | Path | None = None,
        chunk_size: int = 30000,
        callback: Callable[[int, int], None] | None = None,
    ) -> "Parquet":
        """Convert the DBF file to Parquet format.

        Parameters
        ----------
        output_path : str or Path, optional
            Destination file. Defaults to the DBF path with a ``.parquet``
            suffix.
        chunk_size : int, default 30000
            Number of records written per Parquet write.
        callback : callable, optional
            Progress hook called as ``callback(rows_done, total_rows)``.

        Returns
        -------
        Parquet
            Handler for the written file.

        Raises
        ------
        RuntimeError
            If the destination already exists and is not recognised as
            Parquet, or if the written file is not recognised as Parquet.
        """
        from pysus.api.extensions import ExtensionFactory

        out = (
            Path(output_path or self.path.with_suffix(".parquet"))
            .expanduser()
            .resolve()
        )

        if out.exists():
            # NOTE(review): an existing Parquet output is validated but then
            # regenerated, whereas DBC.to_parquet returns it directly --
            # confirm which behaviour is intended.
            existing = await ExtensionFactory.instantiate(out)
            if not isinstance(existing, Parquet):
                raise RuntimeError(f"Could not parse {out} to Parquet")

        # Write next to the destination and rename only on success, so a
        # failed or cancelled conversion never leaves a truncated file at
        # ``out`` that could later be mistaken for a finished one.
        partial = out.with_name(f"{out.name}.partial")
        try:
            await self._write_parquet(partial, chunk_size, callback)
        except BaseException:
            partial.unlink(missing_ok=True)
            raise
        partial.replace(out)

        file = await ExtensionFactory.instantiate(out)
        if not isinstance(file, Parquet):
            raise RuntimeError(f"Could not parse {out} to Parquet")
        return file
[docs]
class DBC(BaseTabularFile):
    """Represent a compressed DBC file, convertible to DBF then Parquet.

    Data access always goes through ``to_parquet``; the DBC itself is never
    parsed in place.
    """

    type: FileType = Field("DBC")

    @property
    def columns(self) -> list[str]:
        """Not supported for DBC files. Convert to Parquet first."""
        raise NotImplementedError(
            "DBC metadata cannot be read directly. Convert to Parquet first."
        )

    @property
    def rows(self) -> int:
        """Not supported for DBC files. Convert to Parquet first."""
        raise NotImplementedError(
            "DBC metadata cannot be read directly. Convert to Parquet first."
        )

    async def load(self) -> pd.DataFrame:
        """Convert to Parquet and load the result as a DataFrame."""
        parquet = await self.to_parquet()
        return await parquet.load()

    async def stream(
        self,
        chunk_size: int = 10000,
    ) -> AsyncGenerator[pd.DataFrame, None]:
        """Convert to Parquet and yield its chunks of ``chunk_size`` rows."""
        parquet = await self.to_parquet()
        async for chunk in parquet.stream(chunk_size=chunk_size):
            yield chunk

    async def to_parquet(
        self,
        output_path: str | Path | None = None,
        chunk_size: int = 30000,
        callback: Callable[[int, int], None] | None = None,
    ) -> "Parquet":
        """Decompress DBC to DBF, then convert to Parquet.

        Parameters
        ----------
        output_path : str or Path, optional
            Destination file. Defaults to the DBC path with a ``.parquet``
            suffix.
        chunk_size : int, default 30000
            Forwarded to the DBF handler's ``to_parquet``.
        callback : callable, optional
            Progress hook forwarded to the DBF handler's ``to_parquet``.

        Returns
        -------
        Parquet
            Handler for the output file. An already existing Parquet output
            is returned as is, without converting again.

        Raises
        ------
        RuntimeError
            If an existing output is not recognised as Parquet, or the
            decompressed file is not recognised as a tabular file.
        ImportError
            If a conversion is needed but ``pyreaddbc`` is not installed.
        """
        import tempfile

        from pysus.api.extensions import ExtensionFactory

        if output_path is None:
            output_path = self.path.with_suffix(".parquet")
        output_path = Path(output_path).expanduser().resolve()

        if output_path.exists():
            file = await ExtensionFactory.instantiate(output_path)
            if not isinstance(file, Parquet):
                raise RuntimeError(f"Could not parse {output_path} to parquet")
            return file

        if not DBC_IMPORT:
            # Without this guard the call below fails with a NameError.
            raise ImportError(
                "pyreaddbc is required to convert DBC files; "
                'run "pip install pysus[dbc]"'
            )

        # The intermediate DBF lives in a private scratch directory beside the
        # source, so a real "<name>.dbf" sitting next to the DBC is neither
        # overwritten nor deleted by the cleanup.
        with tempfile.TemporaryDirectory(
            dir=self.path.parent, prefix=f".{self.path.stem}_"
        ) as scratch:
            tmp_dbf_path = Path(scratch) / f"{self.path.stem}.dbf"
            await to_thread.run_sync(
                dbc2dbf,
                str(self.path),
                str(tmp_dbf_path),
            )
            dbf_ext = await ExtensionFactory.instantiate(tmp_dbf_path)
            if not isinstance(dbf_ext, BaseTabularFile):
                raise RuntimeError(f"Not a DBF: {dbf_ext}")

            return await dbf_ext.to_parquet(
                output_path=output_path,
                chunk_size=chunk_size,
                callback=callback,
            )
[docs]
class JSON(BaseTabularFile):
    """Represent a JSON file with tabular data.

    The file is always parsed as a whole with ``pandas.read_json``.
    """

    type: FileType = Field("JSON")

    @property
    def columns(self) -> list[str]:
        """Return the column names from the JSON file.

        An empty (zero-byte) file has no columns. Otherwise the file is parsed
        in full: ``read_json`` only accepts ``nrows`` together with
        ``lines=True``, so a header-only read is not available here.
        """
        if self.path.stat().st_size == 0:
            return []
        return pd.read_json(self.path).columns.tolist()

    @property
    def rows(self) -> int:
        """Return the number of rows in the JSON file (0 when it is empty)."""
        if self.path.stat().st_size == 0:
            return 0
        return len(pd.read_json(self.path))

    async def load(self) -> pd.DataFrame:
        """Read the entire JSON file into a DataFrame in a worker thread."""
        return await to_thread.run_sync(pd.read_json, self.path)

    async def stream(
        self,
        chunk_size: int = 10000,
    ) -> AsyncGenerator[pd.DataFrame, None]:
        """Yield the entire JSON file as a single DataFrame.

        Parameters
        ----------
        chunk_size : int, default 10000
            Accepted but not used; the file is not split.
        """
        yield await self.load()
[docs]
class PDF(BaseLocalFile):
    """Represent a PDF file, exposed as raw bytes."""

    type: FileType = Field("PDF")

    async def load(self) -> bytes:
        """Read the entire PDF file contents into memory as bytes."""
        return await to_thread.run_sync(self.path.read_bytes)

    async def stream(
        self, chunk_size: int | None = None
    ) -> AsyncGenerator[bytes, None]:
        """Yield the PDF file contents, optionally split into chunks.

        Parameters
        ----------
        chunk_size : int, optional
            Maximum number of bytes per yielded piece. When omitted (or
            zero), the whole file is yielded as one piece.

        Yields
        ------
        bytes
            The file content.
        """
        # Opening and reading happen in a worker thread so disk latency does
        # not block the event loop.
        handle = await to_thread.run_sync(open, self.path, "rb")
        try:
            if not chunk_size:
                yield await to_thread.run_sync(handle.read)
                return
            while chunk := await to_thread.run_sync(handle.read, chunk_size):
                yield chunk
        finally:
            # Also reached when the consumer stops iterating early.
            handle.close()
[docs]
class Zip(BaseCompressedFile):
    """Represent a ZIP archive file."""

    type: FileType = Field("ZIP")

    async def load(self) -> zipfile.ZipFile:
        """Open and return the ZIP archive.

        The returned ``ZipFile`` is open; closing it is up to the caller.
        """
        return await to_thread.run_sync(zipfile.ZipFile, self.path)

    async def list_members(self) -> list[str]:
        """Return the list of member names inside the archive."""

        def _list():
            """List the members; runs in a worker thread."""
            with zipfile.ZipFile(self.path) as z:
                return z.namelist()

        return await to_thread.run_sync(_list)

    async def open_member(self, member_name: str) -> bytes:
        """Read and return the contents of a named archive member."""

        def _read():
            """Read one member; runs in a worker thread."""
            with zipfile.ZipFile(self.path) as z:
                return z.read(member_name)

        return await to_thread.run_sync(_read)

    async def to_parquet(
        self,
        output_path: str | Path | None = None,
        chunk_size: int = 30000,
        callback: Callable[[int, int], None] | None = None,
    ) -> "Parquet":
        """Extract the archive and convert the first tabular file to Parquet.

        Parameters
        ----------
        output_path : str or Path, optional
            Destination file. Defaults to the archive path with a
            ``.parquet`` suffix.
        chunk_size : int, default 30000
            Forwarded to the tabular file's ``to_parquet``.
        callback : callable, optional
            Progress hook forwarded to the tabular file's ``to_parquet``.

        Returns
        -------
        Parquet
            Handler returned by the tabular file's ``to_parquet``.

        Raises
        ------
        ValueError
            If none of the extracted files is a ``BaseTabularFile``.
        """
        final_output = (
            Path(output_path or self.path.with_suffix(".parquet"))
            .expanduser()
            .resolve()
        )
        temp_dir = self.path.with_suffix(".tmp_extract")

        try:
            # ``extract`` is not defined in this class; presumably inherited
            # from BaseCompressedFile -- verify against the base class.
            extracted_files = await self.extract(target_dir=temp_dir)

            tabular_file = next(
                (f for f in extracted_files if isinstance(f, BaseTabularFile)),
                None,
            )
            if not tabular_file:
                raise ValueError(
                    f"No tabular file found inside {self.path.name}",
                )

            return await tabular_file.to_parquet(
                output_path=final_output,
                chunk_size=chunk_size,
                callback=callback,
            )
        finally:
            await self._safe_cleanup(temp_dir)

    async def _safe_cleanup(self, directory: Path):
        """Remove a temporary directory and everything below it.

        Removal is best effort: it handles trees of any depth and never
        raises, so running it from a ``finally`` block cannot replace the
        result (or the error) of the conversion it follows.
        """

        def _cleanup():
            """Delete the tree; runs in a worker thread."""
            shutil.rmtree(directory, ignore_errors=True)

        await to_thread.run_sync(_cleanup)
[docs]
class GZip(BaseCompressedFile):
    """Represent a GZip-compressed file holding a single payload."""

    # The type tag is "ZIP" here as well; there is no separate gzip tag used.
    type: FileType = Field("ZIP")

    async def load(self) -> bytes:
        """Decompress the file and return its whole content as bytes."""

        def _decompress():
            """Inflate the full payload; runs in a worker thread."""
            with gzip.open(self.path, "rb") as stream:
                payload = stream.read()
            return payload

        return await to_thread.run_sync(_decompress)

    async def list_members(self) -> list[str]:
        """Return a one-item list: the file name without its last suffix."""
        member = self.path.stem
        return [member]

    async def open_member(self, member_name: str) -> bytes:
        """Return the decompressed content.

        ``member_name`` is not consulted: the result is always ``load()``.
        """
        return await self.load()
[docs]
class Tar(BaseCompressedFile):
    """Represent a Tar archive file."""

    # The type tag is "ZIP" here as well; there is no separate tar tag used.
    type: FileType = Field("ZIP")

    async def load(self) -> tarfile.TarFile:
        """Open the archive in a worker thread and return the open handle."""
        return await to_thread.run_sync(tarfile.open, self.path)

    async def list_members(self) -> list[str]:
        """Return the names of all members stored in the archive."""

        def _names():
            """Collect member names; runs in a worker thread."""
            with tarfile.open(self.path) as archive:
                names = archive.getnames()
            return names

        return await to_thread.run_sync(_names)

    async def open_member(self, member_name: str) -> bytes:
        """Return the bytes of one archive member.

        When ``extractfile`` yields no file object for the member, empty
        bytes are returned.
        """

        def _extract():
            """Read one member; runs in a worker thread."""
            with tarfile.open(self.path) as archive:
                member = archive.extractfile(member_name)
                if not member:
                    return b""
                return member.read()

        return await to_thread.run_sync(_extract)
[docs]
class DBCNotImported(BaseTabularFile):
    """Stand-in for DBC files when the optional DBC dependency is missing.

    Apart from ``extension``, every accessor and operation raises
    ``ImportError`` carrying ``import_err``.
    """

    path: Path = Field(default_factory=lambda: Path("..."))
    type: str | FileType = Field(default="remote")
    # Message attached to every ImportError raised by this class.
    import_err: ClassVar[str] = """
run "pip install pysus[dbc]" to handle DBC files.
Make sure you also have libffi installed on the system. It may not work
on Windows
"""

    @property
    def name(self) -> str:
        """Always raise ``ImportError``: DBC support is unavailable."""
        raise ImportError(self.import_err)

    @property
    def extension(self) -> str:
        """Return the fixed ``.dbc`` extension."""
        return ".dbc"

    @property
    def size(self) -> int:
        """Always raise ``ImportError``: DBC support is unavailable."""
        raise ImportError(self.import_err)

    @property
    def modify(self) -> datetime:
        """Always raise ``ImportError``: DBC support is unavailable."""
        raise ImportError(self.import_err)

    @property
    def columns(self) -> list[str]:
        """Always raise ``ImportError``: DBC support is unavailable."""
        raise ImportError(self.import_err)

    @property
    def rows(self) -> int:
        """Always raise ``ImportError``: DBC support is unavailable."""
        raise ImportError(self.import_err)

    async def load(self) -> pd.DataFrame:
        """Always raise ``ImportError``: DBC support is unavailable."""
        raise ImportError(self.import_err)

    def stream(
        self,
        chunk_size: int = 10000,
    ) -> AsyncGenerator[pd.DataFrame, None]:
        """Return an async generator that raises ``ImportError`` when iterated.

        Calling this method does not raise by itself; the error surfaces on
        the first iteration step of the returned generator.
        """

        async def _failing_stream():
            """Raise on first use; the ``yield`` only makes this a generator."""
            raise ImportError(self.import_err)
            yield pd.DataFrame()

        return _failing_stream()

    async def to_parquet(
        self,
        output_path: str | Path | None = None,
        chunk_size: int = 10000,
        callback: Callable[[int, int], None] | None = None,
    ) -> Parquet:
        """Always raise ``ImportError``: DBC support is unavailable."""
        raise ImportError(self.import_err)
[docs]
class ExtensionFactory:
    """Factory that maps file extensions and MIME types to handler classes."""

    # MIME type (as reported by ``magic``) -> handler class.
    _mime: dict[str, type[BaseLocalFile]] = {
        "application/zip": Zip,
        "application/x-gzip": GZip,
        "application/x-tar": Tar,
        "text/csv": CSV,
        "application/pdf": PDF,
        "application/json": JSON,
    }

    # Lower-case suffix (possibly compound, e.g. ".tar.gz") -> handler class.
    _extensions: dict[str, type[BaseLocalFile]] = {
        ".zip": Zip,
        ".gz": GZip,
        ".tar": Tar,
        ".tgz": Tar,
        ".tar.gz": Tar,
        ".csv": CSV,
        ".parquet": Parquet,
        ".dbf": DBF,
        ".dbc": DBC if DBC_IMPORT else DBCNotImported,  # type: ignore
        ".pdf": PDF,
        ".json": JSON,
    }

    # MIME sniffing is never attempted on Windows, and is switched off for
    # the rest of the process once importing ``magic`` has failed.
    _magic_available: bool = sys.platform != "win32"

    @classmethod
    async def _identify(cls, path: Path) -> type[BaseLocalFile] | None:
        """Identify the file class by its MIME type.

        Returns
        -------
        type[BaseLocalFile] or None
            The handler mapped to the detected MIME type, or ``None`` when
            sniffing is unavailable, fails, or yields an unmapped type.
        """
        if not cls._magic_available:
            return None

        try:
            import magic
        except (ImportError, OSError):
            cls._magic_available = False
            return None

        try:
            mime = await to_thread.run_sync(
                magic.from_file,
                str(path),
                True,
            )
            return cls._mime.get(mime)
        except (magic.MagicException, OSError):
            return None

    @classmethod
    async def get_file_class(cls, path: Path) -> type[BaseLocalFile]:
        """Return the file handler class for a given path.

        First attempts MIME-type identification; falls back to extension
        matching, preferring the longest known suffix chain.

        Parameters
        ----------
        path : Path
            The file path to classify.

        Returns
        -------
        type[BaseLocalFile]
            The handler class for the file type, or ``File`` when nothing
            matches.
        """
        mime_class = await cls._identify(path)
        if mime_class:
            return mime_class

        # Try ".a.tar.gz", then ".tar.gz", then ".gz": dots inside the file
        # name must not hide a compound extension such as ".tar.gz".
        suffixes = [suffix.lower() for suffix in path.suffixes]
        for start in range(len(suffixes)):
            handler = cls._extensions.get("".join(suffixes[start:]))
            if handler is not None:
                return handler

        return File

    @classmethod
    async def instantiate(cls, path: str | Path) -> BaseLocalFile:
        """Create and return the appropriate file handler for a path.

        Determines whether the path is a directory or a file, resolves the
        handler class, and instantiates it.

        Parameters
        ----------
        path : str or Path
            The filesystem path to wrap in a handler.

        Returns
        -------
        BaseLocalFile
            The instantiated file handler.
        """
        path = Path(path).expanduser().resolve()

        if await to_thread.run_sync(path.is_dir):
            return Directory(path=path, type="DIR")

        FileClass = await cls.get_file_class(path)

        # NOTE(review): ``type`` is declared as a pydantic field on the
        # handlers; whether it is readable as a class attribute depends on
        # the base model, so this may always fall back to "FILE" -- confirm.
        file_type = getattr(FileClass, "type", "FILE")
        if not isinstance(file_type, str):
            file_type = "FILE"

        return FileClass(path=path, type=file_type)