Source code for splinter.driver.djangoclient

# Copyright 2012 splinter authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
from typing import Optional
from urllib import parse

from .lxmldriver import LxmlDriver
from splinter.abc import CookieManagerAPI
from splinter.config import Config
from splinter.request_handler.status_code import StatusCode


class CookieManager(CookieManagerAPI):
    def add(self, cookie, **kwargs):
        for key, value in cookie.items():
            self.driver.cookies[key] = value
            for k, v in kwargs.items():
                self.driver.cookies[key][k] = v

    def delete(self, *cookies):
        if cookies:
            for cookie in cookies:
                try:
                    del self.driver.cookies[cookie]
                except KeyError:
                    pass

    def delete_all(self):
        self.driver.cookies.clear()

    def all(self, verbose=False):  # NOQA: A003
        cookies = {}
        for key, value in self.driver.cookies.items():
            cookies[key] = value
        return cookies

    def __getitem__(self, item):
        return self.driver.cookies[item].value

    def __contains__(self, key):
        return key in self.driver.cookies

    def __eq__(self, other_object):
        if isinstance(other_object, dict):
            cookies_dict = {key: morsel.value for key, morsel in self.driver.cookies.items()}
            return cookies_dict == other_object
        return False


[docs] class DjangoClient(LxmlDriver): driver_name = "django" def __init__( self, user_agent=None, wait_time=2, config: Optional[Config] = None, **kwargs, ): from django.test.client import Client self._custom_headers = kwargs.pop("custom_headers", {}) client_kwargs = {} for key, value in kwargs.items(): if key.startswith("client_"): client_kwargs[key.replace("client_", "")] = value self._browser = Client(**client_kwargs) self._cookie_manager = CookieManager(self._browser) super().__init__( wait_time=wait_time, user_agent=user_agent, config=config, ) def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): pass def _post_load(self): self._forms = {} try: del self._html except AttributeError: pass self.status_code = StatusCode(self._response.status_code, "") def _handle_redirect_chain(self): if self._response.redirect_chain: for redirect_url, redirect_code in self._response.redirect_chain: self._last_urls.append(redirect_url) self._url = self._last_urls[-1] def _set_extra_params(self, url): extra = {} components = parse.urlparse(url) if components.hostname: extra.update({"SERVER_NAME": components.hostname}) if components.port: extra.update({"SERVER_PORT": components.port}) if self.config.user_agent: extra.update({"User-Agent": self._user_agent}) if self._custom_headers: extra.update(self._custom_headers) return extra def _do_method(self, method, url, data=None, record_url=True): self._url = url extra = self._set_extra_params(url) func_method = getattr(self._browser, method.lower()) self._response = func_method(url, data=data, follow=True, **extra) if record_url: self._last_url_index += 1 # Going to a new URL always crops the url history self._last_urls = self._last_urls[: self._last_url_index] self._last_urls.append(url) self._handle_redirect_chain() self._post_load() def submit_data(self, form): return super().submit(form).content @property def html(self): return self._response.content.decode(self._response._charset or "utf-8")