diff --git a/api/migrations/versions/f7b21148126a_add_user_device.py b/api/migrations/versions/f7b21148126a_add_user_device.py new file mode 100644 index 0000000..2e65451 --- /dev/null +++ b/api/migrations/versions/f7b21148126a_add_user_device.py @@ -0,0 +1,41 @@ +"""add user_device + +Revision ID: f7b21148126a +Revises: a9627f63fbed +Create Date: 2022-09-15 17:48:06.764342 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "f7b21148126a" +down_revision = "a049e5eb24dd" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "user_device", + sa.Column("id", sa.Integer, autoincrement=True, primary_key=True), + sa.Column("user_id", sa.Integer, sa.ForeignKey("user.id", ondelete="CASCADE")), + sa.Column("identifier", sa.String, nullable=False), + sa.Column("display_name", sa.String, nullable=True), + sa.Index("user_id_identifier", "user_id", "identifier", unique=True), + ) + op.add_column( + "track", + sa.Column( + "user_device_id", + sa.Integer, + sa.ForeignKey("user_device.id", ondelete="RESTRICT"), + nullable=True, + ), + ) + + +def downgrade(): + op.drop_column("track", "user_device_id") + op.drop_table("user_device") diff --git a/api/obs/api/db.py b/api/obs/api/db.py index f7716a6..74b5b20 100644 --- a/api/obs/api/db.py +++ b/api/obs/api/db.py @@ -221,6 +221,12 @@ class Track(Base): Integer, ForeignKey("user.id", ondelete="CASCADE"), nullable=False ) + user_device_id = Column( + Integer, + ForeignKey("user_device.id", ondelete="RESTRICT"), + nullable=True, + ) + # Statistics... maybe we'll drop some of this if we can easily compute them from SQL recorded_at = Column(DateTime) recorded_until = Column(DateTime) @@ -409,6 +415,28 @@ class User(Base): self.username = new_name +class UserDevice(Base): + __tablename__ = "user_device" + id = Column(Integer, autoincrement=True, primary_key=True) + user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE")) + identifier = Column(String, nullable=False) + display_name = Column(String, nullable=True) + + __table_args__ = ( + Index("user_id_identifier", "user_id", "identifier", unique=True), + ) + + def to_dict(self, for_user_id=None): + if for_user_id != self.user_id: + return {} + + return { + "id": self.id, + "identifier": self.identifier, + "displayName": self.display_name, + } + + class Comment(Base): __tablename__ = "comment" id = Column(Integer, autoincrement=True, primary_key=True) @@ -468,6 +496,14 @@ Track.overtaking_events = relationship( passive_deletes=True, ) +Track.user_device = relationship("UserDevice", back_populates="tracks") +UserDevice.tracks = relationship( + "Track", + order_by=Track.created_at, + back_populates="user_device", + passive_deletes=False, +) + # 0..4 Night, 4..10 Morning, 10..14 Noon, 14..18 Afternoon, 18..22 Evening, 22..00 Night # Two hour intervals diff --git a/api/obs/api/process.py b/api/obs/api/process.py index 6fc2c5e..79a39d8 100644 --- a/api/obs/api/process.py +++ b/api/obs/api/process.py @@ -8,7 +8,7 @@ import pytz from os.path import join from datetime import datetime -from sqlalchemy import delete, select +from sqlalchemy import delete, select, and_ from sqlalchemy.orm import joinedload from obs.face.importer import ImportMeasurementsCsv @@ -27,7 +27,7 @@ from obs.face.filter import ( from obs.face.osm import DataSource, DatabaseTileSource, OverpassTileSource -from obs.api.db import OvertakingEvent, RoadUsage, Track, make_session +from obs.api.db import OvertakingEvent, RoadUsage, Track, UserDevice, make_session from obs.api.app import app log = logging.getLogger(__name__) @@ -144,10 +144,11 @@ async def process_track(session, track, data_source): os.makedirs(output_dir, exist_ok=True) log.info("Annotating and filtering CSV file") - imported_data, statistics = ImportMeasurementsCsv().read( + imported_data, statistics, track_metadata = ImportMeasurementsCsv().read( original_file_path, user_id="dummy", # TODO: user username or id or nothing? dataset_id=Track.slug, # TODO: use track id or slug or nothing? + return_metadata=True, ) annotator = AnnotateMeasurements( @@ -217,6 +218,36 @@ async def process_track(session, track, data_source): await clear_track_data(session, track) await session.commit() + device_identifier = track_metadata.get("DeviceId") + if device_identifier: + if isinstance(device_identifier, list): + device_identifier = device_identifier[0] + + log.info("Finding or creating device %s", device_identifier) + user_device = ( + await session.execute( + select(UserDevice).where( + and_( + UserDevice.user_id == track.author_id, + UserDevice.identifier == device_identifier, + ) + ) + ) + ).scalar() + + log.debug("user_device is %s", user_device) + + if not user_device: + user_device = UserDevice( + user_id=track.author_id, identifier=device_identifier + ) + log.debug("Create new device for this user") + session.add(user_device) + + track.user_device = user_device + else: + log.info("No DeviceId in track metadata.") + log.info("Import events into database...") await import_overtaking_events(session, track, overtaking_events) diff --git a/api/scripts b/api/scripts index 8e9395f..bbc6fec 160000 --- a/api/scripts +++ b/api/scripts @@ -1 +1 @@ -Subproject commit 8e9395fd3cd0f1e83b4413546bc2d3cb0c726738 +Subproject commit bbc6feca08aee9ea4f4263bb7c07e199d9c989ee