Coverage for waveqc/models.py: 54%
188 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-05-15 08:47 +0000
1import datetime
2from collections.abc import Sequence
3from itertools import groupby
4from typing import Any
6import pendulum
7from obspy.core.inventory import Network as ObspyNetwork
8from obspy.core.inventory import Station as ObspyStation
9from pendulum.date import Date as PendulumDate
10from pyramid.config import Configurator
11from pyramid.request import Request
12from sqlalchemy import (
13 Date,
14 DateTime,
15 ForeignKey,
16 Integer,
17 MetaData,
18 String,
19 Subquery,
20 Text,
21 UniqueConstraint,
22 delete,
23 except_,
24 extract,
25 func,
26 literal,
27 not_,
28 or_,
29 select,
30 update,
31)
32from sqlalchemy.dialects.postgresql import insert
33from sqlalchemy.engine import create_engine
34from sqlalchemy.orm import (
35 DeclarativeBase,
36 Mapped,
37 mapped_column,
38 relationship,
39 scoped_session,
40)
41from sqlalchemy.orm.session import Session, sessionmaker
42from transaction import TransactionManager
43from zope.sqlalchemy import register
45from .config import settings as waveqc_settings
# Constraint naming convention shared by every table via Base.metadata,
# so that autogenerated constraint names are stable and predictable.
convention = dict(
    ix="ix_%(column_0_label)s",
    uq="uq_%(table_name)s_%(column_0_name)s",
    ck="ck_%(table_name)s_%(constraint_name)s",
    fk="fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    pk="pk_%(table_name)s",
)
class Base(DeclarativeBase):
    """Declarative base for all waveqc models.

    Attaches the shared constraint naming convention so every table gets
    deterministic constraint names (e.g. ``uq_station_code``).
    """

    metadata = MetaData(naming_convention=convention)
def get_tm_session(
    session_factory: sessionmaker[Session],
    transaction_manager: TransactionManager,
    request: Request | None = None,
) -> Session:
    """Build a session whose lifecycle is managed by *transaction_manager*.

    The request (possibly ``None`` outside the web context) is stored in the
    session's ``info`` dict; ``zope.sqlalchemy.register`` ties commits and
    aborts to the transaction manager.
    """
    dbsession = session_factory(info={"request": request})
    register(dbsession, transaction_manager=transaction_manager)
    return dbsession
def includeme(config: Configurator) -> None:
    """Pyramid inclusion hook: configure pyramid_tm and a ``request.dbsession``.

    An engine may be injected via the ``dbengine`` setting (e.g. for tests);
    otherwise one is created from the waveqc ``PG_DSN`` setting.
    """
    settings = config.get_settings()
    settings["tm.manager_hook"] = "pyramid_tm.explicit_manager"
    config.include("pyramid_tm")
    dbengine = settings.get("dbengine")
    if not dbengine:
        # Fall back to the DSN from waveqc's own configuration.
        dbengine = create_engine(str(waveqc_settings.PG_DSN))

    session_factory = sessionmaker(bind=dbengine)
    config.registry["dbsession_factory"] = session_factory

    def dbsession(request: Request) -> Any:  # noqa: ANN401
        # Reuse a session already placed in the WSGI environ (if any);
        # otherwise create one bound to this request's transaction manager.
        dbsession = request.environ.get("app.dbsession")
        if dbsession is None:
            dbsession = get_tm_session(session_factory, request.tm, request=request)
        return dbsession

    # reify=True caches the session on the request for its lifetime.
    config.add_request_method(dbsession, reify=True)
def checks_scope(dbsession: scoped_session[Session]) -> Subquery:
    """Subquery producing one ``day`` row per date, from the earliest stored
    check up to yesterday (UTC)."""
    first_check_date = dbsession.scalar(select(func.min(Check.date)))
    yesterday = pendulum.yesterday("utc").date()
    one_day = pendulum.duration(days=1)
    day_series = func.generate_series(first_check_date, yesterday, one_day)
    return select(day_series.label("day")).subquery()
class OperatedBy(Base):
    """Association table for the many-to-many Station <-> Operator link."""

    __tablename__ = "operated_by"

    station_id: Mapped[int] = mapped_column(ForeignKey("station.id"), primary_key=True)
    operator_id: Mapped[int] = mapped_column(
        ForeignKey("operator.id"), primary_key=True
    )
class Operator(Base):
    """An agency operating one or more stations."""

    __tablename__ = "operator"

    id: Mapped[int] = mapped_column(primary_key=True)
    # The agency name is the natural key (constraint "uq_operator_agency",
    # used as the upsert conflict target in Network.populate_operators).
    agency: Mapped[str] = mapped_column(unique=True)
    website: Mapped[str | None]

    stations: Mapped[list["Station"]] = relationship(
        secondary="operated_by", back_populates="operators"
    )

    def __repr__(self) -> str:
        return f"Operator({self.agency})"
class Network(Base):
    """A seismic network, plus the logic to sync its stations, operators and
    channels from an obspy inventory."""

    __tablename__ = "network"

    id: Mapped[int] = mapped_column(primary_key=True)
    code: Mapped[str] = mapped_column(String(8), unique=True)
    description: Mapped[str | None] = mapped_column(Text())

    stations: Mapped[list["Station"]] = relationship(
        back_populates="network",
        cascade="all, delete",
        order_by="Station.code",
    )

    def __repr__(self: "Network") -> str:
        return f"Network(code={self.code})"

    @classmethod
    def populate(
        cls: type["Network"], dbsession: scoped_session[Session], network: ObspyNetwork
    ) -> "Network":
        """Return the Network row matching *network*'s code, creating it when
        missing."""
        result = dbsession.scalars(
            select(cls).where(cls.code == network.code)
        ).one_or_none()
        if not result:
            result = cls(code=network.code, description=network.description)
            dbsession.add(result)
            # Flush so result.id is assigned before stations reference it.
            dbsession.flush()
        return result

    def populate_stations(
        self: "Network",
        dbsession: scoped_session[Session],
        stations: list[ObspyStation],
    ) -> None:
        """Upsert *stations* for this network.

        End dates in the future are stored as NULL ("still open"); the
        ``triggered`` flag is never overwritten on conflict.
        """
        if not stations:
            # insert().values([]) raises; nothing to do for an empty inventory.
            return
        values = [
            {
                "code": station.code,
                # NOTE(review): assumes station.start_date/end_date are never
                # None (open stations presumably carry a far-future end date)
                # — confirm against the FDSN source.
                "start_date": station.start_date.datetime,
                "end_date": None
                if station.end_date.datetime.replace(tzinfo=datetime.UTC)
                > pendulum.now("utc")
                else station.end_date.datetime,
                "network_id": self.id,
            }
            for station in stations
        ]

        insert_stmt = insert(Station).values(values)
        # On conflict, update every column except the PK and the local-only
        # "triggered" flag.
        update_stmt = {
            field.name: field
            for field in insert_stmt.excluded
            if field.name not in ["id", "triggered"]
        }
        dbsession.execute(
            insert_stmt.on_conflict_do_update(
                constraint="uq_station_code", set_=update_stmt
            )
        )

    def link_operators_to_stations(
        self: "Network",
        dbsession: scoped_session[Session],
        stations: list[ObspyStation],
    ) -> None:
        """Synchronize each station's operator links with the obspy inventory."""
        for item in stations:
            station = dbsession.scalars(
                select(Station).where(
                    Station.network_id == self.id, Station.code == item.code
                )
            ).one()
            # Remove operators no longer declared for this station.
            # Iterate over a copy: removing from the list being iterated
            # would silently skip elements. The agency set is also built
            # once instead of on every iteration.
            agencies = {obspy_operator.agency for obspy_operator in item.operators}
            for operator in list(station.operators):
                if operator.agency not in agencies:
                    station.operators.remove(operator)
            # Add newly declared operators.
            for operator in item.operators:
                agency = dbsession.scalars(
                    select(Operator).where(Operator.agency == operator.agency)
                ).one()
                if agency not in station.operators:
                    station.operators.append(agency)

    @staticmethod
    def purge_orphaned_operators(dbsession: scoped_session[Session]) -> None:
        """Delete operators that are no longer linked to any station."""
        dbsession.execute(delete(Operator).where(not_(Operator.stations.any())))

    def populate_operators(
        self: "Network",
        dbsession: scoped_session[Session],
        stations: list[ObspyStation],
    ) -> None:
        """Upsert the distinct set of operators declared by *stations*."""
        # Deduplicate by round-tripping each row through a hashable tuple.
        operators = [
            dict(item)
            for item in {
                tuple({"agency": operator.agency, "website": None}.items())
                for station in stations
                for operator in station.operators
            }
        ]
        if not operators:
            # No operator declared anywhere: insert().values([]) would raise.
            return
        insert_stmt = insert(Operator).values(operators)
        update_stmt = {
            code.name: code for code in insert_stmt.excluded if code.name not in ["id"]
        }
        dbsession.execute(
            insert_stmt.on_conflict_do_update(
                constraint="uq_operator_agency", set_=update_stmt
            )
        )

    def populate_channels(
        self: "Network",
        dbsession: scoped_session[Session],
        network: ObspyNetwork,
        stations: Sequence["Station"],
    ) -> None:
        """Upsert channels, keeping only the latest epoch of each
        (station, code, location) triple."""
        waveqc_stations = {station.code: station.id for station in stations}
        channels = [
            {
                "code": channel.code,
                "location": channel.location_code,
                # Future end dates are stored as NULL ("still open"); the raw
                # value is kept temporarily to pick the latest epoch below.
                "end_date": None
                if channel.end_date.datetime.replace(tzinfo=datetime.UTC)
                > pendulum.now("utc")
                else channel.end_date.datetime,
                "raw_end_date": channel.end_date.datetime,
                "station_id": waveqc_stations[station.code],
            }
            for station in network.stations
            for channel in station.channels
        ]
        if not channels:
            # Empty inventory: insert().values([]) would raise.
            return

        # Keep only the last epoch for each channel: sort descending on the
        # nslc key and end date, then take the first row of each group.
        channels.sort(
            reverse=True,
            key=lambda x: (
                x["station_id"],
                x["code"],
                x["location"],
                x["raw_end_date"],
            ),
        )  # Channels are now (reverse) sorted by nslc's and end_date
        last_channels = [
            next(grouped)  # first element of each group == latest end_date
            for _, grouped in groupby(
                channels, lambda x: (x["station_id"], x["code"], x["location"])
            )
        ]
        for channel in last_channels:
            # Helper key only; not a Channel column.
            del channel["raw_end_date"]

        insert_stmt = insert(Channel).values(last_channels)
        update_stmt = {
            code.name: code for code in insert_stmt.excluded if code.name not in ["id"]
        }
        dbsession.execute(
            insert_stmt.on_conflict_do_update(
                constraint="uq_channel_code", set_=update_stmt
            )
        )
class Station(Base):
    """A station belonging to a network; its code is unique per network."""

    __tablename__ = "station"
    __table_args__ = (UniqueConstraint("code", "network_id"),)

    id: Mapped[int] = mapped_column(primary_key=True)
    code: Mapped[str] = mapped_column(String(8))
    # Local-only flag; excluded from the upsert in Network.populate_stations,
    # so it is never overwritten by inventory syncs.
    triggered: Mapped[bool] = mapped_column(default=False)

    start_date: Mapped[datetime.datetime | None] = mapped_column(DateTime())
    # NULL means the station is still open (see Network.populate_stations).
    end_date: Mapped[datetime.datetime | None] = mapped_column(DateTime())

    network_id: Mapped[int] = mapped_column(
        ForeignKey("network.id", ondelete="CASCADE")
    )
    network: Mapped["Network"] = relationship(back_populates="stations")
    channels: Mapped[list["Channel"]] = relationship(
        back_populates="station",
        cascade="all, delete",
    )
    operators: Mapped[list["Operator"]] = relationship(
        secondary="operated_by", back_populates="stations"
    )

    def __repr__(self: "Station") -> str:
        return f"Station(code={self.code})"
class Channel(Base):
    """A channel (code + location) of a station, with its daily QC checks."""

    __tablename__ = "channel"
    __table_args__ = (UniqueConstraint("code", "station_id", "location"),)

    id: Mapped[int] = mapped_column(primary_key=True)
    code: Mapped[str] = mapped_column(String(3))
    location: Mapped[str] = mapped_column(String(8))

    # NULL means the channel is still open (see Network.populate_channels).
    end_date: Mapped[datetime.datetime | None] = mapped_column(DateTime())

    station_id: Mapped[int] = mapped_column(
        ForeignKey("station.id", ondelete="CASCADE")
    )
    station: Mapped["Station"] = relationship(back_populates="channels")
    checks: Mapped[list["Check"]] = relationship(
        back_populates="channel",
        cascade="all, delete",
        order_by="Check.date",
    )

    def __repr__(self: "Channel") -> str:
        return f"Channel(code={self.code}, location={self.location})"

    @classmethod
    def fix_closed_channels_checks(
        cls: type["Channel"], dbsession: scoped_session[Session]
    ) -> None:
        """Backfill CHANNEL_CLOSED checks for days after a channel's end date."""
        result = RESULT_CHANNEL_CLOSED
        date_range = checks_scope(dbsession)
        # To find where checks for closed channels are:
        # - first, select which checks should be marked as closed
        # - then, select actual checks for closed channels
        # - finally, take the difference between theoretical and actual checks
        theoric_closed = (
            select(Channel.id, date_range.c.day)
            .join(date_range, literal(True))  # noqa: FBT003
            .where(
                Channel.end_date <= date_range.c.day,
                # NOTE(review): the year-equality clause restricts backfill to
                # days in the same calendar year as the end date — confirm
                # this limitation is intentional.
                extract("year", Channel.end_date) == extract("year", date_range.c.day),
            )
        )
        actual_closed = select(Check.channel_id, Check.date).where(
            Check.result == result
        )
        difference = except_(theoric_closed, actual_closed).subquery()
        to_close = dbsession.execute(
            select(
                Network.code,
                Station.code,
                Channel.location,
                Channel.code,
                difference.c.day,
            )
            .select_from(Channel)
            .join(difference, Channel.id == difference.c.id)
            .join(Channel.station)
            .join(Station.network)
        ).all()
        for network, station, location, channel, day in to_close:
            nslc = f"{network}.{station}.{location}.{channel}"
            # store_check_result is a classmethod: call it via cls instead of
            # instantiating a throwaway transient Channel object.
            cls.store_check_result(dbsession, nslc, day.date(), result)
        if to_close:
            dbsession.commit()

    @classmethod
    def fix_missing_checks(
        cls: type["Channel"], dbsession: scoped_session[Session]
    ) -> None:
        """Backfill a default (RESULT_NO_DATA) check for every missing
        channel/day pair."""
        date_range = checks_scope(dbsession)
        # To find where missing checks are:
        # - first, select theoretical checks (one check per day per channel)
        # - then, select actual checks stored in db
        # - finally, take the difference between theoretical and actual checks
        theoric_checks = (
            select(Channel.id, date_range.c.day)
            .join(date_range, literal(True))  # noqa: FBT003
            .where(or_(Channel.end_date == None, Channel.end_date > date_range.c.day))  # noqa: E711
        )
        actual_checks = select(Check.channel_id, Check.date)
        difference = except_(theoric_checks, actual_checks).subquery()
        missing = dbsession.execute(
            select(
                Network.code,
                Station.code,
                Channel.location,
                Channel.code,
                difference.c.day,
            )
            .select_from(Channel)
            .join(difference, Channel.id == difference.c.id)
            .join(Channel.station)
            .join(Station.network)
        ).all()
        for network, station, location, channel, day in missing:
            nslc = f"{network}.{station}.{location}.{channel}"
            # Classmethod call; defaults record the day as RESULT_NO_DATA.
            cls.store_check_result(dbsession, nslc, day.date())
        if missing:
            dbsession.commit()

    @classmethod
    def store_check_result(  # noqa: PLR0913
        cls: type["Channel"],
        dbsession: scoped_session[Session],
        nslc: str,
        date: PendulumDate,
        result: int = 0,
        completeness: int = 0,
        trace_count: int = 0,
        shortest_trace: int = 0,
    ) -> None:
        """Insert or update the Check for *nslc* on *date*.

        *nslc* is a dotted "net.sta.loc.chan" string; the network part is
        ignored because station codes resolve the channel here. An existing
        check is updated in place with its ``retries`` counter incremented.
        """
        _, station, location, channel = nslc.split(".")

        channel_checked = dbsession.scalars(
            select(cls).where(
                cls.code == channel,
                cls.location == location,
                cls.station.has(Station.code == station),
            )
        ).one()
        check = dbsession.scalars(
            select(Check).where(
                Check.channel_id == channel_checked.id, Check.date == date
            )
        ).one_or_none()
        if check:
            dbsession.execute(
                update(Check)
                .where(Check.id == check.id)
                .values(
                    result=result,
                    retries=check.retries + 1,
                    completeness=completeness,
                    shortest_trace=shortest_trace,
                    trace_count=trace_count,
                )
            )
        else:
            check = Check(
                channel=channel_checked,
                date=date,
                result=result,
                completeness=completeness,
                shortest_trace=shortest_trace,
                trace_count=trace_count,
            )
            dbsession.add(check)
            dbsession.flush()
# Result codes stored in Check.result.
RESULT_NO_DATA = 0
RESULT_NOT_READABLE = 1
RESULT_DECONVOLUTION_FAILS = 2
RESULT_DECONVOLUTION_PASS = 3
# Assigned by Channel.fix_closed_channels_checks for days past a channel's
# end date.
RESULT_CHANNEL_CLOSED = 4

# Weight per result code — presumably a quality score for display/aggregation
# (255 = best, 0 = closed); TODO confirm against the consumers of this dict.
RESULT_PONDERATION = {
    RESULT_NO_DATA: 50,
    RESULT_NOT_READABLE: 100,
    RESULT_DECONVOLUTION_FAILS: 150,
    RESULT_DECONVOLUTION_PASS: 255,
    RESULT_CHANNEL_CLOSED: 0,
}
class Check(Base):
    """Daily QC result for a channel — one row per (channel, date)."""

    __tablename__ = "check"
    __table_args__ = (UniqueConstraint("channel_id", "date"),)

    id: Mapped[int] = mapped_column(primary_key=True)
    channel_id: Mapped[int] = mapped_column(
        ForeignKey("channel.id", ondelete="CASCADE")
    )
    date: Mapped[datetime.date] = mapped_column(Date())
    # One of the RESULT_* codes defined above.
    result: Mapped[int] = mapped_column(Integer())
    completeness: Mapped[int] = mapped_column(Integer())
    # Incremented each time the same (channel, date) check is stored again
    # (see Channel.store_check_result).
    retries: Mapped[int] = mapped_column(Integer(), default=0)
    shortest_trace: Mapped[int] = mapped_column(Integer())
    trace_count: Mapped[int] = mapped_column(Integer())

    channel: Mapped["Channel"] = relationship(back_populates="checks")

    def __repr__(self: "Check") -> str:
        return f"Check({self.channel.code}, {self.date}, {self.result})"