Skip to content

Commit

Permalink
feat: rip out contextlib.suppress for speed (#861)
Browse files Browse the repository at this point in the history
* feat: rip out contextlib.suppress for speed

* feat: rip out contextlib.suppress for speed

* chore: `black .`

* feat: rip out contextlib.suppress for speed

* chore: refactor (#862)

* chore: `black .`

* fix: ERC20 repr

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
BobTheBuidler and github-actions[bot] authored Dec 17, 2024
1 parent c0f66dc commit acd9a7f
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 35 deletions.
19 changes: 10 additions & 9 deletions y/_db/utils/price.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
import threading
from contextlib import suppress
from decimal import Decimal, InvalidOperation
from typing import Dict, Optional

import a_sync
from a_sync import ProcessingQueue
from brownie import chain
from cachetools import TTLCache, cached
from pony.orm import select
Expand All @@ -18,6 +17,7 @@


logger = logging.getLogger(__name__)
_logger_debug = logger.debug


@a_sync_read_db_session
Expand Down Expand Up @@ -45,12 +45,12 @@ def get_price(address: str, block: int) -> Optional[Decimal]:
if address == constants.EEE_ADDRESS:
address = constants.WRAPPED_GAS_COIN
if price := known_prices_at_block(block).pop(address, None):
logger.debug("found %s block %s price %s in ydb", address, block, price)
_logger_debug("found %s block %s price %s in ydb", address, block, price)
return price
if (price := Price.get(token=(chain.id, address), block=(chain.id, block))) and (
price := price.price
):
logger.debug("found %s block %s price %s in ydb", address, block, price)
_logger_debug("found %s block %s price %s in ydb", address, block, price)
return price


Expand Down Expand Up @@ -80,20 +80,21 @@ async def _set_price(address: str, block: int, price: Decimal) -> None:
if address == constants.EEE_ADDRESS:
address = constants.WRAPPED_GAS_COIN
await ensure_token(str(address), sync=False) # force to string for cache key
with suppress(
InvalidOperation
): # happens with really big numbers sometimes. nbd, we can just skip the cache in this case.
try:
await insert(
type=Price,
block=(chain.id, block),
token=(chain.id, address),
price=Decimal(price),
sync=False,
)
logger.debug("inserted %s block %s price to ydb: %s", address, block, price)
_logger_debug("inserted %s block %s price to ydb: %s", address, block, price)
except InvalidOperation:
# happens with really big numbers sometimes. nbd, we can just skip the cache in this case.
pass


set_price = a_sync.ProcessingQueue(_set_price, num_workers=50, return_data=False)
set_price = ProcessingQueue(_set_price, num_workers=50, return_data=False)


@cached(TTLCache(maxsize=1_000, ttl=5 * 60), lock=threading.Lock())
Expand Down
33 changes: 15 additions & 18 deletions y/classes/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import abc
import asyncio
from contextlib import suppress
from abc import abstractmethod
from asyncio import Task, create_task, ensure_future, gather, get_event_loop
from decimal import Decimal
from functools import cached_property
from logging import getLogger
Expand Down Expand Up @@ -211,15 +210,17 @@ class ERC20(ContractBase):

def __repr__(self) -> str:
cls = type(self).__name__
with suppress(AttributeError):
try:
if ERC20.symbol.has_cache_value(self):
symbol = ERC20.symbol.get_cache_value(self)
return f"<{cls} {symbol} '{self.address}'>"
elif not asyncio.get_event_loop().is_running() and not self.asynchronous:
elif not get_event_loop().is_running() and not self.asynchronous:
try:
return f"<{cls} {self.__symbol__(sync=True)} '{self.address}'>"
except NonStandardERC20:
return f"<{cls} SYMBOL_INVALID '{self.address}'>"
except AttributeError:
pass
return f"<{cls} SYMBOL_NOT_LOADED '{self.address}'>"

@a_sync.aka.cached_property
Expand Down Expand Up @@ -371,7 +372,7 @@ async def total_supply_readable(self, block: Optional[Block] = None) -> float:
>>> await token.total_supply_readable()
1000.0
"""
total_supply, scale = await asyncio.gather(
total_supply, scale = await gather(
self.total_supply(block=block, sync=False), self.__scale__
)
return total_supply / scale
Expand Down Expand Up @@ -399,7 +400,7 @@ async def balance_of(
async def balance_of_readable(
self, address: AnyAddressType, block: Optional[Block] = None
) -> float:
balance, scale = await asyncio.gather(
balance, scale = await gather(
self.balance_of(address, block=block, asynchronous=True), self.__scale__
)
return balance / scale
Expand Down Expand Up @@ -914,7 +915,7 @@ async def value_usd(self) -> Decimal:
"""
if self.balance == 0:
return 0
balance, price = await asyncio.gather(self.__readable__, self.__price__)
balance, price = await gather(self.__readable__, self.__price__)
value = balance * price
self._logger.debug("balance: %s price: %s value: %s", balance, price, value)
return value
Expand Down Expand Up @@ -969,7 +970,7 @@ def __await__(self) -> Generator[Any, None, Literal[True]]:
"""
return self.loaded.__await__()

@abc.abstractmethod
@abstractmethod
async def _load(self) -> NoReturn:
"""
`self._load` is the coro that will run in the daemon task associated with this _Loader.
Expand All @@ -995,7 +996,7 @@ def loaded(self) -> Awaitable[Literal[True]]:
return self._loaded.wait()

@property
def _task(self) -> "asyncio.Task[NoReturn]":
def _task(self) -> "Task[NoReturn]":
"""
The task that runs `self._load()` for this `_Loader`.
Expand All @@ -1014,13 +1015,11 @@ def _task(self) -> "asyncio.Task[NoReturn]":
raise type(self.__exc)(*self.__exc.args).with_traceback(self.__tb)
if self.__task is None:
logger.debug("creating loader task for %s", self)
self.__task = asyncio.create_task(
coro=self.__load(), name=f"{self}.__load()"
)
self.__task = create_task(coro=self.__load(), name=f"{self}.__load()")
self.__task.add_done_callback(self._done_callback)
return self.__task

def _done_callback(self, task: "asyncio.Task[Any]") -> None:
def _done_callback(self, task: "Task[Any]") -> None:
"""
Called on `self._task` when it completes, if applicable.
Expand Down Expand Up @@ -1063,7 +1062,7 @@ class _EventsLoader(_Loader):
"""

@property
@abc.abstractmethod
@abstractmethod
def _events(self) -> "Events":
"""
The Events object associated with this _EventsLoader.
Expand All @@ -1084,9 +1083,7 @@ def loaded(self) -> Awaitable[Literal[True]]:
"""
self._task # ensure task is running and not err'd
if self._loaded is None:
self._loaded = asyncio.ensure_future(
self._events._lock.wait_for(self._init_block)
)
self._loaded = ensure_future(self._events._lock.wait_for(self._init_block))
return self._loaded

async def _load(self) -> NoReturn:
Expand Down
17 changes: 9 additions & 8 deletions y/prices/yearn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import logging
from contextlib import suppress
from asyncio import gather
from decimal import Decimal
from logging import getLogger
from typing import Optional, Tuple

import a_sync
Expand All @@ -28,7 +27,7 @@
from y.utils.logging import get_price_logger
from y.utils.raw_calls import raw_call

logger = logging.getLogger(__name__)
logger = getLogger(__name__)

# NOTE: Yearn and Yearn-inspired

Expand Down Expand Up @@ -80,7 +79,7 @@ async def is_yearn_vault(token: AnyAddressType) -> bool:

# Yearn-like contracts can use these formats
result = any(
await asyncio.gather(
await gather(
has_methods(
token,
(
Expand All @@ -101,8 +100,11 @@ async def is_yearn_vault(token: AnyAddressType) -> bool:
# pricePerShare can revert if totalSupply == 0, which would cause `has_methods` to return `False`,
# but it might still be a vault. This section will correct `result` for problematic vaults.
if not result:
with suppress(ContractNotVerified, MessedUpBrownieContract):
try:
contract = await Contract.coroutine(token)
except (ContractNotVerified, MessedUpBrownieContract):
pass
else:
result = any(
[
hasattr(contract, "pricePerShare"),
Expand All @@ -112,7 +114,6 @@ async def is_yearn_vault(token: AnyAddressType) -> bool:
hasattr(contract, "convertToAssets"),
]
)

return result


Expand Down Expand Up @@ -337,7 +338,7 @@ async def price(
"""
logger = get_price_logger(self.address, block=None, extra="yearn")
underlying: ERC20
share_price, underlying = await asyncio.gather(
share_price, underlying = await gather(
self.share_price(block=block, sync=False), self.__underlying__
)
if share_price is None:
Expand Down

0 comments on commit acd9a7f

Please sign in to comment.