From f2fa806cab725f021998d73483354c4e434a24a8 Mon Sep 17 00:00:00 2001 From: Paul Bienkowski Date: Fri, 29 Oct 2021 11:26:39 +0200 Subject: [PATCH] api: Update to postgis version of scripts --- api/roads_import.lua | 2 +- api/scripts | 2 +- api/tools/db.py | 112 ------------------------------------ api/tools/process_track.py | 63 +++++++++++--------- api/tools/reset_database.py | 2 +- 5 files changed, 38 insertions(+), 143 deletions(-) delete mode 100644 api/tools/db.py diff --git a/api/roads_import.lua b/api/roads_import.lua index 15f106a..8aa9c03 100644 --- a/api/roads_import.lua +++ b/api/roads_import.lua @@ -88,7 +88,7 @@ function osm2pgsql.process_way(object) end roads:add_row({ - geom = { create = 'linear' }, + geometry = { create = 'linear' }, name = tags.name, zone = zone, tags = tags, diff --git a/api/scripts b/api/scripts index 118cc1d..767b716 160000 --- a/api/scripts +++ b/api/scripts @@ -1 +1 @@ -Subproject commit 118cc1d9f9dbd1dd8816a61c0698deaf404cf0ff +Subproject commit 767b7166105b2446c3b4e0334222052114b0ebae diff --git a/api/tools/db.py b/api/tools/db.py deleted file mode 100644 index cf9727a..0000000 --- a/api/tools/db.py +++ /dev/null @@ -1,112 +0,0 @@ -from contextvars import ContextVar -from contextlib import asynccontextmanager - -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.ext.asyncio import create_async_engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy.types import UserDefinedType -from sqlalchemy import ( - Column, - String, - Integer, - Boolean, - select, - DateTime, - Float, - Index, - Enum as SqlEnum, - func, -) - - -Base = declarative_base() - - -engine = ContextVar("engine") -async_session = ContextVar("async_session") - - -@asynccontextmanager -async def make_session(): - async with async_session.get()() as session: - yield session - - -async def init_models(): - async with engine.get().begin() as conn: - await conn.run_sync(Base.metadata.drop_all) - await conn.run_sync(Base.metadata.create_all) - - -@asynccontextmanager -async def connect_db(url): - engine_ = create_async_engine(url, echo=False) - t1 = engine.set(engine_) - - async_session_ = sessionmaker(engine_, class_=AsyncSession, expire_on_commit=False) - t2 = async_session.set(async_session_) - - yield - - # for AsyncEngine created in function scope, close and - # clean-up pooled connections - await engine_.dispose() - engine.reset(t1) - async_session.reset(t2) - - - -ZoneType = SqlEnum("rural", "urban", "motorway", name="zone_type") - - - -class Geometry(UserDefinedType): - def get_col_spec(self): - return "GEOMETRY" - - def bind_expression(self, bindvalue): - return func.ST_GeomFromGeoJSON(json.dumps(bindvalue), type_=self) - - def column_expression(self, col): - return json.loads(func.ST_AsGeoJSON(col, type_=self)) - - -class OvertakingEvent(Base): - __tablename__ = "overtaking_event" - __table_args__ = (Index("road_segment", "way_id", "direction_reversed"),) - - id = Column(Integer, autoincrement=True, primary_key=True, index=True) - track_id = Column(String, index=True) - hex_hash = Column(String, unique=True, index=True) - way_id = Column(Integer, index=True) - - # whether we were traveling along the way in reverse direction - direction_reversed = Column(Boolean) - - geometry = Column(Geometry) - latitude = Column(Float) - longitude = Column(Float) - time = Column(DateTime) - distance_overtaker = Column(Float) - distance_stationary = Column(Float) - course = Column(Float) - speed = Column(Float) - - def __repr__(self): - return f"" - - -class Road(Base): - __tablename__ = "road" - way_id = Column(Integer, primary_key=True, index=True) - zone = Column(ZoneType) - name = Column(String) - geometry = Column(Geometry) - - -class RoadSegment(Base): - __tablename__ = "bike_lane" - way_id = Column(Integer, primary_key=True, index=True) - direction_reversed = Column(Boolean) - geometry = Column(Geometry) diff --git a/api/tools/process_track.py b/api/tools/process_track.py index 4f03d72..2d0f1ab 100755 --- a/api/tools/process_track.py +++ b/api/tools/process_track.py @@ -23,10 +23,10 @@ from obs.face.filter import ( PrivacyZonesFilter, RequiredFieldsFilter, ) -from obs.face.osm import DataSource, OverpassTileSource -from sqlalchemy import delete, func, select +from obs.face.osm import DataSource, DatabaseTileSource, OverpassTileSource +from sqlalchemy import delete -from db import make_session, connect_db, OvertakingEvent +from obs.face.db import make_session, connect_db, OvertakingEvent, async_session log = logging.getLogger(__name__) @@ -64,25 +64,27 @@ async def main(): postgres_url_default = os.environ.get("POSTGRES_URL") parser.add_argument( "--postgres-url", - required=False, + required=not postgres_url_default, action="store", - help="connection string for postgres database, if set, the track result is imported there", + help="connection string for postgres database", default=postgres_url_default, ) args = parser.parse_args() - if args.cache_dir is None: - with tempfile.TemporaryDirectory() as cache_dir: - args.cache_dir = cache_dir + async with connect_db(args.postgres_url): + if args.cache_dir is None: + with tempfile.TemporaryDirectory() as cache_dir: + args.cache_dir = cache_dir + await process(args) + else: await process(args) - else: - await process(args) async def process(args): log.info("Loading OpenStreetMap data") - tile_source = OverpassTileSource(cache_dir=args.cache_dir) + tile_source = DatabaseTileSource(async_session.get()) + # tile_source = OverpassTileSource(args.cache_dir) data_source = DataSource(tile_source) filename_input = os.path.abspath(args.input) @@ -109,9 +111,9 @@ async def process(args): dataset_id=dataset_id, ) - input_data = AnnotateMeasurements(data_source, cache_dir=args.cache_dir).annotate( - imported_data - ) + input_data = await AnnotateMeasurements( + data_source, cache_dir=args.cache_dir + ).annotate(imported_data) filters_from_settings = [] for filter_description in settings.get("filters", []): @@ -145,12 +147,12 @@ async def process(args): overtaking_events = overtaking_events_filter.filter(measurements, log=log) exporter = ExportMeasurements("measurements.dummy") - exporter.add_measurements(measurements) + await exporter.add_measurements(measurements) measurements_json = exporter.get_data() del exporter exporter = ExportMeasurements("overtaking_events.dummy") - exporter.add_measurements(overtaking_events) + await exporter.add_measurements(overtaking_events) overtaking_events_json = exporter.get_data() del exporter @@ -163,8 +165,12 @@ async def process(args): } statistics_json = { - "recordedAt": statistics["t_min"].isoformat(), - "recordedUntil": statistics["t_max"].isoformat(), + "recordedAt": statistics["t_min"].isoformat() + if statistics["t_min"] is not None + else None, + "recordedUntil": statistics["t_max"].isoformat() + if statistics["t_max"] is not None + else None, "duration": statistics["t"], "length": statistics["d"], "segments": statistics["n_segments"], @@ -182,15 +188,11 @@ async def process(args): with open(os.path.join(args.output, output_filename), "w") as fp: json.dump(data, fp, indent=4) - if args.postgres_url: - log.info("Importing to database.") - async with connect_db(args.postgres_url): - async with make_session() as session: - await clear_track_data(session, settings["trackId"]) - await import_overtaking_events( - session, settings["trackId"], overtaking_events - ) - await session.commit() + log.info("Importing to database.") + async with make_session() as session: + await clear_track_data(session, settings["trackId"]) + await import_overtaking_events(session, settings["trackId"], overtaking_events) + await session.commit() async def clear_track_data(session, track_id): @@ -213,7 +215,12 @@ async def import_overtaking_events(session, track_id, overtaking_events): hex_hash=hex_hash, way_id=m["OSM_way_id"], direction_reversed=m["OSM_way_orientation"] < 0, - geometry={"type": "Point", "coordinates": [m["longitude"], m["latitude"]]}, + geometry=json.dumps( + { + "type": "Point", + "coordinates": [m["longitude"], m["latitude"]], + } + ), latitude=m["latitude"], longitude=m["longitude"], time=m["time"].astimezone(pytz.utc).replace(tzinfo=None), diff --git a/api/tools/reset_database.py b/api/tools/reset_database.py index b3c2f7f..1d77328 100755 --- a/api/tools/reset_database.py +++ b/api/tools/reset_database.py @@ -4,7 +4,7 @@ import logging import os import asyncio -from db import init_models, connect_db +from obs.face.db import init_models, connect_db log = logging.getLogger(__name__)