obs-portal/api/tools/prepare_sql_tiles.py

198 lines
5.3 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
import logging
import asyncio
import tempfile
import re
import os
import glob
from os.path import normpath, abspath, join
2022-04-30 18:31:32 +00:00
from typing import List, Tuple
from sqlalchemy import text
import sqlparse
2022-04-30 18:31:32 +00:00
from openmaptiles.sqltomvt import MvtGenerator
from obs.api.app import app
from obs.api.db import connect_db, make_session
log = logging.getLogger(__name__)

# The tile-generator directory is a sibling of the API root directory.
_tile_generator_dir = join(app.config.API_ROOT_DIR, os.pardir, "tile-generator")
TILE_GENERATOR = normpath(abspath(_tile_generator_dir))

# Tileset definition handed to generate-tm2source, generate-sql and the
# MvtGenerator below.
TILESET_FILE = join(TILE_GENERATOR, "openbikesensor.yaml")
2022-04-30 18:31:32 +00:00
# Additional arguments appended to the generated getmvt() SQL function, as
# (argument name, SQL type, SQL default expression) triples. Every one of them
# currently defaults to NULL.
EXTRA_ARGS = [
    (arg_name, arg_type, "NULL")
    for arg_name, arg_type in (
        ("user_id", "integer"),
        ("min_time", "timestamp"),
        ("max_time", "timestamp"),
    )
]
class CustomMvtGenerator(MvtGenerator):
    """openmaptiles-tools ``MvtGenerator`` whose generated SQL tile function
    accepts additional arguments after ``(zoom, x, y)``."""

    def generate_sqltomvt_func(
        self, fname: str, extra_args: List[Tuple[str, str, str]]
    ) -> str:
        """
        Creates the SQL tile function for ``(zoom, x, y)``. This method is
        overridden to allow for custom arguments in the created function.

        :param fname: name of the SQL function to (re)create.
        :param extra_args: ``(name, sql_type, default_sql)`` triples; each one
            becomes an extra argument ``name sql_type DEFAULT default_sql``.
        :returns: SQL that drops an existing function with the same argument
            types and then creates the new one. The function returns
            ``TABLE(mvt bytea, key text)`` when ``self.key_column`` is set,
            otherwise a single ``bytea`` value or null.
        """
        # DROP FUNCTION identifies the function by its argument types only.
        extra_args_types = "".join(f", {arg[1]}" for arg in extra_args)
        extra_args_definitions = "".join(
            f", {arg[0]} {arg[1]} DEFAULT {arg[2]}" for arg in extra_args
        )
        # CALLED ON NULL INPUT: the function must still be evaluated when some
        # arguments are NULL (e.g. optional arguments left at a NULL default).
        return f"""\
DROP FUNCTION IF EXISTS {fname}(integer, integer, integer{extra_args_types});
CREATE FUNCTION {fname}(zoom integer, x integer, y integer{extra_args_definitions})
RETURNS {'TABLE(mvt bytea, key text)' if self.key_column else 'bytea'} AS $$
{self.generate_sql()};
$$ LANGUAGE SQL STABLE CALLED ON NULL INPUT;"""
2022-04-30 18:31:32 +00:00
def parse_pg_url(url=None):
    """Split a PostgreSQL SQLAlchemy URL into its connection parameters.

    :param url: URL such as ``postgresql+asyncpg://user:secret@host:5432/db``.
        Defaults to ``app.config.POSTGRES_URL``, read at call time.
    :returns: ``(user, password, host, port, database)``, all strings. A
        missing user or password becomes ``""`` and a missing port becomes
        ``"5432"``. User and password are percent-decoded, matching how
        SQLAlchemy interprets them.
    :raises ValueError: if ``url`` is not a recognizable PostgreSQL URL.
    """
    from urllib.parse import unquote

    if url is None:
        # Resolved per call: a default argument would freeze the value at
        # import time.
        url = app.config.POSTGRES_URL

    m = re.match(
        r"^postgres(?:ql)?(?:\+\w+)?://"
        # optional "user[:password]@"; the password may itself contain "@"
        r"(?:(?P<user>[^:@/]*)(?::(?P<password>.*))?@)?"
        # the host stops at ":" so that the optional port group can match
        r"(?P<host>\[[^\]]*\]|[^:/@]*)"
        r"(?::(?P<port>\d+))?"
        r"/(?P<database>[^/]+)$",
        url,
    )
    if m is None:
        # The URL is deliberately not echoed: it may contain a password.
        raise ValueError("unsupported PostgreSQL URL format")

    return (
        unquote(m["user"] or ""),
        unquote(m["password"] or ""),
        m["host"],
        m["port"] or "5432",
        m["database"],
    )
async def main():
    """Command-line entry point: enable debug logging, then prepare the SQL tiles."""
    log_format = "%(levelname)s: %(message)s"
    logging.basicConfig(format=log_format, level=logging.DEBUG)
    await prepare_sql_tiles()
async def prepare_sql_tiles():
    """Generate the tile SQL in a temporary directory and import it into the database.

    The temporary directory, and everything generated into it, is removed when
    this coroutine finishes.
    """
    with tempfile.TemporaryDirectory() as workdir:
        await generate_data_yml(workdir)
        snippets = await generate_sql(workdir)
        await import_sql(snippets)
async def _run(cmd):
if isinstance(cmd, list):
cmd = " ".join(cmd)
proc = await asyncio.create_subprocess_shell(
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
log.error(stderr.decode("utf-8"))
raise RuntimeError("external program failed: %s" % str(cmd))
return stdout.decode("utf-8")
async def generate_data_yml(build_dir):
    """Generate the tm2source ``data.yml`` for the tileset inside ``build_dir``.

    Runs ``generate-tm2source`` on ``TILESET_FILE`` with the database
    connection parameters from ``parse_pg_url()``. Its stdout is written to
    ``<build_dir>/openbikesensor.tm2source/data.yml``.
    """
    import shlex

    option_names = ["--user", "--password", "--host", "--port", "--database"]
    connection_args = []
    for option, value in zip(option_names, parse_pg_url()):
        # _run hands the joined command to a shell, so configuration values
        # are quoted. This protects passwords containing spaces or shell
        # metacharacters. It also keeps an empty value as '' instead of
        # dropping it, which could let the following option be taken as its
        # value.
        connection_args += [option, shlex.quote(value)]

    # NOTE(review): the password is passed on the command line, so it is
    # visible in the process list and in _run's error message on failure.
    stdout = await _run(
        [
            "python",
            # left unquoted on purpose: the shell must expand it
            "$(which generate-tm2source)",
            shlex.quote(TILESET_FILE),
            *connection_args,
        ]
    )

    tm2source = join(build_dir, "openbikesensor.tm2source")
    os.makedirs(tm2source, exist_ok=True)
    # _run decoded the output as UTF-8, so write it back the same way rather
    # than in the locale's encoding.
    with open(join(tm2source, "data.yml"), "wt", encoding="utf-8") as f:
        f.write(stdout)
async def generate_sql(build_dir):
    """Collect all SQL needed for vector tile generation, in execution order.

    Runs ``generate-sql`` for ``TILESET_FILE`` into ``<build_dir>/sql``. The
    returned list contains, in order:

    - the extension setup statements;
    - the ``*.sql`` files from ``<API_ROOT_DIR>/src/openmaptiles-tools/sql``
      (sorted by path);
    - the generated ``run_first.sql``;
    - the generated ``parallel/*.sql`` files (sorted by path);
    - the generated ``run_last.sql``;
    - the ``getmvt`` tile function built by ``CustomMvtGenerator``.

    :returns: list of SQL source strings; each may hold several statements.
    """
    import shlex

    sql_dir = join(build_dir, "sql")
    # _run goes through a shell: quote the paths with shlex (repr() is not
    # shell quoting), but leave "$(which ...)" for the shell to expand.
    await _run(
        f"python $(which generate-sql) {shlex.quote(TILESET_FILE)} "
        f"--dir {shlex.quote(sql_dir)}"
    )

    sql_snippet_files = [
        *sorted(
            glob.glob(
                join(
                    app.config.API_ROOT_DIR, "src", "openmaptiles-tools", "sql", "*.sql"
                )
            )
        ),
        join(sql_dir, "run_first.sql"),
        *sorted(glob.glob(join(sql_dir, "parallel", "*.sql"))),
        join(sql_dir, "run_last.sql"),
    ]

    # Two separate snippets; without the comma Python would silently join the
    # adjacent literals into one string.
    sql_snippets = [
        "CREATE EXTENSION IF NOT EXISTS hstore;",
        "CREATE EXTENSION IF NOT EXISTS postgis;",
    ]
    for filename in sql_snippet_files:
        with open(filename, "rt", encoding="utf-8") as f:
            sql_snippets.append(f.read())

    mvt = CustomMvtGenerator(
        tileset=TILESET_FILE,
        postgis_ver="3.0.1",
        zoom="zoom",
        x="x",
        y="y",
        gzip=True,
        test_geometry=False,  # NOTE(review): originally marked "?"; confirm it is needed
        key_column=True,
    )

    getmvt_sql = mvt.generate_sqltomvt_func("getmvt", EXTRA_ARGS)
    log.debug("Generated getmvt function:\n%s", getmvt_sql)

    # Drop the old 3-argument getmvt(). The generated SQL only drops the
    # overload that includes the extra arguments, so both could otherwise
    # coexist.
    sql_snippets.append("DROP FUNCTION IF EXISTS getmvt(integer, integer, integer);")
    sql_snippets.append(getmvt_sql)
    return sql_snippets
async def import_sql(sql_snippets):
    """Execute every SQL statement contained in ``sql_snippets``.

    Each snippet is split into individual statements with ``sqlparse.split``.
    The statements run in order, each in its own session that is committed
    immediately, so statements that already ran stay applied if a later one
    fails.
    """
    # Flatten in a single pass; sum(list_of_lists, []) copies the accumulated
    # list on every addition and is quadratic in the number of statements.
    statements = [
        statement
        for snippet in sql_snippets
        for statement in sqlparse.split(snippet)
    ]
    total = len(statements)

    async with connect_db(
        app.config.POSTGRES_URL,
        app.config.POSTGRES_POOL_SIZE,
        app.config.POSTGRES_MAX_OVERFLOW,
    ):
        for i, statement in enumerate(statements, start=1):
            # Normalized, comment-free form. It is used only for the log line
            # and to skip statements whose cleaned form is empty (e.g.
            # comment-only fragments).
            clean_statement = sqlparse.format(
                statement,
                truncate_strings=20,
                strip_comments=True,
                keyword_case="upper",
            )
            if not clean_statement:
                continue

            log.debug(
                "Running SQL statement %d of %d (%s...)",
                i,
                total,
                clean_statement[:40],
            )
            # NOTE(review): text() treats ":name" tokens as bind parameters;
            # generated SQL containing such tokens (outside "::" casts) could
            # be misinterpreted. Confirm if unexpected bind errors appear.
            async with make_session() as session:
                await session.execute(text(statement))
                await session.commit()
# Run the tile SQL preparation only when executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())