Source code for splinter.driver.flaskclient
# Copyright 2014 splinter authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
from typing import Optional
from urllib.parse import parse_qs
from urllib.parse import urlencode
from urllib.parse import urlparse
from urllib.parse import urlunparse
from .lxmldriver import LxmlDriver
from splinter.abc import CookieManagerAPI
from splinter.config import Config
from splinter.request_handler.status_code import StatusCode
class CookieManager(CookieManagerAPI):
    """Cookie manager backed by a Flask/Werkzeug test client.

    ``self.driver`` is the test client handed to the constructor (presumably
    stored by ``CookieManagerAPI.__init__`` -- confirm). Several methods read
    the client's private ``_cookies`` mapping directly, whose values expose
    ``.key`` and ``.value``.
    """

    def add(self, cookie, **kwargs):
        """Set every ``name: value`` pair of ``cookie`` on domain ``localhost``.

        Extra keyword arguments are forwarded unchanged to the client's
        ``set_cookie``.
        """
        for key, value in cookie.items():
            self.driver.set_cookie(
                key=key,
                value=value,
                domain="localhost",
                **kwargs,
            )

    def delete(self, *cookies):
        """Delete the named cookies.

        Calling with no names deletes nothing (use ``delete_all`` for that).
        A ``KeyError`` raised by the client for an unknown name is ignored.
        """
        for cookie in cookies:
            try:
                self.driver.delete_cookie(cookie)
            except KeyError:
                # Best effort: deleting a cookie that is not set is not an error.
                pass

    def delete_all(self):
        """Remove every cookie held by the test client."""
        self.driver._cookies.clear()

    def all(self, verbose=False):  # NOQA: A003
        """Return all cookies as a ``{name: value}`` dict.

        ``verbose`` is accepted for API compatibility and is not used. If the
        same name is stored under several domains/paths, the last one
        iterated wins.
        """
        return {
            cookie.key: cookie.value for cookie in self.driver._cookies.values()
        }

    def __getitem__(self, item):
        """Return the value of the cookie named ``item``.

        Lookup goes through the client's ``get_cookie`` with its default
        domain/path.

        Raises:
            KeyError: if the client has no such cookie. Previously this
                surfaced as an ``AttributeError`` on ``None``.
        """
        cookie = self.driver.get_cookie(item)
        if cookie is None:
            raise KeyError(item)
        return cookie.value

    def __contains__(self, key):
        """Return True if any stored cookie (any domain/path) is named ``key``."""
        return any(cookie.key == key for cookie in self.driver._cookies.values())

    def __eq__(self, other_object):
        """Compare equal to a dict with exactly the same ``{name: value}`` pairs."""
        if not isinstance(other_object, dict):
            return False
        return self.all() == other_object
[docs]
class FlaskClient(LxmlDriver):
    """Splinter driver that talks to a Flask app through ``app.test_client()``.

    Requests are dispatched in-process by the test client; response HTML is
    exposed through ``html`` for ``LxmlDriver`` to work with.
    """

    driver_name = "flask"

    # Redirect statuses whose ``Location`` header ``_do_method`` follows.
    # 308 matters in practice: Werkzeug's router uses it for strict-slash
    # redirects (``/foo`` -> ``/foo/``).
    _REDIRECT_STATUS_CODES = frozenset({301, 302, 303, 305, 307, 308})
    # Redirects after which the follow-up request is a body-less GET.
    _SEE_OTHER_STATUS_CODES = frozenset({302, 303})

    def __init__(
        self,
        app,
        user_agent=None,
        wait_time=2,
        custom_headers=None,
        config: Optional[Config] = None,
    ):
        """Create a driver bound to ``app``.

        Args:
            app: Flask application. Its ``TESTING`` config flag is set to
                True and a test client is created from it.
            user_agent: Optional ``User-Agent`` value sent with every request.
                A ``User-Agent`` entry already present in ``custom_headers``
                (any letter case) takes precedence.
            wait_time: Forwarded to ``LxmlDriver.__init__``.
            custom_headers: Mapping of extra headers sent with every request.
            config: Accepted for interface compatibility; not used here.
        """
        app.config["TESTING"] = True
        self._browser = app.test_client()
        self._cookie_manager = CookieManager(self._browser)

        headers = custom_headers if custom_headers else {}
        if user_agent is not None and not any(
            str(name).lower() == "user-agent" for name in headers
        ):
            # Copy before adding so the caller's mapping is left untouched.
            headers = dict(headers)
            headers["User-Agent"] = user_agent
        self._custom_headers = headers

        super().__init__(wait_time=wait_time)

    def __enter__(self):
        """Return the driver itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Do nothing on exit; exceptions are not suppressed."""
        pass

    def _post_load(self):
        """Reset per-page state after ``self._response`` has been replaced."""
        self._forms = {}
        # ``_html`` is presumably a cached parse of the previous page, set
        # elsewhere (e.g. by LxmlDriver); drop it so it gets rebuilt.
        try:
            del self._html
        except AttributeError:
            pass
        self.status_code = StatusCode(self._response.status_code, "")

    def _do_method(self, method, url, data=None, record_url=True):
        """Send a request through the test client, following redirects.

        Args:
            method: HTTP method name (any case), e.g. ``"get"`` or ``"post"``.
            url: URL or path to request.
            data: For GET, merged into the URL's query string; for other
                methods, sent as the request body.
            record_url: When true, every URL requested (including each
                redirect hop) is appended to the navigation history.
        """
        self._url = url
        is_get = method.lower() == "get"
        func_method = getattr(self._browser, method.lower())

        # Keep requesting until a response that is not a followable redirect.
        while True:
            if record_url:
                self._last_url_index += 1
                # Going to a new URL always crops the forward history.
                self._last_urls = self._last_urls[: self._last_url_index]
                self._last_urls.append(url)

            if is_get:
                if data:
                    url_parts = urlparse(url)
                    # keep_blank_values keeps params such as "?q=", which
                    # parse_qs would otherwise silently drop.
                    params = parse_qs(url_parts.query, keep_blank_values=True)
                    params.update(data)
                    url = urlunparse(
                        url_parts._replace(query=urlencode(params, doseq=True))
                    )
                # GET data travels in the query string only; passing it as
                # ``data`` would populate ``flask.request.form``.
                data = None

            self._response = func_method(
                url,
                headers=self._custom_headers,
                data=data,
                follow_redirects=False,
            )

            status = self._response.status_code
            if status not in self._REDIRECT_STATUS_CODES:
                break

            location = self._response.headers.get("Location")
            if not location:
                # A redirect status without a target cannot be followed;
                # treat this response as the final page.
                break

            if status in self._SEE_OTHER_STATUS_CODES:
                # Conventional 302/303 handling: re-request with a bare GET.
                is_get = True
                func_method = self._browser.get
                data = None

            url = location
            self._url = url

        self._post_load()

    def submit_data(self, form):
        """Submit ``form`` and return ``.data`` of what ``submit`` returns.

        Presumably that is the test-client response, making this its raw
        body -- confirm against ``LxmlDriver.submit``.
        """
        return super().submit(form).data

    @property
    def html(self):
        """The current response body decoded as text."""
        return self._response.get_data(as_text=True)