api: Update to postgis version of scripts
This commit is contained in:
parent
79f3469df8
commit
f2fa806cab
|
@ -88,7 +88,7 @@ function osm2pgsql.process_way(object)
|
||||||
end
|
end
|
||||||
|
|
||||||
roads:add_row({
|
roads:add_row({
|
||||||
geom = { create = 'linear' },
|
geometry = { create = 'linear' },
|
||||||
name = tags.name,
|
name = tags.name,
|
||||||
zone = zone,
|
zone = zone,
|
||||||
tags = tags,
|
tags = tags,
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit 118cc1d9f9dbd1dd8816a61c0698deaf404cf0ff
|
Subproject commit 767b7166105b2446c3b4e0334222052114b0ebae
|
112
api/tools/db.py
112
api/tools/db.py
|
@ -1,112 +0,0 @@
|
||||||
from contextvars import ContextVar
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from sqlalchemy.types import UserDefinedType
|
|
||||||
from sqlalchemy import (
|
|
||||||
Column,
|
|
||||||
String,
|
|
||||||
Integer,
|
|
||||||
Boolean,
|
|
||||||
select,
|
|
||||||
DateTime,
|
|
||||||
Float,
|
|
||||||
Index,
|
|
||||||
Enum as SqlEnum,
|
|
||||||
func,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
engine = ContextVar("engine")
|
|
||||||
async_session = ContextVar("async_session")
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def make_session():
|
|
||||||
async with async_session.get()() as session:
|
|
||||||
yield session
|
|
||||||
|
|
||||||
|
|
||||||
async def init_models():
|
|
||||||
async with engine.get().begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.drop_all)
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def connect_db(url):
|
|
||||||
engine_ = create_async_engine(url, echo=False)
|
|
||||||
t1 = engine.set(engine_)
|
|
||||||
|
|
||||||
async_session_ = sessionmaker(engine_, class_=AsyncSession, expire_on_commit=False)
|
|
||||||
t2 = async_session.set(async_session_)
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
# for AsyncEngine created in function scope, close and
|
|
||||||
# clean-up pooled connections
|
|
||||||
await engine_.dispose()
|
|
||||||
engine.reset(t1)
|
|
||||||
async_session.reset(t2)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ZoneType = SqlEnum("rural", "urban", "motorway", name="zone_type")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Geometry(UserDefinedType):
|
|
||||||
def get_col_spec(self):
|
|
||||||
return "GEOMETRY"
|
|
||||||
|
|
||||||
def bind_expression(self, bindvalue):
|
|
||||||
return func.ST_GeomFromGeoJSON(json.dumps(bindvalue), type_=self)
|
|
||||||
|
|
||||||
def column_expression(self, col):
|
|
||||||
return json.loads(func.ST_AsGeoJSON(col, type_=self))
|
|
||||||
|
|
||||||
|
|
||||||
class OvertakingEvent(Base):
|
|
||||||
__tablename__ = "overtaking_event"
|
|
||||||
__table_args__ = (Index("road_segment", "way_id", "direction_reversed"),)
|
|
||||||
|
|
||||||
id = Column(Integer, autoincrement=True, primary_key=True, index=True)
|
|
||||||
track_id = Column(String, index=True)
|
|
||||||
hex_hash = Column(String, unique=True, index=True)
|
|
||||||
way_id = Column(Integer, index=True)
|
|
||||||
|
|
||||||
# whether we were traveling along the way in reverse direction
|
|
||||||
direction_reversed = Column(Boolean)
|
|
||||||
|
|
||||||
geometry = Column(Geometry)
|
|
||||||
latitude = Column(Float)
|
|
||||||
longitude = Column(Float)
|
|
||||||
time = Column(DateTime)
|
|
||||||
distance_overtaker = Column(Float)
|
|
||||||
distance_stationary = Column(Float)
|
|
||||||
course = Column(Float)
|
|
||||||
speed = Column(Float)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"<OvertakingEvent {self.id}>"
|
|
||||||
|
|
||||||
|
|
||||||
class Road(Base):
|
|
||||||
__tablename__ = "road"
|
|
||||||
way_id = Column(Integer, primary_key=True, index=True)
|
|
||||||
zone = Column(ZoneType)
|
|
||||||
name = Column(String)
|
|
||||||
geometry = Column(Geometry)
|
|
||||||
|
|
||||||
|
|
||||||
class RoadSegment(Base):
|
|
||||||
__tablename__ = "bike_lane"
|
|
||||||
way_id = Column(Integer, primary_key=True, index=True)
|
|
||||||
direction_reversed = Column(Boolean)
|
|
||||||
geometry = Column(Geometry)
|
|
|
@ -23,10 +23,10 @@ from obs.face.filter import (
|
||||||
PrivacyZonesFilter,
|
PrivacyZonesFilter,
|
||||||
RequiredFieldsFilter,
|
RequiredFieldsFilter,
|
||||||
)
|
)
|
||||||
from obs.face.osm import DataSource, OverpassTileSource
|
from obs.face.osm import DataSource, DatabaseTileSource, OverpassTileSource
|
||||||
from sqlalchemy import delete, func, select
|
from sqlalchemy import delete
|
||||||
|
|
||||||
from db import make_session, connect_db, OvertakingEvent
|
from obs.face.db import make_session, connect_db, OvertakingEvent, async_session
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -64,25 +64,27 @@ async def main():
|
||||||
postgres_url_default = os.environ.get("POSTGRES_URL")
|
postgres_url_default = os.environ.get("POSTGRES_URL")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--postgres-url",
|
"--postgres-url",
|
||||||
required=False,
|
required=not postgres_url_default,
|
||||||
action="store",
|
action="store",
|
||||||
help="connection string for postgres database, if set, the track result is imported there",
|
help="connection string for postgres database",
|
||||||
default=postgres_url_default,
|
default=postgres_url_default,
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.cache_dir is None:
|
async with connect_db(args.postgres_url):
|
||||||
with tempfile.TemporaryDirectory() as cache_dir:
|
if args.cache_dir is None:
|
||||||
args.cache_dir = cache_dir
|
with tempfile.TemporaryDirectory() as cache_dir:
|
||||||
|
args.cache_dir = cache_dir
|
||||||
|
await process(args)
|
||||||
|
else:
|
||||||
await process(args)
|
await process(args)
|
||||||
else:
|
|
||||||
await process(args)
|
|
||||||
|
|
||||||
|
|
||||||
async def process(args):
|
async def process(args):
|
||||||
log.info("Loading OpenStreetMap data")
|
log.info("Loading OpenStreetMap data")
|
||||||
tile_source = OverpassTileSource(cache_dir=args.cache_dir)
|
tile_source = DatabaseTileSource(async_session.get())
|
||||||
|
# tile_source = OverpassTileSource(args.cache_dir)
|
||||||
data_source = DataSource(tile_source)
|
data_source = DataSource(tile_source)
|
||||||
|
|
||||||
filename_input = os.path.abspath(args.input)
|
filename_input = os.path.abspath(args.input)
|
||||||
|
@ -109,9 +111,9 @@ async def process(args):
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
input_data = AnnotateMeasurements(data_source, cache_dir=args.cache_dir).annotate(
|
input_data = await AnnotateMeasurements(
|
||||||
imported_data
|
data_source, cache_dir=args.cache_dir
|
||||||
)
|
).annotate(imported_data)
|
||||||
|
|
||||||
filters_from_settings = []
|
filters_from_settings = []
|
||||||
for filter_description in settings.get("filters", []):
|
for filter_description in settings.get("filters", []):
|
||||||
|
@ -145,12 +147,12 @@ async def process(args):
|
||||||
overtaking_events = overtaking_events_filter.filter(measurements, log=log)
|
overtaking_events = overtaking_events_filter.filter(measurements, log=log)
|
||||||
|
|
||||||
exporter = ExportMeasurements("measurements.dummy")
|
exporter = ExportMeasurements("measurements.dummy")
|
||||||
exporter.add_measurements(measurements)
|
await exporter.add_measurements(measurements)
|
||||||
measurements_json = exporter.get_data()
|
measurements_json = exporter.get_data()
|
||||||
del exporter
|
del exporter
|
||||||
|
|
||||||
exporter = ExportMeasurements("overtaking_events.dummy")
|
exporter = ExportMeasurements("overtaking_events.dummy")
|
||||||
exporter.add_measurements(overtaking_events)
|
await exporter.add_measurements(overtaking_events)
|
||||||
overtaking_events_json = exporter.get_data()
|
overtaking_events_json = exporter.get_data()
|
||||||
del exporter
|
del exporter
|
||||||
|
|
||||||
|
@ -163,8 +165,12 @@ async def process(args):
|
||||||
}
|
}
|
||||||
|
|
||||||
statistics_json = {
|
statistics_json = {
|
||||||
"recordedAt": statistics["t_min"].isoformat(),
|
"recordedAt": statistics["t_min"].isoformat()
|
||||||
"recordedUntil": statistics["t_max"].isoformat(),
|
if statistics["t_min"] is not None
|
||||||
|
else None,
|
||||||
|
"recordedUntil": statistics["t_max"].isoformat()
|
||||||
|
if statistics["t_max"] is not None
|
||||||
|
else None,
|
||||||
"duration": statistics["t"],
|
"duration": statistics["t"],
|
||||||
"length": statistics["d"],
|
"length": statistics["d"],
|
||||||
"segments": statistics["n_segments"],
|
"segments": statistics["n_segments"],
|
||||||
|
@ -182,15 +188,11 @@ async def process(args):
|
||||||
with open(os.path.join(args.output, output_filename), "w") as fp:
|
with open(os.path.join(args.output, output_filename), "w") as fp:
|
||||||
json.dump(data, fp, indent=4)
|
json.dump(data, fp, indent=4)
|
||||||
|
|
||||||
if args.postgres_url:
|
log.info("Importing to database.")
|
||||||
log.info("Importing to database.")
|
async with make_session() as session:
|
||||||
async with connect_db(args.postgres_url):
|
await clear_track_data(session, settings["trackId"])
|
||||||
async with make_session() as session:
|
await import_overtaking_events(session, settings["trackId"], overtaking_events)
|
||||||
await clear_track_data(session, settings["trackId"])
|
await session.commit()
|
||||||
await import_overtaking_events(
|
|
||||||
session, settings["trackId"], overtaking_events
|
|
||||||
)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
async def clear_track_data(session, track_id):
|
async def clear_track_data(session, track_id):
|
||||||
|
@ -213,7 +215,12 @@ async def import_overtaking_events(session, track_id, overtaking_events):
|
||||||
hex_hash=hex_hash,
|
hex_hash=hex_hash,
|
||||||
way_id=m["OSM_way_id"],
|
way_id=m["OSM_way_id"],
|
||||||
direction_reversed=m["OSM_way_orientation"] < 0,
|
direction_reversed=m["OSM_way_orientation"] < 0,
|
||||||
geometry={"type": "Point", "coordinates": [m["longitude"], m["latitude"]]},
|
geometry=json.dumps(
|
||||||
|
{
|
||||||
|
"type": "Point",
|
||||||
|
"coordinates": [m["longitude"], m["latitude"]],
|
||||||
|
}
|
||||||
|
),
|
||||||
latitude=m["latitude"],
|
latitude=m["latitude"],
|
||||||
longitude=m["longitude"],
|
longitude=m["longitude"],
|
||||||
time=m["time"].astimezone(pytz.utc).replace(tzinfo=None),
|
time=m["time"].astimezone(pytz.utc).replace(tzinfo=None),
|
||||||
|
|
|
@ -4,7 +4,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from db import init_models, connect_db
|
from obs.face.db import init_models, connect_db
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue