obs-portal/api/obs/api/process.py

256 lines
8.2 KiB
Python

import logging
import os
import json
import asyncio
import hashlib
import struct
import pytz
from os.path import join
from datetime import datetime
from sqlalchemy import delete, select
from sqlalchemy.orm import joinedload
from obs.face.importer import ImportMeasurementsCsv
from obs.face.geojson import ExportMeasurements
from obs.face.annotate import AnnotateMeasurements
from obs.face.filter import (
AnonymizationMode,
ChainFilter,
ConfirmedFilter,
DistanceMeasuredFilter,
PrivacyFilter,
PrivacyZone,
PrivacyZonesFilter,
RequiredFieldsFilter,
)
from obs.face.osm import DataSource, DatabaseTileSource, OverpassTileSource
from obs.api.db import OvertakingEvent, Track, make_session
from obs.api.app import app
log = logging.getLogger(__name__)
def get_data_source():
"""
Creates a data source based on the configuration of the portal. In *lean*
mode, the OverpassTileSource is used to fetch data on demand. In normal
mode, the roads database is used.
"""
if app.config.LEAN_MODE:
tile_source = OverpassTileSource(cache_dir=app.config.OBS_FACE_CACHE_DIR)
else:
tile_source = DatabaseTileSource()
return DataSource(tile_source)
async def process_tracks_loop(delay):
while True:
try:
async with make_session() as session:
track = (
await session.execute(
select(Track)
.where(Track.processing_status == "queued")
.order_by(Track.processing_queued_at)
.options(joinedload(Track.author))
)
).scalar()
if track is None:
await asyncio.sleep(delay)
continue
data_source = get_data_source()
await process_track(session, track, data_source)
except BaseException:
log.exception("Failed to process track. Will continue.")
await asyncio.sleep(1)
continue
async def process_tracks(tracks):
"""
Processes the tracks and writes event data to the database.
:param tracks: A list of strings which
"""
data_source = get_data_source()
async with make_session() as session:
for track_id_or_slug in tracks:
track = (
await session.execute(
select(Track)
.where(
Track.id == track_id_or_slug
if isinstance(track_id_or_slug, int)
else Track.slug == track_id_or_slug
)
.options(joinedload(Track.author))
)
).scalar()
if not track:
raise ValueError(f"Track {track_id_or_slug!r} not found.")
await process_track(session, track, data_source)
def to_naive_utc(t):
if t is None:
return None
return t.astimezone(pytz.UTC).replace(tzinfo=None)
async def process_track(session, track, data_source):
try:
track.processing_status = "complete"
track.processed_at = datetime.utcnow()
await session.commit()
original_file_path = track.get_original_file_path(app.config)
output_dir = join(
app.config.PROCESSING_OUTPUT_DIR, track.author.username, track.slug
)
os.makedirs(output_dir, exist_ok=True)
log.info("Annotating and filtering CSV file")
imported_data, statistics = ImportMeasurementsCsv().read(
original_file_path,
user_id="dummy", # TODO: user username or id or nothing?
dataset_id=Track.slug, # TODO: use track id or slug or nothing?
)
annotator = AnnotateMeasurements(
data_source, cache_dir=app.config.OBS_FACE_CACHE_DIR
)
input_data = await annotator.annotate(imported_data)
track_filter = ChainFilter(
RequiredFieldsFilter(),
PrivacyFilter(
user_id_mode=AnonymizationMode.REMOVE,
measurement_id_mode=AnonymizationMode.REMOVE,
),
# TODO: load user privacy zones and create a PrivacyZonesFilter() from them
)
measurements_filter = DistanceMeasuredFilter()
overtaking_events_filter = ConfirmedFilter()
track_points = track_filter.filter(input_data, log=log)
measurements = measurements_filter.filter(track_points, log=log)
overtaking_events = overtaking_events_filter.filter(measurements, log=log)
exporter = ExportMeasurements("measurements.dummy")
await exporter.add_measurements(measurements)
measurements_json = exporter.get_data()
del exporter
exporter = ExportMeasurements("overtaking_events.dummy")
await exporter.add_measurements(overtaking_events)
overtaking_events_json = exporter.get_data()
del exporter
track_json = {
"type": "Feature",
"geometry": {
"type": "LineString",
"coordinates": [[m["longitude"], m["latitude"]] for m in track_points],
},
}
for output_filename, data in [
("measurements.json", measurements_json),
("overtakingEvents.json", overtaking_events_json),
("track.json", track_json),
]:
target = join(output_dir, output_filename)
log.debug("Writing file %s", target)
with open(target, "w") as fp:
json.dump(data, fp, indent=4)
log.info("Clearing old track data...")
await clear_track_data(session, track)
await session.commit()
log.info("Import events into database...")
await import_overtaking_events(session, track, overtaking_events)
log.info("Write track statistics and update status...")
track.recorded_at = to_naive_utc(statistics["t_min"])
track.recorded_until = to_naive_utc(statistics["t_max"])
track.duration = statistics["t"]
track.length = statistics["d"]
track.segments = statistics["n_segments"]
track.num_events = statistics["n_confirmed"]
track.num_measurements = statistics["n_measurements"]
track.num_valid = statistics["n_valid"]
track.processing_status = "complete"
track.processed_at = datetime.utcnow()
await session.commit()
log.info("Track %s imported.", track.slug)
except BaseException as e:
await clear_track_data(session, track)
track.processing_status = "error"
track.processing_log = str(e)
track.processed_at = datetime.utcnow()
await session.commit()
raise
async def clear_track_data(session, track):
track.recorded_at = None
track.recorded_until = None
track.duration = None
track.length = None
track.segments = None
track.num_events = None
track.num_measurements = None
track.num_valid = None
await session.execute(
delete(OvertakingEvent).where(OvertakingEvent.track_id == track.id)
)
async def import_overtaking_events(session, track, overtaking_events):
# We use a dictionary to prevent per-track hash collisions, ignoring all
# but the first event of the same hash
event_models = {}
for m in overtaking_events:
hex_hash = hashlib.sha256(
struct.pack(
"ddQ", m["latitude"], m["longitude"], int(m["time"].timestamp())
)
).hexdigest()
event_models[hex_hash] = OvertakingEvent(
track_id=track.id,
hex_hash=hex_hash,
way_id=m.get("OSM_way_id"),
direction_reversed=m.get("OSM_way_orientation", 0) < 0,
geometry=json.dumps(
{
"type": "Point",
"coordinates": [m["longitude"], m["latitude"]],
}
),
latitude=m["latitude"],
longitude=m["longitude"],
time=m["time"].astimezone(pytz.utc).replace(tzinfo=None),
distance_overtaker=m["distance_overtaker"],
distance_stationary=m["distance_stationary"],
course=m["course"],
speed=m["speed"],
)
session.add_all(event_models.values())