api: Update to postgis version of scripts

This commit is contained in:
Paul Bienkowski 2021-10-29 11:26:39 +02:00
parent 79f3469df8
commit f2fa806cab
5 changed files with 38 additions and 143 deletions

View file

@ -88,7 +88,7 @@ function osm2pgsql.process_way(object)
end end
roads:add_row({ roads:add_row({
geom = { create = 'linear' }, geometry = { create = 'linear' },
name = tags.name, name = tags.name,
zone = zone, zone = zone,
tags = tags, tags = tags,

@ -1 +1 @@
Subproject commit 118cc1d9f9dbd1dd8816a61c0698deaf404cf0ff Subproject commit 767b7166105b2446c3b4e0334222052114b0ebae

View file

@ -1,112 +0,0 @@
from contextvars import ContextVar
from contextlib import asynccontextmanager
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.types import UserDefinedType
from sqlalchemy import (
Column,
String,
Integer,
Boolean,
select,
DateTime,
Float,
Index,
Enum as SqlEnum,
func,
)
Base = declarative_base()
engine = ContextVar("engine")
async_session = ContextVar("async_session")
@asynccontextmanager
async def make_session():
async with async_session.get()() as session:
yield session
async def init_models():
async with engine.get().begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
@asynccontextmanager
async def connect_db(url):
engine_ = create_async_engine(url, echo=False)
t1 = engine.set(engine_)
async_session_ = sessionmaker(engine_, class_=AsyncSession, expire_on_commit=False)
t2 = async_session.set(async_session_)
yield
# for AsyncEngine created in function scope, close and
# clean-up pooled connections
await engine_.dispose()
engine.reset(t1)
async_session.reset(t2)
ZoneType = SqlEnum("rural", "urban", "motorway", name="zone_type")
class Geometry(UserDefinedType):
def get_col_spec(self):
return "GEOMETRY"
def bind_expression(self, bindvalue):
return func.ST_GeomFromGeoJSON(json.dumps(bindvalue), type_=self)
def column_expression(self, col):
return json.loads(func.ST_AsGeoJSON(col, type_=self))
class OvertakingEvent(Base):
__tablename__ = "overtaking_event"
__table_args__ = (Index("road_segment", "way_id", "direction_reversed"),)
id = Column(Integer, autoincrement=True, primary_key=True, index=True)
track_id = Column(String, index=True)
hex_hash = Column(String, unique=True, index=True)
way_id = Column(Integer, index=True)
# whether we were traveling along the way in reverse direction
direction_reversed = Column(Boolean)
geometry = Column(Geometry)
latitude = Column(Float)
longitude = Column(Float)
time = Column(DateTime)
distance_overtaker = Column(Float)
distance_stationary = Column(Float)
course = Column(Float)
speed = Column(Float)
def __repr__(self):
return f"<OvertakingEvent {self.id}>"
class Road(Base):
__tablename__ = "road"
way_id = Column(Integer, primary_key=True, index=True)
zone = Column(ZoneType)
name = Column(String)
geometry = Column(Geometry)
class RoadSegment(Base):
__tablename__ = "bike_lane"
way_id = Column(Integer, primary_key=True, index=True)
direction_reversed = Column(Boolean)
geometry = Column(Geometry)

View file

@ -23,10 +23,10 @@ from obs.face.filter import (
PrivacyZonesFilter, PrivacyZonesFilter,
RequiredFieldsFilter, RequiredFieldsFilter,
) )
from obs.face.osm import DataSource, OverpassTileSource from obs.face.osm import DataSource, DatabaseTileSource, OverpassTileSource
from sqlalchemy import delete, func, select from sqlalchemy import delete
from db import make_session, connect_db, OvertakingEvent from obs.face.db import make_session, connect_db, OvertakingEvent, async_session
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -64,25 +64,27 @@ async def main():
postgres_url_default = os.environ.get("POSTGRES_URL") postgres_url_default = os.environ.get("POSTGRES_URL")
parser.add_argument( parser.add_argument(
"--postgres-url", "--postgres-url",
required=False, required=not postgres_url_default,
action="store", action="store",
help="connection string for postgres database, if set, the track result is imported there", help="connection string for postgres database",
default=postgres_url_default, default=postgres_url_default,
) )
args = parser.parse_args() args = parser.parse_args()
if args.cache_dir is None: async with connect_db(args.postgres_url):
with tempfile.TemporaryDirectory() as cache_dir: if args.cache_dir is None:
args.cache_dir = cache_dir with tempfile.TemporaryDirectory() as cache_dir:
args.cache_dir = cache_dir
await process(args)
else:
await process(args) await process(args)
else:
await process(args)
async def process(args): async def process(args):
log.info("Loading OpenStreetMap data") log.info("Loading OpenStreetMap data")
tile_source = OverpassTileSource(cache_dir=args.cache_dir) tile_source = DatabaseTileSource(async_session.get())
# tile_source = OverpassTileSource(args.cache_dir)
data_source = DataSource(tile_source) data_source = DataSource(tile_source)
filename_input = os.path.abspath(args.input) filename_input = os.path.abspath(args.input)
@ -109,9 +111,9 @@ async def process(args):
dataset_id=dataset_id, dataset_id=dataset_id,
) )
input_data = AnnotateMeasurements(data_source, cache_dir=args.cache_dir).annotate( input_data = await AnnotateMeasurements(
imported_data data_source, cache_dir=args.cache_dir
) ).annotate(imported_data)
filters_from_settings = [] filters_from_settings = []
for filter_description in settings.get("filters", []): for filter_description in settings.get("filters", []):
@ -145,12 +147,12 @@ async def process(args):
overtaking_events = overtaking_events_filter.filter(measurements, log=log) overtaking_events = overtaking_events_filter.filter(measurements, log=log)
exporter = ExportMeasurements("measurements.dummy") exporter = ExportMeasurements("measurements.dummy")
exporter.add_measurements(measurements) await exporter.add_measurements(measurements)
measurements_json = exporter.get_data() measurements_json = exporter.get_data()
del exporter del exporter
exporter = ExportMeasurements("overtaking_events.dummy") exporter = ExportMeasurements("overtaking_events.dummy")
exporter.add_measurements(overtaking_events) await exporter.add_measurements(overtaking_events)
overtaking_events_json = exporter.get_data() overtaking_events_json = exporter.get_data()
del exporter del exporter
@ -163,8 +165,12 @@ async def process(args):
} }
statistics_json = { statistics_json = {
"recordedAt": statistics["t_min"].isoformat(), "recordedAt": statistics["t_min"].isoformat()
"recordedUntil": statistics["t_max"].isoformat(), if statistics["t_min"] is not None
else None,
"recordedUntil": statistics["t_max"].isoformat()
if statistics["t_max"] is not None
else None,
"duration": statistics["t"], "duration": statistics["t"],
"length": statistics["d"], "length": statistics["d"],
"segments": statistics["n_segments"], "segments": statistics["n_segments"],
@ -182,15 +188,11 @@ async def process(args):
with open(os.path.join(args.output, output_filename), "w") as fp: with open(os.path.join(args.output, output_filename), "w") as fp:
json.dump(data, fp, indent=4) json.dump(data, fp, indent=4)
if args.postgres_url: log.info("Importing to database.")
log.info("Importing to database.") async with make_session() as session:
async with connect_db(args.postgres_url): await clear_track_data(session, settings["trackId"])
async with make_session() as session: await import_overtaking_events(session, settings["trackId"], overtaking_events)
await clear_track_data(session, settings["trackId"]) await session.commit()
await import_overtaking_events(
session, settings["trackId"], overtaking_events
)
await session.commit()
async def clear_track_data(session, track_id): async def clear_track_data(session, track_id):
@ -213,7 +215,12 @@ async def import_overtaking_events(session, track_id, overtaking_events):
hex_hash=hex_hash, hex_hash=hex_hash,
way_id=m["OSM_way_id"], way_id=m["OSM_way_id"],
direction_reversed=m["OSM_way_orientation"] < 0, direction_reversed=m["OSM_way_orientation"] < 0,
geometry={"type": "Point", "coordinates": [m["longitude"], m["latitude"]]}, geometry=json.dumps(
{
"type": "Point",
"coordinates": [m["longitude"], m["latitude"]],
}
),
latitude=m["latitude"], latitude=m["latitude"],
longitude=m["longitude"], longitude=m["longitude"],
time=m["time"].astimezone(pytz.utc).replace(tzinfo=None), time=m["time"].astimezone(pytz.utc).replace(tzinfo=None),

View file

@ -4,7 +4,7 @@ import logging
import os import os
import asyncio import asyncio
from db import init_models, connect_db from obs.face.db import init_models, connect_db
log = logging.getLogger(__name__) log = logging.getLogger(__name__)