gtfs-dagster/user_code/assets/gtfs_static.py

333 lines
11 KiB
Python

import json
import logging
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from pathlib import Path

import requests
from dagster import (
    asset,
    AssetExecutionContext,
    Output,
    MetadataValue,
    AutomationCondition,
    DynamicPartitionsDefinition,
    Config,
)
from dagster_duckdb import DuckDBResource

from resources import MobilityDatabaseAPI
# Module-level logger named after this module. The assets visible in this file
# log through ``context.log`` instead, so this logger is not referenced in the
# portion of the module shown here.
logger = logging.getLogger(__name__)
class GTFSDownloadConfig(Config):
    """Run configuration consumed by the ``gtfs_feed_downloads`` asset."""

    # Provider name; echoed into the output metadata of a successful download.
    provider: str
    # URL the feed archive is requested from (HEAD first, then GET).
    producer_url: str
def _feed_metadata_row(feed_id: str, feed_info: dict) -> list:
    """Flatten one feed payload into the column order of the metadata INSERT.

    Args:
        feed_id: Feed identifier read from ``agency_list.GTFS``.
        feed_info: Mapping returned by ``mobility_db.get_feed_info``.

    Returns:
        Values for feed_id, provider, status, official, producer_url,
        authentication_type, authentication_info_url, api_key_parameter_name,
        license_url, feed_contact_email and raw_json, in that order.
    """
    # "source_info" may be missing *or* explicitly null. ``.get(key, {})`` only
    # covers the missing case, so use ``or`` to keep a null value from failing
    # the whole feed with an AttributeError.
    source_info = feed_info.get("source_info") or {}
    return [
        feed_id,
        feed_info.get("provider"),
        feed_info.get("status"),
        feed_info.get("official"),
        source_info.get("producer_url"),
        source_info.get("authentication_type"),
        source_info.get("authentication_info_url"),
        source_info.get("api_key_parameter_name"),
        source_info.get("license_url"),
        feed_info.get("feed_contact_email"),
        json.dumps(feed_info),
    ]


@asset(
    deps=["agency_list"],
    group_name="gtfs_metadata",
    automation_condition=AutomationCondition.eager(),
)
def gtfs_feed_metadata(
    context: AssetExecutionContext,
    duckdb: DuckDBResource,
    mobility_db: MobilityDatabaseAPI,
) -> Output[None]:
    """
    Fetch GTFS feed metadata from Mobility Database API for all agencies
    and store in DuckDB.

    Creates the ``gtfs_feed_metadata`` and ``gtfs_download_history`` tables if
    they do not exist, then upserts one metadata row per distinct non-empty
    ``GTFS`` value found in ``agency_list``.

    A failure for an individual feed is logged and counted; it does not abort
    the remaining feeds or raise out of the asset.

    Returns:
        ``Output(None)`` whose metadata holds ``total_feeds``, ``successful``,
        ``failed``, ``total_records_in_db`` and a markdown ``preview`` of up to
        five stored rows.
    """
    with duckdb.get_connection() as conn:
        # Create the metadata table if it doesn't exist
        conn.execute("""
            CREATE TABLE IF NOT EXISTS gtfs_feed_metadata (
                feed_id VARCHAR PRIMARY KEY,
                provider VARCHAR,
                status VARCHAR,
                official BOOLEAN,
                producer_url VARCHAR,
                authentication_type INTEGER,
                authentication_info_url VARCHAR,
                api_key_parameter_name VARCHAR,
                license_url VARCHAR,
                feed_contact_email VARCHAR,
                raw_json JSON,
                fetched_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # Create download history table for static GTFS
        conn.execute("""
            CREATE TABLE IF NOT EXISTS gtfs_download_history (
                feed_id VARCHAR,
                download_url VARCHAR,
                last_modified TIMESTAMP,
                file_path VARCHAR,
                file_size_bytes BIGINT,
                downloaded_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                PRIMARY KEY (feed_id, last_modified)
            )
        """)

        # Get all GTFS feed IDs from agency_list
        feed_ids = conn.execute("""
            SELECT DISTINCT GTFS as feed_id
            FROM agency_list
            WHERE GTFS IS NOT NULL AND GTFS != ''
        """).fetchall()
        context.log.info(f"Found {len(feed_ids)} feeds to fetch")

        successful = 0
        failed = 0
        for (feed_id,) in feed_ids:
            # Best-effort per feed: one bad feed must not stop the others.
            try:
                feed_info = mobility_db.get_feed_info(feed_id)
                # Insert or update the record
                conn.execute("""
                    INSERT OR REPLACE INTO gtfs_feed_metadata (
                        feed_id,
                        provider,
                        status,
                        official,
                        producer_url,
                        authentication_type,
                        authentication_info_url,
                        api_key_parameter_name,
                        license_url,
                        feed_contact_email,
                        raw_json
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, _feed_metadata_row(feed_id, feed_info))
                context.log.info(f"✓ Fetched and stored metadata for {feed_id}")
                successful += 1
            except Exception as e:
                context.log.error(f"✗ Failed to fetch {feed_id}: {e}")
                failed += 1

        # Get summary stats
        total_records = conn.execute(
            "SELECT COUNT(*) FROM gtfs_feed_metadata"
        ).fetchone()[0]

        # Get preview for metadata; ORDER BY keeps the preview stable between
        # runs (LIMIT without ORDER BY returns rows in unspecified order).
        preview_df = conn.execute("""
            SELECT feed_id, provider, status, producer_url
            FROM gtfs_feed_metadata
            ORDER BY feed_id
            LIMIT 5
        """).df()

    return Output(
        None,
        metadata={
            "total_feeds": len(feed_ids),
            "successful": successful,
            "failed": failed,
            "total_records_in_db": total_records,
            "preview": MetadataValue.md(preview_df.to_markdown(index=False)),
        },
    )
# Dynamic partition definition for GTFS feeds, registered under the name
# "gtfs_feeds". Keys are added at runtime by the gtfs_feed_partitions asset
# (one per feed_id) and the definition is used to partition gtfs_feed_downloads.
gtfs_feeds_partitions_def = DynamicPartitionsDefinition(name="gtfs_feeds")
@asset(
    deps=["gtfs_feed_metadata"],
    group_name="gtfs_metadata",
    automation_condition=AutomationCondition.eager(),
)
def gtfs_feed_partitions(
    context: AssetExecutionContext,
    duckdb: DuckDBResource,
) -> Output[None]:
    """
    Register one "gtfs_feeds" dynamic partition key per downloadable feed.

    Reads every ``feed_id`` in ``gtfs_feed_metadata`` that has a non-empty
    ``producer_url`` (sorted by ``feed_id``) and adds those ids as partition
    keys on the Dagster instance. Keys are only added here; nothing in this
    asset removes keys.

    Returns:
        ``Output(None)`` with ``feed_count`` and a markdown bullet list of the
        first 20 feed ids under ``feeds``.
    """
    select_feeds = """
        SELECT feed_id
        FROM gtfs_feed_metadata
        WHERE producer_url IS NOT NULL AND producer_url != ''
        ORDER BY feed_id
    """
    with duckdb.get_connection() as conn:
        rows = conn.execute(select_feeds).fetchall()

    # Each row is a 1-tuple; keep just the feed id.
    partition_keys = [row[0] for row in rows]
    key_count = len(partition_keys)

    # Update the dynamic partitions
    context.instance.add_dynamic_partitions(
        partitions_def_name="gtfs_feeds",
        partition_keys=partition_keys,
    )
    context.log.info(f"Updated partitions with {key_count} feeds")

    bullet_list = "\n".join(f"- {key}" for key in partition_keys[:20])
    return Output(
        None,
        metadata={
            "feed_count": key_count,
            "feeds": MetadataValue.md(bullet_list),
        },
    )
def _parse_last_modified(header_value: str) -> datetime:
    """Convert an HTTP ``Last-Modified`` header value to a naive UTC datetime.

    Uses the standard-library RFC date parser rather than ``strptime`` so the
    result does not depend on the process locale and numeric zone offsets are
    accepted as well as the literal ``GMT`` form.

    Raises:
        ValueError: if the value is not a parseable date (older Python
            versions raise ``TypeError`` instead).
    """
    parsed = parsedate_to_datetime(header_value)
    if parsed.tzinfo is not None:
        # Store as naive UTC so values compare equal to the naive timestamps
        # already recorded in gtfs_download_history.
        parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
    return parsed


def _download_file(url: str, destination: Path) -> int:
    """Stream *url* to *destination* and return the resulting file size in bytes.

    Data is written to a sibling ``.part`` file and moved into place only once
    the transfer has finished, so *destination* never holds a partial archive.
    The ``.part`` file is removed if anything fails.
    """
    partial = destination.with_name(destination.name + ".part")
    try:
        # ``with`` releases the streamed connection even when a write fails.
        with requests.get(url, timeout=120, stream=True) as response:
            response.raise_for_status()
            # Write file in chunks to handle large files
            with open(partial, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        partial.replace(destination)
    finally:
        # After a successful rename the .part file no longer exists.
        partial.unlink(missing_ok=True)
    return destination.stat().st_size


@asset(
    partitions_def=gtfs_feeds_partitions_def,
    deps=["gtfs_feed_partitions"],
    group_name="gtfs_downloads",
)
def gtfs_feed_downloads(
    context: AssetExecutionContext,
    config: GTFSDownloadConfig,
    duckdb: DuckDBResource,
) -> Output[None]:
    """
    Download GTFS feed for each agency partition.

    Only downloads if there's a new version available based on the
    Last-Modified header (compared against ``gtfs_download_history``).
    Files are saved to data/raw/gtfs/<feed_id>/<yyyy-mm-dd-gtfs.zip>; if that
    name is already taken the time of day is appended to the date.

    When the server sends no Last-Modified header the current UTC time is
    used, so such feeds are treated as new on every run.

    NOTE(review): an earlier version of this docstring said the asset runs on
    the hour and whenever partitions are added; no schedule or automation
    condition is attached to the asset in this module - confirm where that is
    configured.

    Failures (header check, download, history insert) are logged and reported
    through the returned metadata rather than raised.

    Returns:
        ``Output(None)`` whose metadata ``status`` is one of ``no_url``,
        ``error``, ``up_to_date`` or ``downloaded``, plus status-specific
        details.
    """
    feed_id = context.partition_key
    download_url = config.producer_url
    provider = config.provider

    if not download_url:
        context.log.warning(f"No download URL for {feed_id}")
        return Output(None, metadata={"status": "no_url"})

    # Check the Last-Modified header without downloading the full file
    try:
        head_response = requests.head(
            download_url,
            timeout=30,
            allow_redirects=True,
        )
        head_response.raise_for_status()
        last_modified_str = head_response.headers.get("Last-Modified")
        if last_modified_str:
            last_modified = _parse_last_modified(last_modified_str)
        else:
            # If no Last-Modified header, use current time (naive UTC, to
            # match the header-derived values).
            last_modified = datetime.now(timezone.utc).replace(tzinfo=None)
            context.log.warning(f"No Last-Modified header for {feed_id}, using current time")
    except Exception as e:
        context.log.error(f"Failed to check headers for {feed_id}: {e}")
        return Output(None, metadata={"status": "error", "error": str(e)})

    # Check if we've already downloaded this version. The connection is opened
    # only for the lookup so it is not held during network transfers.
    with duckdb.get_connection() as conn:
        existing = conn.execute("""
            SELECT file_path, downloaded_at
            FROM gtfs_download_history
            WHERE feed_id = ? AND last_modified = ?
            ORDER BY downloaded_at DESC
            LIMIT 1
        """, [feed_id, last_modified]).fetchone()

    if existing:
        file_path, downloaded_at = existing
        context.log.info(
            f"Already have latest version of {feed_id} "
            f"(modified: {last_modified}, downloaded: {downloaded_at})"
        )
        return Output(
            None,
            metadata={
                "status": "up_to_date",
                "last_modified": last_modified.isoformat(),
                "existing_file": file_path,
                "downloaded_at": downloaded_at.isoformat(),
            },
        )

    # Download the file
    context.log.info(f"Downloading new version of {feed_id} (modified: {last_modified})")

    # Create directory structure: data/raw/gtfs/<feed_id>/
    feed_dir = Path(f"data/raw/gtfs/{feed_id}")
    feed_dir.mkdir(parents=True, exist_ok=True)

    # Filename format: yyyy-mm-dd-gtfs.zip (using last_modified date)
    filename = f"{last_modified.strftime('%Y-%m-%d')}-gtfs.zip"
    file_path = feed_dir / filename

    # If file exists with same name but different modified time, append time to filename
    if file_path.exists():
        filename = f"{last_modified.strftime('%Y-%m-%d-%H%M%S')}-gtfs.zip"
        file_path = feed_dir / filename

    # Tracks whether this run put a file at file_path, so the error path never
    # deletes a file it did not create.
    downloaded = False
    try:
        file_size = _download_file(download_url, file_path)
        downloaded = True

        # Record the download in history
        with duckdb.get_connection() as conn:
            conn.execute("""
                INSERT INTO gtfs_download_history (
                    feed_id, download_url, last_modified, file_path, file_size_bytes
                ) VALUES (?, ?, ?, ?, ?)
            """, [feed_id, download_url, last_modified, str(file_path), file_size])
    except Exception as e:
        context.log.error(f"Failed to download {feed_id}: {e}")
        if downloaded:
            # The history row was not written; drop the file so disk contents
            # and history stay in sync.
            file_path.unlink(missing_ok=True)
        return Output(
            None,
            metadata={
                "status": "error",
                "error": str(e),
                "feed_id": feed_id,
            },
        )

    context.log.info(
        f"✓ Downloaded {feed_id} to {file_path} ({file_size:,} bytes)"
    )
    return Output(
        None,
        metadata={
            "status": "downloaded",
            "file_path": str(file_path),
            "file_size_mb": round(file_size / 1024 / 1024, 2),
            "last_modified": last_modified.isoformat(),
            "provider": provider,
            "download_url": download_url,
        },
    )