initial commit

This commit is contained in:
2026-06-25 21:29:21 +00:00
commit 0d0a7456de
2738 changed files with 542622 additions and 0 deletions
+19
View File
@@ -0,0 +1,19 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
__version__ = "4.43.0"
@@ -0,0 +1,90 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.common.exceptions import (
DetachedShadowRootException,
ElementClickInterceptedException,
ElementNotInteractableException,
ElementNotSelectableException,
ElementNotVisibleException,
ImeActivationFailedException,
ImeNotAvailableException,
InsecureCertificateException,
InvalidArgumentException,
InvalidCookieDomainException,
InvalidCoordinatesException,
InvalidElementStateException,
InvalidSelectorException,
InvalidSessionIdException,
InvalidSwitchToTargetException,
JavascriptException,
MoveTargetOutOfBoundsException,
NoAlertPresentException,
NoSuchAttributeException,
NoSuchCookieException,
NoSuchDriverException,
NoSuchElementException,
NoSuchFrameException,
NoSuchShadowRootException,
NoSuchWindowException,
ScreenshotException,
SessionNotCreatedException,
StaleElementReferenceException,
TimeoutException,
UnableToSetCookieException,
UnexpectedAlertPresentException,
UnexpectedTagNameException,
UnknownMethodException,
WebDriverException,
)
__all__ = [
"DetachedShadowRootException",
"ElementClickInterceptedException",
"ElementNotInteractableException",
"ElementNotSelectableException",
"ElementNotVisibleException",
"ImeActivationFailedException",
"ImeNotAvailableException",
"InsecureCertificateException",
"InvalidArgumentException",
"InvalidCookieDomainException",
"InvalidCoordinatesException",
"InvalidElementStateException",
"InvalidSelectorException",
"InvalidSessionIdException",
"InvalidSwitchToTargetException",
"JavascriptException",
"MoveTargetOutOfBoundsException",
"NoAlertPresentException",
"NoSuchAttributeException",
"NoSuchCookieException",
"NoSuchDriverException",
"NoSuchElementException",
"NoSuchFrameException",
"NoSuchShadowRootException",
"NoSuchWindowException",
"ScreenshotException",
"SessionNotCreatedException",
"StaleElementReferenceException",
"TimeoutException",
"UnableToSetCookieException",
"UnexpectedAlertPresentException",
"UnexpectedTagNameException",
"UnknownMethodException",
"WebDriverException",
]
@@ -0,0 +1,308 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Exceptions that may happen in all the webdriver code."""
from collections.abc import Sequence
from typing import Any
SUPPORT_MSG = "For documentation on this error, please visit:"
ERROR_URL = "https://www.selenium.dev/documentation/webdriver/troubleshooting/errors"
class WebDriverException(Exception):
"""Base webdriver exception."""
def __init__(
self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None
) -> None:
super().__init__()
self.msg = msg
self.screen = screen
self.stacktrace = stacktrace
def __str__(self) -> str:
exception_msg = f"Message: {self.msg}\n"
if self.screen:
exception_msg += "Screenshot: available via screen\n"
if self.stacktrace:
stacktrace = "\n".join(self.stacktrace)
exception_msg += f"Stacktrace:\n{stacktrace}"
return exception_msg
class InvalidSwitchToTargetException(WebDriverException):
"""Thrown when frame or window target to be switched doesn't exist."""
class NoSuchFrameException(InvalidSwitchToTargetException):
"""Thrown when frame target to be switched doesn't exist."""
class NoSuchWindowException(InvalidSwitchToTargetException):
"""Thrown when window target to be switched doesn't exist.
To find the current set of active window handles, you can get a list
of the active window handles in the following way::
print driver.window_handles
"""
class NoSuchElementException(WebDriverException):
"""Thrown when element could not be found.
If you encounter this exception, you may want to check the following:
* Check your selector used in your find_by...
* Element may not yet be on the screen at the time of the find operation,
(webpage is still loading) see selenium.webdriver.support.wait.WebDriverWait()
for how to write a wait wrapper to wait for an element to appear.
"""
def __init__(
self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None
) -> None:
with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#nosuchelementexception"
super().__init__(with_support, screen, stacktrace)
class NoSuchAttributeException(WebDriverException):
"""Thrown when the attribute of element could not be found.
You may want to check if the attribute exists in the particular
browser you are testing against. Some browsers may have different
property names for the same property. (IE8's .innerText vs. Firefox
.textContent)
"""
class NoSuchShadowRootException(WebDriverException):
"""Thrown when trying to access the shadow root of an element when it does not have a shadow root attached."""
class StaleElementReferenceException(WebDriverException):
"""Thrown when a reference to an element is now "stale".
Stale means the element no longer appears on the DOM of the page.
Possible causes of StaleElementReferenceException include, but not limited to:
* You are no longer on the same page, or the page may have refreshed since the element
was located.
* The element may have been removed and re-added to the screen, since it was located.
Such as an element being relocated.
This can happen typically with a javascript framework when values are updated and the
node is rebuilt.
* Element may have been inside an iframe or another context which was refreshed.
"""
def __init__(
self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None
) -> None:
with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#staleelementreferenceexception"
super().__init__(with_support, screen, stacktrace)
class InvalidElementStateException(WebDriverException):
"""Thrown when a command could not be completed because the element is in an invalid state.
This can be caused by attempting to clear an element that isn't both editable and resettable.
"""
class UnexpectedAlertPresentException(WebDriverException):
"""Thrown when an unexpected alert has appeared.
Usually raised when an unexpected modal is blocking the webdriver
from executing commands.
"""
def __init__(
self,
msg: Any | None = None,
screen: str | None = None,
stacktrace: Sequence[str] | None = None,
alert_text: str | None = None,
) -> None:
super().__init__(msg, screen, stacktrace)
self.alert_text = alert_text
def __str__(self) -> str:
return f"Alert Text: {self.alert_text}\n{super().__str__()}"
class NoAlertPresentException(WebDriverException):
"""Thrown when switching to no presented alert.
This can be caused by calling an operation on the Alert() class when
an alert is not yet on the screen.
"""
class ElementNotVisibleException(InvalidElementStateException):
"""Thrown when an element is present on the DOM, but it is not visible, and so is not able to be interacted with.
Most commonly encountered when trying to click or read text of an element that is hidden from view.
"""
def __init__(
self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None
) -> None:
with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotvisibleexception"
super().__init__(with_support, screen, stacktrace)
class ElementNotInteractableException(InvalidElementStateException):
"""Thrown when element interactions will hit another element due to paint order."""
def __init__(
self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None
) -> None:
with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotinteractableexception"
super().__init__(with_support, screen, stacktrace)
class ElementNotSelectableException(InvalidElementStateException):
"""Thrown when trying to select an unselectable element.
For example, selecting a 'script' element.
"""
class InvalidCookieDomainException(WebDriverException):
"""Thrown when attempting to add a cookie under a different domain."""
class UnableToSetCookieException(WebDriverException):
"""Thrown when a driver fails to set a cookie."""
class TimeoutException(WebDriverException):
"""Thrown when a command does not complete in enough time."""
class MoveTargetOutOfBoundsException(WebDriverException):
"""Thrown when the target provided to the `ActionsChains` move() method is invalid, i.e. out of document."""
class UnexpectedTagNameException(WebDriverException):
"""Thrown when a support class did not get an expected web element."""
class InvalidSelectorException(WebDriverException):
"""Thrown when the selector used to find an element does not return a WebElement.
Currently this only happens when the XPath expression is syntactically invalid or does not select WebElements.
"""
def __init__(
self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None
) -> None:
with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#invalidselectorexception"
super().__init__(with_support, screen, stacktrace)
class ImeNotAvailableException(WebDriverException):
"""Thrown when IME support is not available.
This exception is thrown for every IME-related method call if IME
support is not available on the machine.
"""
class ImeActivationFailedException(WebDriverException):
"""Thrown when activating an IME engine has failed."""
class InvalidArgumentException(WebDriverException):
"""The arguments passed to a command are either invalid or malformed."""
class JavascriptException(WebDriverException):
"""An error occurred while executing JavaScript supplied by the user."""
class NoSuchCookieException(WebDriverException):
"""Thrown when no cookie matching the given path name was found."""
class ScreenshotException(WebDriverException):
"""A screen capture was made impossible."""
class ElementClickInterceptedException(WebDriverException):
"""Thrown when element click fails because another element obscures it."""
def __init__(
self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None
) -> None:
with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementclickinterceptedexception"
super().__init__(with_support, screen, stacktrace)
class InsecureCertificateException(WebDriverException):
"""Thrown when the user agent hits a certificate warning (expired or invalid TLS certificate)."""
class InvalidCoordinatesException(WebDriverException):
"""The coordinates provided to an interaction's operation are invalid."""
class InvalidSessionIdException(WebDriverException):
"""Thrown when the given session id is not in the list of active sessions."""
def __init__(
self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None
) -> None:
with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#invalidsessionidexception"
super().__init__(with_support, screen, stacktrace)
class SessionNotCreatedException(WebDriverException):
"""A new session could not be created."""
def __init__(
self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None
) -> None:
with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#sessionnotcreatedexception"
super().__init__(with_support, screen, stacktrace)
class UnknownMethodException(WebDriverException):
"""The requested command matched a known URL but did not match any methods for that URL."""
class NoSuchDriverException(WebDriverException):
"""Raised when driver is not specified and cannot be located."""
def __init__(
self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None
) -> None:
with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}/driver_location"
super().__init__(with_support, screen, stacktrace)
class DetachedShadowRootException(WebDriverException):
"""Raised when referenced shadow root is no longer attached to the DOM."""
View File
@@ -0,0 +1,110 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import importlib
import logging
import os
# Enable debug logging if SE_DEBUG environment variable is set
if os.environ.get("SE_DEBUG"):
logger = logging.getLogger("selenium")
logger.setLevel(logging.DEBUG)
if not logger.handlers:
logger.addHandler(logging.StreamHandler())
logger.warning(
"Environment Variable `SE_DEBUG` is set; "
"Selenium is forcing verbose logging which may override user-specified settings."
)
__version__ = "4.43.0"
# Lazy import mapping: name -> (module_path, attribute_name)
_LAZY_IMPORTS = {
# Chrome
"Chrome": ("selenium.webdriver.chrome.webdriver", "WebDriver"),
"ChromeOptions": ("selenium.webdriver.chrome.options", "Options"),
"ChromeService": ("selenium.webdriver.chrome.service", "Service"),
# Edge
"Edge": ("selenium.webdriver.edge.webdriver", "WebDriver"),
"ChromiumEdge": ("selenium.webdriver.edge.webdriver", "WebDriver"),
"EdgeOptions": ("selenium.webdriver.edge.options", "Options"),
"EdgeService": ("selenium.webdriver.edge.service", "Service"),
# Firefox
"Firefox": ("selenium.webdriver.firefox.webdriver", "WebDriver"),
"FirefoxOptions": ("selenium.webdriver.firefox.options", "Options"),
"FirefoxProfile": ("selenium.webdriver.firefox.firefox_profile", "FirefoxProfile"),
"FirefoxService": ("selenium.webdriver.firefox.service", "Service"),
# IE
"Ie": ("selenium.webdriver.ie.webdriver", "WebDriver"),
"IeOptions": ("selenium.webdriver.ie.options", "Options"),
"IeService": ("selenium.webdriver.ie.service", "Service"),
# Safari
"Safari": ("selenium.webdriver.safari.webdriver", "WebDriver"),
"SafariOptions": ("selenium.webdriver.safari.options", "Options"),
"SafariService": ("selenium.webdriver.safari.service", "Service"),
# Remote
"Remote": ("selenium.webdriver.remote.webdriver", "WebDriver"),
# WebKitGTK
"WebKitGTK": ("selenium.webdriver.webkitgtk.webdriver", "WebDriver"),
"WebKitGTKOptions": ("selenium.webdriver.webkitgtk.options", "Options"),
"WebKitGTKService": ("selenium.webdriver.webkitgtk.service", "Service"),
# WPEWebKit
"WPEWebKit": ("selenium.webdriver.wpewebkit.webdriver", "WebDriver"),
"WPEWebKitOptions": ("selenium.webdriver.wpewebkit.options", "Options"),
"WPEWebKitService": ("selenium.webdriver.wpewebkit.service", "Service"),
# Common utilities
"ActionChains": ("selenium.webdriver.common.action_chains", "ActionChains"),
"DesiredCapabilities": ("selenium.webdriver.common.desired_capabilities", "DesiredCapabilities"),
"Keys": ("selenium.webdriver.common.keys", "Keys"),
"Proxy": ("selenium.webdriver.common.proxy", "Proxy"),
}
# Submodules that can be lazily imported as modules
_LAZY_SUBMODULES = {
"chrome": "selenium.webdriver.chrome",
"chromium": "selenium.webdriver.chromium",
"common": "selenium.webdriver.common",
"edge": "selenium.webdriver.edge",
"firefox": "selenium.webdriver.firefox",
"ie": "selenium.webdriver.ie",
"remote": "selenium.webdriver.remote",
"safari": "selenium.webdriver.safari",
"support": "selenium.webdriver.support",
"webkitgtk": "selenium.webdriver.webkitgtk",
"wpewebkit": "selenium.webdriver.wpewebkit",
}
def __getattr__(name):
if name in _LAZY_IMPORTS:
module_path, attr_name = _LAZY_IMPORTS[name]
module = importlib.import_module(module_path)
value = getattr(module, attr_name)
globals()[name] = value
return value
if name in _LAZY_SUBMODULES:
module = importlib.import_module(_LAZY_SUBMODULES[name])
globals()[name] = module
return module
raise AttributeError(f"module 'selenium.webdriver' has no attribute {name!r}")
def __dir__():
return sorted(set(__all__) | set(_LAZY_SUBMODULES.keys()))
__all__ = sorted(_LAZY_IMPORTS.keys())
@@ -0,0 +1,131 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Type stub with lazy import mapping from __init__.py.
This stub file is necessary for type checkers and IDEs to automatically have
visibility into lazy modules since they are not imported immediately at runtime.
"""
# ruff: noqa: I001
# Expose runtime version
__version__: str
# Chrome
from selenium.webdriver.chrome.webdriver import WebDriver as Chrome
from selenium.webdriver.chrome.options import Options as ChromeOptions
from selenium.webdriver.chrome.service import Service as ChromeService
# Edge
from selenium.webdriver.edge.webdriver import WebDriver as Edge
from selenium.webdriver.edge.webdriver import WebDriver as ChromiumEdge
from selenium.webdriver.edge.options import Options as EdgeOptions
from selenium.webdriver.edge.service import Service as EdgeService
# Firefox
from selenium.webdriver.firefox.webdriver import WebDriver as Firefox
from selenium.webdriver.firefox.options import Options as FirefoxOptions
from selenium.webdriver.firefox.service import Service as FirefoxService
from selenium.webdriver.firefox.firefox_profile import FirefoxProfile
# IE
from selenium.webdriver.ie.webdriver import WebDriver as Ie
from selenium.webdriver.ie.options import Options as IeOptions
from selenium.webdriver.ie.service import Service as IeService
# Safari
from selenium.webdriver.safari.webdriver import WebDriver as Safari
from selenium.webdriver.safari.options import Options as SafariOptions
from selenium.webdriver.safari.service import Service as SafariService
# Remote
from selenium.webdriver.remote.webdriver import WebDriver as Remote
# WebKitGTK
from selenium.webdriver.webkitgtk.webdriver import WebDriver as WebKitGTK
from selenium.webdriver.webkitgtk.options import Options as WebKitGTKOptions
from selenium.webdriver.webkitgtk.service import Service as WebKitGTKService
# WPEWebKit
from selenium.webdriver.wpewebkit.webdriver import WebDriver as WPEWebKit
from selenium.webdriver.wpewebkit.options import Options as WPEWebKitOptions
from selenium.webdriver.wpewebkit.service import Service as WPEWebKitService
# Common utilities
from selenium.webdriver.common.action_chains import ActionChains
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.common.proxy import Proxy
# Submodules
from . import chrome
from . import chromium
from . import common
from . import edge
from . import firefox
from . import ie
from . import remote
from . import safari
from . import support
from . import webkitgtk
from . import wpewebkit
# Exposed names
__all__ = [
# Classes
"ActionChains",
"Chrome",
"ChromeOptions",
"ChromeService",
"ChromiumEdge",
"DesiredCapabilities",
"Edge",
"EdgeOptions",
"EdgeService",
"Firefox",
"FirefoxOptions",
"FirefoxProfile",
"FirefoxService",
"Ie",
"IeOptions",
"IeService",
"Keys",
"Proxy",
"Remote",
"Safari",
"SafariOptions",
"SafariService",
"WPEWebKit",
"WPEWebKitOptions",
"WPEWebKitService",
"WebKitGTK",
"WebKitGTKOptions",
"WebKitGTKService",
# Submodules
"chrome",
"chromium",
"common",
"edge",
"firefox",
"ie",
"remote",
"safari",
"support",
"webkitgtk",
"wpewebkit",
]
@@ -0,0 +1,32 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import importlib
_LAZY_SUBMODULES = ["options", "remote_connection", "service", "webdriver"]
def __getattr__(name):
if name in _LAZY_SUBMODULES:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
return module
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__():
return sorted(_LAZY_SUBMODULES)
@@ -0,0 +1,26 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Type stub with lazy import mapping from __init__.py.
This stub file is necessary for type checkers and IDEs to automatically have
visibility into lazy modules since they are not imported immediately at runtime.
"""
from . import options, remote_connection, service, webdriver
__all__ = ["options", "remote_connection", "service", "webdriver"]
@@ -0,0 +1,34 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.webdriver.chromium.options import ChromiumOptions
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
class Options(ChromiumOptions):
@property
def default_capabilities(self) -> dict:
return DesiredCapabilities.CHROME.copy()
def enable_mobile(
self,
android_package: str | None = "com.android.chrome",
android_activity: str | None = None,
device_serial: str | None = None,
) -> None:
super().enable_mobile(android_package, android_activity, device_serial)
@@ -0,0 +1,41 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.webdriver.chromium.remote_connection import ChromiumRemoteConnection
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
from selenium.webdriver.remote.client_config import ClientConfig
class ChromeRemoteConnection(ChromiumRemoteConnection):
browser_name = DesiredCapabilities.CHROME["browserName"]
def __init__(
self,
remote_server_addr: str,
keep_alive: bool = True,
ignore_proxy: bool = False,
client_config: ClientConfig | None = None,
) -> None:
super().__init__(
remote_server_addr=remote_server_addr,
vendor_prefix="goog",
browser_name=ChromeRemoteConnection.browser_name,
keep_alive=keep_alive,
ignore_proxy=ignore_proxy,
client_config=client_config,
)
@@ -0,0 +1,72 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections.abc import Mapping, Sequence
from typing import IO, Any
from selenium.webdriver.chromium import service
class Service(service.ChromiumService):
"""Service class responsible for starting and stopping the chromedriver executable.
Args:
executable_path: Install path of the chromedriver executable, defaults
to `chromedriver`.
port: Port for the service to run on, defaults to 0 where the operating
system will decide.
service_args: (Optional) Sequence of args to be passed to the subprocess
when launching the executable.
log_output: (Optional) int representation of STDOUT/DEVNULL, any IO
instance or String path to file.
env: (Optional) Mapping of environment variables for the new process,
defaults to `os.environ`.
"""
def __init__(
self,
executable_path: str | None = None,
port: int = 0,
service_args: Sequence[str] | None = None,
log_output: int | str | IO[Any] | None = None,
env: Mapping[str, str] | None = None,
**kwargs,
) -> None:
self._service_args = list(service_args or [])
super().__init__(
executable_path=executable_path,
port=port,
service_args=service_args,
log_output=log_output,
env=env,
**kwargs,
)
def command_line_args(self) -> list[str]:
return ["--enable-chrome-logs", f"--port={self.port}"] + self._service_args
@property
def service_args(self) -> Sequence[str]:
"""Returns the sequence of service arguments."""
return self._service_args
@service_args.setter
def service_args(self, value: Sequence[str]):
if isinstance(value, str) or not isinstance(value, Sequence):
raise TypeError("service_args must be a sequence")
self._service_args = list(value)
@@ -0,0 +1,51 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.chromium.webdriver import ChromiumDriver
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
class WebDriver(ChromiumDriver):
"""Controls the ChromeDriver and allows you to drive the browser."""
def __init__(
self,
options: Options | None = None,
service: Service | None = None,
keep_alive: bool = True,
) -> None:
"""Creates a new instance of the chrome driver.
Starts the service and then creates new instance of chrome driver.
Args:
options: Instance of Options.
service: Service object for handling the browser driver if you need to pass extra details.
keep_alive: Whether to configure ChromeRemoteConnection to use HTTP keep-alive.
"""
service = service if service else Service()
options = options if options else Options()
super().__init__(
browser_name=DesiredCapabilities.CHROME["browserName"],
vendor_prefix="goog",
options=options,
service=service,
keep_alive=keep_alive,
)
@@ -0,0 +1,16 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
@@ -0,0 +1,187 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import base64
import os
from typing import BinaryIO
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
from selenium.webdriver.common.options import ArgOptions
class ChromiumOptions(ArgOptions):
KEY = "goog:chromeOptions"
def __init__(self) -> None:
"""Initialize ChromiumOptions with default settings."""
super().__init__()
self._binary_location: str = ""
self._extension_files: list[str] = []
self._extensions: list[str] = []
self._experimental_options: dict[str, str | int | dict | list[str]] = {}
self._debugger_address: str | None = None
self._enable_webextensions: bool = False
@property
def binary_location(self) -> str:
"""Returns the location of the binary, otherwise an empty string."""
return self._binary_location
@binary_location.setter
def binary_location(self, value: str) -> None:
"""Allows you to set where the chromium binary lives.
Args:
value: Path to the Chromium binary.
"""
if not isinstance(value, str):
raise TypeError(self.BINARY_LOCATION_ERROR)
self._binary_location = value
@property
def debugger_address(self) -> str | None:
"""Returns the address of the remote devtools instance."""
return self._debugger_address
@debugger_address.setter
def debugger_address(self, value: str) -> None:
"""Set the address of the remote devtools instance for active wait connection.
Args:
value: Address of remote devtools instance if any (hostname[:port]).
"""
if not isinstance(value, str):
raise TypeError("Debugger Address must be a string")
self._debugger_address = value
@property
def extensions(self) -> list[str]:
"""Returns a list of encoded extensions that will be loaded."""
def _decode(file_data: BinaryIO) -> str:
# Should not use base64.encodestring() which inserts newlines every
# 76 characters (per RFC 1521). Chromedriver has to remove those
# unnecessary newlines before decoding, causing performance hit.
return base64.b64encode(file_data.read()).decode("utf-8")
encoded_extensions = []
for extension in self._extension_files:
with open(extension, "rb") as f:
encoded_extensions.append(_decode(f))
return encoded_extensions + self._extensions
def add_extension(self, extension: str) -> None:
"""Add the path to an extension to be extracted to ChromeDriver.
Args:
extension: Path to the *.crx file.
"""
if extension:
extension_to_add = os.path.abspath(os.path.expanduser(extension))
if os.path.exists(extension_to_add):
self._extension_files.append(extension_to_add)
else:
raise OSError("Path to the extension doesn't exist")
else:
raise ValueError("argument can not be null")
def add_encoded_extension(self, extension: str) -> None:
"""Add Base64-encoded string with extension data to be extracted to ChromeDriver.
Args:
extension: Base64 encoded string with extension data.
"""
if extension:
self._extensions.append(extension)
else:
raise ValueError("argument can not be null")
@property
def experimental_options(self) -> dict:
"""Returns a dictionary of experimental options for chromium."""
return self._experimental_options
def add_experimental_option(self, name: str, value: str | int | dict | list[str]) -> None:
"""Adds an experimental option which is passed to chromium.
Args:
name: The experimental option name.
value: The option value.
"""
self._experimental_options[name] = value
@property
def enable_webextensions(self) -> bool:
"""Return whether webextension support is enabled for Chromium-based browsers."""
return self._enable_webextensions
@enable_webextensions.setter
def enable_webextensions(self, value: bool) -> None:
"""Enables or disables webextension support for Chromium-based browsers.
Args:
value: True to enable webextension support, False to disable.
Notes:
- When enabled, this automatically adds the required Chromium flags:
- --enable-unsafe-extension-debugging
- --remote-debugging-pipe
- When disabled, this removes BOTH flags listed above, even if they were manually added via add_argument()
before enabling webextensions.
- Enabling --remote-debugging-pipe makes the connection b/w chromedriver
and the browser use a pipe instead of a port, disabling many CDP functionalities
like devtools
"""
self._enable_webextensions = value
if value:
# Add required flags for Chromium webextension support
required_flags = ["--enable-unsafe-extension-debugging", "--remote-debugging-pipe"]
for flag in required_flags:
if flag not in self._arguments:
self.add_argument(flag)
else:
# Remove webextension flags if disabling
flags_to_remove = ["--enable-unsafe-extension-debugging", "--remote-debugging-pipe"]
for flag in flags_to_remove:
if flag in self._arguments:
self._arguments.remove(flag)
def to_capabilities(self) -> dict:
"""Creates a capabilities with all the options that have been set.
Returns:
A dictionary with all set options.
"""
caps = self._caps
chrome_options = self.experimental_options.copy()
if self.mobile_options:
chrome_options.update(self.mobile_options)
chrome_options["extensions"] = self.extensions
if self.binary_location:
chrome_options["binary"] = self.binary_location
chrome_options["args"] = self._arguments
if self.debugger_address:
chrome_options["debuggerAddress"] = self.debugger_address
caps[self.KEY] = chrome_options
return caps
@property
def default_capabilities(self) -> dict:
return DesiredCapabilities.CHROME.copy()
@@ -0,0 +1,59 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.webdriver.remote.client_config import ClientConfig
from selenium.webdriver.remote.remote_connection import RemoteConnection
class ChromiumRemoteConnection(RemoteConnection):
def __init__(
self,
remote_server_addr: str,
vendor_prefix: str,
browser_name: str,
keep_alive: bool = True,
ignore_proxy: bool = False,
client_config: ClientConfig | None = None,
) -> None:
client_config = client_config or ClientConfig(
remote_server_addr=remote_server_addr, keep_alive=keep_alive, timeout=120
)
super().__init__(
ignore_proxy=ignore_proxy,
client_config=client_config,
)
self.browser_name = browser_name
commands = self._remote_commands(vendor_prefix)
for key, value in commands.items():
self._commands[key] = value
def _remote_commands(self, vendor_prefix):
remote_commands = {
"launchApp": ("POST", "/session/$sessionId/chromium/launch_app"),
"setPermissions": ("POST", "/session/$sessionId/permissions"),
"setNetworkConditions": ("POST", "/session/$sessionId/chromium/network_conditions"),
"getNetworkConditions": ("GET", "/session/$sessionId/chromium/network_conditions"),
"deleteNetworkConditions": ("DELETE", "/session/$sessionId/chromium/network_conditions"),
"executeCdpCommand": ("POST", f"/session/$sessionId/{vendor_prefix}/cdp/execute"),
"getSinks": ("GET", f"/session/$sessionId/{vendor_prefix}/cast/get_sinks"),
"getIssueMessage": ("GET", f"/session/$sessionId/{vendor_prefix}/cast/get_issue_message"),
"setSinkToUse": ("POST", f"/session/$sessionId/{vendor_prefix}/cast/set_sink_to_use"),
"startDesktopMirroring": ("POST", f"/session/$sessionId/{vendor_prefix}/cast/start_desktop_mirroring"),
"startTabMirroring": ("POST", f"/session/$sessionId/{vendor_prefix}/cast/start_tab_mirroring"),
"stopCasting": ("POST", f"/session/$sessionId/{vendor_prefix}/cast/stop_casting"),
}
return remote_commands
@@ -0,0 +1,94 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import os
import sys
from collections.abc import Mapping, Sequence
from typing import IO, Any
from selenium.webdriver.common import service
class ChromiumService(service.Service):
"""Service class responsible for starting and stopping the ChromiumDriver WebDriver instance.
Args:
executable_path: (Optional) Install path of the executable.
port: (Optional) Port for the service to run on, defaults to 0 where the operating system will decide.
service_args: (Optional) Sequence of args to be passed to the subprocess when launching the executable.
log_output: (Optional) int representation of STDOUT/DEVNULL, any IO instance or String path to file.
env: (Optional) Mapping of environment variables for the new process, defaults to `os.environ`.
driver_path_env_key: (Optional) Environment variable to use to get the path to the driver executable.
"""
def __init__(
self,
executable_path: str | None = None,
port: int = 0,
service_args: Sequence[str] | None = None,
log_output: int | str | IO[Any] | None = None,
env: Mapping[str, str] | None = None,
driver_path_env_key: str | None = None,
**kwargs,
) -> None:
self._service_args = list(service_args or [])
driver_path_env_key = driver_path_env_key or "SE_CHROMEDRIVER"
if isinstance(log_output, str):
self._service_args.append(f"--log-path={log_output}")
self.log_output = None
else:
self.log_output = log_output
if os.environ.get("SE_DEBUG"):
has_arg_conflicts = any(x in arg for arg in self._service_args for x in ("log-level", "log-path", "silent"))
has_output_conflict = self.log_output is not None
if has_arg_conflicts or has_output_conflict:
logging.getLogger(__name__).warning(
"Environment Variable `SE_DEBUG` is set; "
"forcing ChromiumDriver --verbose and overriding log-level/log-output/silent settings."
)
if has_arg_conflicts:
self._service_args = [
arg for arg in self._service_args if not any(x in arg for x in ("log-level", "log-path", "silent"))
]
self._service_args.append("--verbose")
self.log_output = sys.stderr
super().__init__(
executable_path=executable_path,
port=port,
env=env,
log_output=self.log_output,
driver_path_env_key=driver_path_env_key,
**kwargs,
)
def command_line_args(self) -> list[str]:
return [f"--port={self.port}"] + self._service_args
@property
def service_args(self) -> Sequence[str]:
"""Returns the sequence of service arguments."""
return self._service_args
@service_args.setter
def service_args(self, value: Sequence[str]):
if isinstance(value, str) or not isinstance(value, Sequence):
raise TypeError("service_args must be a sequence")
self._service_args = list(value)
@@ -0,0 +1,205 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.webdriver.chromium.options import ChromiumOptions
from selenium.webdriver.chromium.remote_connection import ChromiumRemoteConnection
from selenium.webdriver.chromium.service import ChromiumService
from selenium.webdriver.common.driver_finder import DriverFinder
from selenium.webdriver.common.webdriver import LocalWebDriver
from selenium.webdriver.remote.command import Command
class ChromiumDriver(LocalWebDriver):
"""Control the WebDriver instance of ChromiumDriver and drive the browser."""
def __init__(
self,
browser_name: str,
vendor_prefix: str,
options: ChromiumOptions | None = None,
service: ChromiumService | None = None,
keep_alive: bool = True,
) -> None:
"""Create a new WebDriver instance, start the service, and create new ChromiumDriver instance.
Args:
browser_name: Browser name used when matching capabilities.
vendor_prefix: Company prefix to apply to vendor-specific WebDriver extension commands.
options: Instance of ChromiumOptions.
service: Service object for handling the browser driver if you need to pass extra details.
keep_alive: Whether to configure ChromiumRemoteConnection to use HTTP keep-alive.
"""
self.service = service if service else ChromiumService()
self.options = options if options else ChromiumOptions()
finder = DriverFinder(self.service, self.options)
if finder.get_browser_path():
self.options.binary_location = finder.get_browser_path()
self.options.browser_version = None
self.service.path = self.service.env_path() or finder.get_driver_path()
self.service.start()
executor = ChromiumRemoteConnection(
remote_server_addr=self.service.service_url,
browser_name=browser_name,
vendor_prefix=vendor_prefix,
keep_alive=keep_alive,
ignore_proxy=self.options._ignore_local_proxy,
)
try:
super().__init__(command_executor=executor, options=self.options)
except Exception:
self.quit()
raise
def launch_app(self, id):
"""Launches Chromium app specified by id.
Args:
id: The id of the Chromium app to launch.
"""
return self.execute("launchApp", {"id": id})
def get_network_conditions(self):
"""Gets Chromium network emulation settings.
Returns:
A dict. For example: {'latency': 4, 'download_throughput': 2, 'upload_throughput': 2}
"""
return self.execute("getNetworkConditions")["value"]
def set_network_conditions(self, **network_conditions) -> None:
"""Sets Chromium network emulation settings.
Args:
**network_conditions: A dict with conditions specification.
Example:
driver.set_network_conditions(
offline=False,
latency=5, # additional latency (ms)
download_throughput=500 * 1024, # maximal throughput
upload_throughput=500 * 1024,
) # maximal throughput
Note: `throughput` can be used to set both (for download and upload).
"""
self.execute("setNetworkConditions", {"network_conditions": network_conditions})
def delete_network_conditions(self) -> None:
"""Resets Chromium network emulation settings."""
self.execute("deleteNetworkConditions")
def set_permissions(self, name: str, value: str) -> None:
"""Sets Applicable Permission.
Args:
name: The item to set the permission on.
value: The value to set on the item
Example:
driver.set_permissions("clipboard-read", "denied")
"""
self.execute("setPermissions", {"descriptor": {"name": name}, "state": value})
def execute_cdp_cmd(self, cmd: str, cmd_args: dict):
"""Execute Chrome Devtools Protocol command and get returned result.
The command and command args should follow chrome devtools protocol domains/commands
See:
- https://chromedevtools.github.io/devtools-protocol/
Args:
cmd: A str, command name
cmd_args: A dict, command args. empty dict {} if there is no command args
Example:
`driver.execute_cdp_cmd('Network.getResponseBody', {'requestId': requestId})`
Returns:
A dict, empty dict {} if there is no result to return.
For example to getResponseBody:
{'base64Encoded': False, 'body': 'response body string'}
"""
return super().execute_cdp_cmd(cmd, cmd_args)
def get_sinks(self) -> list:
"""Get a list of sinks available for Cast."""
return self.execute("getSinks")["value"]
def get_issue_message(self):
"""Returns an error message when there is any issue in a Cast session."""
return self.execute("getIssueMessage")["value"]
@property
def log_types(self):
"""Gets a list of the available log types.
Example:
--------
>>> driver.log_types
"""
return self.execute(Command.GET_AVAILABLE_LOG_TYPES)["value"]
def get_log(self, log_type):
"""Gets the log for a given log type.
Args:
log_type: Type of log that which will be returned
Example:
>>> driver.get_log("browser")
>>> driver.get_log("driver")
>>> driver.get_log("client")
>>> driver.get_log("server")
"""
return self.execute(Command.GET_LOG, {"type": log_type})["value"]
def set_sink_to_use(self, sink_name: str) -> dict:
"""Set a specific sink as a Cast session receiver target.
Args:
sink_name: Name of the sink to use as the target.
"""
return self.execute("setSinkToUse", {"sinkName": sink_name})
def start_desktop_mirroring(self, sink_name: str) -> dict:
"""Starts a desktop mirroring session on a specific receiver target.
Args:
sink_name: Name of the sink to use as the target.
"""
return self.execute("startDesktopMirroring", {"sinkName": sink_name})
def start_tab_mirroring(self, sink_name: str) -> dict:
"""Starts a tab mirroring session on a specific receiver target.
Args:
sink_name: Name of the sink to use as the target.
"""
return self.execute("startTabMirroring", {"sinkName": sink_name})
def stop_casting(self, sink_name: str) -> dict:
"""Stops the existing Cast session on a specific receiver target.
Args:
sink_name: Name of the sink to stop the Cast session.
"""
return self.execute("stopCasting", {"sinkName": sink_name})
@@ -0,0 +1,16 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
@@ -0,0 +1,379 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The ActionChains implementation."""
from __future__ import annotations
from typing import TYPE_CHECKING
from selenium.webdriver.common.actions.action_builder import ActionBuilder
from selenium.webdriver.common.actions.key_input import KeyInput
from selenium.webdriver.common.actions.pointer_input import PointerInput
from selenium.webdriver.common.actions.wheel_input import ScrollOrigin, WheelInput
from selenium.webdriver.common.utils import keys_to_typing
from selenium.webdriver.remote.webelement import WebElement
if TYPE_CHECKING:
from selenium.webdriver.remote.webdriver import WebDriver
class ActionChains:
"""Automate low-level interactions like mouse movements, button actions, key presses, and context menus.
ActionChains are a way to automate low level interactions such as mouse
movements, mouse button actions, key press, and context menu interactions.
This is useful for doing more complex actions like hover over and drag and
drop.
Generate user actions.
When you call methods for actions on the ActionChains object,
the actions are stored in a queue in the ActionChains object.
When you call perform(), the events are fired in the order they
are queued up.
ActionChains can be used in a chain pattern::
menu = driver.find_element(By.CSS_SELECTOR, ".nav")
hidden_submenu = driver.find_element(By.CSS_SELECTOR, ".nav #submenu1")
ActionChains(driver).move_to_element(menu).click(hidden_submenu).perform()
Or actions can be queued up one by one, then performed.::
menu = driver.find_element(By.CSS_SELECTOR, ".nav")
hidden_submenu = driver.find_element(By.CSS_SELECTOR, ".nav #submenu1")
actions = ActionChains(driver)
actions.move_to_element(menu)
actions.click(hidden_submenu)
actions.perform()
Either way, the actions are performed in the order they are called, one after
another.
"""
def __init__(
self,
driver: WebDriver,
duration: int = 250,
devices: list[PointerInput | KeyInput | WheelInput] | None = None,
) -> None:
"""Creates a new ActionChains.
Args:
driver: The WebDriver instance which performs user actions.
duration: override the default 250 msecs of DEFAULT_MOVE_DURATION in PointerInput
devices: Optional list of input devices (PointerInput, KeyInput, WheelInput) to use.
If not provided, default devices will be created.
"""
self._driver = driver
mouse = None
keyboard = None
wheel = None
if devices is not None and isinstance(devices, list):
for device in devices:
if isinstance(device, PointerInput):
mouse = device
if isinstance(device, KeyInput):
keyboard = device
if isinstance(device, WheelInput):
wheel = device
self.w3c_actions = ActionBuilder(driver, mouse=mouse, keyboard=keyboard, wheel=wheel, duration=duration)
def perform(self) -> None:
"""Performs all stored actions."""
self.w3c_actions.perform()
def reset_actions(self) -> None:
"""Clear actions stored locally and on the remote end."""
self.w3c_actions.clear_actions()
for device in self.w3c_actions.devices:
device.clear_actions()
def click(self, on_element: WebElement | None = None) -> ActionChains:
"""Clicks an element.
Args:
on_element: The element to click.
If None, clicks on current mouse position.
"""
if on_element:
self.move_to_element(on_element)
self.w3c_actions.pointer_action.click()
self.w3c_actions.key_action.pause()
self.w3c_actions.key_action.pause()
return self
def click_and_hold(self, on_element: WebElement | None = None) -> ActionChains:
"""Holds down the left mouse button on an element.
Args:
on_element: The element to mouse down.
If None, clicks on current mouse position.
"""
if on_element:
self.move_to_element(on_element)
self.w3c_actions.pointer_action.click_and_hold()
self.w3c_actions.key_action.pause()
return self
def context_click(self, on_element: WebElement | None = None) -> ActionChains:
"""Performs a context-click (right click) on an element.
Args:
on_element: The element to context-click.
If None, clicks on current mouse position.
"""
if on_element:
self.move_to_element(on_element)
self.w3c_actions.pointer_action.context_click()
self.w3c_actions.key_action.pause()
self.w3c_actions.key_action.pause()
return self
def double_click(self, on_element: WebElement | None = None) -> ActionChains:
"""Double-clicks an element.
Args:
on_element: The element to double-click.
If None, clicks on current mouse position.
"""
if on_element:
self.move_to_element(on_element)
self.w3c_actions.pointer_action.double_click()
for _ in range(4):
self.w3c_actions.key_action.pause()
return self
def drag_and_drop(self, source: WebElement, target: WebElement) -> ActionChains:
"""Hold down the left mouse button on an element, then move to target and release.
Args:
source: The element to mouse down.
target: The element to mouse up.
"""
self.click_and_hold(source)
self.release(target)
return self
def drag_and_drop_by_offset(self, source: WebElement, xoffset: int, yoffset: int) -> ActionChains:
"""Hold down the left mouse button on an element, then move by offset and release.
Args:
source: The element to mouse down.
xoffset: X offset to move to.
yoffset: Y offset to move to.
"""
self.click_and_hold(source)
self.move_by_offset(xoffset, yoffset)
self.release()
return self
def key_down(self, value: str, element: WebElement | None = None) -> ActionChains:
"""Send a key press only without releasing it (modifier keys only).
Args:
value: The modifier key to send. Values are defined in `Keys` class.
element: The element to send keys.
If None, sends a key to current focused element.
Example, pressing ctrl+c::
ActionChains(driver).key_down(Keys.CONTROL).send_keys("c").key_up(Keys.CONTROL).perform()
"""
if element:
self.click(element)
self.w3c_actions.key_action.key_down(value)
self.w3c_actions.pointer_action.pause()
return self
def key_up(self, value: str, element: WebElement | None = None) -> ActionChains:
"""Releases a modifier key.
Args:
value: The modifier key to send. Values are defined in Keys class.
element: The element to send keys.
If None, sends a key to current focused element.
Example, pressing ctrl+c::
ActionChains(driver).key_down(Keys.CONTROL).send_keys("c").key_up(Keys.CONTROL).perform()
"""
if element:
self.click(element)
self.w3c_actions.key_action.key_up(value)
self.w3c_actions.pointer_action.pause()
return self
def move_by_offset(self, xoffset: int, yoffset: int) -> ActionChains:
"""Moving the mouse to an offset from current mouse position.
Args:
xoffset: X offset to move to, as a positive or negative integer.
yoffset: Y offset to move to, as a positive or negative integer.
"""
self.w3c_actions.pointer_action.move_by(xoffset, yoffset)
self.w3c_actions.key_action.pause()
return self
def move_to_element(self, to_element: WebElement) -> ActionChains:
"""Moving the mouse to the middle of an element.
Args:
to_element: The WebElement to move to.
"""
self.w3c_actions.pointer_action.move_to(to_element)
self.w3c_actions.key_action.pause()
return self
def move_to_element_with_offset(self, to_element: WebElement, xoffset: int, yoffset: int) -> ActionChains:
"""Move the mouse to an element with the specified offsets.
Offsets are relative to the in-view center point of the element.
Args:
to_element: The WebElement to move to.
xoffset: X offset to move to, as a positive or negative integer.
yoffset: Y offset to move to, as a positive or negative integer.
"""
self.w3c_actions.pointer_action.move_to(to_element, int(xoffset), int(yoffset))
self.w3c_actions.key_action.pause()
return self
def pause(self, seconds: float | int) -> ActionChains:
"""Pause all inputs for the specified duration in seconds."""
self.w3c_actions.pointer_action.pause(seconds)
self.w3c_actions.key_action.pause(int(seconds))
return self
def release(self, on_element: WebElement | None = None) -> ActionChains:
"""Releasing a held mouse button on an element.
Args:
on_element: The element to mouse up.
If None, releases on current mouse position.
"""
if on_element:
self.move_to_element(on_element)
self.w3c_actions.pointer_action.release()
self.w3c_actions.key_action.pause()
return self
def send_keys(self, *keys_to_send: str) -> ActionChains:
"""Sends keys to current focused element.
Args:
keys_to_send: The keys to send. Modifier keys constants can be found in the
'Keys' class.
"""
typing = keys_to_typing(keys_to_send)
for key in typing:
self.key_down(key)
self.key_up(key)
return self
def send_keys_to_element(self, element: WebElement, *keys_to_send: str) -> ActionChains:
"""Sends keys to an element.
Args:
element: The element to send keys.
keys_to_send: The keys to send. Modifier keys constants can be found in the
'Keys' class.
"""
self.click(element)
self.send_keys(*keys_to_send)
return self
def scroll_to_element(self, element: WebElement) -> ActionChains:
"""Scroll the element into the viewport if it's outside it.
Scrolls the bottom of the element to the bottom of the viewport.
Args:
element: Which element to scroll into the viewport.
"""
self.w3c_actions.wheel_action.scroll(origin=element)
return self
def scroll_by_amount(self, delta_x: int, delta_y: int) -> ActionChains:
"""Scroll by a provided amount with the origin in the top left corner.
Scrolls by provided amounts with the origin in the top left corner
of the viewport.
Args:
delta_x: Distance along X axis to scroll using the wheel. A negative value scrolls left.
delta_y: Distance along Y axis to scroll using the wheel. A negative value scrolls up.
"""
self.w3c_actions.wheel_action.scroll(delta_x=delta_x, delta_y=delta_y)
return self
def scroll_from_origin(self, scroll_origin: ScrollOrigin, delta_x: int, delta_y: int) -> ActionChains:
"""Scroll by a provided amount based on a scroll origin (element or viewport).
The scroll origin is either the center of an element or the upper left of the
viewport plus any offsets. If the origin is an element, and the element
is not in the viewport, the bottom of the element will first be
scrolled to the bottom of the viewport.
Args:
scroll_origin: Where scroll originates (viewport or element center) plus provided offsets.
delta_x: Distance along X axis to scroll using the wheel. A negative value scrolls left.
delta_y: Distance along Y axis to scroll using the wheel. A negative value scrolls up.
Raises:
MoveTargetOutOfBoundsException: If the origin with offset is outside the viewport.
"""
if not isinstance(scroll_origin, ScrollOrigin):
raise TypeError(f"Expected object of type ScrollOrigin, got: {type(scroll_origin)}")
self.w3c_actions.wheel_action.scroll(
origin=scroll_origin.origin,
x=scroll_origin.x_offset,
y=scroll_origin.y_offset,
delta_x=delta_x,
delta_y=delta_y,
)
return self
# Context manager so ActionChains can be used in a 'with .. as' statements.
def __enter__(self) -> ActionChains:
return self # Return created instance of self.
def __exit__(self, _type, _value, _traceback) -> None:
pass # Do nothing, does not require additional cleanup.
@@ -0,0 +1,16 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
@@ -0,0 +1,168 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Union
from selenium.webdriver.common.actions import interaction
from selenium.webdriver.common.actions.key_actions import KeyActions
from selenium.webdriver.common.actions.key_input import KeyInput
from selenium.webdriver.common.actions.pointer_actions import PointerActions
from selenium.webdriver.common.actions.pointer_input import PointerInput
from selenium.webdriver.common.actions.wheel_actions import WheelActions
from selenium.webdriver.common.actions.wheel_input import WheelInput
from selenium.webdriver.remote.command import Command
class ActionBuilder:
def __init__(
self,
driver,
mouse: PointerInput | None = None,
wheel: WheelInput | None = None,
keyboard: KeyInput | None = None,
duration: int = 250,
) -> None:
mouse = mouse or PointerInput(interaction.POINTER_MOUSE, "mouse")
keyboard = keyboard or KeyInput(interaction.KEY)
wheel = wheel or WheelInput(interaction.WHEEL)
self.devices: list[PointerInput | KeyInput | WheelInput] = [mouse, keyboard, wheel]
self._key_action = KeyActions(keyboard)
self._pointer_action = PointerActions(mouse, duration=duration)
self._wheel_action = WheelActions(wheel)
self.driver = driver
def get_device_with(self, name: str) -> Union["WheelInput", "PointerInput", "KeyInput"] | None:
"""Get the device with the given name.
Args:
name: The name of the device to get.
Returns:
The device with the given name, or None if not found.
"""
return next(filter(lambda x: x == name, self.devices), None)
@property
def pointer_inputs(self) -> list[PointerInput]:
return [device for device in self.devices if isinstance(device, PointerInput)]
@property
def key_inputs(self) -> list[KeyInput]:
return [device for device in self.devices if isinstance(device, KeyInput)]
@property
def key_action(self) -> KeyActions:
return self._key_action
@property
def pointer_action(self) -> PointerActions:
return self._pointer_action
@property
def wheel_action(self) -> WheelActions:
return self._wheel_action
def add_key_input(self, name: str) -> KeyInput:
"""Add a new key input device to the action builder.
Args:
name: The name of the key input device.
Returns:
The newly created key input device.
Example:
>>> action_builder = ActionBuilder(driver)
>>> action_builder.add_key_input(name="keyboard2")
"""
new_input = KeyInput(name)
self._add_input(new_input)
return new_input
def add_pointer_input(self, kind: str, name: str) -> PointerInput:
"""Add a new pointer input device to the action builder.
Args:
kind: The kind of pointer input device. Valid values are "mouse",
"touch", or "pen".
name: The name of the pointer input device.
Returns:
The newly created pointer input device.
Example:
>>> action_builder = ActionBuilder(driver)
>>> action_builder.add_pointer_input(kind="mouse", name="mouse")
"""
new_input = PointerInput(kind, name)
self._add_input(new_input)
return new_input
def add_wheel_input(self, name: str) -> WheelInput:
"""Add a new wheel input device to the action builder.
Args:
name: The name of the wheel input device.
Returns:
The newly created wheel input device.
Example:
>>> action_builder = ActionBuilder(driver)
>>> action_builder.add_wheel_input(name="wheel2")
"""
new_input = WheelInput(name)
self._add_input(new_input)
return new_input
def perform(self) -> None:
"""Performs all stored actions.
Example:
>>> action_builder = ActionBuilder(driver)
>>> keyboard = action_builder.key_input
>>> el = driver.find_element(id: "some_id")
>>> action_builder.click(el).pause(keyboard).pause(keyboard).pause(keyboard).send_keys("keys").perform()
"""
enc: dict[str, list[Any]] = {"actions": []}
for device in self.devices:
encoded = device.encode()
if encoded["actions"]:
enc["actions"].append(encoded)
device.actions = []
self.driver.execute(Command.W3C_ACTIONS, enc)
def clear_actions(self) -> None:
"""Clears actions that are already stored on the remote end.
Example:
>>> action_builder = ActionBuilder(driver)
>>> keyboard = action_builder.key_input
>>> el = driver.find_element(By.ID, "some_id")
>>> action_builder.click(el).pause(keyboard).pause(keyboard).pause(keyboard).send_keys("keys")
>>> action_builder.clear_actions()
"""
self.driver.execute(Command.W3C_CLEAR_ACTIONS)
def _add_input(self, new_input: KeyInput | PointerInput | WheelInput) -> None:
"""Add a new input device to the action builder.
Args:
new_input: The new input device to add.
"""
self.devices.append(new_input)
@@ -0,0 +1,36 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import uuid
from typing import Any
class InputDevice:
"""Describes the input device being used for the action."""
def __init__(self, name: str | None = None):
self.name = name or uuid.uuid4()
self.actions: list[Any] = []
def add_action(self, action: Any) -> None:
self.actions.append(action)
def clear_actions(self) -> None:
self.actions = []
def create_pause(self, duration: float = 0) -> None:
pass
@@ -0,0 +1,46 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.webdriver.common.actions.input_device import InputDevice
KEY = "key"
POINTER = "pointer"
NONE = "none"
WHEEL = "wheel"
SOURCE_TYPES = {KEY, POINTER, WHEEL, NONE}
POINTER_MOUSE = "mouse"
POINTER_TOUCH = "touch"
POINTER_PEN = "pen"
POINTER_KINDS = {POINTER_MOUSE, POINTER_TOUCH, POINTER_PEN}
class Interaction:
PAUSE = "pause"
def __init__(self, source: InputDevice) -> None:
self.source = source
class Pause(Interaction):
def __init__(self, source, duration: float = 0) -> None:
super().__init__(source)
self.duration = duration
def encode(self) -> dict[str, str | int]:
return {"type": self.PAUSE, "duration": int(self.duration * 1000)}
@@ -0,0 +1,54 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from selenium.webdriver.common.actions.interaction import KEY, Interaction
from selenium.webdriver.common.actions.key_input import KeyInput
from selenium.webdriver.common.actions.pointer_input import PointerInput
from selenium.webdriver.common.actions.wheel_input import WheelInput
from selenium.webdriver.common.utils import keys_to_typing
class KeyActions(Interaction):
def __init__(self, source: KeyInput | PointerInput | WheelInput | None = None) -> None:
if source is None:
source = KeyInput(KEY)
self.input_source = source
super().__init__(source)
def key_down(self, letter: str) -> KeyActions:
return self._key_action("create_key_down", letter)
def key_up(self, letter: str) -> KeyActions:
return self._key_action("create_key_up", letter)
def pause(self, duration: int = 0) -> KeyActions:
return self._key_action("create_pause", duration)
def send_keys(self, text: str | list) -> KeyActions:
if not isinstance(text, list):
text = keys_to_typing(text)
for letter in text:
self.key_down(letter)
self.key_up(letter)
return self
def _key_action(self, action: str, letter) -> KeyActions:
meth = getattr(self.source, action)
meth(letter)
return self
@@ -0,0 +1,48 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.webdriver.common.actions import interaction
from selenium.webdriver.common.actions.input_device import InputDevice
from selenium.webdriver.common.actions.interaction import Interaction, Pause
class KeyInput(InputDevice):
def __init__(self, name: str) -> None:
super().__init__()
self.name = name
self.type = interaction.KEY
def encode(self) -> dict:
return {"type": self.type, "id": self.name, "actions": [acts.encode() for acts in self.actions]}
def create_key_down(self, key) -> None:
self.add_action(TypingInteraction(self, "keyDown", key))
def create_key_up(self, key) -> None:
self.add_action(TypingInteraction(self, "keyUp", key))
def create_pause(self, pause_duration: float = 0) -> None:
self.add_action(Pause(self, pause_duration))
class TypingInteraction(Interaction):
def __init__(self, source, type_, key) -> None:
super().__init__(source)
self.type = type_
self.key = key
def encode(self) -> dict:
return {"type": self.type, "value": self.key}
@@ -0,0 +1,24 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
class MouseButton:
LEFT = 0
MIDDLE = 1
RIGHT = 2
BACK = 3
FORWARD = 4
@@ -0,0 +1,206 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.webdriver.common.actions import interaction
from selenium.webdriver.common.actions.interaction import Interaction
from selenium.webdriver.common.actions.mouse_button import MouseButton
from selenium.webdriver.common.actions.pointer_input import PointerInput
from selenium.webdriver.remote.webelement import WebElement
class PointerActions(Interaction):
def __init__(self, source: PointerInput | None = None, duration: int = 250):
"""Initialize a new PointerActions instance.
Args:
source: Optional PointerInput instance. If not provided, a default
mouse PointerInput will be created.
duration: Override the default 250 msecs of DEFAULT_MOVE_DURATION
in the source.
"""
if source is None:
source = PointerInput(interaction.POINTER_MOUSE, "mouse")
self.source = source
self._duration = duration
super().__init__(source)
def pointer_down(
self,
button=MouseButton.LEFT,
width=None,
height=None,
pressure=None,
tangential_pressure=None,
tilt_x=None,
tilt_y=None,
twist=None,
altitude_angle=None,
azimuth_angle=None,
):
self._button_action(
"create_pointer_down",
button=button,
width=width,
height=height,
pressure=pressure,
tangential_pressure=tangential_pressure,
tilt_x=tilt_x,
tilt_y=tilt_y,
twist=twist,
altitude_angle=altitude_angle,
azimuth_angle=azimuth_angle,
)
return self
def pointer_up(self, button=MouseButton.LEFT):
self._button_action("create_pointer_up", button=button)
return self
def move_to(
self,
element,
x=0,
y=0,
width=None,
height=None,
pressure=None,
tangential_pressure=None,
tilt_x=None,
tilt_y=None,
twist=None,
altitude_angle=None,
azimuth_angle=None,
):
if not isinstance(element, WebElement):
raise AttributeError("move_to requires a WebElement")
self.source.create_pointer_move(
origin=element,
duration=self._duration,
x=int(x),
y=int(y),
width=width,
height=height,
pressure=pressure,
tangential_pressure=tangential_pressure,
tilt_x=tilt_x,
tilt_y=tilt_y,
twist=twist,
altitude_angle=altitude_angle,
azimuth_angle=azimuth_angle,
)
return self
def move_by(
self,
x,
y,
width=None,
height=None,
pressure=None,
tangential_pressure=None,
tilt_x=None,
tilt_y=None,
twist=None,
altitude_angle=None,
azimuth_angle=None,
):
self.source.create_pointer_move(
origin=interaction.POINTER,
duration=self._duration,
x=int(x),
y=int(y),
width=width,
height=height,
pressure=pressure,
tangential_pressure=tangential_pressure,
tilt_x=tilt_x,
tilt_y=tilt_y,
twist=twist,
altitude_angle=altitude_angle,
azimuth_angle=azimuth_angle,
)
return self
def move_to_location(
self,
x,
y,
width=None,
height=None,
pressure=None,
tangential_pressure=None,
tilt_x=None,
tilt_y=None,
twist=None,
altitude_angle=None,
azimuth_angle=None,
):
self.source.create_pointer_move(
origin="viewport",
duration=self._duration,
x=int(x),
y=int(y),
width=width,
height=height,
pressure=pressure,
tangential_pressure=tangential_pressure,
tilt_x=tilt_x,
tilt_y=tilt_y,
twist=twist,
altitude_angle=altitude_angle,
azimuth_angle=azimuth_angle,
)
return self
def click(self, element: WebElement | None = None, button=MouseButton.LEFT):
if element:
self.move_to(element)
self.pointer_down(button)
self.pointer_up(button)
return self
def context_click(self, element: WebElement | None = None):
return self.click(element=element, button=MouseButton.RIGHT)
def click_and_hold(self, element: WebElement | None = None, button=MouseButton.LEFT):
if element:
self.move_to(element)
self.pointer_down(button=button)
return self
def release(self, button=MouseButton.LEFT):
self.pointer_up(button=button)
return self
def double_click(self, element: WebElement | None = None):
if element:
self.move_to(element)
self.pointer_down(MouseButton.LEFT)
self.pointer_up(MouseButton.LEFT)
self.pointer_down(MouseButton.LEFT)
self.pointer_up(MouseButton.LEFT)
return self
def pause(self, duration: float = 0):
self.source.create_pause(duration)
return self
def _button_action(self, action, **kwargs):
meth = getattr(self.source, action)
meth(**kwargs)
return self
@@ -0,0 +1,79 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any
from selenium.common.exceptions import InvalidArgumentException
from selenium.webdriver.common.actions.input_device import InputDevice
from selenium.webdriver.common.actions.interaction import POINTER, POINTER_KINDS
from selenium.webdriver.remote.webelement import WebElement
class PointerInput(InputDevice):
DEFAULT_MOVE_DURATION = 250
def __init__(self, kind, name):
super().__init__()
if kind not in POINTER_KINDS:
raise InvalidArgumentException(f"Invalid PointerInput kind '{kind}'")
self.type = POINTER
self.kind = kind
self.name = name
def create_pointer_move(
self,
duration=DEFAULT_MOVE_DURATION,
x: float = 0,
y: float = 0,
origin: WebElement | None = None,
**kwargs,
):
action = {"type": "pointerMove", "duration": duration, "x": x, "y": y, **kwargs}
if isinstance(origin, WebElement):
action["origin"] = {"element-6066-11e4-a52e-4f735466cecf": origin.id}
elif origin is not None:
action["origin"] = origin
self.add_action(self._convert_keys(action))
def create_pointer_down(self, **kwargs):
data = {"type": "pointerDown", "duration": 0, **kwargs}
self.add_action(self._convert_keys(data))
def create_pointer_up(self, button):
self.add_action({"type": "pointerUp", "duration": 0, "button": button})
def create_pointer_cancel(self):
self.add_action({"type": "pointerCancel"})
def create_pause(self, pause_duration: int | float = 0) -> None:
self.add_action({"type": "pause", "duration": int(pause_duration * 1000)})
def encode(self):
return {"type": self.type, "parameters": {"pointerType": self.kind}, "id": self.name, "actions": self.actions}
def _convert_keys(self, actions: dict[str, Any]):
out = {}
for k, v in actions.items():
if v is None:
continue
if k in ("x", "y"):
out[k] = int(v)
continue
splits = k.split("_")
new_key = splits[0] + "".join(v.title() for v in splits[1:])
out[new_key] = v
return out
@@ -0,0 +1,35 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.webdriver.common.actions.interaction import WHEEL, Interaction
from selenium.webdriver.common.actions.wheel_input import WheelInput
class WheelActions(Interaction):
def __init__(self, source: WheelInput | None = None):
if source is None:
source = WheelInput(WHEEL)
super().__init__(source)
def pause(self, duration: float = 0):
self.source.create_pause(duration)
return self
def scroll(self, x=0, y=0, delta_x=0, delta_y=0, duration=0, origin="viewport"):
self.source.create_scroll(x, y, delta_x, delta_y, duration, origin)
return self
@@ -0,0 +1,75 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.webdriver.common.actions import interaction
from selenium.webdriver.common.actions.input_device import InputDevice
from selenium.webdriver.remote.webelement import WebElement
class ScrollOrigin:
def __init__(self, origin: str | WebElement, x_offset: int, y_offset: int) -> None:
self._origin = origin
self._x_offset = x_offset
self._y_offset = y_offset
@classmethod
def from_element(cls, element: WebElement, x_offset: int = 0, y_offset: int = 0):
return cls(element, x_offset, y_offset)
@classmethod
def from_viewport(cls, x_offset: int = 0, y_offset: int = 0):
return cls("viewport", x_offset, y_offset)
@property
def origin(self) -> str | WebElement:
return self._origin
@property
def x_offset(self) -> int:
return self._x_offset
@property
def y_offset(self) -> int:
return self._y_offset
class WheelInput(InputDevice):
def __init__(self, name) -> None:
super().__init__(name=name)
self.name = name
self.type = interaction.WHEEL
def encode(self) -> dict:
return {"type": self.type, "id": self.name, "actions": self.actions}
def create_scroll(self, x: int, y: int, delta_x: int, delta_y: int, duration: int, origin) -> None:
if isinstance(origin, WebElement):
origin = {"element-6066-11e4-a52e-4f735466cecf": origin.id}
self.add_action(
{
"type": "scroll",
"x": x,
"y": y,
"deltaX": delta_x,
"deltaY": delta_y,
"duration": duration,
"origin": origin,
}
)
def create_pause(self, pause_duration: int | float = 0) -> None:
self.add_action({"type": "pause", "duration": int(pause_duration * 1000)})
@@ -0,0 +1,78 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The Alert implementation."""
from selenium.webdriver.common.utils import keys_to_typing
from selenium.webdriver.remote.command import Command
class Alert:
"""Allows to work with alerts.
Use this class to interact with alert prompts. It contains methods for dismissing,
accepting, inputting, and getting text from alert prompts.
Accepting / Dismissing alert prompts::
Alert(driver).accept()
Alert(driver).dismiss()
Inputting a value into an alert prompt::
name_prompt = Alert(driver)
name_prompt.send_keys("Willian Shakesphere")
name_prompt.accept()
Reading a the text of a prompt for verification::
alert_text = Alert(driver).text
self.assertEqual("Do you wish to quit?", alert_text)
"""
def __init__(self, driver) -> None:
"""Creates a new Alert.
Args:
driver: The WebDriver instance which performs user actions.
"""
self.driver = driver
@property
def text(self) -> str:
"""Gets the text of the Alert."""
return self.driver.execute(Command.W3C_GET_ALERT_TEXT)["value"]
def dismiss(self) -> None:
"""Dismisses the alert available."""
self.driver.execute(Command.W3C_DISMISS_ALERT)
def accept(self) -> None:
"""Accepts the alert available.
Example:
Alert(driver).accept() # Confirm a alert dialog.
"""
self.driver.execute(Command.W3C_ACCEPT_ALERT)
def send_keys(self, keysToSend: str) -> None:
"""Send Keys to the Alert.
Args:
keysToSend: The text to be sent to Alert.
"""
self.driver.execute(Command.W3C_SET_ALERT_VALUE, {"value": keys_to_typing(keysToSend), "text": keysToSend})
@@ -0,0 +1,704 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""APIRequestContext for making HTTP requests with browser cookie synchronization."""
import json
import logging
import pathlib
import time
import urllib.parse
from email.utils import parsedate_to_datetime
from http.client import responses as http_status_phrases
from typing import TYPE_CHECKING, Any
import urllib3
from urllib3.util.retry import Retry
if TYPE_CHECKING:
from selenium.webdriver.remote.webdriver import WebDriver
logger = logging.getLogger(__name__)
class APIRequestFailure(Exception):
"""Raised when an API request returns a non-2xx status and fail_on_status_code is True.
Attributes:
response: The APIResponse that triggered the failure.
"""
def __init__(self, response: "APIResponse") -> None:
self.response = response
super().__init__(f"{response.status} {response.status_text}: {response.url}")
class APIResponse:
"""Represents an HTTP response from an API request.
Attributes:
status: HTTP status code.
status_text: HTTP status text.
headers: Response headers as a dict.
url: The request URL.
"""
def __init__(self, status: int, status_text: str, headers: dict[str, str], url: str, body: bytes) -> None:
self.status = status
self.status_text = status_text
self.headers = headers
self.url = url
self._body = body
@property
def ok(self) -> bool:
"""Whether the response status is in the 200-299 range."""
return 200 <= self.status <= 299
def json(self) -> Any:
"""Parse the response body as JSON.
Returns:
The parsed JSON object.
"""
return json.loads(self._body)
def text(self) -> str:
"""Decode the response body as UTF-8 text.
Returns:
The response body as a string.
"""
return self._body.decode("utf-8")
def body(self) -> bytes:
"""Return the raw response body bytes.
Returns:
The response body as bytes.
"""
return self._body
def dispose(self) -> None:
"""Free the response body memory."""
self._body = b""
def _cookie_matches(cookie: dict, url: str, default_domain: str = "") -> bool:
"""Check if a browser cookie should be sent with a request to the given URL.
Evaluates expiry, domain, path, and secure attribute matching per RFC 6265.
Args:
cookie: A cookie dict from driver.get_cookies().
url: The target request URL.
default_domain: Fallback domain for host-only cookies (no domain attribute).
When a cookie has no domain, it only matches if the request hostname
equals this value. If empty and cookie has no domain, the cookie is skipped.
Returns:
True if the cookie matches the URL.
"""
# Expiry check — skip expired cookies
expiry = cookie.get("expiry")
if expiry is not None and expiry <= int(time.time()):
return False
parsed = urllib.parse.urlparse(url)
hostname = parsed.hostname or ""
path = parsed.path or "/"
scheme = parsed.scheme or "http"
# Domain matching (RFC 6265 section 5.1.3)
cookie_domain = cookie.get("domain", "")
if not cookie_domain:
# Host-only cookie — must match the origin host exactly
if not default_domain or hostname != default_domain:
return False
elif cookie_domain.startswith("."):
# .example.com matches example.com and sub.example.com
if not (hostname == cookie_domain[1:] or hostname.endswith(cookie_domain)):
return False
else:
if hostname != cookie_domain:
return False
# Path matching (RFC 6265 section 5.1.4)
cookie_path = cookie.get("path", "/")
if cookie_path == "/":
pass # root path matches everything
elif path != cookie_path and not path.startswith(cookie_path + "/"):
return False
# Secure matching
if cookie.get("secure", False) and scheme != "https":
return False
return True
def _parse_set_cookie(header_value: str) -> dict:
"""Parse a single Set-Cookie header value into a cookie dict.
Uses manual parsing instead of http.cookies.SimpleCookie which is too
strict for real-world Set-Cookie headers.
Args:
header_value: The Set-Cookie header string.
Returns:
A dict with cookie attributes suitable for driver.add_cookie().
"""
parts = header_value.split(";")
name_value = parts[0].strip()
eq_idx = name_value.find("=")
if eq_idx == -1:
return {}
name = name_value[:eq_idx].strip()
value = name_value[eq_idx + 1 :].strip()
cookie: dict[str, Any] = {"name": name, "value": value}
has_max_age = False
for part in parts[1:]:
part = part.strip()
if not part:
continue
if "=" in part:
attr_name, attr_value = part.split("=", 1)
attr_name = attr_name.strip().lower()
attr_value = attr_value.strip()
else:
attr_name = part.strip().lower()
attr_value = ""
if attr_name == "domain":
cookie["domain"] = attr_value
elif attr_name == "path":
cookie["path"] = attr_value
elif attr_name == "secure":
cookie["secure"] = True
elif attr_name == "httponly":
cookie["httpOnly"] = True
elif attr_name == "samesite":
cookie["sameSite"] = attr_value
elif attr_name == "max-age":
try:
max_age = int(attr_value)
cookie["expiry"] = int(time.time()) + max_age
has_max_age = True
except ValueError:
pass
elif attr_name == "expires" and not has_max_age:
# RFC 6265 §5.3: Max-Age takes precedence over Expires
try:
dt = parsedate_to_datetime(attr_value)
cookie["expiry"] = int(dt.timestamp())
except (ValueError, TypeError):
pass
return cookie
def _get_set_cookie_headers(resp: urllib3.BaseHTTPResponse) -> list[str]:
"""Extract all Set-Cookie header values from a urllib3 response.
Args:
resp: The urllib3 HTTP response.
Returns:
A list of Set-Cookie header strings.
"""
if hasattr(resp.headers, "getlist"):
headers = resp.headers.getlist("Set-Cookie")
if headers:
return headers
sc = resp.headers.get("Set-Cookie")
return [sc] if sc else []
def _resolve_redirect_url(resp: urllib3.BaseHTTPResponse, original_url: str) -> str:
"""Return the final URL after any redirects.
urllib3's retry history records each hop. When redirects occurred,
the last entry's redirect_location resolved against its URL gives
the final destination. When no redirects occurred, the original
request URL is returned unchanged.
"""
history = resp.retries.history if resp.retries else ()
if history:
last = history[-1]
if last.url and last.redirect_location:
return urllib.parse.urljoin(last.url, last.redirect_location)
return original_url
class _BaseRequestContext:
"""Base class with shared HTTP request logic for API request contexts."""
def __init__(
self,
base_url: str = "",
extra_headers: dict[str, str] | None = None,
timeout: float = 30.0,
max_redirects: int = 10,
fail_on_status_code: bool = False,
) -> None:
self._base_url = base_url
self._extra_headers = extra_headers or {}
self._timeout = timeout
self._max_redirects = max_redirects
self._fail_on_status_code = fail_on_status_code
self._pool = urllib3.PoolManager()
def get(self, url: str, **kwargs: Any) -> APIResponse:
"""Send a GET request.
Args:
url: The request URL (absolute or relative to base_url).
**kwargs: Optional arguments: headers, params, timeout, max_redirects, fail_on_status_code.
Returns:
An APIResponse object.
"""
return self._fetch(url, "GET", **kwargs)
def post(self, url: str, **kwargs: Any) -> APIResponse:
"""Send a POST request.
Args:
url: The request URL (absolute or relative to base_url).
**kwargs: Optional arguments: headers, params, data, form,
json_data, timeout, max_redirects, fail_on_status_code.
Returns:
An APIResponse object.
"""
return self._fetch(url, "POST", **kwargs)
def put(self, url: str, **kwargs: Any) -> APIResponse:
"""Send a PUT request.
Args:
url: The request URL (absolute or relative to base_url).
**kwargs: Optional arguments: headers, params, data, form,
json_data, timeout, max_redirects, fail_on_status_code.
Returns:
An APIResponse object.
"""
return self._fetch(url, "PUT", **kwargs)
def patch(self, url: str, **kwargs: Any) -> APIResponse:
"""Send a PATCH request.
Args:
url: The request URL (absolute or relative to base_url).
**kwargs: Optional arguments: headers, params, data, form,
json_data, timeout, max_redirects, fail_on_status_code.
Returns:
An APIResponse object.
"""
return self._fetch(url, "PATCH", **kwargs)
def delete(self, url: str, **kwargs: Any) -> APIResponse:
"""Send a DELETE request.
Args:
url: The request URL (absolute or relative to base_url).
**kwargs: Optional arguments: headers, params, data, form,
json_data, timeout, max_redirects, fail_on_status_code.
Returns:
An APIResponse object.
"""
return self._fetch(url, "DELETE", **kwargs)
def head(self, url: str, **kwargs: Any) -> APIResponse:
"""Send a HEAD request.
Args:
url: The request URL (absolute or relative to base_url).
**kwargs: Optional arguments: headers, params, timeout,
max_redirects, fail_on_status_code.
Returns:
An APIResponse object.
"""
return self._fetch(url, "HEAD", **kwargs)
def fetch(self, url: str, method: str = "GET", **kwargs: Any) -> APIResponse:
"""Send an HTTP request with a custom method.
Args:
url: The request URL (absolute or relative to base_url).
method: The HTTP method to use.
**kwargs: Optional arguments: headers, params, data, form,
json_data, timeout, max_redirects, fail_on_status_code.
Returns:
An APIResponse object.
"""
return self._fetch(url, method, **kwargs)
def dispose(self) -> None:
"""Close the underlying connection pool."""
self._pool.clear()
def _resolve_url(self, url: str) -> str:
"""Resolve a URL, prepending base_url for relative paths."""
if not url.startswith(("http://", "https://")):
return self._base_url.rstrip("/") + "/" + url.lstrip("/")
return url
def _build_headers(self, kwargs: dict[str, Any]) -> dict[str, str]:
"""Merge extra_headers with per-request headers."""
headers = dict(self._extra_headers)
if kwargs.get("headers"):
headers.update(kwargs["headers"])
return headers
def _prepare_body(self, headers: dict[str, str], kwargs: dict[str, Any]) -> bytes | None:
"""Prepare the request body from json_data, form, or data kwargs.
Priority: json_data > form > data. Only one should be provided.
"""
json_data = kwargs.get("json_data")
form = kwargs.get("form")
data = kwargs.get("data")
if json_data is not None:
headers.setdefault("Content-Type", "application/json")
return json.dumps(json_data).encode("utf-8")
elif form is not None:
headers.setdefault("Content-Type", "application/x-www-form-urlencoded")
return urllib.parse.urlencode(form).encode("utf-8")
elif data is not None:
if isinstance(data, dict):
headers.setdefault("Content-Type", "application/x-www-form-urlencoded")
return urllib.parse.urlencode(data).encode("utf-8")
elif isinstance(data, str):
return data.encode("utf-8")
elif isinstance(data, bytes):
return data
return None
def _append_params(self, url: str, kwargs: dict[str, Any]) -> str:
"""Append query parameters to the URL."""
params = kwargs.get("params")
if params:
separator = "&" if "?" in url else "?"
return url + separator + urllib.parse.urlencode(params)
return url
def _execute_request(
self, method: str, url: str, headers: dict[str, str], body: bytes | None, kwargs: dict[str, Any]
) -> urllib3.BaseHTTPResponse:
"""Execute the HTTP request via urllib3."""
timeout = kwargs.get("timeout", self._timeout)
max_redirects = kwargs.get("max_redirects", self._max_redirects)
follow = max_redirects > 0
retries = Retry(
connect=0,
read=0,
status=0,
other=0,
redirect=max_redirects if follow else 0,
raise_on_redirect=False,
)
return self._pool.request(
method,
url,
headers=headers,
body=body,
timeout=timeout,
redirect=follow,
retries=retries,
preload_content=True,
)
def _build_response(self, resp: urllib3.BaseHTTPResponse, url: str) -> APIResponse:
"""Build an APIResponse from a urllib3 response."""
# Merge duplicate headers per RFC 7230 §3.2.2 (combine with ", ")
resp_headers: dict[str, str] = {}
for k, v in resp.headers.items():
key = k.lower()
if key in resp_headers:
resp_headers[key] = resp_headers[key] + ", " + v
else:
resp_headers[key] = v
# urllib3 2.x removed resp.reason; fall back to stdlib phrase lookup
reason = getattr(resp, "reason", None)
status_text = reason or http_status_phrases.get(resp.status, "")
return APIResponse(
status=resp.status,
status_text=status_text,
headers=resp_headers,
url=url,
body=resp.data,
)
def _get_cookies_for_request(self, url: str) -> list[dict]:
"""Get cookies that should be sent with the request. Overridden by subclasses."""
return []
def _handle_response_cookies(self, set_cookie_headers: list[str], url: str) -> None:
"""Process Set-Cookie headers from the response. Overridden by subclasses."""
def _fetch(self, url: str, method: str, **kwargs: Any) -> APIResponse:
"""Execute an HTTP request with cookie handling.
Args:
url: The request URL.
method: The HTTP method.
**kwargs: Optional arguments.
Returns:
An APIResponse object.
"""
url = self._resolve_url(url)
headers = self._build_headers(kwargs)
# Apply cookies
matching_cookies = self._get_cookies_for_request(url)
if matching_cookies:
cookie_header = "; ".join(f"{c['name']}={c['value']}" for c in matching_cookies)
if "Cookie" in headers:
headers["Cookie"] = headers["Cookie"] + "; " + cookie_header
else:
headers["Cookie"] = cookie_header
body = self._prepare_body(headers, kwargs)
url = self._append_params(url, kwargs)
resp = self._execute_request(method, url, headers, body, kwargs)
# After redirects, associate cookies with the final destination's
# origin, not the initial request URL.
final_url = _resolve_redirect_url(resp, url)
# Process response cookies
set_cookie_headers = _get_set_cookie_headers(resp)
if set_cookie_headers:
self._handle_response_cookies(set_cookie_headers, final_url)
response = self._build_response(resp, final_url)
fail = kwargs.get("fail_on_status_code", self._fail_on_status_code)
if fail and not response.ok:
raise APIRequestFailure(response)
return response
class APIRequestContext(_BaseRequestContext):
"""Makes HTTP requests with automatic browser cookie synchronization.
Cookies from the browser session are sent with API requests, and cookies
from API responses are synced back to the browser.
Args:
driver: The WebDriver instance to sync cookies with.
base_url: Optional base URL prepended to relative request paths.
extra_headers: Optional headers included in every request.
timeout: Default request timeout in seconds.
max_redirects: Maximum number of redirects to follow.
fail_on_status_code: If True, raise APIRequestFailure for non-2xx responses.
"""
def __init__(
self,
driver: "WebDriver",
base_url: str = "",
extra_headers: dict[str, str] | None = None,
timeout: float = 30.0,
max_redirects: int = 10,
fail_on_status_code: bool = False,
) -> None:
super().__init__(
base_url=base_url,
extra_headers=extra_headers,
timeout=timeout,
max_redirects=max_redirects,
fail_on_status_code=fail_on_status_code,
)
self._driver = driver
def new_context(
self,
base_url: str = "",
extra_headers: dict[str, str] | None = None,
storage_state: dict | str | pathlib.Path | None = None,
fail_on_status_code: bool = False,
) -> "_IsolatedAPIRequestContext":
"""Create an isolated API request context that does not sync with the browser.
Args:
base_url: Optional base URL for this context.
extra_headers: Optional headers for this context.
storage_state: Optional cookies to pre-load, as a dict, JSON file path, or Path.
fail_on_status_code: If True, raise APIRequestFailure for non-2xx responses.
Returns:
An _IsolatedAPIRequestContext instance.
"""
cookies: list[dict] = []
if storage_state is not None:
if isinstance(storage_state, (str, pathlib.Path)):
file_path = pathlib.Path(storage_state)
if not file_path.exists():
raise FileNotFoundError(f"Storage state file not found: {file_path}")
try:
with open(file_path) as f:
state = json.load(f)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in storage state file {file_path}: {e}") from e
except OSError as e:
raise OSError(f"Cannot read storage state file {file_path}: {e}") from e
else:
state = storage_state
cookies = list(state.get("cookies", []))
return _IsolatedAPIRequestContext(
base_url=base_url,
extra_headers=extra_headers,
cookies=cookies,
timeout=self._timeout,
max_redirects=self._max_redirects,
fail_on_status_code=fail_on_status_code,
)
def get_storage_state(self, path: str | pathlib.Path | None = None) -> dict[str, Any]:
"""Export the current browser cookies as a storage state dict.
Args:
path: Optional file path to save the storage state as JSON.
Returns:
A dict with a "cookies" key containing the browser cookies.
"""
cookies = self._driver.get_cookies()
state: dict[str, Any] = {"cookies": cookies}
if path is not None:
file_path = pathlib.Path(path)
try:
with open(file_path, "w") as f:
json.dump(state, f, indent=2)
except OSError as e:
raise OSError(f"Cannot write storage state to {file_path}: {e}") from e
return state
def _get_cookies_for_request(self, url: str) -> list[dict]:
"""Get matching browser cookies for the request URL."""
try:
browser_cookies = self._driver.get_cookies()
except Exception:
logger.debug("Could not retrieve browser cookies", exc_info=True)
return []
# Derive default domain from the browser's current page for host-only cookies
default_domain = ""
try:
current = self._driver.current_url
if current:
default_domain = urllib.parse.urlparse(current).hostname or ""
except Exception:
logger.debug("Could not get current URL for host-only cookie matching", exc_info=True)
return [c for c in browser_cookies if _cookie_matches(c, url, default_domain)]
def _handle_response_cookies(self, set_cookie_headers: list[str], url: str) -> None:
"""Sync Set-Cookie headers back to the browser."""
parsed_url = urllib.parse.urlparse(url)
for sc_header in set_cookie_headers:
cookie = _parse_set_cookie(sc_header)
if not cookie.get("name"):
continue
cookie.setdefault("domain", parsed_url.hostname or "")
cookie.setdefault("path", "/")
expiry = cookie.get("expiry")
if expiry is not None and expiry <= int(time.time()):
try:
self._driver.delete_cookie(cookie["name"])
except Exception:
pass
continue
try:
self._driver.add_cookie(cookie)
except Exception:
logger.warning(
"Could not sync cookie '%s' to browser (domain mismatch with current page)",
cookie.get("name"),
exc_info=True,
)
class _IsolatedAPIRequestContext(_BaseRequestContext):
"""An isolated API request context that maintains its own cookie jar.
Does not synchronize cookies with any browser session.
"""
def __init__(
self,
base_url: str = "",
extra_headers: dict[str, str] | None = None,
cookies: list[dict] | None = None,
timeout: float = 30.0,
max_redirects: int = 10,
fail_on_status_code: bool = False,
) -> None:
super().__init__(
base_url=base_url,
extra_headers=extra_headers,
timeout=timeout,
max_redirects=max_redirects,
fail_on_status_code=fail_on_status_code,
)
self._cookies: list[dict] = cookies or []
def get_storage_state(self) -> dict[str, Any]:
"""Return the current cookies as a storage state dict."""
return {"cookies": list(self._cookies)}
def _get_cookies_for_request(self, url: str) -> list[dict]:
"""Get matching cookies from the internal jar."""
# For isolated contexts, use the request hostname as default domain
default_domain = urllib.parse.urlparse(url).hostname or ""
return [c for c in self._cookies if _cookie_matches(c, url, default_domain)]
def _handle_response_cookies(self, set_cookie_headers: list[str], url: str) -> None:
"""Store Set-Cookie headers in the internal jar."""
parsed_url = urllib.parse.urlparse(url)
now = int(time.time())
for sc_header in set_cookie_headers:
cookie = _parse_set_cookie(sc_header)
if not cookie.get("name"):
continue
cookie.setdefault("domain", parsed_url.hostname or "")
cookie.setdefault("path", "/")
# Cookies are unique by (name, domain, path)
key = (cookie["name"], cookie.get("domain", ""), cookie.get("path", "/"))
# Remove existing cookie with same key
self._cookies = [
c for c in self._cookies if (c.get("name"), c.get("domain", ""), c.get("path", "/")) != key
]
# Only store if not expired (Max-Age=0 or negative means delete)
expiry = cookie.get("expiry")
if expiry is not None and expiry <= now:
continue
self._cookies.append(cookie)
@@ -0,0 +1,16 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
@@ -0,0 +1,280 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
from typing import Any
from selenium.webdriver.common.bidi.common import command_builder
from selenium.webdriver.common.bidi.session import UserPromptHandler
from selenium.webdriver.common.proxy import Proxy
class ClientWindowState:
"""Represents a window state."""
FULLSCREEN = "fullscreen"
MAXIMIZED = "maximized"
MINIMIZED = "minimized"
NORMAL = "normal"
VALID_STATES = {FULLSCREEN, MAXIMIZED, MINIMIZED, NORMAL}
class ClientWindowInfo:
"""Represents a client window information."""
def __init__(
self,
client_window: str,
state: str,
width: int,
height: int,
x: int,
y: int,
active: bool,
):
self.client_window = client_window
self.state = state
self.width = width
self.height = height
self.x = x
self.y = y
self.active = active
def get_state(self) -> str:
"""Gets the state of the client window.
Returns:
str: The state of the client window (one of the ClientWindowState constants).
"""
return self.state
def get_client_window(self) -> str:
"""Gets the client window identifier.
Returns:
str: The client window identifier.
"""
return self.client_window
def get_width(self) -> int:
"""Gets the width of the client window.
Returns:
int: The width of the client window.
"""
return self.width
def get_height(self) -> int:
"""Gets the height of the client window.
Returns:
int: The height of the client window.
"""
return self.height
def get_x(self) -> int:
"""Gets the x coordinate of the client window.
Returns:
int: The x coordinate of the client window.
"""
return self.x
def get_y(self) -> int:
"""Gets the y coordinate of the client window.
Returns:
int: The y coordinate of the client window.
"""
return self.y
def is_active(self) -> bool:
"""Checks if the client window is active.
Returns:
bool: True if the client window is active, False otherwise.
"""
return self.active
@classmethod
def from_dict(cls, data: dict) -> "ClientWindowInfo":
"""Creates a ClientWindowInfo instance from a dictionary.
Args:
data: A dictionary containing the client window information.
Returns:
ClientWindowInfo: A new instance of ClientWindowInfo.
Raises:
ValueError: If required fields are missing or have invalid types.
"""
try:
client_window = data["clientWindow"]
if not isinstance(client_window, str):
raise ValueError("clientWindow must be a string")
state = data["state"]
if not isinstance(state, str):
raise ValueError("state must be a string")
if state not in ClientWindowState.VALID_STATES:
raise ValueError(f"Invalid state: {state}. Must be one of {ClientWindowState.VALID_STATES}")
width = data["width"]
if not isinstance(width, int) or width < 0:
raise ValueError(f"width must be a non-negative integer, got {width}")
height = data["height"]
if not isinstance(height, int) or height < 0:
raise ValueError(f"height must be a non-negative integer, got {height}")
x = data["x"]
if not isinstance(x, int):
raise ValueError(f"x must be an integer, got {type(x).__name__}")
y = data["y"]
if not isinstance(y, int):
raise ValueError(f"y must be an integer, got {type(y).__name__}")
active = data["active"]
if not isinstance(active, bool):
raise ValueError("active must be a boolean")
return cls(
client_window=client_window,
state=state,
width=width,
height=height,
x=x,
y=y,
active=active,
)
except (KeyError, TypeError) as e:
raise ValueError(f"Invalid data format for ClientWindowInfo: {e}") from e
class Browser:
"""BiDi implementation of the browser module."""
def __init__(self, conn):
self.conn = conn
def create_user_context(
self,
accept_insecure_certs: bool | None = None,
proxy: Proxy | None = None,
unhandled_prompt_behavior: UserPromptHandler | None = None,
) -> str:
"""Creates a new user context.
Args:
accept_insecure_certs: Optional flag to accept insecure TLS certificates.
proxy: Optional proxy configuration for the user context.
unhandled_prompt_behavior: Optional configuration for handling user prompts.
Returns:
str: The ID of the created user context.
"""
params: dict[str, Any] = {}
if accept_insecure_certs is not None:
params["acceptInsecureCerts"] = accept_insecure_certs
if proxy is not None:
params["proxy"] = proxy.to_bidi_dict()
if unhandled_prompt_behavior is not None:
params["unhandledPromptBehavior"] = unhandled_prompt_behavior.to_dict()
result = self.conn.execute(command_builder("browser.createUserContext", params))
return result["userContext"]
def get_user_contexts(self) -> list[str]:
"""Gets all user contexts.
Returns:
List[str]: A list of user context IDs.
"""
result = self.conn.execute(command_builder("browser.getUserContexts", {}))
return [context_info["userContext"] for context_info in result["userContexts"]]
def remove_user_context(self, user_context_id: str) -> None:
"""Removes a user context.
Args:
user_context_id: The ID of the user context to remove.
Raises:
ValueError: If the user context ID is "default" or does not exist.
"""
if user_context_id == "default":
raise ValueError("Cannot remove the default user context")
params = {"userContext": user_context_id}
self.conn.execute(command_builder("browser.removeUserContext", params))
def get_client_windows(self) -> list[ClientWindowInfo]:
"""Gets all client windows.
Returns:
List[ClientWindowInfo]: A list of client window information.
"""
result = self.conn.execute(command_builder("browser.getClientWindows", {}))
return [ClientWindowInfo.from_dict(window) for window in result["clientWindows"]]
def set_download_behavior(
self,
*,
allowed: bool | None = None,
destination_folder: str | os.PathLike | None = None,
user_contexts: list[str] | None = None,
) -> None:
"""Set the download behavior for the browser or specific user contexts.
Args:
allowed: True to allow downloads, False to deny downloads, or None to
clear download behavior (revert to default).
destination_folder: Required when allowed is True. Specifies the folder
to store downloads in.
user_contexts: Optional list of user context IDs to apply this
behavior to. If omitted, updates the default behavior.
Raises:
ValueError: If allowed=True and destination_folder is missing, or if
allowed=False and destination_folder is provided.
"""
params: dict[str, Any] = {}
if allowed is None:
params["downloadBehavior"] = None
else:
if allowed:
if not destination_folder:
raise ValueError("destination_folder is required when allowed=True.")
params["downloadBehavior"] = {
"type": "allowed",
"destinationFolder": os.fspath(destination_folder),
}
else:
if destination_folder:
raise ValueError("destination_folder should not be provided when allowed=False.")
params["downloadBehavior"] = {"type": "denied"}
if user_contexts is not None:
params["userContexts"] = user_contexts
self.conn.execute(command_builder("browser.setDownloadBehavior", params))
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,515 @@
# The MIT License(MIT)
#
# Copyright(c) 2018 Hyperion Gray
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp
import contextvars
import importlib
import itertools
import json
import logging
import pathlib
from collections import defaultdict
from collections.abc import AsyncGenerator, AsyncIterator, Generator
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from typing import Any, TypeVar
import trio
from trio_websocket import ConnectionClosed as WsConnectionClosed
from trio_websocket import connect_websocket_url
logger = logging.getLogger("trio_cdp")
T = TypeVar("T")
MAX_WS_MESSAGE_SIZE = 2**24
devtools = None
version = None
def import_devtools(ver):
"""Attempt to load the current latest available devtools into the module cache for use later."""
global devtools
global version
version = ver
base = "selenium.webdriver.common.devtools.v"
try:
devtools = importlib.import_module(f"{base}{ver}")
return devtools
except ModuleNotFoundError:
# Attempt to parse and load the 'most recent' devtools module. This is likely
# because cdp has been updated but selenium python has not been released yet.
devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools")
versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir() and f.name != "latest")
latest = max(int(x[1:]) for x in versions)
selenium_logger = logging.getLogger(__name__)
selenium_logger.debug("Falling back to loading `devtools`: v%s", latest)
devtools = importlib.import_module(f"{base}{latest}")
return devtools
_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context")
_session_context: contextvars.ContextVar = contextvars.ContextVar("session_context")
def get_connection_context(fn_name):
"""Look up the current connection.
If there is no current connection, raise a ``RuntimeError`` with a
helpful message.
"""
try:
return _connection_context.get()
except LookupError:
raise RuntimeError(f"{fn_name}() must be called in a connection context.")
def get_session_context(fn_name):
"""Look up the current session.
If there is no current session, raise a ``RuntimeError`` with a
helpful message.
"""
try:
return _session_context.get()
except LookupError:
raise RuntimeError(f"{fn_name}() must be called in a session context.")
@contextmanager
def connection_context(connection):
"""Context manager installs ``connection`` as the session context for the current Trio task."""
token = _connection_context.set(connection)
try:
yield
finally:
_connection_context.reset(token)
@contextmanager
def session_context(session):
"""Context manager installs ``session`` as the session context for the current Trio task."""
token = _session_context.set(session)
try:
yield
finally:
_session_context.reset(token)
def set_global_connection(connection):
"""Install ``connection`` in the root context so that it will become the default connection for all tasks.
This is generally not recommended, except it may be necessary in
certain use cases such as running inside Jupyter notebook.
"""
global _connection_context
_connection_context = contextvars.ContextVar("_connection_context", default=connection)
def set_global_session(session):
"""Install ``session`` in the root context so that it will become the default session for all tasks.
This is generally not recommended, except it may be necessary in
certain use cases such as running inside Jupyter notebook.
"""
global _session_context
_session_context = contextvars.ContextVar("_session_context", default=session)
class BrowserError(Exception):
"""This exception is raised when the browser's response to a command indicates that an error occurred."""
def __init__(self, obj):
self.code = obj.get("code")
self.message = obj.get("message")
self.detail = obj.get("data")
def __str__(self):
return f"BrowserError<code={self.code} message={self.message}> {self.detail}"
class CdpConnectionClosed(WsConnectionClosed):
"""Raised when a public method is called on a closed CDP connection."""
def __init__(self, reason):
"""Constructor.
Args:
reason: wsproto.frame_protocol.CloseReason
"""
self.reason = reason
def __repr__(self):
"""Return representation."""
return f"{self.__class__.__name__}<{self.reason}>"
class InternalError(Exception):
"""This exception is only raised when there is faulty logic in TrioCDP or the integration with PyCDP."""
pass
@dataclass
class CmEventProxy:
"""A proxy object returned by :meth:`CdpBase.wait_for()``.
After the context manager executes, this proxy object will have a
value set that contains the returned event.
"""
value: Any = None
class CdpBase:
def __init__(self, ws, session_id, target_id):
self.ws = ws
self.session_id = session_id
self.target_id = target_id
self.channels = defaultdict(set)
self.id_iter = itertools.count()
self.inflight_cmd = {}
self.inflight_result = {}
async def execute(self, cmd: Generator[dict, T, Any]) -> T:
"""Execute a command on the server and wait for the result.
Args:
cmd: any CDP command
Returns:
a CDP result
"""
cmd_id = next(self.id_iter)
cmd_event = trio.Event()
self.inflight_cmd[cmd_id] = cmd, cmd_event
request = next(cmd)
request["id"] = cmd_id
if self.session_id:
request["sessionId"] = self.session_id
request_str = json.dumps(request)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Sending CDP message: {cmd_id} {cmd_event}: {request_str}")
try:
await self.ws.send_message(request_str)
except WsConnectionClosed as wcc:
raise CdpConnectionClosed(wcc.reason) from None
await cmd_event.wait()
response = self.inflight_result.pop(cmd_id)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Received CDP message: {response}")
if isinstance(response, Exception):
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}")
raise response
return response
def listen(self, *event_types, buffer_size=10):
"""Listen for events.
Returns:
An async iterator that iterates over events matching the indicated types.
"""
sender, receiver = trio.open_memory_channel(buffer_size)
for event_type in event_types:
self.channels[event_type].add(sender)
return receiver
@asynccontextmanager
async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]:
"""Wait for an event of the given type and return it.
This is an async context manager, so you should open it inside
an async with block. The block will not exit until the indicated
event is received.
"""
sender: trio.MemorySendChannel
receiver: trio.MemoryReceiveChannel
sender, receiver = trio.open_memory_channel(buffer_size)
self.channels[event_type].add(sender)
proxy = CmEventProxy()
yield proxy
async with receiver:
event = await receiver.receive()
proxy.value = event
def _handle_data(self, data):
"""Handle incoming WebSocket data.
Args:
data: a JSON dictionary
"""
if "id" in data:
self._handle_cmd_response(data)
else:
self._handle_event(data)
def _handle_cmd_response(self, data: dict):
"""Handle a response to a command.
This will set an event flag that will return control to the
task that called the command.
Args:
data: response as a JSON dictionary
"""
cmd_id = data["id"]
try:
cmd, event = self.inflight_cmd.pop(cmd_id)
except KeyError:
logger.warning("Got a message with a command ID that does not exist: %s", data)
return
if "error" in data:
# If the server reported an error, convert it to an exception and do
# not process the response any further.
self.inflight_result[cmd_id] = BrowserError(data["error"])
else:
# Otherwise, continue the generator to parse the JSON result
# into a CDP object.
try:
_ = cmd.send(data["result"])
raise InternalError("The command's generator function did not exit when expected!")
except StopIteration as exit:
return_ = exit.value
self.inflight_result[cmd_id] = return_
event.set()
def _handle_event(self, data: dict):
"""Handle an event.
Args:
data: event as a JSON dictionary
"""
global devtools
if devtools is None:
raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.")
event = devtools.util.parse_json_event(data)
logger.debug("Received event: %s", event)
to_remove = set()
for sender in self.channels[type(event)]:
try:
sender.send_nowait(event)
except trio.WouldBlock:
logger.error('Unable to send event "%r" due to full channel %s', event, sender)
except trio.BrokenResourceError:
to_remove.add(sender)
if to_remove:
self.channels[type(event)] -= to_remove
class CdpSession(CdpBase):
"""Contains the state for a CDP session.
Generally you should not instantiate this object yourself; you should call
:meth:`CdpConnection.open_session`.
"""
def __init__(self, ws, session_id, target_id):
"""Constructor.
Args:
ws: trio_websocket.WebSocketConnection
session_id: devtools.target.SessionID
target_id: devtools.target.TargetID
"""
super().__init__(ws, session_id, target_id)
self._dom_enable_count = 0
self._dom_enable_lock = trio.Lock()
self._page_enable_count = 0
self._page_enable_lock = trio.Lock()
@asynccontextmanager
async def dom_enable(self):
"""Context manager that executes ``dom.enable()`` when it enters and then calls ``dom.disable()``.
This keeps track of concurrent callers and only disables DOM
events when all callers have exited.
"""
global devtools
async with self._dom_enable_lock:
self._dom_enable_count += 1
if self._dom_enable_count == 1:
await self.execute(devtools.dom.enable())
yield
async with self._dom_enable_lock:
self._dom_enable_count -= 1
if self._dom_enable_count == 0:
await self.execute(devtools.dom.disable())
@asynccontextmanager
async def page_enable(self):
"""Context manager executes ``page.enable()`` when it enters and then calls ``page.disable()`` when it exits.
This keeps track of concurrent callers and only disables page
events when all callers have exited.
"""
global devtools
async with self._page_enable_lock:
self._page_enable_count += 1
if self._page_enable_count == 1:
await self.execute(devtools.page.enable())
yield
async with self._page_enable_lock:
self._page_enable_count -= 1
if self._page_enable_count == 0:
await self.execute(devtools.page.disable())
class CdpConnection(CdpBase, trio.abc.AsyncResource):
"""Contains the connection state for a Chrome DevTools Protocol server.
CDP can multiplex multiple "sessions" over a single connection. This
class corresponds to the "root" session, i.e. the implicitly created
session that has no session ID. This class is responsible for
reading incoming WebSocket messages and forwarding them to the
corresponding session, as well as handling messages targeted at the
root session itself. You should generally call the
:func:`open_cdp()` instead of instantiating this class directly.
"""
def __init__(self, ws):
"""Constructor.
Args:
ws: trio_websocket.WebSocketConnection
"""
super().__init__(ws, session_id=None, target_id=None)
self.sessions = {}
async def aclose(self):
"""Close the underlying WebSocket connection.
This will cause the reader task to gracefully exit when it tries
to read the next message from the WebSocket. All of the public
APIs (``execute()``, ``listen()``, etc.) will raise
``CdpConnectionClosed`` after the CDP connection is closed. It
is safe to call this multiple times.
"""
await self.ws.aclose()
@asynccontextmanager
async def open_session(self, target_id) -> AsyncIterator[CdpSession]:
"""Context manager opens a session and enables the "simple" style of calling CDP APIs.
For example, inside a session context, you can call ``await
dom.get_document()`` and it will execute on the current session
automatically.
"""
session = await self.connect_session(target_id)
with session_context(session):
yield session
async def connect_session(self, target_id) -> "CdpSession":
"""Returns a new :class:`CdpSession` connected to the specified target."""
global devtools
if devtools is None:
raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.")
session_id = await self.execute(devtools.target.attach_to_target(target_id, True))
session = CdpSession(self.ws, session_id, target_id)
self.sessions[session_id] = session
return session
async def _reader_task(self):
"""Runs in the background and handles incoming messages.
Dispatches responses to commands and events to listeners.
"""
global devtools
if devtools is None:
raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.")
while True:
try:
message = await self.ws.get_message()
except WsConnectionClosed:
# If the WebSocket is closed, we don't want to throw an
# exception from the reader task. Instead we will throw
# exceptions from the public API methods, and we can quietly
# exit the reader task here.
break
try:
data = json.loads(message)
except json.JSONDecodeError:
raise BrowserError({"code": -32700, "message": "Client received invalid JSON", "data": message})
logger.debug("Received message %r", data)
if "sessionId" in data:
session_id = devtools.target.SessionID(data["sessionId"])
try:
session = self.sessions[session_id]
except KeyError:
raise BrowserError(
{
"code": -32700,
"message": "Browser sent a message for an invalid session",
"data": f"{session_id!r}",
}
)
session._handle_data(data)
else:
self._handle_data(data)
for _, session in self.sessions.items():
for _, senders in session.channels.items():
for sender in senders:
sender.close()
@asynccontextmanager
async def open_cdp(url) -> AsyncIterator[CdpConnection]:
"""Async context manager opens a connection to the browser then closes the connection when the block exits.
The context manager also sets the connection as the default
connection for the current task, so that commands like ``await
target.get_targets()`` will run on this connection automatically. If
you want to use multiple connections concurrently, it is recommended
to open each on in a separate task.
"""
async with trio.open_nursery() as nursery:
conn = await connect_cdp(nursery, url)
try:
with connection_context(conn):
yield conn
finally:
await conn.aclose()
async def connect_cdp(nursery, url) -> CdpConnection:
"""Connect to the browser specified by ``url`` and spawn a background task in the specified nursery.
The ``open_cdp()`` context manager is preferred in most situations.
You should only use this function if you need to specify a custom
nursery. This connection is not automatically closed! You can either
use the connection object as a context manager (``async with
conn:``) or else call ``await conn.aclose()`` on it when you are
done with it. If ``set_context`` is True, then the returned
connection will be installed as the default connection for the
current task. This argument is for unusual use cases, such as
running inside of a notebook.
"""
ws = await connect_websocket_url(nursery, url, max_message_size=MAX_WS_MESSAGE_SIZE)
cdp_conn = CdpConnection(ws)
nursery.start_soon(cdp_conn._reader_task)
return cdp_conn
@@ -0,0 +1,36 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections.abc import Generator
def command_builder(method: str, params: dict | None = None) -> Generator[dict, dict, dict]:
"""Build a command iterator to send to the BiDi protocol.
Args:
method: The method to execute.
params: The parameters to pass to the method. Default is None.
Returns:
The response from the command execution.
"""
if params is None:
params = {}
command = {"method": method, "params": params}
cmd = yield command
return cmd
@@ -0,0 +1,24 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from enum import Enum
class Console(Enum):
ALL = "all"
LOG = "log"
ERROR = "error"
@@ -0,0 +1,524 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Any, TypeVar
from selenium.webdriver.common.bidi.common import command_builder
if TYPE_CHECKING:
from selenium.webdriver.remote.websocket_connection import WebSocketConnection
class ScreenOrientationNatural(Enum):
"""Natural screen orientation."""
PORTRAIT = "portrait"
LANDSCAPE = "landscape"
class ScreenOrientationType(Enum):
"""Screen orientation type."""
PORTRAIT_PRIMARY = "portrait-primary"
PORTRAIT_SECONDARY = "portrait-secondary"
LANDSCAPE_PRIMARY = "landscape-primary"
LANDSCAPE_SECONDARY = "landscape-secondary"
E = TypeVar("E", ScreenOrientationNatural, ScreenOrientationType)
def _convert_to_enum(value: E | str, enum_class: type[E]) -> E:
if isinstance(value, enum_class):
return value
assert isinstance(value, str)
try:
return enum_class(value.lower())
except ValueError:
raise ValueError(f"Invalid orientation: {value}")
class ScreenOrientation:
"""Represents screen orientation configuration."""
def __init__(
self,
natural: ScreenOrientationNatural | str,
type: ScreenOrientationType | str,
):
"""Initialize ScreenOrientation.
Args:
natural: Natural screen orientation ("portrait" or "landscape").
type: Screen orientation type ("portrait-primary", "portrait-secondary",
"landscape-primary", or "landscape-secondary").
Raises:
ValueError: If natural or type values are invalid.
"""
# handle string values
self.natural = _convert_to_enum(natural, ScreenOrientationNatural)
self.type = _convert_to_enum(type, ScreenOrientationType)
def to_dict(self) -> dict[str, str]:
return {
"natural": self.natural.value,
"type": self.type.value,
}
class GeolocationCoordinates:
"""Represents geolocation coordinates."""
def __init__(
self,
latitude: float,
longitude: float,
accuracy: float = 1.0,
altitude: float | None = None,
altitude_accuracy: float | None = None,
heading: float | None = None,
speed: float | None = None,
):
"""Initialize GeolocationCoordinates.
Args:
latitude: Latitude coordinate (-90.0 to 90.0).
longitude: Longitude coordinate (-180.0 to 180.0).
accuracy: Accuracy in meters (>= 0.0), defaults to 1.0.
altitude: Altitude in meters or None, defaults to None.
altitude_accuracy: Altitude accuracy in meters (>= 0.0) or None, defaults to None.
heading: Heading in degrees (0.0 to 360.0) or None, defaults to None.
speed: Speed in meters per second (>= 0.0) or None, defaults to None.
Raises:
ValueError: If coordinates are out of valid range or if altitude_accuracy is provided without altitude.
"""
self.latitude = latitude
self.longitude = longitude
self.accuracy = accuracy
self.altitude = altitude
self.altitude_accuracy = altitude_accuracy
self.heading = heading
self.speed = speed
@property
def latitude(self) -> float:
return self._latitude
@latitude.setter
def latitude(self, value: float) -> None:
if not (-90.0 <= value <= 90.0):
raise ValueError("latitude must be between -90.0 and 90.0")
self._latitude = value
@property
def longitude(self) -> float:
return self._longitude
@longitude.setter
def longitude(self, value: float) -> None:
if not (-180.0 <= value <= 180.0):
raise ValueError("longitude must be between -180.0 and 180.0")
self._longitude = value
@property
def accuracy(self) -> float:
return self._accuracy
@accuracy.setter
def accuracy(self, value: float) -> None:
if value < 0.0:
raise ValueError("accuracy must be >= 0.0")
self._accuracy = value
@property
def altitude(self) -> float | None:
return self._altitude
@altitude.setter
def altitude(self, value: float | None) -> None:
self._altitude = value
@property
def altitude_accuracy(self) -> float | None:
return self._altitude_accuracy
@altitude_accuracy.setter
def altitude_accuracy(self, value: float | None) -> None:
if value is not None and self.altitude is None:
raise ValueError("altitude_accuracy cannot be set without altitude")
if value is not None and value < 0.0:
raise ValueError("altitude_accuracy must be >= 0.0")
self._altitude_accuracy = value
@property
def heading(self) -> float | None:
return self._heading
@heading.setter
def heading(self, value: float | None) -> None:
if value is not None and not (0.0 <= value < 360.0):
raise ValueError("heading must be between 0.0 and 360.0")
self._heading = value
@property
def speed(self) -> float | None:
return self._speed
@speed.setter
def speed(self, value: float | None) -> None:
if value is not None and value < 0.0:
raise ValueError("speed must be >= 0.0")
self._speed = value
def to_dict(self) -> dict[str, float | None]:
result: dict[str, float | None] = {
"latitude": self.latitude,
"longitude": self.longitude,
"accuracy": self.accuracy,
}
if self.altitude is not None:
result["altitude"] = self.altitude
if self.altitude_accuracy is not None:
result["altitudeAccuracy"] = self.altitude_accuracy
if self.heading is not None:
result["heading"] = self.heading
if self.speed is not None:
result["speed"] = self.speed
return result
class GeolocationPositionError:
"""Represents a geolocation position error."""
TYPE_POSITION_UNAVAILABLE = "positionUnavailable"
def __init__(self, type: str = TYPE_POSITION_UNAVAILABLE):
if type != self.TYPE_POSITION_UNAVAILABLE:
raise ValueError(f'type must be "{self.TYPE_POSITION_UNAVAILABLE}"')
self.type = type
def to_dict(self) -> dict[str, str]:
return {"type": self.type}
class Emulation:
"""BiDi implementation of the emulation module."""
def __init__(self, conn: WebSocketConnection) -> None:
self.conn = conn
def set_geolocation_override(
self,
coordinates: GeolocationCoordinates | None = None,
error: GeolocationPositionError | None = None,
contexts: list[str] | None = None,
user_contexts: list[str] | None = None,
) -> None:
"""Set geolocation override for the given contexts or user contexts.
Args:
coordinates: Geolocation coordinates to emulate, or None.
error: Geolocation error to emulate, or None.
contexts: List of browsing context IDs to apply the override to.
user_contexts: List of user context IDs to apply the override to.
Raises:
ValueError: If both coordinates and error are provided, or if both contexts
and user_contexts are provided, or if neither contexts nor
user_contexts are provided.
"""
if coordinates is not None and error is not None:
raise ValueError("Cannot specify both coordinates and error")
if contexts is not None and user_contexts is not None:
raise ValueError("Cannot specify both contexts and userContexts")
if contexts is None and user_contexts is None:
raise ValueError("Must specify either contexts or userContexts")
params: dict[str, Any] = {}
if coordinates is not None:
params["coordinates"] = coordinates.to_dict()
elif error is not None:
params["error"] = error.to_dict()
if contexts is not None:
params["contexts"] = contexts
elif user_contexts is not None:
params["userContexts"] = user_contexts
self.conn.execute(command_builder("emulation.setGeolocationOverride", params))
def set_timezone_override(
self,
timezone: str | None = None,
contexts: list[str] | None = None,
user_contexts: list[str] | None = None,
) -> None:
"""Set timezone override for the given contexts or user contexts.
Args:
timezone: Timezone identifier (IANA timezone name or offset string like '+01:00'),
or None to clear the override.
contexts: List of browsing context IDs to apply the override to.
user_contexts: List of user context IDs to apply the override to.
Raises:
ValueError: If both contexts and user_contexts are provided, or if neither
contexts nor user_contexts are provided.
"""
if contexts is not None and user_contexts is not None:
raise ValueError("Cannot specify both contexts and user_contexts")
if contexts is None and user_contexts is None:
raise ValueError("Must specify either contexts or user_contexts")
params: dict[str, Any] = {"timezone": timezone}
if contexts is not None:
params["contexts"] = contexts
elif user_contexts is not None:
params["userContexts"] = user_contexts
self.conn.execute(command_builder("emulation.setTimezoneOverride", params))
def set_locale_override(
self,
locale: str | None = None,
contexts: list[str] | None = None,
user_contexts: list[str] | None = None,
) -> None:
"""Set locale override for the given contexts or user contexts.
Args:
locale: Locale string as per BCP 47, or None to clear override.
contexts: List of browsing context IDs to apply the override to.
user_contexts: List of user context IDs to apply the override to.
Raises:
ValueError: If both contexts and user_contexts are provided, or if neither
contexts nor user_contexts are provided, or if locale is invalid.
"""
if contexts is not None and user_contexts is not None:
raise ValueError("Cannot specify both contexts and userContexts")
if contexts is None and user_contexts is None:
raise ValueError("Must specify either contexts or userContexts")
params: dict[str, Any] = {"locale": locale}
if contexts is not None:
params["contexts"] = contexts
elif user_contexts is not None:
params["userContexts"] = user_contexts
self.conn.execute(command_builder("emulation.setLocaleOverride", params))
def set_scripting_enabled(
self,
enabled: bool | None = False,
contexts: list[str] | None = None,
user_contexts: list[str] | None = None,
) -> None:
"""Set scripting enabled override for the given contexts or user contexts.
Args:
enabled: False to disable scripting, None to clear the override.
Note: Only emulation of disabled JavaScript is supported.
contexts: List of browsing context IDs to apply the override to.
user_contexts: List of user context IDs to apply the override to.
Raises:
ValueError: If both contexts and user_contexts are provided, or if neither
contexts nor user_contexts are provided, or if enabled is True.
"""
if enabled:
raise ValueError("Only emulation of disabled JavaScript is supported (enabled must be False or None)")
if contexts is not None and user_contexts is not None:
raise ValueError("Cannot specify both contexts and userContexts")
if contexts is None and user_contexts is None:
raise ValueError("Must specify either contexts or userContexts")
params: dict[str, Any] = {"enabled": enabled}
if contexts is not None:
params["contexts"] = contexts
elif user_contexts is not None:
params["userContexts"] = user_contexts
self.conn.execute(command_builder("emulation.setScriptingEnabled", params))
def set_screen_orientation_override(
self,
screen_orientation: ScreenOrientation | None = None,
contexts: list[str] | None = None,
user_contexts: list[str] | None = None,
) -> None:
"""Set screen orientation override for the given contexts or user contexts.
Args:
screen_orientation: ScreenOrientation object to emulate, or None to clear the override.
contexts: List of browsing context IDs to apply the override to.
user_contexts: List of user context IDs to apply the override to.
Raises:
ValueError: If both contexts and user_contexts are provided, or if neither
contexts nor user_contexts are provided.
"""
if contexts is not None and user_contexts is not None:
raise ValueError("Cannot specify both contexts and userContexts")
if contexts is None and user_contexts is None:
raise ValueError("Must specify either contexts or userContexts")
params: dict[str, Any] = {
"screenOrientation": screen_orientation.to_dict() if screen_orientation is not None else None
}
if contexts is not None:
params["contexts"] = contexts
elif user_contexts is not None:
params["userContexts"] = user_contexts
self.conn.execute(command_builder("emulation.setScreenOrientationOverride", params))
def set_user_agent_override(
self,
user_agent: str | None = None,
contexts: list[str] | None = None,
user_contexts: list[str] | None = None,
) -> None:
"""Set user agent override for the given contexts or user contexts.
Args:
user_agent: User agent string to emulate, or None to clear the override.
contexts: List of browsing context IDs to apply the override to.
user_contexts: List of user context IDs to apply the override to.
Raises:
ValueError: If both contexts and user_contexts are provided, or if neither
contexts nor user_contexts are provided.
"""
if contexts is not None and user_contexts is not None:
raise ValueError("Cannot specify both contexts and user_contexts")
if contexts is None and user_contexts is None:
raise ValueError("Must specify either contexts or user_contexts")
params: dict[str, Any] = {"userAgent": user_agent}
if contexts is not None:
params["contexts"] = contexts
elif user_contexts is not None:
params["userContexts"] = user_contexts
self.conn.execute(command_builder("emulation.setUserAgentOverride", params))
def set_network_conditions(
self,
offline: bool = False,
contexts: list[str] | None = None,
user_contexts: list[str] | None = None,
) -> None:
"""Set network conditions for the given contexts or user contexts.
Args:
offline: True to emulate offline network conditions, False to clear the override.
contexts: List of browsing context IDs to apply the conditions to.
user_contexts: List of user context IDs to apply the conditions to.
Raises:
ValueError: If both contexts and user_contexts are provided, or if neither
contexts nor user_contexts are provided.
"""
if contexts is not None and user_contexts is not None:
raise ValueError("Cannot specify both contexts and user_contexts")
if contexts is None and user_contexts is None:
raise ValueError("Must specify either contexts or user_contexts")
params: dict[str, Any] = {}
if offline:
params["networkConditions"] = {"type": "offline"}
else:
# if offline is False or None, then clear the override
params["networkConditions"] = None
if contexts is not None:
params["contexts"] = contexts
elif user_contexts is not None:
params["userContexts"] = user_contexts
self.conn.execute(command_builder("emulation.setNetworkConditions", params))
def set_screen_settings_override(
self,
width: int | None = None,
height: int | None = None,
contexts: list[str] | None = None,
user_contexts: list[str] | None = None,
) -> None:
"""Set screen settings override for the given contexts or user contexts.
Args:
width: Screen width in pixels (>= 0). None to clear the override.
height: Screen height in pixels (>= 0). None to clear the override.
contexts: List of browsing context IDs to apply the override to.
user_contexts: List of user context IDs to apply the override to.
Raises:
ValueError: If only one of width/height is provided, or if both contexts
and user_contexts are provided, or if neither is provided.
"""
if (width is None) != (height is None):
raise ValueError("Must provide both width and height, or neither to clear the override")
if contexts is not None and user_contexts is not None:
raise ValueError("Cannot specify both contexts and user_contexts")
if contexts is None and user_contexts is None:
raise ValueError("Must specify either contexts or user_contexts")
screen_area = None
if width is not None and height is not None:
if not isinstance(width, int) or not isinstance(height, int):
raise ValueError("width and height must be integers")
if width < 0 or height < 0:
raise ValueError("width and height must be >= 0")
screen_area = {"width": width, "height": height}
params: dict[str, Any] = {"screenArea": screen_area}
if contexts is not None:
params["contexts"] = contexts
elif user_contexts is not None:
params["userContexts"] = user_contexts
self.conn.execute(command_builder("emulation.setScreenSettingsOverride", params))
@@ -0,0 +1,462 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import math
from dataclasses import dataclass, field
from typing import Any
from selenium.webdriver.common.bidi.common import command_builder
from selenium.webdriver.common.bidi.session import Session
class PointerType:
"""Represents the possible pointer types."""
MOUSE = "mouse"
PEN = "pen"
TOUCH = "touch"
VALID_TYPES = {MOUSE, PEN, TOUCH}
class Origin:
"""Represents the possible origin types."""
VIEWPORT = "viewport"
POINTER = "pointer"
@dataclass
class ElementOrigin:
"""Represents an element origin for input actions."""
type: str
element: dict
def __init__(self, element_reference: dict):
self.type = "element"
self.element = element_reference
def to_dict(self) -> dict:
"""Convert the ElementOrigin to a dictionary."""
return {"type": self.type, "element": self.element}
@dataclass
class PointerParameters:
"""Represents pointer parameters for pointer actions."""
pointer_type: str = PointerType.MOUSE
def __post_init__(self):
if self.pointer_type not in PointerType.VALID_TYPES:
raise ValueError(f"Invalid pointer type: {self.pointer_type}. Must be one of {PointerType.VALID_TYPES}")
def to_dict(self) -> dict:
"""Convert the PointerParameters to a dictionary."""
return {"pointerType": self.pointer_type}
@dataclass
class PointerCommonProperties:
"""Common properties for pointer actions."""
width: int = 1
height: int = 1
pressure: float = 0.0
tangential_pressure: float = 0.0
twist: int = 0
altitude_angle: float = 0.0
azimuth_angle: float = 0.0
def __post_init__(self):
if self.width < 1:
raise ValueError("width must be at least 1")
if self.height < 1:
raise ValueError("height must be at least 1")
if not (0.0 <= self.pressure <= 1.0):
raise ValueError("pressure must be between 0.0 and 1.0")
if not (0.0 <= self.tangential_pressure <= 1.0):
raise ValueError("tangential_pressure must be between 0.0 and 1.0")
if not (0 <= self.twist <= 359):
raise ValueError("twist must be between 0 and 359")
if not (0.0 <= self.altitude_angle <= math.pi / 2):
raise ValueError("altitude_angle must be between 0.0 and π/2")
if not (0.0 <= self.azimuth_angle <= 2 * math.pi):
raise ValueError("azimuth_angle must be between 0.0 and 2π")
def to_dict(self) -> dict:
"""Convert the PointerCommonProperties to a dictionary."""
result: dict[str, Any] = {}
if self.width != 1:
result["width"] = self.width
if self.height != 1:
result["height"] = self.height
if self.pressure != 0.0:
result["pressure"] = self.pressure
if self.tangential_pressure != 0.0:
result["tangentialPressure"] = self.tangential_pressure
if self.twist != 0:
result["twist"] = self.twist
if self.altitude_angle != 0.0:
result["altitudeAngle"] = self.altitude_angle
if self.azimuth_angle != 0.0:
result["azimuthAngle"] = self.azimuth_angle
return result
# Action classes
@dataclass
class PauseAction:
"""Represents a pause action."""
duration: int | None = None
@property
def type(self) -> str:
return "pause"
def to_dict(self) -> dict:
"""Convert the PauseAction to a dictionary."""
result: dict[str, Any] = {"type": self.type}
if self.duration is not None:
result["duration"] = self.duration
return result
@dataclass
class KeyDownAction:
"""Represents a key down action."""
value: str = ""
@property
def type(self) -> str:
return "keyDown"
def to_dict(self) -> dict:
"""Convert the KeyDownAction to a dictionary."""
return {"type": self.type, "value": self.value}
@dataclass
class KeyUpAction:
"""Represents a key up action."""
value: str = ""
@property
def type(self) -> str:
return "keyUp"
def to_dict(self) -> dict:
"""Convert the KeyUpAction to a dictionary."""
return {"type": self.type, "value": self.value}
@dataclass
class PointerDownAction:
"""Represents a pointer down action."""
button: int = 0
properties: PointerCommonProperties | None = None
@property
def type(self) -> str:
return "pointerDown"
def to_dict(self) -> dict:
"""Convert the PointerDownAction to a dictionary."""
result: dict[str, Any] = {"type": self.type, "button": self.button}
if self.properties:
result.update(self.properties.to_dict())
return result
@dataclass
class PointerUpAction:
"""Represents a pointer up action."""
button: int = 0
@property
def type(self) -> str:
return "pointerUp"
def to_dict(self) -> dict:
"""Convert the PointerUpAction to a dictionary."""
return {"type": self.type, "button": self.button}
@dataclass
class PointerMoveAction:
"""Represents a pointer move action."""
x: float = 0
y: float = 0
duration: int | None = None
origin: str | ElementOrigin | None = None
properties: PointerCommonProperties | None = None
@property
def type(self) -> str:
return "pointerMove"
def to_dict(self) -> dict:
"""Convert the PointerMoveAction to a dictionary."""
result: dict[str, Any] = {"type": self.type, "x": self.x, "y": self.y}
if self.duration is not None:
result["duration"] = self.duration
if self.origin is not None:
if isinstance(self.origin, ElementOrigin):
result["origin"] = self.origin.to_dict()
else:
result["origin"] = self.origin
if self.properties:
result.update(self.properties.to_dict())
return result
@dataclass
class WheelScrollAction:
"""Represents a wheel scroll action."""
x: int = 0
y: int = 0
delta_x: int = 0
delta_y: int = 0
duration: int | None = None
origin: str | ElementOrigin | None = Origin.VIEWPORT
@property
def type(self) -> str:
return "scroll"
def to_dict(self) -> dict:
"""Convert the WheelScrollAction to a dictionary."""
result: dict[str, Any] = {
"type": self.type,
"x": self.x,
"y": self.y,
"deltaX": self.delta_x,
"deltaY": self.delta_y,
}
if self.duration is not None:
result["duration"] = self.duration
if self.origin is not None:
if isinstance(self.origin, ElementOrigin):
result["origin"] = self.origin.to_dict()
else:
result["origin"] = self.origin
return result
# Source Actions
@dataclass
class NoneSourceActions:
"""Represents a sequence of none actions."""
id: str = ""
actions: list[PauseAction] = field(default_factory=list)
@property
def type(self) -> str:
return "none"
def to_dict(self) -> dict:
"""Convert the NoneSourceActions to a dictionary."""
return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]}
@dataclass
class KeySourceActions:
"""Represents a sequence of key actions."""
id: str = ""
actions: list[PauseAction | KeyDownAction | KeyUpAction] = field(default_factory=list)
@property
def type(self) -> str:
return "key"
def to_dict(self) -> dict:
"""Convert the KeySourceActions to a dictionary."""
return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]}
@dataclass
class PointerSourceActions:
"""Represents a sequence of pointer actions."""
id: str = ""
parameters: PointerParameters | None = None
actions: list[PauseAction | PointerDownAction | PointerUpAction | PointerMoveAction] = field(default_factory=list)
def __post_init__(self):
if self.parameters is None:
self.parameters = PointerParameters()
@property
def type(self) -> str:
return "pointer"
def to_dict(self) -> dict:
"""Convert the PointerSourceActions to a dictionary."""
result: dict[str, Any] = {
"type": self.type,
"id": self.id,
"actions": [action.to_dict() for action in self.actions],
}
if self.parameters:
result["parameters"] = self.parameters.to_dict()
return result
@dataclass
class WheelSourceActions:
"""Represents a sequence of wheel actions."""
id: str = ""
actions: list[PauseAction | WheelScrollAction] = field(default_factory=list)
@property
def type(self) -> str:
return "wheel"
def to_dict(self) -> dict:
"""Convert the WheelSourceActions to a dictionary."""
return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]}
@dataclass
class FileDialogInfo:
"""Represents file dialog information from input.fileDialogOpened event."""
context: str
multiple: bool
element: dict | None = None
@classmethod
def from_dict(cls, data: dict) -> "FileDialogInfo":
"""Creates a FileDialogInfo instance from a dictionary.
Args:
data: A dictionary containing the file dialog information.
Returns:
FileDialogInfo: A new instance of FileDialogInfo.
"""
return cls(context=data["context"], multiple=data["multiple"], element=data.get("element"))
# Event Class
class FileDialogOpened:
"""Event class for input.fileDialogOpened event."""
event_class = "input.fileDialogOpened"
@classmethod
def from_json(cls, json):
"""Create FileDialogInfo from JSON data."""
return FileDialogInfo.from_dict(json)
class Input:
"""BiDi implementation of the input module."""
def __init__(self, conn):
self.conn = conn
self.subscriptions = {}
self.callbacks = {}
def perform_actions(
self,
context: str,
actions: list[NoneSourceActions | KeySourceActions | PointerSourceActions | WheelSourceActions],
) -> None:
"""Performs a sequence of user input actions.
Args:
context: The browsing context ID where actions should be performed.
actions: A list of source actions to perform.
"""
params = {"context": context, "actions": [action.to_dict() for action in actions]}
self.conn.execute(command_builder("input.performActions", params))
def release_actions(self, context: str) -> None:
"""Releases all input state for the given context.
Args:
context: The browsing context ID to release actions for.
"""
params = {"context": context}
self.conn.execute(command_builder("input.releaseActions", params))
def set_files(self, context: str, element: dict, files: list[str]) -> None:
"""Sets files for a file input element.
Args:
context: The browsing context ID.
element: The element reference (script.SharedReference).
files: A list of file paths to set.
"""
params = {"context": context, "element": element, "files": files}
self.conn.execute(command_builder("input.setFiles", params))
def add_file_dialog_handler(self, handler) -> int:
"""Add a handler for file dialog opened events.
Args:
handler: Callback function that takes a FileDialogInfo object.
Returns:
int: Callback ID for removing the handler later.
"""
# Subscribe to the event if not already subscribed
if FileDialogOpened.event_class not in self.subscriptions:
session = Session(self.conn)
self.conn.execute(session.subscribe(FileDialogOpened.event_class))
self.subscriptions[FileDialogOpened.event_class] = []
# Add callback - the callback receives the parsed FileDialogInfo directly
callback_id = self.conn.add_callback(FileDialogOpened, handler)
self.subscriptions[FileDialogOpened.event_class].append(callback_id)
self.callbacks[callback_id] = handler
return callback_id
def remove_file_dialog_handler(self, callback_id: int) -> None:
"""Remove a file dialog handler.
Args:
callback_id: The callback ID returned by add_file_dialog_handler.
"""
if callback_id in self.callbacks:
del self.callbacks[callback_id]
if FileDialogOpened.event_class in self.subscriptions:
if callback_id in self.subscriptions[FileDialogOpened.event_class]:
self.subscriptions[FileDialogOpened.event_class].remove(callback_id)
# If no more callbacks for this event, unsubscribe
if not self.subscriptions[FileDialogOpened.event_class]:
session = Session(self.conn)
self.conn.execute(session.unsubscribe(FileDialogOpened.event_class))
del self.subscriptions[FileDialogOpened.event_class]
self.conn.remove_callback(FileDialogOpened, callback_id)
@@ -0,0 +1,81 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
class LogEntryAdded:
event_class = "log.entryAdded"
@classmethod
def from_json(cls, json: dict[str, Any]) -> ConsoleLogEntry | JavaScriptLogEntry | None:
if json["type"] == "console":
return ConsoleLogEntry.from_json(json)
elif json["type"] == "javascript":
return JavaScriptLogEntry.from_json(json)
return None
@dataclass
class ConsoleLogEntry:
level: str
text: str
timestamp: str
method: str
args: list[dict[str, Any]]
type_: str
@classmethod
def from_json(cls, json: dict[str, Any]) -> ConsoleLogEntry:
return cls(
level=json["level"],
text=json["text"],
timestamp=json["timestamp"],
method=json["method"],
args=json["args"],
type_=json["type"],
)
@dataclass
class JavaScriptLogEntry:
level: str
text: str
timestamp: str
stacktrace: dict[str, Any]
type_: str
@classmethod
def from_json(cls, json: dict[str, Any]) -> JavaScriptLogEntry:
return cls(
level=json["level"],
text=json["text"],
timestamp=json["timestamp"],
stacktrace=json["stackTrace"],
type_=json["type"],
)
class LogLevel:
"""Represents log level."""
DEBUG = "debug"
INFO = "info"
WARN = "warn"
ERROR = "error"

Some files were not shown because too many files have changed in this diff Show More