api: Update to postgis version of scripts
This commit is contained in:
parent
79f3469df8
commit
f2fa806cab
|
@ -88,7 +88,7 @@ function osm2pgsql.process_way(object)
|
|||
end
|
||||
|
||||
roads:add_row({
|
||||
geom = { create = 'linear' },
|
||||
geometry = { create = 'linear' },
|
||||
name = tags.name,
|
||||
zone = zone,
|
||||
tags = tags,
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 118cc1d9f9dbd1dd8816a61c0698deaf404cf0ff
|
||||
Subproject commit 767b7166105b2446c3b4e0334222052114b0ebae
|
112
api/tools/db.py
112
api/tools/db.py
|
@ -1,112 +0,0 @@
|
|||
from contextvars import ContextVar
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.types import UserDefinedType
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
String,
|
||||
Integer,
|
||||
Boolean,
|
||||
select,
|
||||
DateTime,
|
||||
Float,
|
||||
Index,
|
||||
Enum as SqlEnum,
|
||||
func,
|
||||
)
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
engine = ContextVar("engine")
|
||||
async_session = ContextVar("async_session")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_session():
|
||||
async with async_session.get()() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def init_models():
|
||||
async with engine.get().begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect_db(url):
|
||||
engine_ = create_async_engine(url, echo=False)
|
||||
t1 = engine.set(engine_)
|
||||
|
||||
async_session_ = sessionmaker(engine_, class_=AsyncSession, expire_on_commit=False)
|
||||
t2 = async_session.set(async_session_)
|
||||
|
||||
yield
|
||||
|
||||
# for AsyncEngine created in function scope, close and
|
||||
# clean-up pooled connections
|
||||
await engine_.dispose()
|
||||
engine.reset(t1)
|
||||
async_session.reset(t2)
|
||||
|
||||
|
||||
|
||||
ZoneType = SqlEnum("rural", "urban", "motorway", name="zone_type")
|
||||
|
||||
|
||||
|
||||
class Geometry(UserDefinedType):
|
||||
def get_col_spec(self):
|
||||
return "GEOMETRY"
|
||||
|
||||
def bind_expression(self, bindvalue):
|
||||
return func.ST_GeomFromGeoJSON(json.dumps(bindvalue), type_=self)
|
||||
|
||||
def column_expression(self, col):
|
||||
return json.loads(func.ST_AsGeoJSON(col, type_=self))
|
||||
|
||||
|
||||
class OvertakingEvent(Base):
|
||||
__tablename__ = "overtaking_event"
|
||||
__table_args__ = (Index("road_segment", "way_id", "direction_reversed"),)
|
||||
|
||||
id = Column(Integer, autoincrement=True, primary_key=True, index=True)
|
||||
track_id = Column(String, index=True)
|
||||
hex_hash = Column(String, unique=True, index=True)
|
||||
way_id = Column(Integer, index=True)
|
||||
|
||||
# whether we were traveling along the way in reverse direction
|
||||
direction_reversed = Column(Boolean)
|
||||
|
||||
geometry = Column(Geometry)
|
||||
latitude = Column(Float)
|
||||
longitude = Column(Float)
|
||||
time = Column(DateTime)
|
||||
distance_overtaker = Column(Float)
|
||||
distance_stationary = Column(Float)
|
||||
course = Column(Float)
|
||||
speed = Column(Float)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OvertakingEvent {self.id}>"
|
||||
|
||||
|
||||
class Road(Base):
|
||||
__tablename__ = "road"
|
||||
way_id = Column(Integer, primary_key=True, index=True)
|
||||
zone = Column(ZoneType)
|
||||
name = Column(String)
|
||||
geometry = Column(Geometry)
|
||||
|
||||
|
||||
class RoadSegment(Base):
|
||||
__tablename__ = "bike_lane"
|
||||
way_id = Column(Integer, primary_key=True, index=True)
|
||||
direction_reversed = Column(Boolean)
|
||||
geometry = Column(Geometry)
|
|
@ -23,10 +23,10 @@ from obs.face.filter import (
|
|||
PrivacyZonesFilter,
|
||||
RequiredFieldsFilter,
|
||||
)
|
||||
from obs.face.osm import DataSource, OverpassTileSource
|
||||
from sqlalchemy import delete, func, select
|
||||
from obs.face.osm import DataSource, DatabaseTileSource, OverpassTileSource
|
||||
from sqlalchemy import delete
|
||||
|
||||
from db import make_session, connect_db, OvertakingEvent
|
||||
from obs.face.db import make_session, connect_db, OvertakingEvent, async_session
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -64,14 +64,15 @@ async def main():
|
|||
postgres_url_default = os.environ.get("POSTGRES_URL")
|
||||
parser.add_argument(
|
||||
"--postgres-url",
|
||||
required=False,
|
||||
required=not postgres_url_default,
|
||||
action="store",
|
||||
help="connection string for postgres database, if set, the track result is imported there",
|
||||
help="connection string for postgres database",
|
||||
default=postgres_url_default,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
async with connect_db(args.postgres_url):
|
||||
if args.cache_dir is None:
|
||||
with tempfile.TemporaryDirectory() as cache_dir:
|
||||
args.cache_dir = cache_dir
|
||||
|
@ -82,7 +83,8 @@ async def main():
|
|||
|
||||
async def process(args):
|
||||
log.info("Loading OpenStreetMap data")
|
||||
tile_source = OverpassTileSource(cache_dir=args.cache_dir)
|
||||
tile_source = DatabaseTileSource(async_session.get())
|
||||
# tile_source = OverpassTileSource(args.cache_dir)
|
||||
data_source = DataSource(tile_source)
|
||||
|
||||
filename_input = os.path.abspath(args.input)
|
||||
|
@ -109,9 +111,9 @@ async def process(args):
|
|||
dataset_id=dataset_id,
|
||||
)
|
||||
|
||||
input_data = AnnotateMeasurements(data_source, cache_dir=args.cache_dir).annotate(
|
||||
imported_data
|
||||
)
|
||||
input_data = await AnnotateMeasurements(
|
||||
data_source, cache_dir=args.cache_dir
|
||||
).annotate(imported_data)
|
||||
|
||||
filters_from_settings = []
|
||||
for filter_description in settings.get("filters", []):
|
||||
|
@ -145,12 +147,12 @@ async def process(args):
|
|||
overtaking_events = overtaking_events_filter.filter(measurements, log=log)
|
||||
|
||||
exporter = ExportMeasurements("measurements.dummy")
|
||||
exporter.add_measurements(measurements)
|
||||
await exporter.add_measurements(measurements)
|
||||
measurements_json = exporter.get_data()
|
||||
del exporter
|
||||
|
||||
exporter = ExportMeasurements("overtaking_events.dummy")
|
||||
exporter.add_measurements(overtaking_events)
|
||||
await exporter.add_measurements(overtaking_events)
|
||||
overtaking_events_json = exporter.get_data()
|
||||
del exporter
|
||||
|
||||
|
@ -163,8 +165,12 @@ async def process(args):
|
|||
}
|
||||
|
||||
statistics_json = {
|
||||
"recordedAt": statistics["t_min"].isoformat(),
|
||||
"recordedUntil": statistics["t_max"].isoformat(),
|
||||
"recordedAt": statistics["t_min"].isoformat()
|
||||
if statistics["t_min"] is not None
|
||||
else None,
|
||||
"recordedUntil": statistics["t_max"].isoformat()
|
||||
if statistics["t_max"] is not None
|
||||
else None,
|
||||
"duration": statistics["t"],
|
||||
"length": statistics["d"],
|
||||
"segments": statistics["n_segments"],
|
||||
|
@ -182,14 +188,10 @@ async def process(args):
|
|||
with open(os.path.join(args.output, output_filename), "w") as fp:
|
||||
json.dump(data, fp, indent=4)
|
||||
|
||||
if args.postgres_url:
|
||||
log.info("Importing to database.")
|
||||
async with connect_db(args.postgres_url):
|
||||
async with make_session() as session:
|
||||
await clear_track_data(session, settings["trackId"])
|
||||
await import_overtaking_events(
|
||||
session, settings["trackId"], overtaking_events
|
||||
)
|
||||
await import_overtaking_events(session, settings["trackId"], overtaking_events)
|
||||
await session.commit()
|
||||
|
||||
|
||||
|
@ -213,7 +215,12 @@ async def import_overtaking_events(session, track_id, overtaking_events):
|
|||
hex_hash=hex_hash,
|
||||
way_id=m["OSM_way_id"],
|
||||
direction_reversed=m["OSM_way_orientation"] < 0,
|
||||
geometry={"type": "Point", "coordinates": [m["longitude"], m["latitude"]]},
|
||||
geometry=json.dumps(
|
||||
{
|
||||
"type": "Point",
|
||||
"coordinates": [m["longitude"], m["latitude"]],
|
||||
}
|
||||
),
|
||||
latitude=m["latitude"],
|
||||
longitude=m["longitude"],
|
||||
time=m["time"].astimezone(pytz.utc).replace(tzinfo=None),
|
||||
|
|
|
@ -4,7 +4,7 @@ import logging
|
|||
import os
|
||||
import asyncio
|
||||
|
||||
from db import init_models, connect_db
|
||||
from obs.face.db import init_models, connect_db
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
Loading…
Reference in a new issue