obs-portal/api/tools/process_track.py

239 lines
7.5 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
import argparse
import logging
import os
import tempfile
import json
import shutil
import asyncio
import hashlib
import struct
import pytz
from obs.face.importer import ImportMeasurementsCsv
from obs.face.geojson import ExportMeasurements
from obs.face.annotate import AnnotateMeasurements
from obs.face.filter import (
AnonymizationMode,
ChainFilter,
ConfirmedFilter,
DistanceMeasuredFilter,
PrivacyFilter,
PrivacyZone,
PrivacyZonesFilter,
RequiredFieldsFilter,
)
from obs.face.osm import DataSource, DatabaseTileSource, OverpassTileSource
from sqlalchemy import delete
from obs.face.db import make_session, connect_db, OvertakingEvent, async_session
log = logging.getLogger(__name__)
async def main():
logging.basicConfig(level=logging.DEBUG, format="%(levelname)s: %(message)s")
parser = argparse.ArgumentParser(
description="processes a single track for use in the portal, "
"using the obs.face algorithms"
)
parser.add_argument(
"-i", "--input", required=True, action="store", help="path to input CSV file"
)
parser.add_argument(
"-o", "--output", required=True, action="store", help="path to output directory"
)
parser.add_argument(
"--path-cache",
action="store",
default=None,
dest="cache_dir",
help="path where the visualization data will be stored",
)
parser.add_argument(
"--settings",
dest="settings_file",
required=True,
default=None,
help="path to track settings file",
)
# https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
postgres_url_default = os.environ.get("POSTGRES_URL")
parser.add_argument(
"--postgres-url",
required=not postgres_url_default,
action="store",
help="connection string for postgres database",
default=postgres_url_default,
)
args = parser.parse_args()
async with connect_db(args.postgres_url):
if args.cache_dir is None:
with tempfile.TemporaryDirectory() as cache_dir:
args.cache_dir = cache_dir
await process(args)
else:
await process(args)
async def process(args):
log.info("Loading OpenStreetMap data")
tile_source = DatabaseTileSource(async_session.get())
# tile_source = OverpassTileSource(args.cache_dir)
data_source = DataSource(tile_source)
filename_input = os.path.abspath(args.input)
dataset_id = os.path.splitext(os.path.basename(args.input))[0]
os.makedirs(args.output, exist_ok=True)
log.info("Loading settings")
settings_path = os.path.abspath(args.settings_file)
with open(settings_path, "rt") as f:
settings = json.load(f)
settings_output_path = os.path.abspath(
os.path.join(args.output, "track-settings.json")
)
if settings_path != settings_output_path:
log.info("Copy settings to output directory")
shutil.copyfile(settings_path, settings_output_path)
log.info("Annotating and filtering CSV file")
imported_data, statistics = ImportMeasurementsCsv().read(
filename_input,
user_id="dummy",
dataset_id=dataset_id,
)
input_data = await AnnotateMeasurements(
data_source, cache_dir=args.cache_dir
).annotate(imported_data)
filters_from_settings = []
for filter_description in settings.get("filters", []):
filter_type = filter_description.get("type")
if filter_type == "PrivacyZonesFilter":
privacy_zones = [
PrivacyZone(
latitude=zone.get("latitude"),
longitude=zone.get("longitude"),
radius=zone.get("radius"),
)
for zone in filter_description.get("config", {}).get("privacyZones", [])
]
filters_from_settings.append(PrivacyZonesFilter(privacy_zones))
else:
log.warning("Ignoring unknown filter type %r in settings file", filter_type)
track_filter = ChainFilter(
RequiredFieldsFilter(),
PrivacyFilter(
user_id_mode=AnonymizationMode.REMOVE,
measurement_id_mode=AnonymizationMode.REMOVE,
),
*filters_from_settings,
)
measurements_filter = DistanceMeasuredFilter()
overtaking_events_filter = ConfirmedFilter()
track_points = track_filter.filter(input_data, log=log)
measurements = measurements_filter.filter(track_points, log=log)
overtaking_events = overtaking_events_filter.filter(measurements, log=log)
exporter = ExportMeasurements("measurements.dummy")
await exporter.add_measurements(measurements)
measurements_json = exporter.get_data()
del exporter
exporter = ExportMeasurements("overtaking_events.dummy")
await exporter.add_measurements(overtaking_events)
overtaking_events_json = exporter.get_data()
del exporter
track_json = {
"type": "Feature",
"geometry": {
"type": "LineString",
"coordinates": [[m["latitude"], m["longitude"]] for m in track_points],
},
}
statistics_json = {
"recordedAt": statistics["t_min"].isoformat()
if statistics["t_min"] is not None
else None,
"recordedUntil": statistics["t_max"].isoformat()
if statistics["t_max"] is not None
else None,
"duration": statistics["t"],
"length": statistics["d"],
"segments": statistics["n_segments"],
"numEvents": statistics["n_confirmed"],
"numMeasurements": statistics["n_measurements"],
"numValid": statistics["n_valid"],
}
for output_filename, data in [
("measurements.json", measurements_json),
("overtakingEvents.json", overtaking_events_json),
("track.json", track_json),
("statistics.json", statistics_json),
]:
with open(os.path.join(args.output, output_filename), "w") as fp:
json.dump(data, fp, indent=4)
log.info("Importing to database.")
async with make_session() as session:
await clear_track_data(session, settings["trackId"])
await import_overtaking_events(session, settings["trackId"], overtaking_events)
await session.commit()
async def clear_track_data(session, track_id):
await session.execute(
delete(OvertakingEvent).where(OvertakingEvent.track_id == track_id)
)
async def import_overtaking_events(session, track_id, overtaking_events):
event_models = []
for m in overtaking_events:
sha = hashlib.sha256()
sha.update(track_id.encode("utf-8"))
sha.update(struct.pack("Q", int(m["time"].timestamp())))
hex_hash = sha.hexdigest()
event_models.append(
OvertakingEvent(
track_id=track_id,
hex_hash=hex_hash,
way_id=m["OSM_way_id"],
direction_reversed=m["OSM_way_orientation"] < 0,
geometry=json.dumps(
{
"type": "Point",
"coordinates": [m["longitude"], m["latitude"]],
}
),
latitude=m["latitude"],
longitude=m["longitude"],
time=m["time"].astimezone(pytz.utc).replace(tzinfo=None),
distance_overtaker=m["distance_overtaker"],
distance_stationary=m["distance_stationary"],
course=m["course"],
speed=m["speed"],
)
)
session.add_all(event_models)
if __name__ == "__main__":
asyncio.run(main())