|
20 | 20 | Iterable, |
21 | 21 | List, |
22 | 22 | Optional, |
| 23 | + Set, |
23 | 24 | Tuple, |
24 | 25 | Union, |
25 | 26 | cast, |
@@ -454,106 +455,175 @@ def _get_applicable_edits_txn(txn: LoggingTransaction) -> Dict[str, str]: |
454 | 455 | } |
455 | 456 |
|
456 | 457 | @cached() |
457 | | - async def get_thread_summary( |
458 | | - self, event_id: str, room_id: str |
459 | | - ) -> Tuple[int, Optional[EventBase]]: |
| 458 | + def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]: |
| 459 | + raise NotImplementedError() |
| 460 | + |
| 461 | + @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") |
| 462 | + async def _get_thread_summaries( |
| 463 | + self, event_ids: Collection[str] |
| 464 | + ) -> Dict[str, Optional[Tuple[int, EventBase]]]: |
460 | 465 | """Get the number of threaded replies and the latest reply (if any) for the given event. |
461 | 466 |
|
462 | 467 | Args: |
463 | | - event_id: Summarize the thread related to this event ID. |
464 | | - room_id: The room the event belongs to. |
| 468 | + event_ids: Summarize the thread related to this event ID. |
465 | 469 |
|
466 | 470 | Returns: |
467 | | - The number of items in the thread and the most recent response, if any. |
| 471 | + A map of the thread summary each event. A missing event implies there |
| 472 | + are no threaded replies. |
| 473 | +
|
| 474 | + Each summary includes the number of items in the thread and the most |
| 475 | + recent response. |
468 | 476 | """ |
469 | 477 |
|
470 | | - def _get_thread_summary_txn( |
| 478 | + def _get_thread_summaries_txn( |
471 | 479 | txn: LoggingTransaction, |
472 | | - ) -> Tuple[int, Optional[str]]: |
473 | | - # Fetch the latest event ID in the thread. |
| 480 | + ) -> Tuple[Dict[str, int], Dict[str, str]]: |
| 481 | + # Fetch the count of threaded events and the latest event ID. |
474 | 482 | # TODO Should this only allow m.room.message events. |
475 | | - sql = """ |
476 | | - SELECT event_id |
477 | | - FROM event_relations |
478 | | - INNER JOIN events USING (event_id) |
479 | | - WHERE |
480 | | - relates_to_id = ? |
481 | | - AND room_id = ? |
482 | | - AND relation_type = ? |
483 | | - ORDER BY topological_ordering DESC, stream_ordering DESC |
484 | | - LIMIT 1 |
485 | | - """ |
| 483 | + if isinstance(self.database_engine, PostgresEngine): |
| 484 | + # The `DISTINCT ON` clause will pick the *first* row it encounters, |
| 485 | + # so ordering by topologica ordering + stream ordering desc will |
| 486 | + # ensure we get the latest event in the thread. |
| 487 | + sql = """ |
| 488 | + SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child |
| 489 | + INNER JOIN event_relations USING (event_id) |
| 490 | + INNER JOIN events AS parent ON |
| 491 | + parent.event_id = relates_to_id |
| 492 | + AND parent.room_id = child.room_id |
| 493 | + WHERE |
| 494 | + %s |
| 495 | + AND relation_type = ? |
| 496 | + ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC |
| 497 | + """ |
| 498 | + else: |
| 499 | + # SQLite uses a simplified query which returns all entries for a |
| 500 | + # thread. The first result for each thread is chosen to and subsequent |
| 501 | + # results for a thread are ignored. |
| 502 | + sql = """ |
| 503 | + SELECT parent.event_id, child.event_id FROM events AS child |
| 504 | + INNER JOIN event_relations USING (event_id) |
| 505 | + INNER JOIN events AS parent ON |
| 506 | + parent.event_id = relates_to_id |
| 507 | + AND parent.room_id = child.room_id |
| 508 | + WHERE |
| 509 | + %s |
| 510 | + AND relation_type = ? |
| 511 | + ORDER BY child.topological_ordering DESC, child.stream_ordering DESC |
| 512 | + """ |
| 513 | + |
| 514 | + clause, args = make_in_list_sql_clause( |
| 515 | + txn.database_engine, "relates_to_id", event_ids |
| 516 | + ) |
| 517 | + args.append(RelationTypes.THREAD) |
486 | 518 |
|
487 | | - txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) |
488 | | - row = txn.fetchone() |
489 | | - if row is None: |
490 | | - return 0, None |
| 519 | + txn.execute(sql % (clause,), args) |
| 520 | + latest_event_ids = {} |
| 521 | + for parent_event_id, child_event_id in txn: |
| 522 | + # Only consider the latest threaded reply (by topological ordering). |
| 523 | + if parent_event_id not in latest_event_ids: |
| 524 | + latest_event_ids[parent_event_id] = child_event_id |
491 | 525 |
|
492 | | - latest_event_id = row[0] |
| 526 | + # If no threads were found, bail. |
| 527 | + if not latest_event_ids: |
| 528 | + return {}, latest_event_ids |
493 | 529 |
|
494 | 530 | # Fetch the number of threaded replies. |
495 | 531 | sql = """ |
496 | | - SELECT COUNT(event_id) |
497 | | - FROM event_relations |
498 | | - INNER JOIN events USING (event_id) |
| 532 | + SELECT parent.event_id, COUNT(child.event_id) FROM events AS child |
| 533 | + INNER JOIN event_relations USING (event_id) |
| 534 | + INNER JOIN events AS parent ON |
| 535 | + parent.event_id = relates_to_id |
| 536 | + AND parent.room_id = child.room_id |
499 | 537 | WHERE |
500 | | - relates_to_id = ? |
501 | | - AND room_id = ? |
| 538 | + %s |
502 | 539 | AND relation_type = ? |
| 540 | + GROUP BY parent.event_id |
503 | 541 | """ |
504 | | - txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) |
505 | | - count = cast(Tuple[int], txn.fetchone())[0] |
506 | 542 |
|
507 | | - return count, latest_event_id |
| 543 | + # Regenerate the arguments since only threads found above could |
| 544 | + # possibly have any replies. |
| 545 | + clause, args = make_in_list_sql_clause( |
| 546 | + txn.database_engine, "relates_to_id", latest_event_ids.keys() |
| 547 | + ) |
| 548 | + args.append(RelationTypes.THREAD) |
| 549 | + |
| 550 | + txn.execute(sql % (clause,), args) |
| 551 | + counts = dict(cast(List[Tuple[str, int]], txn.fetchall())) |
508 | 552 |
|
509 | | - count, latest_event_id = await self.db_pool.runInteraction( |
510 | | - "get_thread_summary", _get_thread_summary_txn |
| 553 | + return counts, latest_event_ids |
| 554 | + |
| 555 | + counts, latest_event_ids = await self.db_pool.runInteraction( |
| 556 | + "get_thread_summaries", _get_thread_summaries_txn |
511 | 557 | ) |
512 | 558 |
|
513 | | - latest_event = None |
514 | | - if latest_event_id: |
515 | | - latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined] |
| 559 | + latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] |
| 560 | + |
| 561 | + # Map to the event IDs to the thread summary. |
| 562 | + # |
| 563 | + # There might not be a summary due to there not being a thread or |
| 564 | + # due to the latest event not being known, either case is treated the same. |
| 565 | + summaries = {} |
| 566 | + for parent_event_id, latest_event_id in latest_event_ids.items(): |
| 567 | + latest_event = latest_events.get(latest_event_id) |
| 568 | + |
| 569 | + summary = None |
| 570 | + if latest_event: |
| 571 | + summary = (counts[parent_event_id], latest_event) |
| 572 | + summaries[parent_event_id] = summary |
516 | 573 |
|
517 | | - return count, latest_event |
| 574 | + return summaries |
518 | 575 |
|
519 | 576 | @cached() |
520 | | - async def get_thread_participated( |
521 | | - self, event_id: str, room_id: str, user_id: str |
522 | | - ) -> bool: |
523 | | - """Get whether the requesting user participated in a thread. |
| 577 | + def get_thread_participated(self, event_id: str, user_id: str) -> bool: |
| 578 | + raise NotImplementedError() |
524 | 579 |
|
525 | | - This is separate from get_thread_summary since that can be cached across |
526 | | - all users while this value is specific to the requeser. |
| 580 | + @cachedList(cached_method_name="get_thread_participated", list_name="event_ids") |
| 581 | + async def _get_threads_participated( |
| 582 | + self, event_ids: Collection[str], user_id: str |
| 583 | + ) -> Dict[str, bool]: |
| 584 | + """Get whether the requesting user participated in the given threads. |
| 585 | +
|
| 586 | + This is separate from get_thread_summaries since that can be cached across |
| 587 | + all users while this value is specific to the requester. |
527 | 588 |
|
528 | 589 | Args: |
529 | | - event_id: The thread related to this event ID. |
530 | | - room_id: The room the event belongs to. |
| 590 | + event_ids: The thread related to these event IDs. |
531 | 591 | user_id: The user requesting the summary. |
532 | 592 |
|
533 | 593 | Returns: |
534 | | - True if the requesting user participated in the thread, otherwise false. |
| 594 | + A map of event ID to a boolean which represents if the requesting |
| 595 | + user participated in that event's thread, otherwise false. |
535 | 596 | """ |
536 | 597 |
|
537 | | - def _get_thread_summary_txn(txn: LoggingTransaction) -> bool: |
| 598 | + def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]: |
538 | 599 | # Fetch whether the requester has participated or not. |
539 | 600 | sql = """ |
540 | | - SELECT 1 |
541 | | - FROM event_relations |
542 | | - INNER JOIN events USING (event_id) |
| 601 | + SELECT DISTINCT relates_to_id |
| 602 | + FROM events AS child |
| 603 | + INNER JOIN event_relations USING (event_id) |
| 604 | + INNER JOIN events AS parent ON |
| 605 | + parent.event_id = relates_to_id |
| 606 | + AND parent.room_id = child.room_id |
543 | 607 | WHERE |
544 | | - relates_to_id = ? |
545 | | - AND room_id = ? |
| 608 | + %s |
546 | 609 | AND relation_type = ? |
547 | | - AND sender = ? |
| 610 | + AND child.sender = ? |
548 | 611 | """ |
549 | 612 |
|
550 | | - txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id)) |
551 | | - return bool(txn.fetchone()) |
| 613 | + clause, args = make_in_list_sql_clause( |
| 614 | + txn.database_engine, "relates_to_id", event_ids |
| 615 | + ) |
| 616 | + args.extend((RelationTypes.THREAD, user_id)) |
552 | 617 |
|
553 | | - return await self.db_pool.runInteraction( |
| 618 | + txn.execute(sql % (clause,), args) |
| 619 | + return {row[0] for row in txn.fetchall()} |
| 620 | + |
| 621 | + participated_threads = await self.db_pool.runInteraction( |
554 | 622 | "get_thread_summary", _get_thread_summary_txn |
555 | 623 | ) |
556 | 624 |
|
| 625 | + return {event_id: event_id in participated_threads for event_id in event_ids} |
| 626 | + |
557 | 627 | async def events_have_relations( |
558 | 628 | self, |
559 | 629 | parent_ids: List[str], |
@@ -700,21 +770,6 @@ async def _get_bundled_aggregation_for_event( |
700 | 770 | if references.chunk: |
701 | 771 | aggregations.references = await references.to_dict(cast("DataStore", self)) |
702 | 772 |
|
703 | | - # If this event is the start of a thread, include a summary of the replies. |
704 | | - if self._msc3440_enabled: |
705 | | - thread_count, latest_thread_event = await self.get_thread_summary( |
706 | | - event_id, room_id |
707 | | - ) |
708 | | - participated = await self.get_thread_participated( |
709 | | - event_id, room_id, user_id |
710 | | - ) |
711 | | - if latest_thread_event: |
712 | | - aggregations.thread = _ThreadAggregation( |
713 | | - latest_event=latest_thread_event, |
714 | | - count=thread_count, |
715 | | - current_user_participated=participated, |
716 | | - ) |
717 | | - |
718 | 773 | # Store the bundled aggregations in the event metadata for later use. |
719 | 774 | return aggregations |
720 | 775 |
|
@@ -763,6 +818,27 @@ async def get_bundled_aggregations( |
763 | 818 | for event_id, edit in edits.items(): |
764 | 819 | results.setdefault(event_id, BundledAggregations()).replace = edit |
765 | 820 |
|
| 821 | + # Fetch thread summaries. |
| 822 | + if self._msc3440_enabled: |
| 823 | + summaries = await self._get_thread_summaries(seen_event_ids) |
| 824 | + # Only fetch participated for a limited selection based on what had |
| 825 | + # summaries. |
| 826 | + participated = await self._get_threads_participated( |
| 827 | + summaries.keys(), user_id |
| 828 | + ) |
| 829 | + for event_id, summary in summaries.items(): |
| 830 | + if summary: |
| 831 | + thread_count, latest_thread_event = summary |
| 832 | + results.setdefault( |
| 833 | + event_id, BundledAggregations() |
| 834 | + ).thread = _ThreadAggregation( |
| 835 | + latest_event=latest_thread_event, |
| 836 | + count=thread_count, |
| 837 | + # If there's a thread summary it must also exist in the |
| 838 | + # participated dictionary. |
| 839 | + current_user_participated=participated[event_id], |
| 840 | + ) |
| 841 | + |
766 | 842 | return results |
767 | 843 |
|
768 | 844 |
|
|
0 commit comments