obs-portal/api/tools/db.py
2021-11-22 09:50:33 +01:00

113 lines
2.8 KiB
Python

import json
from contextvars import ContextVar
from contextlib import asynccontextmanager

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.types import UserDefinedType
from sqlalchemy import (
    Column,
    String,
    Integer,
    Boolean,
    select,
    DateTime,
    Float,
    Index,
    Enum as SqlEnum,
    func,
)
# Declarative base class; the ORM models defined in this module subclass it,
# and init_models() uses its metadata to drop and create the tables.
Base = declarative_base()
# Context variable holding the async engine that connect_db() publishes.
# It has no default, so .get() raises LookupError until it has been set.
engine = ContextVar("engine")
# Context variable holding the session factory that connect_db() publishes
# and make_session() calls. It has no default either.
async_session = ContextVar("async_session")
@asynccontextmanager
async def make_session():
    """Yield a database session for the duration of an ``async with`` block.

    The session factory is read from the ``async_session`` context variable
    and the session it produces is entered as an async context manager, so
    it is exited again when the caller's block ends.

    Raises:
        LookupError: if ``async_session`` has not been set in the current
            context (the variable has no default).
    """
    session_factory = async_session.get()
    async with session_factory() as db_session:
        yield db_session
async def init_models():
    """Drop all tables registered on ``Base.metadata`` and create them again.

    Both steps run, in that order, inside a single ``begin()`` block on the
    engine stored in the ``engine`` context variable. Because the tables are
    dropped first, any rows they held are discarded.

    Raises:
        LookupError: if ``engine`` has not been set in the current context.
    """
    metadata = Base.metadata
    async with engine.get().begin() as connection:
        # Order matters: drop first, then recreate from the current models.
        for schema_step in (metadata.drop_all, metadata.create_all):
            await connection.run_sync(schema_step)
@asynccontextmanager
async def connect_db(url):
    """Create an async engine for ``url`` and publish it for the enclosed block.

    While the ``async with`` body runs, the ``engine`` context variable holds
    the new engine and ``async_session`` holds a session factory bound to it
    (``AsyncSession`` instances with ``expire_on_commit=False``). On exit the
    engine is disposed and both context variables are restored to their
    previous state.

    The cleanup runs in ``finally`` blocks so that it also happens when the
    body raises; otherwise the pooled connections would be left open and the
    context variables would keep pointing at this engine.

    Args:
        url: database URL passed to ``create_async_engine``.
    """
    engine_ = create_async_engine(url, echo=False)
    engine_token = engine.set(engine_)
    async_session_ = sessionmaker(engine_, class_=AsyncSession, expire_on_commit=False)
    session_token = async_session.set(async_session_)
    try:
        yield
    finally:
        try:
            # for AsyncEngine created in function scope, close and
            # clean-up pooled connections
            await engine_.dispose()
        finally:
            # Restore the previous values even if dispose() itself fails.
            engine.reset(engine_token)
            async_session.reset(session_token)
# SQL enum type named "zone_type" with the values "rural", "urban" and
# "motorway"; used as the type of the Road.zone column.
ZoneType = SqlEnum("rural", "urban", "motorway", name="zone_type")
class Geometry(UserDefinedType):
    """Column type emitted as ``GEOMETRY`` and exchanged with Python as GeoJSON.

    On the SQL side, bound values are wrapped in ``ST_GeomFromGeoJSON`` and
    selected columns in ``ST_AsGeoJSON``. On the Python side, values are
    serialized with ``json.dumps`` before being sent and parsed with
    ``json.loads`` when read back.

    The JSON conversion lives in ``bind_processor`` / ``result_processor``
    because ``bind_expression`` and ``column_expression`` receive SQL
    expression objects, not values; calling ``json.dumps`` / ``json.loads``
    on those raises ``TypeError`` when the statement is built.
    """

    # The type has no per-instance state, so it is safe to take part in
    # SQLAlchemy's statement cache key.
    cache_ok = True

    def get_col_spec(self):
        """Return the type name used for this column in DDL."""
        return "GEOMETRY"

    def bind_expression(self, bindvalue):
        """Wrap the bound parameter in ``ST_GeomFromGeoJSON`` in the SQL."""
        return func.ST_GeomFromGeoJSON(bindvalue, type_=self)

    def column_expression(self, col):
        """Wrap the selected column in ``ST_AsGeoJSON`` in the SQL."""
        return func.ST_AsGeoJSON(col, type_=self)

    def bind_processor(self, dialect):
        """Return a converter that serializes a Python value to a JSON string."""

        def process(value):
            # None is passed through unchanged so that NULL stays NULL.
            if value is None:
                return None
            return json.dumps(value)

        return process

    def result_processor(self, dialect, coltype):
        """Return a converter that parses a JSON string into a Python value."""

        def process(value):
            # None is passed through unchanged so that NULL stays None.
            if value is None:
                return None
            return json.loads(value)

        return process
class OvertakingEvent(Base):
    """ORM model mapped to the ``overtaking_event`` table."""

    __tablename__ = "overtaking_event"

    # Composite index named "road_segment" over (way_id, direction_reversed).
    __table_args__ = (
        Index("road_segment", "way_id", "direction_reversed"),
    )

    # Auto-incrementing integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True, index=True)

    # Indexed string identifiers; hex_hash additionally carries a unique
    # constraint.
    track_id = Column(String, index=True)
    hex_hash = Column(String, unique=True, index=True)

    # Indexed way identifier (Road also uses a way_id column as its key).
    way_id = Column(Integer, index=True)

    # whether we were traveling along the way in reverse direction
    direction_reversed = Column(Boolean)

    # Position of the event, stored both as a geometry and as plain floats.
    geometry = Column(Geometry)
    latitude = Column(Float)
    longitude = Column(Float)

    time = Column(DateTime)

    # Measured values; units are not stated in this module.
    distance_overtaker = Column(Float)
    distance_stationary = Column(Float)
    course = Column(Float)
    speed = Column(Float)

    def __repr__(self):
        """Return a short debug representation containing the primary key."""
        return f"<OvertakingEvent {self.id}>"
class Road(Base):
    """ORM model mapped to the ``road`` table, keyed by ``way_id``."""

    __tablename__ = "road"

    # Integer primary key, also explicitly indexed.
    way_id = Column(Integer, primary_key=True, index=True)

    # One of the ZoneType enum values ("rural", "urban", "motorway").
    zone = Column(ZoneType)

    name = Column(String)

    # Stored through the custom Geometry column type.
    geometry = Column(Geometry)
class RoadSegment(Base):
    """ORM model mapped to the ``bike_lane`` table, keyed by ``way_id``."""

    # NOTE(review): the table is called "bike_lane" while the class is named
    # RoadSegment -- confirm that the mismatch is intentional.
    __tablename__ = "bike_lane"

    # Integer primary key, also explicitly indexed. way_id alone is the key,
    # so there can be only one row per way regardless of direction_reversed.
    way_id = Column(Integer, primary_key=True, index=True)

    direction_reversed = Column(Boolean)

    # Stored through the custom Geometry column type.
    geometry = Column(Geometry)