"""Add balance change tracking.

Adds a ``last_update`` timestamp column to ``balances`` and introduces the
``cron_jobs`` table, seeding it with the ``balance`` job.

Revision ID: 8ece21fbeb47
Revises: 1c06d0ade60c
Create Date: 2025-03-18 09:58:57.469799

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import func


# revision identifiers, used by Alembic.
revision = '8ece21fbeb47'
down_revision = '1c06d0ade60c'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Apply the balance-tracking schema changes."""
    # Track when each balance row last changed; pre-existing rows default to
    # the time the migration runs.
    op.add_column(
        "balances",
        sa.Column(
            "last_update",
            sa.TIMESTAMP(timezone=True),
            nullable=False,
            server_default=func.now(),
            onupdate=func.now(),
        ),
    )

    op.create_table(
        "cron_jobs",
        sa.Column("id", sa.String(), nullable=False),
        # Interval is specified in seconds.
        # NOTE(review): the client-side default of 24 looks like a leftover
        # from an hour-based interval (the seeded row below uses 3600 s) —
        # confirm before relying on this default.
        sa.Column("interval", sa.Integer(), nullable=False, default=24),
        sa.Column("last_run", sa.TIMESTAMP(timezone=True), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )

    # Seed the balance checker: run hourly; last_run is set far in the past
    # so the first check happens on the next scheduler pass.
    op.execute(
        """
        INSERT INTO cron_jobs(id, interval, last_run) VALUES ('balance', 3600, '2025-01-01 00:00:00')
        """
    )


def downgrade() -> None:
    """Revert the schema changes introduced by :func:`upgrade`."""
    op.drop_column("balances", "last_update")

    op.drop_table("cron_jobs")
refresh_cache_materialized_views @@ -147,6 +149,10 @@ async def main(args: List[str]) -> None: garbage_collector = GarbageCollector( session_factory=session_factory, storage_service=storage_service ) + cron_job = CronJob( + session_factory=session_factory, + jobs={"balance": BalanceCronJob(session_factory=session_factory)}, + ) chain_data_service = ChainDataService( session_factory=session_factory, storage_service=storage_service, @@ -203,6 +209,10 @@ async def main(args: List[str]) -> None: ) LOGGER.debug("Initialized garbage collector task") + LOGGER.debug("Initializing cron job task") + tasks.append(cron_job_task(config=config, cron_job=cron_job)) + LOGGER.debug("Initialized cron job task") + LOGGER.debug("Running event loop") await asyncio.gather(*tasks) diff --git a/src/aleph/config.py b/src/aleph/config.py index 97464805f..89a860776 100644 --- a/src/aleph/config.py +++ b/src/aleph/config.py @@ -38,6 +38,10 @@ def get_defaults(): # Maximum number of chain/sync events processed at the same time. "max_concurrency": 20, }, + "cron": { + # Interval between cron job trackers runs, expressed in hours. + "period": 0.5, # 30 mins + }, }, "cache": { "ttl": { diff --git a/src/aleph/db/accessors/balances.py b/src/aleph/db/accessors/balances.py index e1e9b1eed..225274cb5 100644 --- a/src/aleph/db/accessors/balances.py +++ b/src/aleph/db/accessors/balances.py @@ -1,3 +1,4 @@ +import datetime as dt from decimal import Decimal from io import StringIO from typing import Dict, Mapping, Optional, Sequence @@ -7,6 +8,7 @@ from sqlalchemy.sql import Select from aleph.db.models import AlephBalanceDb +from aleph.toolkit.timestamp import utc_now from aleph.types.db_session import DbSession @@ -140,6 +142,8 @@ def update_balances( table from the temporary one. 
""" + last_update = utc_now() + session.execute( "CREATE TEMPORARY TABLE temp_balances AS SELECT * FROM balances WITH NO DATA" # type: ignore[arg-type] ) @@ -151,21 +155,21 @@ def update_balances( csv_balances = StringIO( "\n".join( [ - f"{address};{chain.value};{dapp or ''};{balance};{eth_height}" + f"{address};{chain.value};{dapp or ''};{balance};{eth_height};{last_update}" for address, balance in balances.items() ] ) ) cursor.copy_expert( - "COPY temp_balances(address, chain, dapp, balance, eth_height) FROM STDIN WITH CSV DELIMITER ';'", + "COPY temp_balances(address, chain, dapp, balance, eth_height, last_update) FROM STDIN WITH CSV DELIMITER ';'", csv_balances, ) session.execute( """ - INSERT INTO balances(address, chain, dapp, balance, eth_height) - (SELECT address, chain, dapp, balance, eth_height FROM temp_balances) + INSERT INTO balances(address, chain, dapp, balance, eth_height, last_update) + (SELECT address, chain, dapp, balance, eth_height, last_update FROM temp_balances) ON CONFLICT ON CONSTRAINT balances_address_chain_dapp_uindex DO UPDATE - SET balance = excluded.balance, eth_height = excluded.eth_height + SET balance = excluded.balance, eth_height = excluded.eth_height, last_update = (CASE WHEN excluded.balance <> balances.balance THEN excluded.last_update ELSE balances.last_update END) WHERE excluded.eth_height > balances.eth_height """ # type: ignore[arg-type] ) @@ -174,3 +178,12 @@ def update_balances( # tends to reuse connections. Dropping the table here guarantees it will not be present # on the next run. 
def get_updated_balance_accounts(
    session: DbSession, last_update: dt.datetime
) -> Sequence[str]:
    """Return the distinct addresses whose balance changed since ``last_update``.

    :param session: DB session.
    :param last_update: Inclusive lower bound on ``balances.last_update``.
    :return: Unique account addresses updated at or after ``last_update``.
    """
    select_stmt = (
        select(AlephBalanceDb.address)
        .where(AlephBalanceDb.last_update >= last_update)
        .distinct()
    )
    return session.execute(select_stmt).scalars().all()
"""DB accessors for the ``cron_jobs`` table."""

import datetime as dt
from typing import List, Optional

from sqlalchemy import delete, select, update

from aleph.db.models.cron_jobs import CronJobDb
from aleph.types.db_session import DbSession


def get_cron_jobs(session: DbSession) -> List[CronJobDb]:
    """Return every registered cron job."""
    return session.execute(select(CronJobDb)).scalars().all()


def get_cron_job(session: DbSession, id: str) -> Optional[CronJobDb]:
    """Return the cron job with the given id, or None if it does not exist."""
    select_stmt = select(CronJobDb).where(CronJobDb.id == id)
    return session.execute(select_stmt).scalar_one_or_none()


def update_cron_job(session: DbSession, id: str, last_run: dt.datetime) -> None:
    """Record ``last_run`` as the latest run time of the given cron job."""
    update_stmt = (
        update(CronJobDb).where(CronJobDb.id == id).values(last_run=last_run)
    )
    session.execute(update_stmt)


def delete_cron_job(session: DbSession, id: str) -> None:
    """Remove the cron job with the given id."""
    session.execute(delete(CronJobDb).where(CronJobDb.id == id))
# TODO: Improve performance
def update_file_pin_grace_period(
    session: DbSession,
    item_hash: str,
    delete_by: Union[dt.datetime, None],
) -> None:
    """Move a file pin between the MESSAGE and GRACE_PERIOD states.

    With ``delete_by=None`` an existing grace-period pin is promoted back to a
    regular message pin; otherwise the message pin (if any) is demoted to a
    grace-period pin expiring at ``delete_by``.
    """
    if delete_by is None:
        # Recovery path: turn the grace-period pin back into a message pin.
        deleted_pin = session.execute(
            delete(GracePeriodFilePinDb)
            .where(GracePeriodFilePinDb.item_hash == item_hash)
            .returning(
                GracePeriodFilePinDb.file_hash,
                GracePeriodFilePinDb.owner,
                GracePeriodFilePinDb.ref,
                GracePeriodFilePinDb.created,
            )
        ).first()
        if deleted_pin is None:
            return

        file_hash, owner, ref, created = deleted_pin
        insert_message_file_pin(
            session=session,
            item_hash=item_hash,
            file_hash=file_hash,
            owner=owner,
            ref=ref,
            created=created,
        )
        # NOTE(review): unlike the demotion branch below, this path does not
        # refresh the file tag — confirm whether the tag must also be
        # recomputed on recovery.
    else:
        # Demotion path: turn the message pin into a grace-period pin.
        deleted_pin = session.execute(
            delete(MessageFilePinDb)
            .where(MessageFilePinDb.item_hash == item_hash)
            .returning(
                MessageFilePinDb.file_hash,
                MessageFilePinDb.owner,
                MessageFilePinDb.ref,
                MessageFilePinDb.created,
            )
        ).first()
        if deleted_pin is None:
            return

        file_hash, owner, ref, created = deleted_pin
        insert_grace_period_file_pin(
            session=session,
            item_hash=item_hash,
            file_hash=file_hash,
            owner=owner,
            ref=ref,
            created=created,
            delete_by=delete_by,
        )

        # The message pin no longer exists, so the tag for this file must be
        # recomputed from the remaining pins.
        refresh_file_tag(
            session=session,
            tag=make_file_tag(
                owner=owner,
                ref=ref,
                item_hash=item_hash,
            ),
        )
def get_one_message_by_item_hash(
    session: DbSession, item_hash: str
) -> Optional[MessageDb]:
    """Return the message with the given item hash, or None if absent.

    Fix: the return annotation previously said ``Optional[RejectedMessageDb]``
    although the query selects from ``MessageDb``.
    """
    select_stmt = select(MessageDb).where(MessageDb.item_hash == item_hash)
    return session.execute(select_stmt).scalar_one_or_none()
import datetime as dt

from sqlalchemy import TIMESTAMP, Column, Integer, String

from .base import Base


class CronJobDb(Base):
    """A periodically executed maintenance job tracked in the database."""

    __tablename__ = "cron_jobs"

    # Stable identifier of the job (e.g. "balance").
    id: str = Column(String, primary_key=True)
    # Interval is specified in seconds
    interval: int = Column(Integer, nullable=False)
    # Timezone-aware timestamp of the last completed run.
    last_run: dt.datetime = Column(TIMESTAMP(timezone=True), nullable=False)
message_status.status == MessageStatus.REMOVED: + logger.info("Message %s was removed, nothing to do.", item_hash) if message_status.status == MessageStatus.FORGOTTEN: logger.info("Message %s is already forgotten, nothing to do.", item_hash) append_to_forgotten_by( @@ -188,7 +195,11 @@ async def _forget_item_hash( ) return - if message_status.status != MessageStatus.PROCESSED: + # Note: Only allow to forget messages that are processed or marked for removing + if ( + message_status.status != MessageStatus.PROCESSED + and message_status.status != MessageStatus.REMOVING + ): logger.error( "FORGET message %s targets message %s which is not processed yet. This should not happen.", forgotten_by.item_hash, diff --git a/src/aleph/handlers/content/store.py b/src/aleph/handlers/content/store.py index 9e732530c..eeade47b4 100644 --- a/src/aleph/handlers/content/store.py +++ b/src/aleph/handlers/content/store.py @@ -9,7 +9,7 @@ import datetime as dt import logging from decimal import Decimal -from typing import List, Optional, Set +from typing import List, Set import aioipfs from aleph_message.models import ItemHash, ItemType, StoreContent @@ -39,7 +39,7 @@ from aleph.toolkit.costs import are_store_and_program_free from aleph.toolkit.timestamp import timestamp_to_datetime, utc_now from aleph.types.db_session import DbSession -from aleph.types.files import FileTag, FileType +from aleph.types.files import FileType from aleph.types.message_status import ( FileUnavailable, InsufficientBalanceException, @@ -48,7 +48,7 @@ StoreCannotUpdateStoreWithRef, StoreRefNotFound, ) -from aleph.utils import item_type_from_hash +from aleph.utils import item_type_from_hash, make_file_tag LOGGER = logging.getLogger(__name__) @@ -62,36 +62,6 @@ def _get_store_content(message: MessageDb) -> StoreContent: return content -def make_file_tag(owner: str, ref: Optional[str], item_hash: str) -> FileTag: - """ - Builds the file tag corresponding to a STORE message. 
- - The file tag can be set to two different values: - * if the `ref` field is not set, the tag will be set to . - * if the `ref` field is set, two cases: if `ref` is an item hash, the tag is - the value of the ref field. If it is a user-defined value, the tag is - /. - - :param owner: Owner of the file. - :param ref: Value of the `ref` field of the message content. - :param item_hash: Item hash of the message. - :return: The computed file tag. - """ - - # When the user does not specify a ref, we use the item hash. - if ref is None: - return FileTag(item_hash) - - # If ref is an item hash, return it as is - try: - _item_hash = ItemHash(ref) - return FileTag(ref) - except ValueError: - pass - - return FileTag(f"{owner}/{ref}") - - class StoreMessageHandler(ContentHandler): def __init__(self, storage_service: StorageService, grace_period: int): self.storage_service = storage_service diff --git a/src/aleph/jobs/cron/balance_job.py b/src/aleph/jobs/cron/balance_job.py new file mode 100644 index 000000000..2651a2961 --- /dev/null +++ b/src/aleph/jobs/cron/balance_job.py @@ -0,0 +1,148 @@ +import datetime as dt +import logging +from typing import List + +from aleph_message.models import MessageType, PaymentType + +from aleph.db.accessors.balances import get_total_balance, get_updated_balance_accounts +from aleph.db.accessors.cost import get_total_costs_for_address_grouped_by_message +from aleph.db.accessors.files import update_file_pin_grace_period +from aleph.db.accessors.messages import ( + get_message_by_item_hash, + get_message_status, + make_message_status_upsert_query, +) +from aleph.db.models.cron_jobs import CronJobDb +from aleph.db.models.messages import MessageStatusDb +from aleph.jobs.cron.cron_job import BaseCronJob +from aleph.services.cost import calculate_storage_size +from aleph.toolkit.constants import ( + MAX_UNAUTHENTICATED_UPLOAD_FILE_SIZE, + STORE_AND_PROGRAM_COST_CUTOFF_HEIGHT, + MiB, +) +from aleph.toolkit.timestamp import utc_now +from 
class BalanceCronJob(BaseCronJob):
    """Reconcile hold-based message costs against account balances.

    For every account whose balance changed since this job's last run, walk
    its hold costs in insertion order: messages the balance can no longer
    cover are marked for removal; messages previously marked are recovered
    once the balance covers them again.
    """

    def __init__(self, session_factory: DbSessionFactory):
        self.session_factory = session_factory

    async def run(self, now: dt.datetime, job: CronJobDb):
        """Check all accounts updated since ``job.last_run``."""
        with self.session_factory() as session:
            accounts = get_updated_balance_accounts(session, job.last_run)

            LOGGER.info("Checking '%d' updated account balances...", len(accounts))

            for address in accounts:
                remaining_balance = get_total_balance(session, address)

                to_delete: List[str] = []
                to_recover: List[str] = []

                hold_costs = get_total_costs_for_address_grouped_by_message(
                    session, address, PaymentType.hold
                )

                for item_hash, height, cost, _ in hold_costs:
                    LOGGER.info(
                        "Checking %s message, with height %s and cost %s",
                        item_hash,
                        height,
                        cost,
                    )

                    # Only messages confirmed at or after the cost cutoff
                    # height are eligible for removal.
                    should_remove = remaining_balance < cost and (
                        height is not None
                        and height >= STORE_AND_PROGRAM_COST_CUTOFF_HEIGHT
                    )
                    # NOTE(review): the cost is deducted even when the message
                    # is slated for removal — confirm this is intended.
                    remaining_balance = max(0, remaining_balance - cost)

                    # Fix: the original fetched the status twice per message,
                    # discarding the first result; fetch it once.
                    status = get_message_status(session, item_hash)
                    if status is None:
                        continue

                    if should_remove:
                        if status.status not in (
                            MessageStatus.REMOVING,
                            MessageStatus.REMOVED,
                        ):
                            to_delete.append(item_hash)
                    elif status.status == MessageStatus.REMOVING:
                        to_recover.append(item_hash)

                if to_delete:
                    LOGGER.info(
                        "'%d' messages to delete for account '%s'...",
                        len(to_delete),
                        address,
                    )
                    await self.delete_messages(session, to_delete)

                if to_recover:
                    LOGGER.info(
                        "'%d' messages to recover for account '%s'...",
                        len(to_recover),
                        address,
                    )
                    await self.recover_messages(session, to_recover)

            session.commit()

    async def delete_messages(self, session: DbSession, messages: List[str]):
        """Mark messages as REMOVING; STORE file pins get a grace period."""
        for item_hash in messages:
            message = get_message_by_item_hash(session, item_hash)

            if message is None:
                continue

            if message.type == MessageType.store:
                # Files small enough for unauthenticated upload are never
                # removed for lack of balance.
                storage_size_mib = calculate_storage_size(
                    session, message.parsed_content
                )

                if storage_size_mib and storage_size_mib <= (
                    MAX_UNAUTHENTICATED_UPLOAD_FILE_SIZE / MiB
                ):
                    continue

            now = utc_now()
            # 24 hours of grace plus 1 hour of slack before actual deletion.
            delete_by = now + dt.timedelta(hours=24 + 1)

            if message.type == MessageType.store:
                update_file_pin_grace_period(
                    session=session,
                    item_hash=item_hash,
                    delete_by=delete_by,
                )

            # Only downgrade messages that are currently PROCESSED.
            session.execute(
                make_message_status_upsert_query(
                    item_hash=item_hash,
                    new_status=MessageStatus.REMOVING,
                    reception_time=now,
                    where=(MessageStatusDb.status == MessageStatus.PROCESSED),
                )
            )

    async def recover_messages(self, session: DbSession, messages: List[str]):
        """Restore REMOVING messages (and their file pins) to PROCESSED."""
        for item_hash in messages:
            message = get_message_by_item_hash(session, item_hash)
            if message is None:
                continue

            if message.type == MessageType.store:
                # delete_by=None promotes the grace-period pin back to a
                # regular message pin.
                update_file_pin_grace_period(
                    session=session,
                    item_hash=item_hash,
                    delete_by=None,
                )

            session.execute(
                make_message_status_upsert_query(
                    item_hash=item_hash,
                    new_status=MessageStatus.PROCESSED,
                    reception_time=utc_now(),
                    where=(MessageStatusDb.status == MessageStatus.REMOVING),
                )
            )
LOGGER = logging.getLogger(__name__)


class BaseCronJob(abc.ABC):
    """Interface for a job runnable by :class:`CronJob`."""

    @abc.abstractmethod
    async def run(self, now: dt.datetime, job: CronJobDb) -> None:
        """Execute the job; ``job`` carries its DB state (interval, last run)."""


class CronJob:
    """Scheduler that runs every registered job whose interval has elapsed."""

    def __init__(self, session_factory: DbSessionFactory, jobs: Dict[str, BaseCronJob]):
        self.session_factory = session_factory
        # Maps cron_jobs.id -> implementation; DB rows without an entry here
        # are ignored.
        self.jobs = jobs

    async def __run_job(
        self,
        session: DbSession,
        cron_job: BaseCronJob,
        now: dt.datetime,
        job: CronJobDb,
    ):
        """Run one job, recording ``now`` as its last run on success."""
        try:
            LOGGER.info("Starting '%s' cron job check...", job.id)
            await cron_job.run(now, job)

            update_cron_job(session, job.id, now)

            LOGGER.info("'%s' cron job ran successfully.", job.id)

        except Exception:
            # A failing job must not prevent the other scheduled jobs from
            # running (nor update its last_run, so it is retried next pass).
            LOGGER.exception(
                f"An unexpected error occurred during '{job.id}' cron job execution."
            )

    async def run(self, now: dt.datetime):
        """Run all jobs that are due at ``now`` and persist their last-run times."""
        with self.session_factory() as session:
            jobs_to_run: List[Coroutine] = []

            for job in get_cron_jobs(session):
                due_at = job.last_run + dt.timedelta(seconds=job.interval)
                if now < due_at:
                    continue

                cron_job = self.jobs.get(job.id)
                if cron_job:
                    jobs_to_run.append(self.__run_job(session, cron_job, now, job))
                    LOGGER.info(
                        f"'{job.id}' cron job scheduled for running successfully."
                    )

            await asyncio.gather(*jobs_to_run)

            session.commit()


async def cron_job_task(config: Config, cron_job: CronJob) -> None:
    """Periodically trigger ``cron_job.run`` forever.

    Fix: the original looped back immediately after an unexpected error,
    busy-retrying without any delay; the interval sleep now also applies on
    failure.
    """
    interval = dt.timedelta(hours=config.aleph.jobs.cron.period.value)

    # Start by waiting, this gives the node time to start up and process potential pending
    # messages that could pin files.
    LOGGER.info("Warming up cron job runner... next run: %s.", utc_now() + interval)
    await asyncio.sleep(interval.total_seconds())

    while True:
        now = utc_now()
        try:
            LOGGER.info("Starting cron job check...")
            await cron_job.run(now=now)
            LOGGER.info("Cron job ran successfully.")
        except Exception:
            LOGGER.exception("An unexpected error occurred during cron job check.")

        # Sleep even after a failure so errors do not turn into a hot loop.
        LOGGER.info("Next cron job run: %s.", now + interval)
        await asyncio.sleep(interval.total_seconds())
MESSAGE_CLS_DICT: Dict[ @@ -179,6 +187,22 @@ class ProcessedMessageStatus(BaseMessageStatus): message: AlephMessage +class RemovingMessageStatus(BaseMessageStatus): + model_config = ConfigDict(from_attributes=True) + + status: MessageStatus = MessageStatus.REMOVING + message: AlephMessage + reason: RemovedMessageReason + + +class RemovedMessageStatus(BaseMessageStatus): + model_config = ConfigDict(from_attributes=True) + + status: MessageStatus = MessageStatus.REMOVED + message: AlephMessage + reason: RemovedMessageReason + + class ForgottenMessage(BaseModel): model_config = ConfigDict(from_attributes=True) @@ -218,6 +242,8 @@ class MessageHashes(BaseMessageStatus): ProcessedMessageStatus, ForgottenMessageStatus, RejectedMessageStatus, + RemovingMessageStatus, + RemovedMessageStatus, ] diff --git a/src/aleph/services/storage/garbage_collector.py b/src/aleph/services/storage/garbage_collector.py index 49cb2df81..b09531fd7 100644 --- a/src/aleph/services/storage/garbage_collector.py +++ b/src/aleph/services/storage/garbage_collector.py @@ -3,14 +3,26 @@ import logging from aioipfs import NotPinnedError -from aleph_message.models import ItemHash, ItemType +from aleph_message.models import ItemHash, ItemType, MessageType from configmanager import Config +from aleph.db.accessors.cost import delete_costs_for_forgotten_and_deleted_messages from aleph.db.accessors.files import delete_file as delete_file_db -from aleph.db.accessors.files import delete_grace_period_file_pins, get_unpinned_files +from aleph.db.accessors.files import ( + delete_grace_period_file_pins, + file_pin_exists, + get_unpinned_files, +) +from aleph.db.accessors.messages import ( + get_matching_hashes, + get_one_message_by_item_hash, + make_message_status_upsert_query, +) +from aleph.db.models.messages import MessageStatusDb from aleph.storage import StorageService from aleph.toolkit.timestamp import utc_now from aleph.types.db_session import DbSessionFactory +from aleph.types.message_status import 
MessageStatus LOGGER = logging.getLogger(__name__) @@ -42,6 +54,69 @@ async def _delete_from_local_storage(self, file_hash: ItemHash): await self.storage_service.storage_engine.delete(file_hash) LOGGER.debug(f"Removed from local storage: {file_hash}") + async def _check_and_update_removing_messages(self): + """ + Check all messages with status REMOVING and update to REMOVED if their resources + have been fully deleted. + """ + LOGGER.info("Checking messages with REMOVING status") + + with self.session_factory() as session: + # Get all messages with REMOVING status + removing_messages = list( + get_matching_hashes( + session=session, + status=MessageStatus.REMOVING, + hash_only=False, + pagination=0, # Get all matching messages + ) + ) + + LOGGER.info( + "Found %d messages with REMOVING status", len(removing_messages) + ) + + for message_status in removing_messages: + item_hash = message_status.item_hash + try: + # For STORE messages, check if the file is still pinned + # We need to get message details to check its type + message = get_one_message_by_item_hash( + session=session, item_hash=item_hash + ) + + resources_deleted = True + + if message and message.type == MessageType.store: + # Check if the file is still pinned (by item_hash cause there could be other messages pinning the same file_hash) + if file_pin_exists(session=session, item_hash=item_hash): + resources_deleted = False + + # If all resources have been deleted, update status to REMOVED + if resources_deleted: + now = utc_now() + session.execute( + make_message_status_upsert_query( + item_hash=item_hash, + new_status=MessageStatus.REMOVED, + reception_time=now, + where=( + MessageStatusDb.status == MessageStatus.REMOVING + ), + ) + ) + + except Exception as err: + LOGGER.error( + "Failed to check or update message status %s: %s", + item_hash, + str(err), + ) + + delete_costs_for_forgotten_and_deleted_messages(session=session) + + session.commit() + async def collect(self, datetime: dt.datetime): with 
self.session_factory() as session: # Delete outdated grace period file pins @@ -59,18 +134,22 @@ async def collect(self, datetime: dt.datetime): LOGGER.info("Deleting %s...", file_hash) delete_file_db(session=session, file_hash=file_hash) - session.commit() if file_hash.item_type == ItemType.ipfs: await self._delete_from_ipfs(file_hash) elif file_hash.item_type == ItemType.storage: await self._delete_from_local_storage(file_hash) + session.commit() + LOGGER.info("Deleted %s", file_hash) except Exception as err: LOGGER.error("Failed to delete file %s: %s", file_hash, str(err)) session.rollback() + # After collecting garbage, check and update message status + await self._check_and_update_removing_messages() + async def garbage_collector_task( config: Config, garbage_collector: GarbageCollector diff --git a/src/aleph/toolkit/constants.py b/src/aleph/toolkit/constants.py index b70ef7137..d33ccc1ad 100644 --- a/src/aleph/toolkit/constants.py +++ b/src/aleph/toolkit/constants.py @@ -274,8 +274,8 @@ "community_wallet_timestamp": 1739301770, } -STORE_AND_PROGRAM_COST_DEADLINE_HEIGHT = 22388870 -STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP = 1746101025 +STORE_AND_PROGRAM_COST_CUTOFF_HEIGHT = 22196000 +STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP = 1743775079 MAX_FILE_SIZE = 100 * MiB MAX_UNAUTHENTICATED_UPLOAD_FILE_SIZE = 25 * MiB diff --git a/src/aleph/toolkit/costs.py b/src/aleph/toolkit/costs.py index 20b52329b..c706e7703 100644 --- a/src/aleph/toolkit/costs.py +++ b/src/aleph/toolkit/costs.py @@ -5,8 +5,8 @@ from aleph.db.models.messages import MessageDb from aleph.toolkit.constants import ( PRICE_PRECISION, - STORE_AND_PROGRAM_COST_DEADLINE_HEIGHT, - STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP, + STORE_AND_PROGRAM_COST_CUTOFF_HEIGHT, + STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP, ) from aleph.toolkit.timestamp import timestamp_to_datetime @@ -27,6 +27,6 @@ def are_store_and_program_free(message: MessageDb) -> bool: date: dt.datetime = message.time if height is not None: - return 
height < STORE_AND_PROGRAM_COST_DEADLINE_HEIGHT + return height < STORE_AND_PROGRAM_COST_CUTOFF_HEIGHT else: - return date < timestamp_to_datetime(STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP) + return date < timestamp_to_datetime(STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP) diff --git a/src/aleph/types/message_status.py b/src/aleph/types/message_status.py index 406696e79..cc0fa8d83 100644 --- a/src/aleph/types/message_status.py +++ b/src/aleph/types/message_status.py @@ -14,6 +14,8 @@ class MessageStatus(str, Enum): PROCESSED = "processed" REJECTED = "rejected" FORGOTTEN = "forgotten" + REMOVING = "removing" + REMOVED = "removed" class MessageProcessingStatus(str, Enum): @@ -56,6 +58,10 @@ class ErrorCode(IntEnum): FORGOTTEN_DUPLICATE = 504 +class RemovedMessageReason(str, Enum): + BALANCE_INSUFFICIENT = "balance_insufficient" + + class MessageProcessingException(Exception): error_code: ErrorCode diff --git a/src/aleph/utils.py b/src/aleph/utils.py index bba1901d5..3cf0b16c4 100644 --- a/src/aleph/utils.py +++ b/src/aleph/utils.py @@ -1,11 +1,12 @@ import asyncio from hashlib import sha256 -from typing import Union +from typing import Optional, Union -from aleph_message.models import ItemType +from aleph_message.models import ItemHash, ItemType from aleph.exceptions import UnknownHashError from aleph.settings import settings +from aleph.types.files import FileTag async def run_in_executor(executor, func, *args): @@ -40,3 +41,33 @@ def safe_getattr(obj, attr, default=None): if obj is default: break return obj + + +def make_file_tag(owner: str, ref: Optional[str], item_hash: str) -> FileTag: + """ + Builds the file tag corresponding to a STORE message. + + The file tag can be set to two different values: + * if the `ref` field is not set, the tag will be set to . + * if the `ref` field is set, two cases: if `ref` is an item hash, the tag is + the value of the ref field. If it is a user-defined value, the tag is + /. + + :param owner: Owner of the file. 
+ :param ref: Value of the `ref` field of the message content. + :param item_hash: Item hash of the message. + :return: The computed file tag. + """ + + # When the user does not specify a ref, we use the item hash. + if ref is None: + return FileTag(item_hash) + + # If ref is an item hash, return it as is + try: + _item_hash = ItemHash(ref) + return FileTag(ref) + except ValueError: + pass + + return FileTag(f"{owner}/{ref}") diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index ca4535728..0f1f4bfad 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -40,12 +40,14 @@ PostMessage, ProcessedMessageStatus, RejectedMessageStatus, + RemovedMessageStatus, + RemovingMessageStatus, format_message, format_message_dict, ) from aleph.toolkit.shield import shielded from aleph.types.db_session import DbSession, DbSessionFactory -from aleph.types.message_status import MessageStatus +from aleph.types.message_status import MessageStatus, RemovedMessageReason from aleph.types.sort_order import SortBy, SortOrder from aleph.web.controllers.app_state_getters import ( get_config_from_request, @@ -565,6 +567,36 @@ def _get_message_with_status( message=rejected_message_db.message, ) + if status == MessageStatus.REMOVING: + message_db = get_message_by_item_hash( + session=session, item_hash=ItemHash(item_hash) + ) + if not message_db: + raise web.HTTPNotFound() + + message = format_message(message_db) + return RemovingMessageStatus( + item_hash=item_hash, + reception_time=reception_time, + message=message, + reason=RemovedMessageReason.BALANCE_INSUFFICIENT, + ) + + if status == MessageStatus.REMOVED: + message_db = get_message_by_item_hash( + session=session, item_hash=ItemHash(item_hash) + ) + if not message_db: + raise web.HTTPNotFound() + + message = format_message(message_db) + return RemovedMessageStatus( + item_hash=item_hash, + reception_time=reception_time, + message=message, + 
reason=RemovedMessageReason.BALANCE_INSUFFICIENT, + ) + raise NotImplementedError(f"Unknown message status: {status}") diff --git a/src/aleph/web/controllers/prices.py b/src/aleph/web/controllers/prices.py index 5d22ed3ce..48fc9df32 100644 --- a/src/aleph/web/controllers/prices.py +++ b/src/aleph/web/controllers/prices.py @@ -46,6 +46,10 @@ class HTTPProcessing(HTTPException): web.HTTPGone, "This message has been forgotten", ), + MessageStatus.REMOVED: ( + web.HTTPGone, + "This message has been removed", + ), } diff --git a/tests/jobs/test_balance_job.py b/tests/jobs/test_balance_job.py new file mode 100644 index 000000000..c3f4371d6 --- /dev/null +++ b/tests/jobs/test_balance_job.py @@ -0,0 +1,387 @@ +import datetime as dt +from decimal import Decimal + +import pytest +import pytest_asyncio +from aleph_message.models import Chain, ItemType, MessageType, PaymentType + +from aleph.db.accessors.messages import get_message_status +from aleph.db.models.account_costs import AccountCostsDb +from aleph.db.models.balances import AlephBalanceDb +from aleph.db.models.chains import ChainTxDb +from aleph.db.models.cron_jobs import CronJobDb +from aleph.db.models.files import ( + FilePinType, + GracePeriodFilePinDb, + MessageFilePinDb, + StoredFileDb, +) +from aleph.db.models.messages import MessageDb, MessageStatusDb +from aleph.jobs.cron.balance_job import BalanceCronJob +from aleph.toolkit.constants import STORE_AND_PROGRAM_COST_CUTOFF_HEIGHT, MiB +from aleph.toolkit.timestamp import utc_now +from aleph.types.chain_sync import ChainSyncProtocol +from aleph.types.cost import CostType +from aleph.types.db_session import DbSessionFactory +from aleph.types.files import FileType +from aleph.types.message_status import MessageStatus + + +@pytest.fixture +def balance_job(session_factory: DbSessionFactory) -> BalanceCronJob: + return BalanceCronJob(session_factory=session_factory) + + +@pytest.fixture +def now(): + return utc_now() + + +def create_cron_job(id, now): + """Create a 
cron job entry for testing.""" + return CronJobDb( + id=id, + interval=1, + last_run=now - dt.timedelta(hours=1), + ) + + +def create_wallet(address, balance, now): + """Create a wallet with the specified balance.""" + return AlephBalanceDb( + address=address, + balance=Decimal(str(balance)), + last_update=now, + chain=Chain.ETH, + eth_height=STORE_AND_PROGRAM_COST_CUTOFF_HEIGHT, + ) + + +def create_store_message( + item_hash, + sender, + file_hash, + now, + size=30 * MiB, + status=MessageStatus.PROCESSED, +): + """Create a store message with associated file and status.""" + message = MessageDb( + item_hash=item_hash, + sender=sender, + chain=Chain.ETH, + type=MessageType.store, + time=now, + item_type=ItemType.ipfs, + signature=f"sig_{item_hash[:8]}", + size=size, + content={ + "address": "0xB68B9D4f3771c246233823ed1D3Add451055F9Ef", + "time": 1645794065.439, + "hashes": ["QmTQPocJ8n3r7jhwYxmCDR5bJ4SNsEhdVm8WwkNbGctgJF"], + "reason": "None", + "type": "TEST", + "item_hash": file_hash, + "item_type": ItemType.ipfs.value, + }, + ) + + file = StoredFileDb( + hash=file_hash, + size=size, + type=FileType.FILE, + ) + + if status == MessageStatus.PROCESSED: + file_pin = MessageFilePinDb( + item_hash=item_hash, + file_hash=file_hash, + type=FilePinType.MESSAGE, + owner=sender, + created=now, + ) + elif status == MessageStatus.REMOVING: + file_pin = GracePeriodFilePinDb( + item_hash=item_hash, + file_hash=file_hash, + type=FilePinType.GRACE_PERIOD, + owner=sender, + created=now, + delete_by=now + dt.timedelta(hours=24), + ) + + message_status = MessageStatusDb( + item_hash=item_hash, + status=status, + reception_time=now, + ) + + return message, file, file_pin, message_status + + +def create_message_cost(owner, item_hash, cost_hold): + """Create a cost record for a message.""" + return AccountCostsDb( + owner=owner, + item_hash=item_hash, + type=CostType.STORAGE, + name="store", + payment_type=PaymentType.hold, + cost_hold=Decimal(str(cost_hold)), + 
cost_stream=Decimal("0.0"), + ) + + +def add_chain_confirmation(message, height, now): + """Add a chain confirmation with specified height to a message.""" + chain_confirm = ChainTxDb( + hash="0x111", + chain=Chain.ETH, + height=height, + datetime=now, + publisher="0xabadbabe", + protocol=ChainSyncProtocol.OFF_CHAIN_SYNC, + protocol_version=1, + content="Qmsomething", + ) + message.confirmations = [chain_confirm] + return message + + +@pytest_asyncio.fixture +async def fixture_base_data(session_factory, now): + """Create base data that can be used by multiple tests.""" + # Create cron job + cron_job = create_cron_job("balance_check_base", now) + + with session_factory() as session: + session.add(cron_job) + session.commit() + + return {"cron_job_name": "balance_check_base"} + + +@pytest_asyncio.fixture +async def fixture_message_for_removal(session_factory, now, fixture_base_data): + """ + Setup for testing a message that should be marked for removal due to insufficient balance. + """ + wallet_address = "0xtestaddress1" + message_hash = "abcd1234" * 4 + file_hash = "1234" * 16 + + # Create wallet with low balance + wallet = create_wallet(wallet_address, "10.0", now) + + # Create message and associated records + message, file, file_pin, message_status = create_store_message( + message_hash, wallet_address, file_hash, now + ) + + # Add message cost (more than wallet balance) + message_cost = create_message_cost(wallet_address, message_hash, "15.0") + + # Add chain confirmation with height above cutoff + message = add_chain_confirmation( + message, STORE_AND_PROGRAM_COST_CUTOFF_HEIGHT + 1000, now + ) + + with session_factory() as session: + session.add_all([message]) + session.commit() + + session.add_all([wallet, file, file_pin, message_status, message_cost]) + session.commit() + + return { + "wallet_address": wallet_address, + "message_hash": message_hash, + } + + +@pytest_asyncio.fixture +async def fixture_message_below_cutoff(session_factory, now, 
fixture_base_data): + """ + Setup for testing a message that should not be marked for removal + because its height is below the cutoff. + """ + wallet_address = "0xtestaddress2" + message_hash = "bcde2345" * 4 + file_hash = "1234" * 16 + + # Create wallet with low balance + wallet = create_wallet(wallet_address, "5.0", now) + + # Create message and associated records + message, file, file_pin, message_status = create_store_message( + message_hash, wallet_address, file_hash, now + ) + + # Add message cost (more than wallet balance) + message_cost = create_message_cost(wallet_address, message_hash, "10.0") + + # Add chain confirmation with height BELOW cutoff + message = add_chain_confirmation( + message, STORE_AND_PROGRAM_COST_CUTOFF_HEIGHT - 1000, now + ) + + with session_factory() as session: + session.add_all([message]) + session.commit() + + session.add_all([wallet, file, file_pin, message_status, message_cost]) + session.commit() + + return { + "wallet_address": wallet_address, + "message_hash": message_hash, + } + + +@pytest_asyncio.fixture +async def fixture_message_for_recovery(session_factory, now, fixture_base_data): + """ + Setup for testing a message that should be recovered from REMOVING status + because the wallet balance is now sufficient. 
+ """ + wallet_address = "0xtestaddress3" + message_hash = "cdef3456" * 4 + file_hash = "1234" * 16 + + # Create wallet with sufficient balance + wallet = create_wallet(wallet_address, "50.0", now) + + # Create message and associated records with REMOVING status + message, file, file_pin, _ = create_store_message( + message_hash, + wallet_address, + file_hash, + now, + status=MessageStatus.REMOVING, # Set status to REMOVING + ) + + # Override the message status to ensure it's REMOVING + message_status = MessageStatusDb( + item_hash=message_hash, + status=MessageStatus.REMOVING, + reception_time=now, + ) + + # Add message cost (less than wallet balance now) + message_cost = create_message_cost(wallet_address, message_hash, "20.0") + + # Add chain confirmation with height above cutoff + message = add_chain_confirmation( + message, STORE_AND_PROGRAM_COST_CUTOFF_HEIGHT + 2000, now + ) + + with session_factory() as session: + session.add_all([message]) + session.commit() + + session.add_all([wallet, file, file_pin, message_status, message_cost]) + session.commit() + + return { + "wallet_address": wallet_address, + "message_hash": message_hash, + } + + +@pytest.mark.asyncio +async def test_balance_job_marks_messages_for_removal( + session_factory, balance_job, fixture_message_for_removal, now +): + """Test that the balance job marks messages for removal when balance is insufficient.""" + # Get the cron job + with session_factory() as session: + cron_job = session.query(CronJobDb).filter_by(id="balance_check_base").one() + + # Run the balance job + await balance_job.run(now, cron_job) + + # Check if the message was marked for removal + with session_factory() as session: + # Check message status changed to REMOVING + message_status = get_message_status( + session=session, item_hash=fixture_message_for_removal["message_hash"] + ) + assert message_status is not None + assert message_status.status == MessageStatus.REMOVING + + # Check if a grace period was added to the file 
pin + grace_period_pins = ( + session.query(GracePeriodFilePinDb) + .filter_by(item_hash=fixture_message_for_removal["message_hash"]) + .all() + ) + + assert len(grace_period_pins) == 1 + assert grace_period_pins[0].delete_by is not None + + # Delete should be around 25 hours in the future (24+1 as specified in the code) + delete_by = grace_period_pins[0].delete_by + time_diff = delete_by - now + assert ( + 24.5 <= time_diff.total_seconds() / 3600 <= 25.5 + ) # Between 24.5 and 25.5 hours + + +@pytest.mark.asyncio +async def test_balance_job_ignores_messages_below_cutoff_height( + session_factory, balance_job, fixture_message_below_cutoff, now +): + """Test that the balance job ignores messages with height below the cutoff.""" + # Get the cron job + with session_factory() as session: + cron_job = session.query(CronJobDb).filter_by(id="balance_check_base").one() + + # Run the balance job + await balance_job.run(now, cron_job) + + # Check that the message was NOT marked for removal (still PROCESSED) + with session_factory() as session: + message_status = get_message_status( + session=session, item_hash=fixture_message_below_cutoff["message_hash"] + ) + assert message_status is not None + assert message_status.status == MessageStatus.PROCESSED + + # Check no grace period was added + grace_period_pins = ( + session.query(GracePeriodFilePinDb) + .filter_by(item_hash=fixture_message_below_cutoff["message_hash"]) + .all() + ) + + assert len(grace_period_pins) == 0 + + +@pytest.mark.asyncio +async def test_balance_job_recovers_messages_with_sufficient_balance( + session_factory, balance_job, fixture_message_for_recovery, now +): + """Test that the balance job recovers messages with REMOVING status when balance is sufficient.""" + # Get the cron job + with session_factory() as session: + cron_job = session.query(CronJobDb).filter_by(id="balance_check_base").one() + + # Run the balance job + await balance_job.run(now, cron_job) + + # Check that the message was recovered 
(marked as PROCESSED again) + with session_factory() as session: + message_status = get_message_status( + session=session, item_hash=fixture_message_for_recovery["message_hash"] + ) + assert message_status is not None + assert message_status.status == MessageStatus.PROCESSED + + # Check grace period was updated to null (no deletion date) + grace_period_pins = ( + session.query(GracePeriodFilePinDb) + .filter_by(item_hash=fixture_message_for_recovery["message_hash"]) + .all() + ) + + assert len(grace_period_pins) == 0 diff --git a/tests/jobs/test_check_removing_messages.py b/tests/jobs/test_check_removing_messages.py new file mode 100644 index 000000000..9c3d0a74f --- /dev/null +++ b/tests/jobs/test_check_removing_messages.py @@ -0,0 +1,149 @@ +import pytest +import pytest_asyncio +from aleph_message.models import Chain, ItemType, MessageType + +from aleph.db.accessors.messages import get_message_status +from aleph.db.models.files import FilePinType, MessageFilePinDb, StoredFileDb +from aleph.db.models.messages import MessageDb, MessageStatusDb +from aleph.services.storage.garbage_collector import GarbageCollector +from aleph.storage import StorageService +from aleph.toolkit.timestamp import utc_now +from aleph.types.db_session import DbSessionFactory +from aleph.types.files import FileType +from aleph.types.message_status import MessageStatus + + +@pytest.fixture +def gc( + session_factory: DbSessionFactory, test_storage_service: StorageService +) -> GarbageCollector: + return GarbageCollector( + session_factory=session_factory, storage_service=test_storage_service + ) + + +@pytest_asyncio.fixture +async def fixture_removing_messages(session_factory: DbSessionFactory): + # Set up test data with messages in REMOVING status + now = utc_now() + + # Create test data + store_message_hash = "abcd" * 16 + store_message_file_hash = "1234" * 16 + + # Message with REMOVING status that should be changed to REMOVED (no pinned files) + store_message = MessageDb( + 
item_hash=store_message_hash, + sender="0xsender1", + chain=Chain.ETH, + type=MessageType.store, + time=now, + item_type=ItemType.ipfs, + signature="sig1", + size=1000, + content={ + "type": "TEST", + "item_hash": store_message_file_hash, + "item_type": ItemType.ipfs.value, + }, + ) + + # Create file reference + store_file = StoredFileDb( + hash=store_message_file_hash, + size=1000, + type=FileType.FILE, + # No pins - file is no longer pinned + ) + + # Message status with REMOVING + store_message_status = MessageStatusDb( + item_hash=store_message_hash, + status=MessageStatus.REMOVING, + reception_time=now, + ) + + # Message that should stay in REMOVING status (file still pinned) + pinned_message_hash = "efgh" * 16 + pinned_file_hash = "5678" * 16 + + pinned_message = MessageDb( + item_hash=pinned_message_hash, + sender="0xsender2", + chain=Chain.ETH, + type=MessageType.store, + time=now, + item_type=ItemType.ipfs, + signature="sig2", + size=1000, + content={ + "type": "TEST", + "item_hash": pinned_file_hash, + "item_type": ItemType.ipfs.value, + }, + ) + + # Create file with pins + pinned_file = StoredFileDb( + hash=pinned_file_hash, + size=2000, + type=FileType.FILE, + ) + + # Create a separate pin for the file + pinned_file_pin = MessageFilePinDb( + item_hash=pinned_message_hash, + file_hash=pinned_file_hash, + type=FilePinType.MESSAGE, + created=now, + owner="0xowner1", + ) + + # Message status with REMOVING + pinned_message_status = MessageStatusDb( + item_hash=pinned_message_hash, + status=MessageStatus.REMOVING, + reception_time=now, + ) + + with session_factory() as session: + session.add_all( + [ + store_message, + store_file, + store_message_status, + pinned_message, + pinned_file, + pinned_file_pin, + pinned_message_status, + ] + ) + session.commit() + + yield { + "removable_message": store_message_hash, + "pinned_message": pinned_message_hash, + } + + +@pytest.mark.asyncio +async def test_check_and_update_removing_messages( + session_factory: 
DbSessionFactory, gc: GarbageCollector, fixture_removing_messages +): + # Run the function that checks and updates message status + await gc._check_and_update_removing_messages() + + with session_factory() as session: + # The message with no pinned files should now have REMOVED status + removable_status = get_message_status( + session=session, item_hash=fixture_removing_messages["removable_message"] + ) + assert removable_status is not None + assert removable_status.status == MessageStatus.REMOVED + + # The message with a pinned file should still have REMOVING status + pinned_status = get_message_status( + session=session, item_hash=fixture_removing_messages["pinned_message"] + ) + assert pinned_status is not None + assert pinned_status.status == MessageStatus.REMOVING diff --git a/tests/jobs/test_cron_job.py b/tests/jobs/test_cron_job.py new file mode 100644 index 000000000..16471ef1f --- /dev/null +++ b/tests/jobs/test_cron_job.py @@ -0,0 +1,34 @@ +import datetime as dt + +import pytest + +from aleph.jobs.cron.balance_job import BalanceCronJob +from aleph.jobs.cron.cron_job import CronJob +from aleph.types.db_session import DbSessionFactory + + +@pytest.fixture +def cron_job(session_factory: DbSessionFactory) -> CronJob: + return CronJob( + session_factory=session_factory, + jobs={"balance": BalanceCronJob(session_factory=session_factory)}, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "cron_run_datetime", + [ + dt.datetime(2040, 1, 1, tzinfo=dt.timezone.utc), + dt.datetime(2023, 6, 1, tzinfo=dt.timezone.utc), + dt.datetime(2020, 1, 1, tzinfo=dt.timezone.utc), + ], +) +async def test_balance_job_run( + session_factory: DbSessionFactory, + cron_job: CronJob, + cron_run_datetime: dt.datetime, +): + with session_factory() as session: + await cron_job.run(now=cron_run_datetime) + session.commit() diff --git a/tests/message_processing/test_process_stores.py b/tests/message_processing/test_process_stores.py index 884d608f5..6b1b3be5d 100644 --- 
a/tests/message_processing/test_process_stores.py +++ b/tests/message_processing/test_process_stores.py @@ -19,7 +19,7 @@ from aleph.storage import StorageService from aleph.toolkit.constants import ( MAX_UNAUTHENTICATED_UPLOAD_FILE_SIZE, - STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP, + STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP, ) from aleph.toolkit.timestamp import timestamp_to_datetime from aleph.types.channel import Channel @@ -57,14 +57,14 @@ def fixture_ipfs_store_message() -> PendingMessageDb: signature="0xb9d164e6e43a8fcd341abc01eda47bed0333eaf480e888f2ed2ae0017048939d18850a33352e7281645e95e8673bad733499b6a8ce4069b9da9b9a79ddc1a0b31b", item_type=ItemType.inline, item_content='{"address": "0x696879aE4F6d8DaDD5b8F1cbb1e663B89b08f106", "time": 1665478676.6585264, "item_type": "ipfs", "item_hash": "QmWVxvresoeadRbCeG4BmvsoSsqHV7VwUNuGK6nUCKKFGQ", "mime_type": "text/plain"}', - time=timestamp_to_datetime(STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1), + time=timestamp_to_datetime(STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1), channel=Channel("TEST"), check_message=True, retries=0, next_attempt=dt.datetime(2023, 1, 1), fetched=False, reception_time=timestamp_to_datetime( - STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1 + STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1 ), ) @@ -79,14 +79,14 @@ def fixture_store_message_with_cost() -> PendingMessageDb: signature="0xb9d164e6e43a8fcd341abc01eda47bed0333eaf480e888f2ed2ae0017048939d18850a33352e7281645e95e8673bad733499b6a8ce4069b9da9b9a79ddc1a0b31b", item_type=ItemType.inline, item_content='{"address": "0x696879aE4F6d8DaDD5b8F1cbb1e663B89b08f106", "time": 1665478676.6585264, "item_type": "storage", "item_hash": "c25b0525bc308797d3e35763faf5c560f2974dab802cb4a734ae4e9d1040319e", "mime_type": "text/plain"}', - time=timestamp_to_datetime(STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1), + time=timestamp_to_datetime(STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1), channel=Channel("TEST"), check_message=True, retries=0, 
next_attempt=dt.datetime(2023, 1, 1), fetched=False, reception_time=timestamp_to_datetime( - STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1 + STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1 ), ) @@ -96,7 +96,7 @@ def create_message_db(mocker): def _create_message( item_hash="test-hash", address="0x696879aE4F6d8DaDD5b8F1cbb1e663B89b08f106", - time=STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1, + time=STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1, item_type=ItemType.ipfs, item_content_hash="QmWVxvresoeadRbCeG4BmvsoSsqHV7VwUNuGK6nUCKKFGQ", ): @@ -365,11 +365,11 @@ async def test_pre_check_balance_free_store_message( # Create a message with timestamp before the deadline message = mocker.MagicMock(spec=MessageDb) message.time = timestamp_to_datetime( - STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP - 1 + STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP - 1 ) content = StoreContent( address="0x696879aE4F6d8DaDD5b8F1cbb1e663B89b08f106", - time=STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP - 1, + time=STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP - 1, item_type=ItemType.ipfs, item_hash="QmWVxvresoeadRbCeG4BmvsoSsqHV7VwUNuGK6nUCKKFGQ", ) @@ -407,11 +407,11 @@ async def test_pre_check_balance_small_ipfs_file(mocker, session_factory, mock_c with session_factory() as session: message = mocker.MagicMock(spec=MessageDb) message.time = timestamp_to_datetime( - STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1 + STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1 ) content = StoreContent( address="0x696879aE4F6d8DaDD5b8F1cbb1e663B89b08f106", - time=STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1, + time=STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1, item_type=ItemType.ipfs, item_hash="QmWVxvresoeadRbCeG4BmvsoSsqHV7VwUNuGK6nUCKKFGQ", ) @@ -455,11 +455,11 @@ async def test_pre_check_balance_large_ipfs_file_insufficient_balance( with session_factory() as session: message = mocker.MagicMock(spec=MessageDb) message.time = timestamp_to_datetime( - STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1 + 
STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1 ) content = StoreContent( address="0x696879aE4F6d8DaDD5b8F1cbb1e663B89b08f106", - time=STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1, + time=STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1, item_type=ItemType.ipfs, item_hash="QmWVxvresoeadRbCeG4BmvsoSsqHV7VwUNuGK6nUCKKFGQ", ) @@ -509,11 +509,11 @@ async def test_pre_check_balance_large_ipfs_file_sufficient_balance( address = "0x696879aE4F6d8DaDD5b8F1cbb1e663B89b08f106" message = mocker.MagicMock(spec=MessageDb) message.time = timestamp_to_datetime( - STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1 + STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1 ) content = StoreContent( address=address, - time=STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1, + time=STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1, item_type=ItemType.ipfs, item_hash="QmWVxvresoeadRbCeG4BmvsoSsqHV7VwUNuGK6nUCKKFGQ", ) @@ -561,11 +561,11 @@ async def test_pre_check_balance_non_ipfs_file(mocker, session_factory, mock_con # Create a message with a non-IPFS file type message = mocker.MagicMock(spec=MessageDb) message.time = timestamp_to_datetime( - STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1 + STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1 ) content = StoreContent( address="0x696879aE4F6d8DaDD5b8F1cbb1e663B89b08f106", - time=STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1, + time=STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1, item_type=ItemType.storage, # Not IPFS item_hash="af2e19894099d954f3d1fa274547f62484bc2d93964658547deecc70316acc72", ) @@ -607,11 +607,11 @@ async def test_pre_check_balance_ipfs_disabled(mocker, session_factory): with session_factory() as session: message = mocker.MagicMock(spec=MessageDb) message.time = timestamp_to_datetime( - STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1 + STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1 ) content = StoreContent( address="0x696879aE4F6d8DaDD5b8F1cbb1e663B89b08f106", - time=STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1, + time=STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1, 
item_type=ItemType.ipfs, item_hash="QmWVxvresoeadRbCeG4BmvsoSsqHV7VwUNuGK6nUCKKFGQ", ) @@ -645,11 +645,11 @@ async def test_pre_check_balance_ipfs_size_none(mocker, session_factory, mock_co with session_factory() as session: message = mocker.MagicMock(spec=MessageDb) message.time = timestamp_to_datetime( - STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1 + STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1 ) content = StoreContent( address="0x696879aE4F6d8DaDD5b8F1cbb1e663B89b08f106", - time=STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1, + time=STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1, item_type=ItemType.ipfs, item_hash="QmWVxvresoeadRbCeG4BmvsoSsqHV7VwUNuGK6nUCKKFGQ", ) @@ -701,11 +701,11 @@ async def test_pre_check_balance_with_existing_costs( address = "0x696879aE4F6d8DaDD5b8F1cbb1e663B89b08f106" message = mocker.MagicMock(spec=MessageDb) message.time = timestamp_to_datetime( - STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1 + STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1 ) content = StoreContent( address=address, - time=STORE_AND_PROGRAM_COST_DEADLINE_TIMESTAMP + 1, + time=STORE_AND_PROGRAM_COST_CUTOFF_TIMESTAMP + 1, item_type=ItemType.ipfs, item_hash="QmacDVDroxPVY1enhckVco1rTBziwC8hjf731apEKr3QoG", )