Coverage for waveqc/models.py: 54%

188 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-05-15 08:47 +0000

1import datetime 

2from collections.abc import Sequence 

3from itertools import groupby 

4from typing import Any 

5 

6import pendulum 

7from obspy.core.inventory import Network as ObspyNetwork 

8from obspy.core.inventory import Station as ObspyStation 

9from pendulum.date import Date as PendulumDate 

10from pyramid.config import Configurator 

11from pyramid.request import Request 

12from sqlalchemy import ( 

13 Date, 

14 DateTime, 

15 ForeignKey, 

16 Integer, 

17 MetaData, 

18 String, 

19 Subquery, 

20 Text, 

21 UniqueConstraint, 

22 delete, 

23 except_, 

24 extract, 

25 func, 

26 literal, 

27 not_, 

28 or_, 

29 select, 

30 update, 

31) 

32from sqlalchemy.dialects.postgresql import insert 

33from sqlalchemy.engine import create_engine 

34from sqlalchemy.orm import ( 

35 DeclarativeBase, 

36 Mapped, 

37 mapped_column, 

38 relationship, 

39 scoped_session, 

40) 

41from sqlalchemy.orm.session import Session, sessionmaker 

42from transaction import TransactionManager 

43from zope.sqlalchemy import register 

44 

45from .config import settings as waveqc_settings 

46 

# Naming convention applied to the MetaData below so every index/constraint
# gets a deterministic, table-derived name.
convention = dict(
    ix="ix_%(column_0_label)s",
    uq="uq_%(table_name)s_%(column_0_name)s",
    ck="ck_%(table_name)s_%(constraint_name)s",
    fk="fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    pk="pk_%(table_name)s",
)

54 

55 

class Base(DeclarativeBase):
    """Declarative base for all waveqc models.

    Carries a MetaData configured with the naming convention above so that
    constraint names (e.g. "uq_station_code") are deterministic.
    """

    metadata = MetaData(naming_convention=convention)

58 

59 

def get_tm_session(
    session_factory: sessionmaker[Session],
    transaction_manager: TransactionManager,
    request: Request | None = None,  # None outside of a web request context
) -> Session:
    """Create a session whose lifecycle is bound to ``transaction_manager``.

    The request (if any) is stashed in the session's ``info`` dict, and
    ``zope.sqlalchemy.register`` joins the session to the transaction manager
    so commits/aborts follow the transaction's fate.
    """
    dbsession = session_factory(info={"request": request})
    register(dbsession, transaction_manager=transaction_manager)
    return dbsession

68 

69 

def includeme(config: Configurator) -> None:
    """Wire SQLAlchemy session handling into a Pyramid application.

    Activates pyramid_tm with the explicit transaction manager, builds an
    engine (unless one is supplied via the "dbengine" setting), stores the
    session factory in the registry, and exposes a reified
    ``request.dbsession`` property.
    """
    settings = config.get_settings()
    settings["tm.manager_hook"] = "pyramid_tm.explicit_manager"
    config.include("pyramid_tm")

    engine = settings.get("dbengine")
    if not engine:
        engine = create_engine(str(waveqc_settings.PG_DSN))

    session_factory = sessionmaker(bind=engine)
    config.registry["dbsession_factory"] = session_factory

    def dbsession(request: Request) -> Any:  # noqa: ANN401
        # Reuse a session already placed in the WSGI environ, if any;
        # otherwise create one tied to this request's transaction manager.
        session = request.environ.get("app.dbsession")
        if session is None:
            session = get_tm_session(session_factory, request.tm, request=request)
        return session

    config.add_request_method(dbsession, reify=True)

88 

89 

def checks_scope(dbsession: scoped_session[Session]) -> Subquery:
    """Build a one-row-per-day subquery covering the whole checking period.

    The series runs from the earliest stored check date up to yesterday
    (UTC), stepping one day at a time; its single column is labelled "day".
    """
    first_check_day = dbsession.scalar(select(func.min(Check.date)))
    yesterday = pendulum.yesterday("utc").date()
    one_day = pendulum.duration(days=1)
    series = func.generate_series(first_check_day, yesterday, one_day)
    return select(series.label("day")).subquery()

95 

96 

class OperatedBy(Base):
    """Association table for the Station <-> Operator many-to-many link."""

    __tablename__ = "operated_by"

    # Composite primary key: one row per (station, operator) pair.
    station_id: Mapped[int] = mapped_column(ForeignKey("station.id"), primary_key=True)
    operator_id: Mapped[int] = mapped_column(
        ForeignKey("operator.id"), primary_key=True
    )

104 

105 

class Operator(Base):
    """An agency operating one or more stations."""

    __tablename__ = "operator"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Unique agency name; upserts conflict on "uq_operator_agency".
    agency: Mapped[str] = mapped_column(unique=True)
    website: Mapped[str | None]

    # Many-to-many through the operated_by association table.
    stations: Mapped[list["Station"]] = relationship(
        secondary="operated_by", back_populates="operators"
    )

    def __repr__(self) -> str:
        return f"Operator({self.agency})"

119 

120 

class Network(Base):
    """A seismic network; owns stations and the upsert logic fed by obspy
    inventory objects."""

    __tablename__ = "network"

    id: Mapped[int] = mapped_column(primary_key=True)
    code: Mapped[str] = mapped_column(String(8), unique=True)
    description: Mapped[str | None] = mapped_column(Text())

    stations: Mapped[list["Station"]] = relationship(
        back_populates="network",
        cascade="all, delete",
        order_by="Station.code",
    )

    def __repr__(self: "Network") -> str:
        return f"Network(code={self.code})"

    @classmethod
    def populate(
        cls: type["Network"], dbsession: scoped_session[Session], network: ObspyNetwork
    ) -> "Network":
        """Return the Network row matching ``network.code``, creating it if absent."""
        result = dbsession.scalars(
            select(cls).where(cls.code == network.code)
        ).one_or_none()
        if not result:
            result = cls(code=network.code, description=network.description)
            dbsession.add(result)
            dbsession.flush()
        return result

    @staticmethod
    def _closing_date(
        end_date: Any, now: datetime.datetime
    ) -> datetime.datetime | None:
        """Return the end date to store, or None when the epoch is still open.

        An epoch counts as open when obspy reports no end date at all, or
        when the reported end date lies in the future.
        """
        if end_date is None:
            return None
        value = end_date.datetime  # obspy UTCDateTime -> naive datetime
        return None if value.replace(tzinfo=datetime.UTC) > now else value

    def populate_stations(
        self: "Network",
        dbsession: scoped_session[Session],
        stations: list[ObspyStation],
    ) -> None:
        """Bulk-upsert ``stations`` for this network (conflict on station code)."""
        if not stations:
            return  # insert(...).values([]) is invalid SQL; nothing to do anyway
        now = pendulum.now("utc")
        values = [
            {
                "code": station.code,
                "start_date": station.start_date.datetime,
                "end_date": self._closing_date(station.end_date, now),
                "network_id": self.id,
            }
            for station in stations
        ]

        insert_stmt = insert(Station).values(values)
        # On conflict refresh every field except the primary key and the
        # user-managed "triggered" flag.
        update_stmt = {
            field.name: field
            for field in insert_stmt.excluded
            if field.name not in ["id", "triggered"]
        }
        dbsession.execute(
            insert_stmt.on_conflict_do_update(
                constraint="uq_station_code", set_=update_stmt
            )
        )

    def link_operators_to_stations(
        self: "Network",
        dbsession: scoped_session[Session],
        stations: list[ObspyStation],
    ) -> None:
        """Synchronize each station's operator list with the obspy metadata.

        Operators no longer listed for a station are detached; newly listed
        ones are attached (they must already exist — see populate_operators).
        """
        for item in stations:
            station = dbsession.scalars(
                select(Station).where(
                    Station.network_id == self.id, Station.code == item.code
                )
            ).one()
            agencies = {obspy_operator.agency for obspy_operator in item.operators}
            # Remove old operators from the station. Iterate over a copy:
            # removing from a list while iterating it skips elements.
            for operator in list(station.operators):
                if operator.agency not in agencies:
                    station.operators.remove(operator)
            # Add new operators to the station
            for operator in item.operators:
                agency = dbsession.scalars(
                    select(Operator).where(Operator.agency == operator.agency)
                ).one()
                if agency not in station.operators:
                    station.operators.append(agency)

    @staticmethod
    def purge_orphaned_operators(dbsession: scoped_session[Session]) -> None:
        """Delete operators that are no longer linked to any station."""
        dbsession.execute(delete(Operator).where(not_(Operator.stations.any())))

    def populate_operators(
        self: "Network",
        dbsession: scoped_session[Session],
        stations: list[ObspyStation],
    ) -> None:
        """Bulk-upsert the distinct operator agencies mentioned by ``stations``."""
        # Deduplicate by agency name; sorted for deterministic statement order.
        agencies = sorted(
            {operator.agency for station in stations for operator in station.operators}
        )
        if not agencies:
            return  # insert(...).values([]) is invalid SQL
        operators = [{"agency": agency, "website": None} for agency in agencies]
        insert_stmt = insert(Operator).values(operators)
        update_stmt = {
            code.name: code for code in insert_stmt.excluded if code.name not in ["id"]
        }
        dbsession.execute(
            insert_stmt.on_conflict_do_update(
                constraint="uq_operator_agency", set_=update_stmt
            )
        )

    def populate_channels(
        self: "Network",
        dbsession: scoped_session[Session],
        network: ObspyNetwork,
        stations: Sequence["Station"],
    ) -> None:
        """Bulk-upsert the latest epoch of every channel found in ``network``.

        ``stations`` are the already-persisted Station rows used to resolve
        station ids from obspy station codes.
        """
        waveqc_stations = {station.code: station.id for station in stations}
        now = pendulum.now("utc")
        channels = [
            {
                "code": channel.code,
                "location": channel.location_code,
                "end_date": self._closing_date(channel.end_date, now),
                # Sort helper only (dropped below); an open epoch (no end
                # date) must sort after any dated epoch.
                "raw_end_date": channel.end_date.datetime
                if channel.end_date is not None
                else datetime.datetime.max,
                "station_id": waveqc_stations[station.code],
            }
            for station in network.stations
            for channel in station.channels
        ]
        if not channels:
            return  # insert(...).values([]) is invalid SQL

        # Here, we filter only the last epoch of each channel: reverse-sort by
        # nslc and end date, then keep the first entry of each nslc group.
        channels.sort(
            reverse=True,
            key=lambda x: (
                x["station_id"],
                x["code"],
                x["location"],
                x["raw_end_date"],
            ),
        )
        last_channels = [
            next(grouped)  # first element of the group = latest end_date
            for _, grouped in groupby(
                channels, lambda x: (x["station_id"], x["code"], x["location"])
            )
        ]
        for channel in last_channels:
            del channel["raw_end_date"]  # not a column, sort helper only

        insert_stmt = insert(Channel).values(last_channels)
        update_stmt = {
            code.name: code for code in insert_stmt.excluded if code.name not in ["id"]
        }
        dbsession.execute(
            insert_stmt.on_conflict_do_update(
                constraint="uq_channel_code", set_=update_stmt
            )
        )

281 

282 

class Station(Base):
    """A station belonging to a network; station codes are unique per network."""

    __tablename__ = "station"
    __table_args__ = (UniqueConstraint("code", "network_id"),)

    id: Mapped[int] = mapped_column(primary_key=True)
    code: Mapped[str] = mapped_column(String(8))
    # Preserved across metadata upserts: Network.populate_stations excludes
    # this field from its ON CONFLICT update.
    triggered: Mapped[bool] = mapped_column(default=False)

    start_date: Mapped[datetime.datetime | None] = mapped_column(DateTime())
    # None while the station is still open (see Network.populate_stations).
    end_date: Mapped[datetime.datetime | None] = mapped_column(DateTime())

    network_id: Mapped[int] = mapped_column(
        ForeignKey("network.id", ondelete="CASCADE")
    )
    network: Mapped["Network"] = relationship(back_populates="stations")
    channels: Mapped[list["Channel"]] = relationship(
        back_populates="station",
        cascade="all, delete",
    )
    # Many-to-many through the operated_by association table.
    operators: Mapped[list["Operator"]] = relationship(
        secondary="operated_by", back_populates="stations"
    )

    def __repr__(self: "Station") -> str:
        return f"Station(code={self.code})"

308 

309 

class Channel(Base):
    """A channel (code + location) of a station, carrying its daily checks."""

    __tablename__ = "channel"
    __table_args__ = (UniqueConstraint("code", "station_id", "location"),)

    id: Mapped[int] = mapped_column(primary_key=True)
    code: Mapped[str] = mapped_column(String(3))
    location: Mapped[str] = mapped_column(String(8))

    # None while the channel epoch is still open.
    end_date: Mapped[datetime.datetime | None] = mapped_column(DateTime())

    station_id: Mapped[int] = mapped_column(
        ForeignKey("station.id", ondelete="CASCADE")
    )
    station: Mapped["Station"] = relationship(back_populates="channels")
    checks: Mapped[list["Check"]] = relationship(
        back_populates="channel",
        cascade="all, delete",
        order_by="Check.date",
    )

    def __repr__(self: "Channel") -> str:
        return f"Channel(code={self.code}, location={self.location})"

    @classmethod
    def fix_closed_channels_checks(
        cls: type["Channel"], dbsession: scoped_session[Session]
    ) -> None:
        """Backfill RESULT_CHANNEL_CLOSED checks for days past a channel's end date.

        Only days in the same year as the channel's end date are considered
        (per the extract("year") filter below).
        """
        result = RESULT_CHANNEL_CLOSED
        date_range = checks_scope(dbsession)
        # To find where checks for closed channels are missing:
        # - first, we select which checks should be marked as closed
        # - then, we select actual checks for closed channels
        # - finally, we take the difference between theoretical and actual checks
        theoric_closed = (
            select(Channel.id, date_range.c.day)
            .join(date_range, literal(True))  # noqa: FBT003
            .where(
                Channel.end_date <= date_range.c.day,
                # limit the backfill to the year the channel closed
                extract("year", Channel.end_date) == extract("year", date_range.c.day),
            )
        )
        actual_closed = select(Check.channel_id, Check.date).where(
            Check.result == result
        )
        difference = except_(theoric_closed, actual_closed).subquery()
        to_close = dbsession.execute(
            select(
                Network.code,
                Station.code,
                Channel.location,
                Channel.code,
                difference.c.day,
            )
            .select_from(Channel)
            .join(difference, Channel.id == difference.c.id)
            .join(Channel.station)
            .join(Station.network)
        ).all()
        for network, station, location, channel, day in to_close:
            nslc = f"{network}.{station}.{location}.{channel}"
            # store_check_result is a classmethod: call it on cls instead of
            # building a throwaway Channel() instance per row.
            cls.store_check_result(dbsession, nslc, day.date(), result)
        if to_close:
            dbsession.commit()

    @classmethod
    def fix_missing_checks(
        cls: type["Channel"], dbsession: scoped_session[Session]
    ) -> None:
        """Backfill a default (RESULT_NO_DATA) check for every missing
        (open channel, day) pair."""
        date_range = checks_scope(dbsession)
        # To find where missing checks are:
        # - first, we select theoretical checks (a check per day per channel)
        # - then, we select actual checks stored in db
        # - finally, we take the difference between theoretical and actual checks
        theoric_checks = (
            select(Channel.id, date_range.c.day)
            .join(date_range, literal(True))  # noqa: FBT003
            .where(or_(Channel.end_date == None, Channel.end_date > date_range.c.day))  # noqa: E711
        )
        actual_checks = select(Check.channel_id, Check.date)
        difference = except_(theoric_checks, actual_checks).subquery()
        missing = dbsession.execute(
            select(
                Network.code,
                Station.code,
                Channel.location,
                Channel.code,
                difference.c.day,
            )
            .select_from(Channel)
            .join(difference, Channel.id == difference.c.id)
            .join(Channel.station)
            .join(Station.network)
        ).all()
        for network, station, location, channel, day in missing:
            nslc = f"{network}.{station}.{location}.{channel}"
            # default result=0 is RESULT_NO_DATA; classmethod call, no instance
            cls.store_check_result(dbsession, nslc, day.date())
        if missing:
            dbsession.commit()

    @classmethod
    def store_check_result(  # noqa: PLR0913
        cls: type["Channel"],
        dbsession: scoped_session[Session],
        nslc: str,
        date: datetime.date,  # pendulum dates are datetime.date subclasses
        result: int = 0,
        completeness: int = 0,
        trace_count: int = 0,
        shortest_trace: int = 0,
    ) -> None:
        """Insert or update the check for ``nslc`` on ``date``.

        ``nslc`` is a dotted "network.station.location.channel" code. An
        existing check for (channel, date) is updated in place with its
        ``retries`` counter incremented; otherwise a new row is added and
        flushed.
        """
        # NOTE(review): the network part of nslc is discarded, so the station
        # lookup below assumes station codes are globally unique — two
        # networks sharing a station code would raise MultipleResultsFound.
        # Confirm against the data this runs on.
        _, station, location, channel = nslc.split(".")

        channel_checked = dbsession.scalars(
            select(cls).where(
                cls.code == channel,
                cls.location == location,
                cls.station.has(Station.code == station),
            )
        ).one()
        check = dbsession.scalars(
            select(Check).where(
                Check.channel_id == channel_checked.id, Check.date == date
            )
        ).one_or_none()
        if check:
            dbsession.execute(
                update(Check)
                .where(Check.id == check.id)
                .values(
                    result=result,
                    retries=check.retries + 1,
                    completeness=completeness,
                    shortest_trace=shortest_trace,
                    trace_count=trace_count,
                )
            )
        else:
            check = Check(
                channel=channel_checked,
                date=date,
                result=result,
                completeness=completeness,
                shortest_trace=shortest_trace,
                trace_count=trace_count,
            )
            dbsession.add(check)
            dbsession.flush()

457 

458 

# Check result codes stored in Check.result.
RESULT_NO_DATA = 0
RESULT_NOT_READABLE = 1
RESULT_DECONVOLUTION_FAILS = 2
RESULT_DECONVOLUTION_PASS = 3
RESULT_CHANNEL_CLOSED = 4

# Weight per result code — presumably used elsewhere to score or shade
# channel quality (usage not visible in this module; TODO confirm). A full
# deconvolution pass weighs highest (255), a closed channel is neutral (0).
RESULT_PONDERATION = {
    RESULT_NO_DATA: 50,
    RESULT_NOT_READABLE: 100,
    RESULT_DECONVOLUTION_FAILS: 150,
    RESULT_DECONVOLUTION_PASS: 255,
    RESULT_CHANNEL_CLOSED: 0,
}

472 

473 

class Check(Base):
    """Daily QC result for one channel; one row per (channel, date)."""

    __tablename__ = "check"
    __table_args__ = (UniqueConstraint("channel_id", "date"),)

    id: Mapped[int] = mapped_column(primary_key=True)
    channel_id: Mapped[int] = mapped_column(
        ForeignKey("channel.id", ondelete="CASCADE")
    )
    # The day being checked (drawn from the checks_scope day series).
    date: Mapped[datetime.date] = mapped_column(Date())
    # One of the RESULT_* codes defined above.
    result: Mapped[int] = mapped_column(Integer())
    completeness: Mapped[int] = mapped_column(Integer())
    # Incremented each time the same (channel, date) is re-checked
    # (see Channel.store_check_result).
    retries: Mapped[int] = mapped_column(Integer(), default=0)
    shortest_trace: Mapped[int] = mapped_column(Integer())
    trace_count: Mapped[int] = mapped_column(Integer())

    channel: Mapped["Channel"] = relationship(back_populates="checks")

    def __repr__(self: "Check") -> str:
        return f"Check({self.channel.code}, {self.date}, {self.result})"