Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding __aenter__ and __aexit__ Methods to AsyncHTMLSession #556

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions requests_html.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import sys
import asyncio
import pyppeteer
import requests
import http.cookiejar
import lxml

from typing import Set, Union, List, MutableMapping, Optional

from urllib.parse import urlparse, urlunparse, urljoin
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures._base import TimeoutError
from functools import partial
from typing import Set, Union, List, MutableMapping, Optional

import pyppeteer
import requests
import http.cookiejar
from pyquery import PyQuery

from fake_useragent import UserAgent
from lxml.html.clean import Cleaner
import lxml
from lxml import etree
from lxml.html import HtmlElement
from lxml.html import tostring as lxml_html_tostring
Expand Down Expand Up @@ -771,7 +772,6 @@ def __init__(self, mock_browser : bool = True, verify : bool = True,

self.__browser_args = browser_args


def response_hook(self, response, **kwargs) -> HTMLResponse:
""" Change response encoding and replace it by a HTMLResponse. """
if not response.encoding:
Expand Down Expand Up @@ -823,6 +823,12 @@ def __init__(self, loop=None, workers=None,
self.loop = loop or asyncio.get_event_loop()
self.thread_pool = ThreadPoolExecutor(max_workers=workers)

async def __aenter__(self) -> 'AsyncHTMLSession':
return self

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.close()

def request(self, *args, **kwargs):
""" Partial original request func and run it in a thread. """
func = partial(super().request, *args, **kwargs)
Expand Down
26 changes: 24 additions & 2 deletions tests/test_requests_html.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
import pytest

from functools import partial

import pytest
from pyppeteer.browser import Browser
from pyppeteer.page import Page
from requests_html import HTMLSession, AsyncHTMLSession, HTML
from requests_file import FileAdapter

Expand Down Expand Up @@ -322,3 +322,25 @@ async def test_async_browser_session():
browser = await session.browser
assert isinstance(browser, Browser)
await session.close()


@pytest.mark.asyncio
async def test_async_context_manager():
"""
Test the behavior of the async context manager for AsyncHTMLSession.

This test case validates that the AsyncHTMLSession instance can be used
as an asynchronous context manager, and the session can successfully make
an HTTP GET request within the context.

Note: If the user has no connection, a ConnectionError may occur, and the
test will be skipped.

"""
async with AsyncHTMLSession() as s:
try:
results = await s.get('https://www.google.com')
assert results.status_code == 200
except ConnectionError:
# if the user has no connection skip this test
pass