"""
Async-compatible client for FFIEC Data Connect.
This module provides async/await support and parallel processing capabilities
while maintaining backward compatibility with the synchronous API.
"""
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from ffiec_data_connect import credentials, ffiec_connection, methods
from ffiec_data_connect.credentials import OAuth2Credentials
[docs]
class RateLimiter:
"""Thread-safe rate limiter for both sync and async use."""
def __init__(self, calls_per_second: float = 10) -> None:
"""Initialize rate limiter.
Args:
calls_per_second: Maximum number of calls per second allowed
"""
self.calls_per_second: float = calls_per_second
self.min_interval: float = 1.0 / calls_per_second
self.last_call: float = 0.0
self.lock = threading.Lock()
[docs]
def wait_if_needed(self) -> None:
"""Synchronous rate limit wait."""
with self.lock:
elapsed = time.time() - self.last_call
if elapsed < self.min_interval:
time.sleep(self.min_interval - elapsed)
self.last_call = time.time()
[docs]
async def async_wait_if_needed(self) -> None:
"""Asynchronous rate limit wait."""
# For async contexts, we need to be careful about blocking
# Use a simple approach to avoid threading issues in async
elapsed = time.time() - self.last_call
if elapsed < self.min_interval:
await asyncio.sleep(self.min_interval - elapsed)
self.last_call = time.time()
[docs]
class AsyncCompatibleClient:
"""Client that supports both sync and async usage patterns.
This client provides:
- Backward compatible synchronous methods
- Parallel processing with thread pools
- Async/await support for integration with async frameworks
- Rate limiting to respect API limits
- Thread-safe operation
"""
def __init__(
self,
credentials: Union[credentials.WebserviceCredentials, "OAuth2Credentials"],
max_concurrent: int = 5,
rate_limit: Optional[float] = 10, # requests per second
executor: Optional[ThreadPoolExecutor] = None,
) -> None:
"""Initialize the async-compatible client.
**ENHANCED**: Now supports both SOAP and REST APIs automatically based on credential type.
For better performance, use OAuth2Credentials for REST API access.
Args:
credentials: Either WebserviceCredentials (SOAP) or OAuth2Credentials (REST)
max_concurrent: Maximum concurrent requests
rate_limit: Maximum requests per second (None to disable)
executor: Optional thread pool executor to use
"""
self.credentials = credentials
self.max_concurrent = max_concurrent
self.rate_limiter = RateLimiter(rate_limit) if rate_limit else None
self.executor = executor or ThreadPoolExecutor(max_workers=max_concurrent)
# Enhanced for dual protocol support
from .credentials import OAuth2Credentials
self._is_rest_client = isinstance(credentials, OAuth2Credentials)
if self._is_rest_client:
# REST clients don't need connection caching
self._connection_cache: Dict[int, ffiec_connection.FFIECConnection] = {}
else:
# SOAP clients use connection caching
self._connection_cache = {}
self._lock = threading.Lock()
self._owned_executor = executor is None # Track if we created the executor
# ===== Synchronous Methods (Backward Compatible) =====
[docs]
def collect_data(
self,
reporting_period: str,
rssd_id: str,
series: str = "call",
output_type: str = "list",
date_output_format: str = "string_original",
) -> Union[List[Dict[str, Any]], Any]:
"""Standard synchronous method - backward compatible.
Args:
reporting_period: Reporting period (e.g., "2020-03-31" or "1Q2020")
rssd_id: RSSD ID of the institution
series: Data series ("call" or "ubpr")
output_type: Output format ("list", "pandas", or "polars")
date_output_format: Date format in output
Returns:
Collected data in requested format
"""
if self.rate_limiter:
self.rate_limiter.wait_if_needed()
# Enhanced for dual protocol support
if self._is_rest_client:
# REST API doesn't need a connection object
return methods.collect_data(
None,
self.credentials,
reporting_period,
rssd_id,
series,
output_type,
date_output_format,
)
else:
# SOAP API uses cached connection
conn = self._get_connection()
return methods.collect_data(
conn,
self.credentials,
reporting_period,
rssd_id,
series,
output_type,
date_output_format,
)
[docs]
def collect_reporting_periods(
self,
series: str = "call",
output_type: str = "list",
date_output_format: str = "string_original",
) -> Union[List[str], Any]:
"""Get available reporting periods - backward compatible.
Args:
series: Data series ("call" or "ubpr")
output_type: Output format ("list", "pandas", or "polars")
date_output_format: Date format in output
Returns:
Available reporting periods
"""
if self.rate_limiter:
self.rate_limiter.wait_if_needed()
# Enhanced for dual protocol support
if self._is_rest_client:
# REST API doesn't need a connection object
return methods.collect_reporting_periods(
None, self.credentials, series, output_type, date_output_format
)
else:
# SOAP API uses cached connection
conn = self._get_connection()
return methods.collect_reporting_periods(
conn, self.credentials, series, output_type, date_output_format
)
[docs]
def collect_data_parallel(
self,
reporting_period: str,
rssd_ids: List[str],
series: str = "call",
output_type: str = "list",
date_output_format: str = "string_original",
progress_callback: Optional[Callable[[str, Any], None]] = None,
) -> Dict[str, Union[List[Dict[str, Any]], Dict[str, Any]]]:
"""Collect data for multiple banks in parallel (sync interface).
Args:
reporting_period: Reporting period
rssd_ids: List of RSSD IDs
series: Data series
output_type: Output format
date_output_format: Date format
progress_callback: Optional callback for progress updates
Returns:
Dictionary mapping RSSD IDs to their data or error info
"""
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
futures = {}
for rssd_id in rssd_ids:
if self.rate_limiter:
self.rate_limiter.wait_if_needed()
future = executor.submit(
self.collect_data,
reporting_period,
rssd_id,
series,
output_type,
date_output_format,
)
futures[future] = rssd_id
results = {}
for future in as_completed(futures):
rssd_id = futures[future]
try:
result = future.result()
results[rssd_id] = result
if progress_callback:
progress_callback(rssd_id, result)
except Exception as e:
results[rssd_id] = {"error": str(e), "rssd_id": rssd_id}
if progress_callback:
progress_callback(rssd_id, {"error": str(e)})
return results
[docs]
def collect_time_series(
self,
rssd_id: str,
reporting_periods: List[str],
series: str = "call",
output_type: str = "list",
date_output_format: str = "string_original",
) -> Dict[str, Union[List[Dict[str, Any]], Dict[str, Any]]]:
"""Collect multiple periods for one bank in parallel (sync interface).
Args:
rssd_id: RSSD ID of the institution
reporting_periods: List of reporting periods
series: Data series
output_type: Output format
date_output_format: Date format
Returns:
Dictionary mapping periods to their data
"""
with ThreadPoolExecutor(
max_workers=min(len(reporting_periods), self.max_concurrent)
) as executor:
futures = {}
for period in reporting_periods:
if self.rate_limiter:
self.rate_limiter.wait_if_needed()
future = executor.submit(
self.collect_data,
period,
rssd_id,
series,
output_type,
date_output_format,
)
futures[future] = period
results = {}
for future in as_completed(futures):
period = futures[future]
try:
results[period] = future.result()
except Exception as e:
results[period] = {"error": str(e), "period": period}
return results
# ===== Async Methods (New Functionality) =====
[docs]
async def collect_data_async(
self,
reporting_period: str,
rssd_id: str,
series: str = "call",
output_type: str = "list",
date_output_format: str = "string_original",
) -> Union[List[Dict[str, Any]], Any]:
"""Async version - runs sync code in thread pool.
Args:
reporting_period: Reporting period
rssd_id: RSSD ID of the institution
series: Data series
output_type: Output format
date_output_format: Date format
Returns:
Collected data in requested format
"""
if self.rate_limiter:
await self.rate_limiter.async_wait_if_needed()
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
self.collect_data,
reporting_period,
rssd_id,
series,
output_type,
date_output_format,
)
[docs]
async def collect_batch_async(
self,
reporting_period: str,
rssd_ids: List[str],
series: str = "call",
output_type: str = "list",
date_output_format: str = "string_original",
progress_callback: Optional[Callable[[str, Any], None]] = None,
) -> Dict[str, Union[List[Dict[str, Any]], Dict[str, Any]]]:
"""Collect data for multiple banks with rate limiting and progress tracking.
Args:
reporting_period: Reporting period
rssd_ids: List of RSSD IDs
series: Data series
output_type: Output format
date_output_format: Date format
progress_callback: Optional async callback for progress
Returns:
Dictionary mapping RSSD IDs to their data
"""
semaphore = asyncio.Semaphore(self.max_concurrent)
results = {}
async def fetch_one(rssd_id: str) -> Tuple[str, Any]:
async with semaphore:
try:
if self.rate_limiter:
await self.rate_limiter.async_wait_if_needed()
result = await self.collect_data_async(
reporting_period,
rssd_id,
series,
output_type,
date_output_format,
)
if progress_callback:
if asyncio.iscoroutinefunction(progress_callback):
await progress_callback(rssd_id, result) # type: ignore[attr-defined]
else:
progress_callback(rssd_id, result)
return rssd_id, result
except Exception as e:
error_result = {"error": str(e), "rssd_id": rssd_id}
if progress_callback:
if asyncio.iscoroutinefunction(progress_callback):
await progress_callback(rssd_id, error_result) # type: ignore[attr-defined]
else:
progress_callback(rssd_id, error_result)
return rssd_id, error_result
tasks = [fetch_one(rssd_id) for rssd_id in rssd_ids]
for coro in asyncio.as_completed(tasks):
rssd_id, result = await coro
results[rssd_id] = result
return results
[docs]
async def collect_time_series_async(
self,
rssd_id: str,
reporting_periods: List[str],
series: str = "call",
output_type: str = "list",
date_output_format: str = "string_original",
) -> Dict[str, Union[List[Dict[str, Any]], Dict[str, Any]]]:
"""Collect multiple periods for one bank in parallel (async).
Args:
rssd_id: RSSD ID of the institution
reporting_periods: List of reporting periods
series: Data series
output_type: Output format
date_output_format: Date format
Returns:
Dictionary mapping periods to their data
"""
tasks = []
for period in reporting_periods:
task = self.collect_data_async(
period, rssd_id, series, output_type, date_output_format
)
tasks.append((period, task))
results = {}
for period, task in tasks:
try:
results[period] = await task
except Exception as e:
results[period] = {"error": str(e), "period": period}
return results
# ===== Helper Methods =====
def _get_connection(self) -> ffiec_connection.FFIECConnection:
"""Get or create thread-local connection.
Returns:
Thread-local FFIECConnection instance
"""
thread_id = threading.get_ident()
with self._lock:
if thread_id not in self._connection_cache:
self._connection_cache[thread_id] = ffiec_connection.FFIECConnection()
return self._connection_cache[thread_id]
[docs]
def close(self) -> None:
"""Close all connections and cleanup resources."""
with self._lock:
# Close all cached connections
for conn in self._connection_cache.values():
try:
conn.close()
except Exception:
pass
self._connection_cache.clear()
# Shutdown executor if we created it
if self._owned_executor and self.executor:
self.executor.shutdown(wait=True)
# ===== Context Manager Support =====
def __enter__(self) -> "AsyncCompatibleClient":
"""Sync context manager entry."""
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Sync context manager exit - cleanup."""
self.close()
async def __aenter__(self) -> "AsyncCompatibleClient":
"""Async context manager entry."""
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Async context manager exit - cleanup."""
self.close()