You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
279 lines
10 KiB
279 lines
10 KiB
import time
|
|
from contextlib import contextmanager
|
|
from enum import Enum
|
|
from typing import Any, Callable, Dict, Generator, Optional, Type, cast
|
|
from weakref import ReferenceType, ref
|
|
|
|
from aiohttp import ClientResponse, ClientSession
|
|
from sqlalchemy import JSON, BigInteger, Column, Enum as SQLEnum, Integer, Text, and_, create_engine
|
|
from sqlalchemy.engine.base import Engine
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import Query, Session, sessionmaker
|
|
|
|
from hummingbot.model.transaction_base import TransactionBase
|
|
|
|
# Declarative base class that the ORM model(s) in this module derive from; its metadata is used to create the tables.
Base = declarative_base()
|
|
|
|
|
|
class HttpRequestMethod(Enum):
    """
    HTTP verbs that a playback record can be tagged with.
    """
    # Member values are explicit integers; member names match the upper-case HTTP verb.
    POST = 1
    GET = 2
    PUT = 3
    PATCH = 4
    DELETE = 5
|
|
|
|
|
|
class HttpRequestType(Enum):
    """
    Describes which kind of payload accompanied a recorded request.
    """
    PLAIN = 1        # neither query params nor a JSON body
    WITH_PARAMS = 2  # request carried `params`
    WITH_JSON = 3    # request carried `json`
|
|
|
|
|
|
class HttpResponseType(Enum):
    """
    Describes how much of a response has been captured in a playback record.
    """
    HEADER_ONLY = 1  # only the status code is known; no body was read
    WITH_TEXT = 2    # body was read as text
    WITH_JSON = 3    # body was read as JSON
|
|
|
|
|
|
class HttpPlayback(Base):
    """
    ORM model for one recorded HTTP exchange: the request that was sent and the response that came back.
    """
    __tablename__ = "HttpPlayback"

    # --- record identity ---
    id = Column(Integer, primary_key=True, autoincrement=True)
    timestamp = Column(BigInteger, nullable=False)

    # --- request side ---
    url = Column(Text, index=True, nullable=False)
    method = Column(SQLEnum(HttpRequestMethod), nullable=False)
    request_type = Column(SQLEnum(HttpRequestType), nullable=False)
    request_params = Column(JSON)
    request_json = Column(JSON)

    # --- response side ---
    response_type = Column(SQLEnum(HttpResponseType), nullable=False)
    response_code = Column(Integer, nullable=False)
    response_text = Column(Text)
    response_json = Column(JSON)
|
|
|
|
|
|
class HttpRecorderClientResponse(ClientResponse):
    """
    `ClientResponse` subclass that writes the response body into the matching `HttpPlayback` row whenever the body
    is read through `text()` or `json()`.

    The owning recorder and the database row id are attached after construction through the `parent_recorder` and
    `database_id` properties.
    """
    _database_id: Optional[int]
    _parent_recorder_ref: Optional[ReferenceType]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._database_id = None
        self._parent_recorder_ref = None

    @property
    def database_id(self) -> Optional[int]:
        """
        Primary key of the `HttpPlayback` row for this response, or None if it has not been assigned yet.
        """
        return self._database_id

    @database_id.setter
    def database_id(self, value: int):
        self._database_id = value

    @property
    def parent_recorder(self) -> Optional["HttpRecorder"]:
        """
        The recorder that produced this response, or None if none was set (or the weak reference no longer
        resolves).
        """
        if self._parent_recorder_ref is not None:
            return self._parent_recorder_ref()
        return None

    @parent_recorder.setter
    def parent_recorder(self, value: "HttpRecorder"):
        # Only a weak reference is kept, so the response does not hold the recorder alive.
        self._parent_recorder_ref = ref(value)

    def get_playback_entry(self, session: Session) -> HttpPlayback:
        """
        Load the `HttpPlayback` row whose id equals `database_id`.

        :param session: session to run the query on
        :return: the single matching row (`.one()` raises if there is not exactly one match)
        """
        return session.query(HttpPlayback).filter(HttpPlayback.id == self.database_id).one()

    async def text(self, *args, **kwargs) -> str:
        """
        Read the body as text, store it in the playback row, and return it.
        """
        response_text: str = await super().text(*args, **kwargs)
        # `begin()` comes from the recorder's base class - presumably a transaction scope that yields a session.
        with self.parent_recorder.begin() as session:
            playback_entry: HttpPlayback = self.get_playback_entry(session)
            # Mark the row as carrying a text body. (Previously this enum was assigned to `response_text` and then
            # immediately overwritten, so `response_type` never left HEADER_ONLY.)
            playback_entry.response_type = HttpResponseType.WITH_TEXT
            playback_entry.response_text = response_text
        return response_text

    async def json(self, *args, **kwargs) -> Any:
        """
        Read the body as JSON, store the decoded object in the playback row, and return it.
        """
        response_obj: Any = await super().json(*args, **kwargs)
        with self.parent_recorder.begin() as session:
            playback_entry: HttpPlayback = self.get_playback_entry(session)
            playback_entry.response_type = HttpResponseType.WITH_JSON
            playback_entry.response_json = response_obj
        return response_obj
|
|
|
|
|
|
class HttpPlayerBase(TransactionBase):
    """
    Shared plumbing for the HTTP recorder and player: opens the SQLite database at `db_path`, creates the tables
    declared on `Base`, and offers a context manager that reroutes `ClientSession._request` to the subclass's
    `aiohttp_request_method()`.
    """

    def __init__(self, db_path: str):
        self._db_path: str = db_path
        self._db_engine: Engine = create_engine(f"sqlite:///{db_path}")
        self._session_factory: Callable[[], Session] = sessionmaker(bind=self._db_engine)
        # Make sure the tables exist before anything tries to read or write them.
        Base.metadata.create_all(self._db_engine)

    def get_new_session(self) -> Session:
        """
        Create a fresh session bound to this object's database engine.
        """
        return self._session_factory()

    @contextmanager
    def patch_aiohttp_client(self) -> Generator[Type[ClientSession], None, None]:
        """
        Context manager that replaces `ClientSession._request` with a forwarder to `self.aiohttp_request_method`
        for the duration of the `with` block, and restores the original on exit.

        The original function is parked on `ClientSession._original_request_func` while the patch is active.
        """
        def forward_request(client_session, *args, **kwargs):
            # Installed on the class, so the first argument is the ClientSession instance making the call.
            return self.aiohttp_request_method(client_session, *args, **kwargs)

        try:
            ClientSession._original_request_func = ClientSession._request
            ClientSession._request = forward_request
            yield ClientSession
        finally:
            ClientSession._request = ClientSession._original_request_func
            del ClientSession._original_request_func
|
|
|
|
|
|
class HttpRecorder(HttpPlayerBase):
    """
    Records HTTP conversations made over any aiohttp.ClientSession object, and records them to an SQLite database file
    for replaying.

    Usage:
      recorder = HttpRecorder('test.db')
      with recorder.patch_aiohttp_client():
        # all aiohttp conversations inside this block will be recorded to test.db
        async with aiohttp.ClientSession() as client:
          async with client.get("https://api.binance.com/api/v3/time") as resp:
            data = await resp.json()    # the request and response are recorded to test.db
            ...
    """

    async def aiohttp_request_method(
            self,
            client: ClientSession,
            method: str,
            url: str,
            **kwargs) -> HttpRecorderClientResponse:
        """
        Perform the real request through the original aiohttp function and insert an `HttpPlayback` row for it.

        :param client: the `ClientSession` making the request
        :param method: HTTP method name
        :param url: request URL
        :param kwargs: forwarded unchanged to the original request function; `params` and `json` are also recorded
        :return: the live response, tagged with this recorder and the id of the row that was inserted
        """
        try:
            # Several requests may be in flight on the same client, so a counter on the client tracks how many
            # calls are active. The first one in saves the original response class; the last one out restores it.
            if hasattr(client, "_reentrant_ref_count"):
                client._reentrant_ref_count += 1
            else:
                client._reentrant_ref_count = 1
                client._original_response_class = client._response_class
            client._response_class = HttpRecorderClientResponse

            request_type: HttpRequestType = HttpRequestType.PLAIN
            request_params: Optional[Dict[str, str]] = None
            request_json: Optional[Any] = None
            if "params" in kwargs:
                request_type = HttpRequestType.WITH_PARAMS
                request_params = kwargs.get("params")
            if "json" in kwargs:
                # When both are given, the record is tagged WITH_JSON but both payloads are still stored.
                request_type = HttpRequestType.WITH_JSON
                request_json = kwargs.get("json")

            response: HttpRecorderClientResponse = await client._original_request_func(method, url, **kwargs)
            response.parent_recorder = self
            with self.begin() as session:
                playback_entry: HttpPlayback = HttpPlayback(
                    timestamp=int(time.time() * 1e3),   # milliseconds since the epoch
                    # aiohttp accepts URL objects as well as strings; the column is Text, so store the string form.
                    url=str(url),
                    method=method,
                    request_type=request_type,
                    request_params=request_params,
                    request_json=request_json,
                    # The body has not been read yet; text()/json() on the response update the row later.
                    response_type=HttpResponseType.HEADER_ONLY,
                    response_code=response.status
                )
                session.add(playback_entry)
                # Flush so the autoincrement id is populated before it is copied onto the response.
                session.flush()
                response.database_id = playback_entry.id
            return response
        finally:
            client._reentrant_ref_count -= 1
            if client._reentrant_ref_count < 1:
                client._response_class = client._original_response_class
                del client._original_response_class
                del client._reentrant_ref_count
|
|
|
|
|
|
class HttpPlayerResponse:
    """
    Minimal stand-in for an aiohttp response, backed by values that were recorded earlier.
    """

    def __init__(self, method: str, url: str, status: int, response_text: Optional[str], response_json: Optional[Any]):
        self.method, self.url, self.status = method, url, status
        self._response_text: Optional[str] = response_text
        self._response_json: Optional[Any] = response_json

    async def text(self) -> str:
        """
        Return the recorded text body.

        :raises EnvironmentError: if no text body was recorded
        """
        recorded = self._response_text
        if recorded is not None:
            return recorded
        raise EnvironmentError("No response text has been recorded for replaying.")

    async def json(self) -> Any:
        """
        Return the recorded JSON body.

        :raises EnvironmentError: if no JSON body was recorded
        """
        recorded = self._response_json
        if recorded is not None:
            return recorded
        raise EnvironmentError("No response json has been recorded for replaying.")

    def release(self):
        """
        Intentionally does nothing; it exists because ClientSession logic expects a response to have release().
        """
        return None
|
|
|
|
|
|
class HttpPlayer(HttpPlayerBase):
    """
    Given a HTTP conversation record db, patch aiohttp.ClientSession such that it will only replay matched recorded
    conversations.

    When aiohttp.ClientSession makes any request inside `patch_aiohttp_client()`, the player will search for a matching
    response by URL, request params and request JSON. If no matching response is found, then an exception will be
    raised.

    Usage:
      player = HttpPlayer('test.db')
      with player.patch_aiohttp_client():
        # all aiohttp responses within this block will be replays from past records in test.db.
        async with aiohttp.ClientSession() as client:
          async with client.get("https://api.binance.com/api/v3/time") as resp:
            data = await resp.json()    # the data returned will be the recorded response
            ...
    """
    _replay_timestamp_ms: Optional[int]

    def __init__(self, db_path: str):
        super().__init__(db_path)
        self._replay_timestamp_ms = None

    @property
    def replay_timestamp_ms(self) -> Optional[int]:
        """
        Lower bound (in milliseconds) on the timestamp of records eligible for replay, or None for no bound.
        """
        return self._replay_timestamp_ms

    @replay_timestamp_ms.setter
    def replay_timestamp_ms(self, value: Optional[int]):
        self._replay_timestamp_ms = value

    async def aiohttp_request_method(
            self,
            _: ClientSession,
            method: str,
            url: str,
            **kwargs) -> HttpPlayerResponse:
        """
        Look up a recorded response for the request and return it instead of touching the network.

        The first lookup matches on URL, method, `params` and `json` (when given) and the replay timestamp (when
        set). If that finds nothing, a second lookup drops the `params`/`json` conditions.

        :param _: the `ClientSession` making the request (unused)
        :param method: HTTP method name
        :param url: request URL
        :param kwargs: request keyword arguments; only `params` and `json` are inspected
        :return: a response object carrying the recorded status, text and JSON
        :raises EnvironmentError: if no recorded entry matches even the loosened lookup
        """
        with self.begin() as session:
            # Conditions shared by both the precise and the loosened lookup.
            loose_conditions = [HttpPlayback.url == url, HttpPlayback.method == method]
            precise_conditions = list(loose_conditions)
            if "params" in kwargs:
                precise_conditions.append(HttpPlayback.request_params == kwargs["params"])
            if "json" in kwargs:
                precise_conditions.append(HttpPlayback.request_json == kwargs["json"])
            if self._replay_timestamp_ms is not None:
                not_before = HttpPlayback.timestamp >= self._replay_timestamp_ms
                loose_conditions.append(not_before)
                precise_conditions.append(not_before)

            playback_entry: Optional[HttpPlayback] = (
                session.query(HttpPlayback).filter(and_(*precise_conditions)).first()
            )

            # Loosen the query conditions if the first, precise query didn't work.
            if playback_entry is None:
                playback_entry = (
                    session.query(HttpPlayback).filter(and_(*loose_conditions)).first()
                )

            # Fail with a descriptive error rather than an AttributeError from dereferencing None below.
            if playback_entry is None:
                raise EnvironmentError(f"No recorded response found for replaying {method} {url}.")

            return HttpPlayerResponse(
                method,
                url,
                playback_entry.response_code,
                playback_entry.response_text,
                playback_entry.response_json
            )
|
|
|