initial commit

This commit is contained in:
2026-06-25 21:29:21 +00:00
commit 0d0a7456de
2738 changed files with 542622 additions and 0 deletions
@@ -0,0 +1,16 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
@@ -0,0 +1,280 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
from typing import Any
from selenium.webdriver.common.bidi.common import command_builder
from selenium.webdriver.common.bidi.session import UserPromptHandler
from selenium.webdriver.common.proxy import Proxy
class ClientWindowState:
    """String constants naming the states a client window can be in."""

    FULLSCREEN = "fullscreen"
    MAXIMIZED = "maximized"
    MINIMIZED = "minimized"
    NORMAL = "normal"

    # The complete set of state strings that are considered valid.
    VALID_STATES = set((NORMAL, MINIMIZED, MAXIMIZED, FULLSCREEN))
class ClientWindowInfo:
    """Snapshot of a single client window: identifier, state, geometry and focus."""

    def __init__(
        self,
        client_window: str,
        state: str,
        width: int,
        height: int,
        x: int,
        y: int,
        active: bool,
    ):
        self.client_window, self.state = client_window, state
        self.width, self.height = width, height
        self.x, self.y = x, y
        self.active = active

    def get_state(self) -> str:
        """Return the window state.

        Returns:
            str: The state of the client window (one of the ClientWindowState constants).
        """
        return self.state

    def get_client_window(self) -> str:
        """Return the client window identifier.

        Returns:
            str: The client window identifier.
        """
        return self.client_window

    def get_width(self) -> int:
        """Return the window width.

        Returns:
            int: The width of the client window.
        """
        return self.width

    def get_height(self) -> int:
        """Return the window height.

        Returns:
            int: The height of the client window.
        """
        return self.height

    def get_x(self) -> int:
        """Return the window's x coordinate.

        Returns:
            int: The x coordinate of the client window.
        """
        return self.x

    def get_y(self) -> int:
        """Return the window's y coordinate.

        Returns:
            int: The y coordinate of the client window.
        """
        return self.y

    def is_active(self) -> bool:
        """Tell whether the window is active.

        Returns:
            bool: True if the client window is active, False otherwise.
        """
        return self.active

    @classmethod
    def from_dict(cls, data: dict) -> "ClientWindowInfo":
        """Build a ClientWindowInfo from a dictionary, validating each field in turn.

        Fields are read and checked in the order clientWindow, state, width,
        height, x, y, active; the first problem found is reported.

        Args:
            data: A dictionary containing the client window information.

        Returns:
            ClientWindowInfo: A new instance of ClientWindowInfo.

        Raises:
            ValueError: If required fields are missing or have invalid types.
        """
        try:
            window_id = data["clientWindow"]
            if not isinstance(window_id, str):
                raise ValueError("clientWindow must be a string")

            window_state = data["state"]
            if not isinstance(window_state, str):
                raise ValueError("state must be a string")
            if window_state not in ClientWindowState.VALID_STATES:
                raise ValueError(
                    f"Invalid state: {window_state}. Must be one of {ClientWindowState.VALID_STATES}"
                )

            # Dimensions must be non-negative integers.
            size = {}
            for key in ("width", "height"):
                dimension = data[key]
                if not isinstance(dimension, int) or dimension < 0:
                    raise ValueError(f"{key} must be a non-negative integer, got {dimension}")
                size[key] = dimension

            # Coordinates may be negative but must still be integers.
            position = {}
            for key in ("x", "y"):
                coordinate = data[key]
                if not isinstance(coordinate, int):
                    raise ValueError(f"{key} must be an integer, got {type(coordinate).__name__}")
                position[key] = coordinate

            is_active = data["active"]
            if not isinstance(is_active, bool):
                raise ValueError("active must be a boolean")

            return cls(
                client_window=window_id,
                state=window_state,
                width=size["width"],
                height=size["height"],
                x=position["x"],
                y=position["y"],
                active=is_active,
            )
        except (KeyError, TypeError) as e:
            # Missing keys or a non-subscriptable payload are reported uniformly.
            raise ValueError(f"Invalid data format for ClientWindowInfo: {e}") from e
class Browser:
    """BiDi implementation of the browser module."""

    def __init__(self, conn):
        self.conn = conn

    def create_user_context(
        self,
        accept_insecure_certs: bool | None = None,
        proxy: Proxy | None = None,
        unhandled_prompt_behavior: UserPromptHandler | None = None,
    ) -> str:
        """Create a new user context and return its ID.

        Only the options that are not None are sent to the browser.

        Args:
            accept_insecure_certs: Optional flag to accept insecure TLS certificates.
            proxy: Optional proxy configuration for the user context.
            unhandled_prompt_behavior: Optional configuration for handling user prompts.

        Returns:
            str: The ID of the created user context.
        """
        options: dict[str, Any] = {}
        if accept_insecure_certs is not None:
            options["acceptInsecureCerts"] = accept_insecure_certs
        if proxy is not None:
            options["proxy"] = proxy.to_bidi_dict()
        if unhandled_prompt_behavior is not None:
            options["unhandledPromptBehavior"] = unhandled_prompt_behavior.to_dict()

        command = command_builder("browser.createUserContext", options)
        response = self.conn.execute(command)
        return response["userContext"]

    def get_user_contexts(self) -> list[str]:
        """List the IDs of all user contexts.

        Returns:
            List[str]: A list of user context IDs.
        """
        response = self.conn.execute(command_builder("browser.getUserContexts", {}))
        context_ids = []
        for entry in response["userContexts"]:
            context_ids.append(entry["userContext"])
        return context_ids

    def remove_user_context(self, user_context_id: str) -> None:
        """Remove a user context.

        Args:
            user_context_id: The ID of the user context to remove.

        Raises:
            ValueError: If the user context ID is "default" or does not exist.
        """
        # The default context is rejected locally, before any command is sent.
        if user_context_id == "default":
            raise ValueError("Cannot remove the default user context")

        command = command_builder("browser.removeUserContext", {"userContext": user_context_id})
        self.conn.execute(command)

    def get_client_windows(self) -> list[ClientWindowInfo]:
        """List all client windows.

        Returns:
            List[ClientWindowInfo]: A list of client window information.
        """
        response = self.conn.execute(command_builder("browser.getClientWindows", {}))
        windows = []
        for raw_window in response["clientWindows"]:
            windows.append(ClientWindowInfo.from_dict(raw_window))
        return windows

    def set_download_behavior(
        self,
        *,
        allowed: bool | None = None,
        destination_folder: str | os.PathLike | None = None,
        user_contexts: list[str] | None = None,
    ) -> None:
        """Set the download behavior for the browser or specific user contexts.

        Args:
            allowed: True to allow downloads, False to deny downloads, or None to
                clear download behavior (revert to default).
            destination_folder: Required when allowed is True. Specifies the folder
                to store downloads in.
            user_contexts: Optional list of user context IDs to apply this
                behavior to. If omitted, updates the default behavior.

        Raises:
            ValueError: If allowed=True and destination_folder is missing, or if
                allowed=False and destination_folder is provided.
        """
        behavior: Any
        if allowed is None:
            # An explicit null clears any previously configured behavior.
            behavior = None
        elif allowed:
            if not destination_folder:
                raise ValueError("destination_folder is required when allowed=True.")
            behavior = {
                "type": "allowed",
                "destinationFolder": os.fspath(destination_folder),
            }
        else:
            if destination_folder:
                raise ValueError("destination_folder should not be provided when allowed=False.")
            behavior = {"type": "denied"}

        params: dict[str, Any] = {"downloadBehavior": behavior}
        if user_contexts is not None:
            params["userContexts"] = user_contexts
        self.conn.execute(command_builder("browser.setDownloadBehavior", params))
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,515 @@
# The MIT License(MIT)
#
# Copyright(c) 2018 Hyperion Gray
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files(the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp
import contextvars
import importlib
import itertools
import json
import logging
import pathlib
from collections import defaultdict
from collections.abc import AsyncGenerator, AsyncIterator, Generator
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from typing import Any, TypeVar
import trio
from trio_websocket import ConnectionClosed as WsConnectionClosed
from trio_websocket import connect_websocket_url
# Module-level logger used by the CDP connection and session classes below.
logger = logging.getLogger("trio_cdp")
# Generic type variable used to annotate command results and awaited event types.
T = TypeVar("T")
# Passed as ``max_message_size`` when the WebSocket is opened in connect_cdp()
# (2**24; presumably bytes, i.e. 16 MiB -- confirm against trio_websocket).
MAX_WS_MESSAGE_SIZE = 2**24
# The loaded, version-specific devtools module; set by import_devtools().
devtools = None
# The version most recently requested through import_devtools().
version = None
def import_devtools(ver):
    """Load the devtools module for ``ver`` into the module cache for later use.

    The module ``selenium.webdriver.common.devtools.v<ver>`` is imported and
    stored in the module-level ``devtools`` global; ``ver`` is stored in the
    module-level ``version`` global. If that module does not exist, the
    highest-numbered ``v<digits>`` directory found next to this package is
    loaded instead.

    Args:
        ver: The devtools protocol version to load (interpolated after ``v``).

    Returns:
        The imported devtools module.

    Raises:
        ModuleNotFoundError: If the requested version is unavailable and no
            ``v<digits>`` directory exists to fall back to.
    """
    global devtools
    global version
    version = ver
    base = "selenium.webdriver.common.devtools.v"
    try:
        devtools = importlib.import_module(f"{base}{ver}")
        return devtools
    except ModuleNotFoundError:
        # Attempt to parse and load the 'most recent' devtools module. This is likely
        # because cdp has been updated but selenium python has not been released yet.
        devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools")
        # Only consider directories named "v<digits>"; anything else (for
        # example "latest" or "__pycache__") is not a protocol version and
        # must not be fed to int().
        available = [
            int(entry.name[1:])
            for entry in devtools_path.iterdir()
            if entry.is_dir()
            and entry.name.startswith("v")
            and entry.name[1:].isascii()
            and entry.name[1:].isdigit()
        ]
        if not available:
            # Nothing to fall back to: surface the original import failure
            # rather than an unrelated error from max() on an empty sequence.
            raise
        latest = max(available)
        selenium_logger = logging.getLogger(__name__)
        selenium_logger.debug("Falling back to loading `devtools`: v%s", latest)
        devtools = importlib.import_module(f"{base}{latest}")
        return devtools
# Holds the current connection; set by connection_context() and replaced
# wholesale by set_global_connection().
_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context")
# Holds the current session; set by session_context() and replaced wholesale
# by set_global_session().
_session_context: contextvars.ContextVar = contextvars.ContextVar("session_context")
def get_connection_context(fn_name):
    """Return the connection installed for the current context.

    Args:
        fn_name: Name of the calling function, used in the error message.

    Raises:
        RuntimeError: If no connection has been installed.
    """
    try:
        current = _connection_context.get()
    except LookupError:
        raise RuntimeError(f"{fn_name}() must be called in a connection context.")
    else:
        return current
def get_session_context(fn_name):
    """Return the session installed for the current context.

    Args:
        fn_name: Name of the calling function, used in the error message.

    Raises:
        RuntimeError: If no session has been installed.
    """
    try:
        current = _session_context.get()
    except LookupError:
        raise RuntimeError(f"{fn_name}() must be called in a session context.")
    else:
        return current
@contextmanager
def connection_context(connection):
    """Context manager installs ``connection`` as the connection context for the current Trio task.

    The previous value is restored when the block exits, even on error.
    """
    restore_token = _connection_context.set(connection)
    try:
        yield
    finally:
        _connection_context.reset(restore_token)
@contextmanager
def session_context(session):
    """Context manager installs ``session`` as the session context for the current Trio task.

    The previous value is restored when the block exits, even on error.
    """
    restore_token = _session_context.set(session)
    try:
        yield
    finally:
        _session_context.reset(restore_token)
def set_global_connection(connection):
    """Make ``connection`` the default connection for every task.

    This replaces the module-level context variable with a new one whose
    default value is ``connection``. It is generally not recommended, except
    it may be necessary in certain use cases such as running inside Jupyter
    notebook.
    """
    global _connection_context
    replacement = contextvars.ContextVar("_connection_context", default=connection)
    _connection_context = replacement
def set_global_session(session):
    """Make ``session`` the default session for every task.

    This replaces the module-level context variable with a new one whose
    default value is ``session``. It is generally not recommended, except it
    may be necessary in certain use cases such as running inside Jupyter
    notebook.
    """
    global _session_context
    replacement = contextvars.ContextVar("_session_context", default=session)
    _session_context = replacement
class BrowserError(Exception):
    """Raised when the browser's response to a command indicates that an error occurred.

    The ``code``, ``message`` and ``detail`` attributes are taken from the
    "code", "message" and "data" entries of the error object; entries that
    are absent become ``None``.
    """

    def __init__(self, obj):
        self.code, self.message, self.detail = (obj.get(key) for key in ("code", "message", "data"))

    def __str__(self):
        summary = f"BrowserError<code={self.code} message={self.message}>"
        return f"{summary} {self.detail}"
class CdpConnectionClosed(WsConnectionClosed):
    """Raised when a public method is called on a closed CDP connection."""

    def __init__(self, reason):
        """Store the reason the connection was closed.

        Args:
            reason: wsproto.frame_protocol.CloseReason
        """
        self.reason = reason

    def __repr__(self):
        """Return ``ClassName<reason>``."""
        class_name = self.__class__.__name__
        return f"{class_name}<{self.reason}>"
class InternalError(Exception):
    """Raised only when there is faulty logic in TrioCDP or the integration with PyCDP."""
@dataclass
class CmEventProxy:
    """Placeholder yielded by :meth:`CdpBase.wait_for`.

    ``value`` starts out as ``None``; once the context manager has finished,
    it holds the event that was received.
    """

    # The received event, filled in when the wait_for() block exits.
    value: Any = None
class CdpBase:
    """Command and event plumbing shared by the root CDP connection and its sessions.

    Outgoing commands are tracked in ``inflight_cmd`` until a response with a
    matching ID is handled; incoming events are fanned out to the memory
    channels registered through :meth:`listen` and :meth:`wait_for`.
    """

    def __init__(self, ws, session_id, target_id):
        """Constructor.

        Args:
            ws: WebSocket connection used to send commands.
            session_id: Session ID added to outgoing requests when truthy.
            target_id: ID of the target this object is associated with.
        """
        self.ws = ws
        self.session_id = session_id
        self.target_id = target_id
        # event type -> set of memory-channel senders subscribed to that type
        self.channels = defaultdict(set)
        # source of unique command IDs
        self.id_iter = itertools.count()
        # command ID -> (command generator, trio.Event set once the response is handled)
        self.inflight_cmd = {}
        # command ID -> parsed result or exception, consumed by execute()
        self.inflight_result = {}

    async def execute(self, cmd: Generator[dict, T, Any]) -> T:
        """Execute a command on the server and wait for the result.

        Args:
            cmd: any CDP command

        Returns:
            a CDP result

        Raises:
            CdpConnectionClosed: If the WebSocket is closed when sending.
            BrowserError: If the browser's response reports an error.
        """
        cmd_id = next(self.id_iter)
        cmd_event = trio.Event()
        request = next(cmd)
        request["id"] = cmd_id
        if self.session_id:
            request["sessionId"] = self.session_id
        request_str = json.dumps(request)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Sending CDP message: {cmd_id} {cmd_event}: {request_str}")
        # Register the command only once the request has been built, and
        # immediately before sending (there is no await in between), so a
        # failure while building the request leaves nothing behind.
        self.inflight_cmd[cmd_id] = cmd, cmd_event
        try:
            await self.ws.send_message(request_str)
        except WsConnectionClosed as wcc:
            # No response can arrive for this command; drop the bookkeeping
            # entry instead of leaking it.
            self.inflight_cmd.pop(cmd_id, None)
            raise CdpConnectionClosed(wcc.reason) from None
        await cmd_event.wait()
        response = self.inflight_result.pop(cmd_id)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Received CDP message: {response}")
        if isinstance(response, Exception):
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}")
            raise response
        return response

    def listen(self, *event_types, buffer_size=10):
        """Listen for events.

        Args:
            *event_types: Event classes to subscribe to.
            buffer_size: Capacity of the memory channel that buffers events.

        Returns:
            An async iterator that iterates over events matching the indicated types.
        """
        sender, receiver = trio.open_memory_channel(buffer_size)
        for event_type in event_types:
            self.channels[event_type].add(sender)
        return receiver

    @asynccontextmanager
    async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]:
        """Wait for an event of the given type and return it.

        This is an async context manager, so you should open it inside
        an async with block. The block will not exit until the indicated
        event is received. The yielded proxy's ``value`` is set to the event
        after the block's body has run.
        """
        sender: trio.MemorySendChannel
        receiver: trio.MemoryReceiveChannel
        sender, receiver = trio.open_memory_channel(buffer_size)
        # Subscribe before yielding so events fired inside the block are buffered.
        self.channels[event_type].add(sender)
        proxy = CmEventProxy()
        yield proxy
        async with receiver:
            event = await receiver.receive()
        proxy.value = event

    def _handle_data(self, data):
        """Handle incoming WebSocket data.

        Messages carrying an "id" are command responses; everything else is
        treated as an event.

        Args:
            data: a JSON dictionary
        """
        if "id" in data:
            self._handle_cmd_response(data)
        else:
            self._handle_event(data)

    def _handle_cmd_response(self, data: dict):
        """Handle a response to a command.

        This will set an event flag that will return control to the
        task that called the command.

        Args:
            data: response as a JSON dictionary

        Raises:
            InternalError: If the command generator yields again instead of
                finishing after being sent the result.
        """
        cmd_id = data["id"]
        try:
            cmd, event = self.inflight_cmd.pop(cmd_id)
        except KeyError:
            logger.warning("Got a message with a command ID that does not exist: %s", data)
            return
        if "error" in data:
            # If the server reported an error, convert it to an exception and do
            # not process the response any further.
            self.inflight_result[cmd_id] = BrowserError(data["error"])
        else:
            # Otherwise, continue the generator to parse the JSON result
            # into a CDP object.
            try:
                _ = cmd.send(data["result"])
                raise InternalError("The command's generator function did not exit when expected!")
            except StopIteration as stop:
                # The generator's return value is the parsed result.
                self.inflight_result[cmd_id] = stop.value
        event.set()

    def _handle_event(self, data: dict):
        """Handle an event.

        The event is parsed and offered, without blocking, to every channel
        subscribed to its type. Senders whose channel is broken are
        unsubscribed; full channels only log an error.

        Args:
            data: event as a JSON dictionary

        Raises:
            RuntimeError: If the devtools module has not been loaded.
        """
        global devtools
        if devtools is None:
            raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.")
        event = devtools.util.parse_json_event(data)
        logger.debug("Received event: %s", event)
        to_remove = set()
        for sender in self.channels[type(event)]:
            try:
                sender.send_nowait(event)
            except trio.WouldBlock:
                logger.error('Unable to send event "%r" due to full channel %s', event, sender)
            except trio.BrokenResourceError:
                # Collect first; the set cannot be mutated while iterating it.
                to_remove.add(sender)
        if to_remove:
            self.channels[type(event)] -= to_remove
class CdpSession(CdpBase):
    """Contains the state for a CDP session.

    Generally you should not instantiate this object yourself; you should call
    :meth:`CdpConnection.open_session`.
    """

    def __init__(self, ws, session_id, target_id):
        """Constructor.

        Args:
            ws: trio_websocket.WebSocketConnection
            session_id: devtools.target.SessionID
            target_id: devtools.target.TargetID
        """
        super().__init__(ws, session_id, target_id)
        # Number of active dom_enable()/page_enable() blocks; each lock
        # serializes the matching enable/disable commands.
        self._dom_enable_count = 0
        self._dom_enable_lock = trio.Lock()
        self._page_enable_count = 0
        self._page_enable_lock = trio.Lock()

    @asynccontextmanager
    async def dom_enable(self):
        """Context manager that executes ``dom.enable()`` when it enters and then calls ``dom.disable()``.

        This keeps track of concurrent callers and only disables DOM
        events when all callers have exited, including callers whose block
        exits with an exception.
        """
        global devtools
        async with self._dom_enable_lock:
            if self._dom_enable_count == 0:
                await self.execute(devtools.dom.enable())
            # Count the caller only after enable succeeded, so a failed
            # enable does not leave the counter permanently raised.
            self._dom_enable_count += 1
        try:
            yield
        finally:
            # Always release this caller's reference, even if the body raised.
            async with self._dom_enable_lock:
                self._dom_enable_count -= 1
                if self._dom_enable_count == 0:
                    await self.execute(devtools.dom.disable())

    @asynccontextmanager
    async def page_enable(self):
        """Context manager executes ``page.enable()`` when it enters and then calls ``page.disable()`` when it exits.

        This keeps track of concurrent callers and only disables page
        events when all callers have exited, including callers whose block
        exits with an exception.
        """
        global devtools
        async with self._page_enable_lock:
            if self._page_enable_count == 0:
                await self.execute(devtools.page.enable())
            # Count the caller only after enable succeeded, so a failed
            # enable does not leave the counter permanently raised.
            self._page_enable_count += 1
        try:
            yield
        finally:
            # Always release this caller's reference, even if the body raised.
            async with self._page_enable_lock:
                self._page_enable_count -= 1
                if self._page_enable_count == 0:
                    await self.execute(devtools.page.disable())
class CdpConnection(CdpBase, trio.abc.AsyncResource):
    """Contains the connection state for a Chrome DevTools Protocol server.

    CDP can multiplex multiple "sessions" over a single connection. This
    class corresponds to the "root" session, i.e. the implicitly created
    session that has no session ID. This class is responsible for
    reading incoming WebSocket messages and forwarding them to the
    corresponding session, as well as handling messages targeted at the
    root session itself. You should generally call the
    :func:`open_cdp()` instead of instantiating this class directly.
    """

    def __init__(self, ws):
        """Constructor.

        Args:
            ws: trio_websocket.WebSocketConnection
        """
        super().__init__(ws, session_id=None, target_id=None)
        # session ID -> CdpSession, filled in by connect_session()
        self.sessions = {}

    async def aclose(self):
        """Close the underlying WebSocket connection.

        This will cause the reader task to gracefully exit when it tries
        to read the next message from the WebSocket. All of the public
        APIs (``execute()``, ``listen()``, etc.) will raise
        ``CdpConnectionClosed`` after the CDP connection is closed. It
        is safe to call this multiple times.
        """
        await self.ws.aclose()

    @asynccontextmanager
    async def open_session(self, target_id) -> AsyncIterator[CdpSession]:
        """Context manager opens a session and enables the "simple" style of calling CDP APIs.

        For example, inside a session context, you can call ``await
        dom.get_document()`` and it will execute on the current session
        automatically.
        """
        new_session = await self.connect_session(target_id)
        with session_context(new_session):
            yield new_session

    async def connect_session(self, target_id) -> "CdpSession":
        """Returns a new :class:`CdpSession` connected to the specified target."""
        global devtools
        if devtools is None:
            raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.")
        attach_command = devtools.target.attach_to_target(target_id, True)
        session_id = await self.execute(attach_command)
        new_session = CdpSession(self.ws, session_id, target_id)
        self.sessions[session_id] = new_session
        return new_session

    async def _reader_task(self):
        """Runs in the background and handles incoming messages.

        Dispatches responses to commands and events to listeners. When the
        WebSocket closes, every sender registered on a session is closed
        before the task returns.
        """
        global devtools
        if devtools is None:
            raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.")
        while True:
            try:
                raw_message = await self.ws.get_message()
            except WsConnectionClosed:
                # If the WebSocket is closed, we don't want to throw an
                # exception from the reader task. Instead we will throw
                # exceptions from the public API methods, and we can quietly
                # exit the reader task here.
                break
            try:
                payload = json.loads(raw_message)
            except json.JSONDecodeError:
                raise BrowserError(
                    {"code": -32700, "message": "Client received invalid JSON", "data": raw_message}
                )
            logger.debug("Received message %r", payload)
            if "sessionId" not in payload:
                # No session ID: the message belongs to the root session.
                self._handle_data(payload)
                continue
            session_id = devtools.target.SessionID(payload["sessionId"])
            try:
                target_session = self.sessions[session_id]
            except KeyError:
                raise BrowserError(
                    {
                        "code": -32700,
                        "message": "Browser sent a message for an invalid session",
                        "data": f"{session_id!r}",
                    }
                )
            target_session._handle_data(payload)
        # Connection closed: close every session's event senders.
        for open_session in self.sessions.values():
            for subscribed in open_session.channels.values():
                for sender in subscribed:
                    sender.close()
@asynccontextmanager
async def open_cdp(url) -> AsyncIterator[CdpConnection]:
    """Async context manager opens a connection to the browser then closes the connection when the block exits.

    The context manager also sets the connection as the default
    connection for the current task, so that commands like ``await
    target.get_targets()`` will run on this connection automatically. If
    you want to use multiple connections concurrently, it is recommended
    to open each one in a separate task.
    """
    async with trio.open_nursery() as nursery:
        connection = await connect_cdp(nursery, url)
        # CdpConnection is a trio AsyncResource, so leaving this block
        # awaits connection.aclose() whether or not the body raised.
        async with connection:
            with connection_context(connection):
                yield connection
async def connect_cdp(nursery, url) -> CdpConnection:
    """Connect to the browser specified by ``url`` and spawn a background task in the specified nursery.

    The ``open_cdp()`` context manager is preferred in most situations.
    You should only use this function if you need to specify a custom
    nursery. This connection is not automatically closed! You can either
    use the connection object as a context manager (``async with
    conn:``) or else call ``await conn.aclose()`` on it when you are
    done with it.
    """
    websocket = await connect_websocket_url(nursery, url, max_message_size=MAX_WS_MESSAGE_SIZE)
    connection = CdpConnection(websocket)
    # The reader task runs in the caller's nursery for the life of the connection.
    nursery.start_soon(connection._reader_task)
    return connection
@@ -0,0 +1,36 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections.abc import Generator
def command_builder(method: str, params: dict | None = None) -> Generator[dict, dict, dict]:
"""Build a command iterator to send to the BiDi protocol.
Args:
method: The method to execute.
params: The parameters to pass to the method. Default is None.
Returns:
The response from the command execution.
"""
if params is None:
params = {}
command = {"method": method, "params": params}
cmd = yield command
return cmd
@@ -0,0 +1,24 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from enum import Enum
class Console(Enum):
    """Console selector values: "all", "log" and "error"."""

    ALL = "all"
    LOG = "log"
    ERROR = "error"
@@ -0,0 +1,524 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Any, TypeVar
from selenium.webdriver.common.bidi.common import command_builder
if TYPE_CHECKING:
from selenium.webdriver.remote.websocket_connection import WebSocketConnection
class ScreenOrientationNatural(Enum):
    """Natural screen orientation: portrait or landscape."""

    PORTRAIT = "portrait"
    LANDSCAPE = "landscape"
class ScreenOrientationType(Enum):
    """Screen orientation type: primary or secondary variant of portrait/landscape."""

    PORTRAIT_PRIMARY = "portrait-primary"
    PORTRAIT_SECONDARY = "portrait-secondary"
    LANDSCAPE_PRIMARY = "landscape-primary"
    LANDSCAPE_SECONDARY = "landscape-secondary"
# Type variable constrained to the two orientation enums; used to type
# _convert_to_enum() so it returns the same enum class it was given.
E = TypeVar("E", ScreenOrientationNatural, ScreenOrientationType)
def _convert_to_enum(value: E | str, enum_class: type[E]) -> E:
if isinstance(value, enum_class):
return value
assert isinstance(value, str)
try:
return enum_class(value.lower())
except ValueError:
raise ValueError(f"Invalid orientation: {value}")
class ScreenOrientation:
    """Screen orientation configuration: a natural orientation plus an orientation type."""

    def __init__(
        self,
        natural: ScreenOrientationNatural | str,
        type: ScreenOrientationType | str,
    ):
        """Initialize ScreenOrientation.

        Both arguments may be given as enum members or as their string values.

        Args:
            natural: Natural screen orientation ("portrait" or "landscape").
            type: Screen orientation type ("portrait-primary", "portrait-secondary",
                "landscape-primary", or "landscape-secondary").

        Raises:
            ValueError: If natural or type values are invalid.
        """
        # Normalize strings to enum members; ``natural`` is converted first.
        natural_member = _convert_to_enum(natural, ScreenOrientationNatural)
        type_member = _convert_to_enum(type, ScreenOrientationType)
        self.natural = natural_member
        self.type = type_member

    def to_dict(self) -> dict[str, str]:
        """Return the orientation as a dictionary of enum values keyed "natural" and "type"."""
        serialized = {}
        serialized["natural"] = self.natural.value
        serialized["type"] = self.type.value
        return serialized
class GeolocationCoordinates:
"""Represents geolocation coordinates."""
def __init__(
self,
latitude: float,
longitude: float,
accuracy: float = 1.0,
altitude: float | None = None,
altitude_accuracy: float | None = None,
heading: float | None = None,
speed: float | None = None,
):
"""Initialize GeolocationCoordinates.
Args:
latitude: Latitude coordinate (-90.0 to 90.0).
longitude: Longitude coordinate (-180.0 to 180.0).
accuracy: Accuracy in meters (>= 0.0), defaults to 1.0.
altitude: Altitude in meters or None, defaults to None.
altitude_accuracy: Altitude accuracy in meters (>= 0.0) or None, defaults to None.
heading: Heading in degrees (0.0 to 360.0) or None, defaults to None.
speed: Speed in meters per second (>= 0.0) or None, defaults to None.
Raises:
ValueError: If coordinates are out of valid range or if altitude_accuracy is provided without altitude.
"""
self.latitude = latitude
self.longitude = longitude
self.accuracy = accuracy
self.altitude = altitude
self.altitude_accuracy = altitude_accuracy
self.heading = heading
self.speed = speed
@property
def latitude(self) -> float:
return self._latitude
@latitude.setter
def latitude(self, value: float) -> None:
if not (-90.0 <= value <= 90.0):
raise ValueError("latitude must be between -90.0 and 90.0")
self._latitude = value
@property
def longitude(self) -> float:
return self._longitude
@longitude.setter
def longitude(self, value: float) -> None:
if not (-180.0 <= value <= 180.0):
raise ValueError("longitude must be between -180.0 and 180.0")
self._longitude = value
@property
def accuracy(self) -> float:
return self._accuracy
@accuracy.setter
def accuracy(self, value: float) -> None:
if value < 0.0:
raise ValueError("accuracy must be >= 0.0")
self._accuracy = value
@property
def altitude(self) -> float | None:
return self._altitude
@altitude.setter
def altitude(self, value: float | None) -> None:
self._altitude = value
@property
def altitude_accuracy(self) -> float | None:
return self._altitude_accuracy
@altitude_accuracy.setter
def altitude_accuracy(self, value: float | None) -> None:
if value is not None and self.altitude is None:
raise ValueError("altitude_accuracy cannot be set without altitude")
if value is not None and value < 0.0:
raise ValueError("altitude_accuracy must be >= 0.0")
self._altitude_accuracy = value
@property
def heading(self) -> float | None:
return self._heading
@heading.setter
def heading(self, value: float | None) -> None:
if value is not None and not (0.0 <= value < 360.0):
raise ValueError("heading must be between 0.0 and 360.0")
self._heading = value
@property
def speed(self) -> float | None:
return self._speed
@speed.setter
def speed(self, value: float | None) -> None:
if value is not None and value < 0.0:
raise ValueError("speed must be >= 0.0")
self._speed = value
def to_dict(self) -> dict[str, float | None]:
result: dict[str, float | None] = {
"latitude": self.latitude,
"longitude": self.longitude,
"accuracy": self.accuracy,
}
if self.altitude is not None:
result["altitude"] = self.altitude
if self.altitude_accuracy is not None:
result["altitudeAccuracy"] = self.altitude_accuracy
if self.heading is not None:
result["heading"] = self.heading
if self.speed is not None:
result["speed"] = self.speed
return result
class GeolocationPositionError:
    """Geolocation position error payload.

    Only the "positionUnavailable" error type is accepted.
    """

    TYPE_POSITION_UNAVAILABLE = "positionUnavailable"

    def __init__(self, type: str = TYPE_POSITION_UNAVAILABLE):
        """Create the error.

        Args:
            type: Error type; must equal TYPE_POSITION_UNAVAILABLE.

        Raises:
            ValueError: If any other type is given.
        """
        if type == self.TYPE_POSITION_UNAVAILABLE:
            self.type = type
            return
        raise ValueError(f'type must be "{self.TYPE_POSITION_UNAVAILABLE}"')

    def to_dict(self) -> dict[str, str]:
        """Return the error as a one-key dictionary holding its type."""
        return dict(type=self.type)
class Emulation:
    """BiDi implementation of the emulation module."""

    def __init__(self, conn: WebSocketConnection) -> None:
        """Store the connection on which every emulation command is executed."""
        self.conn = conn

    @staticmethod
    def _target_params(
        contexts: list[str] | None,
        user_contexts: list[str] | None,
        label: str,
    ) -> dict[str, Any]:
        """Validate the target selection and return it as command parameters.

        Args:
            contexts: Browsing context IDs, or None.
            user_contexts: User context IDs, or None.
            label: Spelling of the user-context argument used in the error
                messages (the public methods use two different spellings,
                which are preserved).

        Returns:
            ``{"contexts": contexts}`` or ``{"userContexts": user_contexts}``.

        Raises:
            ValueError: If both targets or neither target is given.
        """
        if contexts is not None and user_contexts is not None:
            raise ValueError(f"Cannot specify both contexts and {label}")
        if contexts is None and user_contexts is None:
            raise ValueError(f"Must specify either contexts or {label}")
        if contexts is not None:
            return {"contexts": contexts}
        return {"userContexts": user_contexts}

    def set_geolocation_override(
        self,
        coordinates: GeolocationCoordinates | None = None,
        error: GeolocationPositionError | None = None,
        contexts: list[str] | None = None,
        user_contexts: list[str] | None = None,
    ) -> None:
        """Set geolocation override for the given contexts or user contexts.

        Args:
            coordinates: Geolocation coordinates to emulate, or None.
            error: Geolocation error to emulate, or None.
            contexts: List of browsing context IDs to apply the override to.
            user_contexts: List of user context IDs to apply the override to.

        Raises:
            ValueError: If both coordinates and error are provided, or if both contexts
                and user_contexts are provided, or if neither contexts nor
                user_contexts are provided.
        """
        if coordinates is not None and error is not None:
            raise ValueError("Cannot specify both coordinates and error")
        target = self._target_params(contexts, user_contexts, "userContexts")

        params: dict[str, Any] = {}
        if error is not None:
            params["error"] = error.to_dict()
        else:
            # Always send one of the two payload keys. With no coordinates the
            # value is None, the same way the other setters in this class
            # transmit a cleared override.
            params["coordinates"] = coordinates.to_dict() if coordinates is not None else None
        params.update(target)
        self.conn.execute(command_builder("emulation.setGeolocationOverride", params))

    def set_timezone_override(
        self,
        timezone: str | None = None,
        contexts: list[str] | None = None,
        user_contexts: list[str] | None = None,
    ) -> None:
        """Set timezone override for the given contexts or user contexts.

        Args:
            timezone: Timezone identifier (IANA timezone name or offset string like '+01:00'),
                or None to clear the override.
            contexts: List of browsing context IDs to apply the override to.
            user_contexts: List of user context IDs to apply the override to.

        Raises:
            ValueError: If both contexts and user_contexts are provided, or if neither
                contexts nor user_contexts are provided.
        """
        target = self._target_params(contexts, user_contexts, "user_contexts")
        params: dict[str, Any] = {"timezone": timezone}
        params.update(target)
        self.conn.execute(command_builder("emulation.setTimezoneOverride", params))

    def set_locale_override(
        self,
        locale: str | None = None,
        contexts: list[str] | None = None,
        user_contexts: list[str] | None = None,
    ) -> None:
        """Set locale override for the given contexts or user contexts.

        The locale string is forwarded as given; it is not validated here.

        Args:
            locale: Locale string as per BCP 47, or None to clear override.
            contexts: List of browsing context IDs to apply the override to.
            user_contexts: List of user context IDs to apply the override to.

        Raises:
            ValueError: If both contexts and user_contexts are provided, or if neither
                contexts nor user_contexts are provided.
        """
        target = self._target_params(contexts, user_contexts, "userContexts")
        params: dict[str, Any] = {"locale": locale}
        params.update(target)
        self.conn.execute(command_builder("emulation.setLocaleOverride", params))

    def set_scripting_enabled(
        self,
        enabled: bool | None = False,
        contexts: list[str] | None = None,
        user_contexts: list[str] | None = None,
    ) -> None:
        """Set scripting enabled override for the given contexts or user contexts.

        Args:
            enabled: False to disable scripting, None to clear the override.
                Note: Only emulation of disabled JavaScript is supported.
            contexts: List of browsing context IDs to apply the override to.
            user_contexts: List of user context IDs to apply the override to.

        Raises:
            ValueError: If both contexts and user_contexts are provided, or if neither
                contexts nor user_contexts are provided, or if enabled is True.
        """
        # The enabled check runs before the target check.
        if enabled:
            raise ValueError("Only emulation of disabled JavaScript is supported (enabled must be False or None)")
        target = self._target_params(contexts, user_contexts, "userContexts")
        params: dict[str, Any] = {"enabled": enabled}
        params.update(target)
        self.conn.execute(command_builder("emulation.setScriptingEnabled", params))

    def set_screen_orientation_override(
        self,
        screen_orientation: ScreenOrientation | None = None,
        contexts: list[str] | None = None,
        user_contexts: list[str] | None = None,
    ) -> None:
        """Set screen orientation override for the given contexts or user contexts.

        Args:
            screen_orientation: ScreenOrientation object to emulate, or None to clear the override.
            contexts: List of browsing context IDs to apply the override to.
            user_contexts: List of user context IDs to apply the override to.

        Raises:
            ValueError: If both contexts and user_contexts are provided, or if neither
                contexts nor user_contexts are provided.
        """
        target = self._target_params(contexts, user_contexts, "userContexts")
        params: dict[str, Any] = {
            "screenOrientation": screen_orientation.to_dict() if screen_orientation is not None else None
        }
        params.update(target)
        self.conn.execute(command_builder("emulation.setScreenOrientationOverride", params))

    def set_user_agent_override(
        self,
        user_agent: str | None = None,
        contexts: list[str] | None = None,
        user_contexts: list[str] | None = None,
    ) -> None:
        """Set user agent override for the given contexts or user contexts.

        Args:
            user_agent: User agent string to emulate, or None to clear the override.
            contexts: List of browsing context IDs to apply the override to.
            user_contexts: List of user context IDs to apply the override to.

        Raises:
            ValueError: If both contexts and user_contexts are provided, or if neither
                contexts nor user_contexts are provided.
        """
        target = self._target_params(contexts, user_contexts, "user_contexts")
        params: dict[str, Any] = {"userAgent": user_agent}
        params.update(target)
        self.conn.execute(command_builder("emulation.setUserAgentOverride", params))

    def set_network_conditions(
        self,
        offline: bool = False,
        contexts: list[str] | None = None,
        user_contexts: list[str] | None = None,
    ) -> None:
        """Set network conditions for the given contexts or user contexts.

        Args:
            offline: True to emulate offline network conditions, False to clear the override.
            contexts: List of browsing context IDs to apply the conditions to.
            user_contexts: List of user context IDs to apply the conditions to.

        Raises:
            ValueError: If both contexts and user_contexts are provided, or if neither
                contexts nor user_contexts are provided.
        """
        target = self._target_params(contexts, user_contexts, "user_contexts")
        # Any falsy value for offline (False or None) clears the override.
        params: dict[str, Any] = {"networkConditions": {"type": "offline"} if offline else None}
        params.update(target)
        self.conn.execute(command_builder("emulation.setNetworkConditions", params))

    def set_screen_settings_override(
        self,
        width: int | None = None,
        height: int | None = None,
        contexts: list[str] | None = None,
        user_contexts: list[str] | None = None,
    ) -> None:
        """Set screen settings override for the given contexts or user contexts.

        Args:
            width: Screen width in pixels (>= 0). None to clear the override.
            height: Screen height in pixels (>= 0). None to clear the override.
            contexts: List of browsing context IDs to apply the override to.
            user_contexts: List of user context IDs to apply the override to.

        Raises:
            ValueError: If only one of width/height is provided, if width or height
                is not an integer or is negative, or if both contexts and
                user_contexts are provided, or if neither is provided.
        """
        if (width is None) != (height is None):
            raise ValueError("Must provide both width and height, or neither to clear the override")
        # Target validation happens before the type and range checks.
        target = self._target_params(contexts, user_contexts, "user_contexts")

        screen_area = None
        if width is not None and height is not None:
            if not isinstance(width, int) or not isinstance(height, int):
                raise ValueError("width and height must be integers")
            if width < 0 or height < 0:
                raise ValueError("width and height must be >= 0")
            screen_area = {"width": width, "height": height}

        params: dict[str, Any] = {"screenArea": screen_area}
        params.update(target)
        self.conn.execute(command_builder("emulation.setScreenSettingsOverride", params))
@@ -0,0 +1,462 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import math
from dataclasses import dataclass, field
from typing import Any
from selenium.webdriver.common.bidi.common import command_builder
from selenium.webdriver.common.bidi.session import Session
class PointerType:
    """String constants for the supported pointer types."""

    MOUSE = "mouse"
    PEN = "pen"
    TOUCH = "touch"

    # Every accepted pointer type, collected in a set for membership checks.
    VALID_TYPES = set((MOUSE, PEN, TOUCH))
class Origin:
    """String constants for the named origin types ("viewport" and "pointer")."""

    POINTER = "pointer"
    VIEWPORT = "viewport"
@dataclass
class ElementOrigin:
    """Origin that refers to an element, for use in input actions.

    The hand-written __init__ fixes ``type`` to "element", so callers supply
    only the element reference.
    """

    type: str
    element: dict

    def __init__(self, element_reference: dict):
        """Build the origin from an element reference dictionary."""
        self.element = element_reference
        self.type = "element"

    def to_dict(self) -> dict:
        """Return the origin as a dictionary with "type" and "element" keys."""
        return dict(type=self.type, element=self.element)
@dataclass
class PointerParameters:
    """Parameters attached to a pointer input source."""

    pointer_type: str = PointerType.MOUSE

    def __post_init__(self):
        """Reject any pointer type that is not listed in PointerType.VALID_TYPES.

        Raises:
            ValueError: If pointer_type is not a known pointer type.
        """
        if self.pointer_type in PointerType.VALID_TYPES:
            return
        raise ValueError(f"Invalid pointer type: {self.pointer_type}. Must be one of {PointerType.VALID_TYPES}")

    def to_dict(self) -> dict:
        """Return the parameters as a dictionary keyed by "pointerType"."""
        return dict(pointerType=self.pointer_type)
@dataclass
class PointerCommonProperties:
    """Optional properties shared by pointer actions.

    Values are range-checked on construction. Serialization emits only the
    properties that differ from their defaults.
    """

    width: int = 1
    height: int = 1
    pressure: float = 0.0
    tangential_pressure: float = 0.0
    twist: int = 0
    altitude_angle: float = 0.0
    azimuth_angle: float = 0.0

    def __post_init__(self):
        """Validate every property; the first violation found is reported.

        Raises:
            ValueError: If any property is outside its allowed range.
        """
        if self.width < 1:
            raise ValueError("width must be at least 1")
        if self.height < 1:
            raise ValueError("height must be at least 1")

        pressure_ok = 0.0 <= self.pressure <= 1.0
        if not pressure_ok:
            raise ValueError("pressure must be between 0.0 and 1.0")

        tangential_ok = 0.0 <= self.tangential_pressure <= 1.0
        if not tangential_ok:
            raise ValueError("tangential_pressure must be between 0.0 and 1.0")

        twist_ok = 0 <= self.twist <= 359
        if not twist_ok:
            raise ValueError("twist must be between 0 and 359")

        altitude_ok = 0.0 <= self.altitude_angle <= math.pi / 2
        if not altitude_ok:
            raise ValueError("altitude_angle must be between 0.0 and π/2")

        azimuth_ok = 0.0 <= self.azimuth_angle <= 2 * math.pi
        if not azimuth_ok:
            raise ValueError("azimuth_angle must be between 0.0 and 2π")

    def to_dict(self) -> dict:
        """Return the properties that differ from their defaults, camelCase keyed."""
        # (serialized key, current value, default value), in output order.
        table = (
            ("width", self.width, 1),
            ("height", self.height, 1),
            ("pressure", self.pressure, 0.0),
            ("tangentialPressure", self.tangential_pressure, 0.0),
            ("twist", self.twist, 0),
            ("altitudeAngle", self.altitude_angle, 0.0),
            ("azimuthAngle", self.azimuth_angle, 0.0),
        )
        serialized: dict[str, Any] = {}
        for key, current, default in table:
            if current != default:
                serialized[key] = current
        return serialized
# Action classes
@dataclass
class PauseAction:
    """A pause action with an optional duration."""

    duration: int | None = None

    @property
    def type(self) -> str:
        """Action type string, always "pause"."""
        return "pause"

    def to_dict(self) -> dict:
        """Return the action as a dictionary; "duration" is included only when set."""
        payload: dict[str, Any] = {"type": self.type}
        if self.duration is None:
            return payload
        payload["duration"] = self.duration
        return payload
@dataclass
class KeyDownAction:
    """A key down action carrying the key value."""

    value: str = ""

    @property
    def type(self) -> str:
        """Action type string, always "keyDown"."""
        return "keyDown"

    def to_dict(self) -> dict:
        """Return the action as a dictionary with "type" and "value" keys."""
        return dict(type=self.type, value=self.value)
@dataclass
class KeyUpAction:
    """A key up action carrying the key value."""

    value: str = ""

    @property
    def type(self) -> str:
        """Action type string, always "keyUp"."""
        return "keyUp"

    def to_dict(self) -> dict:
        """Return the action as a dictionary with "type" and "value" keys."""
        return dict(type=self.type, value=self.value)
@dataclass
class PointerDownAction:
    """A pointer down action for a button, with optional pointer properties."""

    button: int = 0
    properties: PointerCommonProperties | None = None

    @property
    def type(self) -> str:
        """Action type string, always "pointerDown"."""
        return "pointerDown"

    def to_dict(self) -> dict:
        """Return the action as a dictionary.

        When properties are present, their serialized entries are merged into
        the top level of the result.
        """
        payload: dict[str, Any] = {"type": self.type, "button": self.button}
        extras = self.properties.to_dict() if self.properties else {}
        payload.update(extras)
        return payload
@dataclass
class PointerUpAction:
    """A pointer up action for a button."""

    button: int = 0

    @property
    def type(self) -> str:
        """Action type string, always "pointerUp"."""
        return "pointerUp"

    def to_dict(self) -> dict:
        """Return the action as a dictionary with "type" and "button" keys."""
        return dict(type=self.type, button=self.button)
@dataclass
class PointerMoveAction:
    """A pointer move action to (x, y), with optional duration, origin and properties."""

    x: float = 0
    y: float = 0
    duration: int | None = None
    origin: str | ElementOrigin | None = None
    properties: PointerCommonProperties | None = None

    @property
    def type(self) -> str:
        """Action type string, always "pointerMove"."""
        return "pointerMove"

    def to_dict(self) -> dict:
        """Return the action as a dictionary.

        "duration" and "origin" are added only when set. An ElementOrigin is
        serialized through its own to_dict; any other origin value is used
        as-is. Serialized properties are merged into the top level.
        """
        payload: dict[str, Any] = {"type": self.type, "x": self.x, "y": self.y}

        if self.duration is not None:
            payload["duration"] = self.duration

        origin = self.origin
        if origin is not None:
            payload["origin"] = origin.to_dict() if isinstance(origin, ElementOrigin) else origin

        extras = self.properties.to_dict() if self.properties else {}
        payload.update(extras)
        return payload
@dataclass
class WheelScrollAction:
    """A wheel scroll action at (x, y) with scroll deltas, optional duration and origin."""

    x: int = 0
    y: int = 0
    delta_x: int = 0
    delta_y: int = 0
    duration: int | None = None
    origin: str | ElementOrigin | None = Origin.VIEWPORT

    @property
    def type(self) -> str:
        """Action type string, always "scroll"."""
        return "scroll"

    def to_dict(self) -> dict:
        """Return the action as a dictionary.

        "duration" and "origin" are added only when set. An ElementOrigin is
        serialized through its own to_dict; any other origin value is used
        as-is.
        """
        payload: dict[str, Any] = dict(
            type=self.type,
            x=self.x,
            y=self.y,
            deltaX=self.delta_x,
            deltaY=self.delta_y,
        )

        if self.duration is not None:
            payload["duration"] = self.duration

        origin = self.origin
        if isinstance(origin, ElementOrigin):
            payload["origin"] = origin.to_dict()
        elif origin is not None:
            payload["origin"] = origin
        return payload
# Source Actions
@dataclass
class NoneSourceActions:
    """An identified sequence of actions for a "none" input source."""

    id: str = ""
    actions: list[PauseAction] = field(default_factory=list)

    @property
    def type(self) -> str:
        """Source type string, always "none"."""
        return "none"

    def to_dict(self) -> dict:
        """Return the source as a dictionary with every action serialized in order."""
        payload: dict = {"type": self.type, "id": self.id}
        payload["actions"] = [step.to_dict() for step in self.actions]
        return payload
@dataclass
class KeySourceActions:
    """An identified sequence of actions for a "key" input source."""

    id: str = ""
    actions: list[PauseAction | KeyDownAction | KeyUpAction] = field(default_factory=list)

    @property
    def type(self) -> str:
        """Source type string, always "key"."""
        return "key"

    def to_dict(self) -> dict:
        """Return the source as a dictionary with every action serialized in order."""
        payload: dict = {"type": self.type, "id": self.id}
        payload["actions"] = [step.to_dict() for step in self.actions]
        return payload
@dataclass
class PointerSourceActions:
    """An identified sequence of actions for a "pointer" input source."""

    id: str = ""
    parameters: PointerParameters | None = None
    actions: list[PauseAction | PointerDownAction | PointerUpAction | PointerMoveAction] = field(default_factory=list)

    def __post_init__(self):
        """Fall back to a default PointerParameters when none were supplied."""
        if self.parameters is not None:
            return
        self.parameters = PointerParameters()

    @property
    def type(self) -> str:
        """Source type string, always "pointer"."""
        return "pointer"

    def to_dict(self) -> dict:
        """Return the source as a dictionary.

        Actions are serialized in order; "parameters" is appended when the
        parameters object is truthy.
        """
        payload: dict[str, Any] = dict(
            type=self.type,
            id=self.id,
            actions=[step.to_dict() for step in self.actions],
        )
        if not self.parameters:
            return payload
        payload["parameters"] = self.parameters.to_dict()
        return payload
@dataclass
class WheelSourceActions:
    """An identified sequence of actions for a "wheel" input source."""

    id: str = ""
    actions: list[PauseAction | WheelScrollAction] = field(default_factory=list)

    @property
    def type(self) -> str:
        """Source type string, always "wheel"."""
        return "wheel"

    def to_dict(self) -> dict:
        """Return the source as a dictionary with every action serialized in order."""
        payload: dict = {"type": self.type, "id": self.id}
        payload["actions"] = [step.to_dict() for step in self.actions]
        return payload
@dataclass
class FileDialogInfo:
    """File dialog information carried by the input.fileDialogOpened event."""

    context: str
    multiple: bool
    element: dict | None = None

    @classmethod
    def from_dict(cls, data: dict) -> "FileDialogInfo":
        """Build a FileDialogInfo from an event payload dictionary.

        Args:
            data: Dictionary with required "context" and "multiple" keys and
                an optional "element" key.

        Returns:
            FileDialogInfo: The populated instance; element is None when the
            key is absent.

        Raises:
            KeyError: If "context" or "multiple" is missing.
        """
        return cls(data["context"], data["multiple"], data.get("element"))
# Event Class
class FileDialogOpened:
    """Event descriptor for the input.fileDialogOpened event."""

    event_class = "input.fileDialogOpened"

    @classmethod
    def from_json(cls, json):
        """Parse the event payload into a FileDialogInfo."""
        parsed = FileDialogInfo.from_dict(json)
        return parsed
class Input:
    """BiDi implementation of the input module."""

    def __init__(self, conn):
        """Store the connection and start with no subscriptions or callbacks."""
        self.conn = conn
        # Event name -> callback ids registered for that event.
        self.subscriptions = {}
        # Callback id -> handler supplied by the caller.
        self.callbacks = {}

    def perform_actions(
        self,
        context: str,
        actions: list[NoneSourceActions | KeySourceActions | PointerSourceActions | WheelSourceActions],
    ) -> None:
        """Perform a sequence of user input actions.

        Args:
            context: The browsing context ID where actions should be performed.
            actions: Source action objects; each is serialized via to_dict.
        """
        serialized = [source.to_dict() for source in actions]
        payload = {"context": context, "actions": serialized}
        self.conn.execute(command_builder("input.performActions", payload))

    def release_actions(self, context: str) -> None:
        """Release all input state for the given context.

        Args:
            context: The browsing context ID to release actions for.
        """
        self.conn.execute(command_builder("input.releaseActions", {"context": context}))

    def set_files(self, context: str, element: dict, files: list[str]) -> None:
        """Set files for a file input element.

        Args:
            context: The browsing context ID.
            element: The element reference (script.SharedReference).
            files: A list of file paths to set.
        """
        payload = dict(context=context, element=element, files=files)
        self.conn.execute(command_builder("input.setFiles", payload))

    def add_file_dialog_handler(self, handler) -> int:
        """Register a handler for file dialog opened events.

        The session subscription is created the first time a handler is added.

        Args:
            handler: Callback function that takes a FileDialogInfo object.

        Returns:
            int: Callback ID for removing the handler later.
        """
        event_name = FileDialogOpened.event_class

        # First handler for this event: subscribe before registering it.
        if event_name not in self.subscriptions:
            session = Session(self.conn)
            self.conn.execute(session.subscribe(event_name))
            self.subscriptions[event_name] = []

        handler_id = self.conn.add_callback(FileDialogOpened, handler)
        self.subscriptions[event_name].append(handler_id)
        self.callbacks[handler_id] = handler
        return handler_id

    def remove_file_dialog_handler(self, callback_id: int) -> None:
        """Remove a file dialog handler.

        When the last handler for the event is removed, the session is
        unsubscribed from the event as well.

        Args:
            callback_id: The callback ID returned by add_file_dialog_handler.
        """
        self.callbacks.pop(callback_id, None)

        event_name = FileDialogOpened.event_class
        if event_name in self.subscriptions:
            registered = self.subscriptions[event_name]
            if callback_id in registered:
                registered.remove(callback_id)

            # Nothing left listening: drop the subscription too.
            if not registered:
                session = Session(self.conn)
                self.conn.execute(session.unsubscribe(event_name))
                del self.subscriptions[event_name]

        self.conn.remove_callback(FileDialogOpened, callback_id)
@@ -0,0 +1,81 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
class LogEntryAdded:
    """Event descriptor for log.entryAdded that dispatches on the entry type."""

    event_class = "log.entryAdded"

    @classmethod
    def from_json(cls, json: dict[str, Any]) -> ConsoleLogEntry | JavaScriptLogEntry | None:
        """Parse a log entry payload.

        Returns:
            A ConsoleLogEntry for type "console", a JavaScriptLogEntry for
            type "javascript", and None for any other type.

        Raises:
            KeyError: If the payload has no "type" key.
        """
        entry_type = json["type"]
        if entry_type == "console":
            return ConsoleLogEntry.from_json(json)
        if entry_type == "javascript":
            return JavaScriptLogEntry.from_json(json)
        return None
@dataclass
class ConsoleLogEntry:
    """A log entry of type "console" as delivered by log.entryAdded."""

    level: str
    text: str
    timestamp: str
    method: str
    args: list[dict[str, Any]]
    type_: str

    @classmethod
    def from_json(cls, json: dict[str, Any]) -> ConsoleLogEntry:
        """Build the entry from its payload; every mapped key is required.

        Raises:
            KeyError: If any of level, text, timestamp, method, args or type
                is missing from the payload.
        """
        # Keys whose payload name equals the field name, read in field order.
        same_named = ("level", "text", "timestamp", "method", "args")
        fields = {name: json[name] for name in same_named}
        return cls(type_=json["type"], **fields)
@dataclass
class JavaScriptLogEntry:
    """A log entry of type "javascript" as delivered by log.entryAdded."""

    level: str
    text: str
    timestamp: str
    # None when the payload carries no "stackTrace" key.
    stacktrace: dict[str, Any] | None
    type_: str

    @classmethod
    def from_json(cls, json: dict[str, Any]) -> JavaScriptLogEntry:
        """Build the entry from its payload.

        level, text, timestamp and type are required. "stackTrace" is read
        leniently so that an entry without one still parses.

        Raises:
            KeyError: If level, text, timestamp or type is missing.
        """
        return cls(
            level=json["level"],
            text=json["text"],
            timestamp=json["timestamp"],
            stacktrace=json.get("stackTrace"),
            type_=json["type"],
        )
class LogLevel:
    """String constants for the log levels: debug, info, warn and error."""

    # Listed in the order debug, info, warn, error.
    DEBUG, INFO, WARN, ERROR = "debug", "info", "warn", "error"
@@ -0,0 +1,338 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from selenium.webdriver.common.bidi.common import command_builder
from selenium.webdriver.remote.websocket_connection import WebSocketConnection
class NetworkEvent:
    """Represents a network event.

    Holds the event name in ``event_class`` and every other keyword argument
    in the ``params`` dictionary.
    """

    def __init__(self, event_class: str, **kwargs: Any) -> None:
        """Store the event name and keep the remaining keyword arguments as params."""
        self.event_class = event_class
        self.params = kwargs

    @classmethod
    def from_json(cls, json: dict[str, Any]) -> NetworkEvent:
        """Build an event from a payload dictionary.

        An "event_class" entry in the payload becomes the event name (default
        ""); all remaining entries become ``params``. The payload itself is
        not modified.
        """
        # Work on a copy and pop the name out: passing it both explicitly and
        # through ** would raise TypeError for a duplicated keyword argument.
        params = dict(json)
        event_class = params.pop("event_class", "")
        return cls(event_class=event_class, **params)
class Network:
    """BiDi implementation of the network module: intercepts and request handlers."""

    # Handler name -> BiDi event name.
    EVENTS = {
        "before_request": "network.beforeRequestSent",
        "response_started": "network.responseStarted",
        "response_completed": "network.responseCompleted",
        "auth_required": "network.authRequired",
        "fetch_error": "network.fetchError",
        "continue_request": "network.continueRequest",
        "continue_auth": "network.continueWithAuth",
    }

    # Handler name -> intercept phase. Only these names can be used with
    # add_request_handler, because it needs both an event and a phase.
    PHASES = {
        "before_request": "beforeRequestSent",
        "response_started": "responseStarted",
        "auth_required": "authRequired",
    }

    def __init__(self, conn: WebSocketConnection) -> None:
        """Store the connection and start with empty bookkeeping."""
        self.conn = conn
        self.intercepts: list[str] = []
        # Mixed mapping: event name -> list of callback ids, and
        # callback id -> intercept id registered for that handler.
        self.callbacks: dict[str | int, Any] = {}
        # Event name -> callback ids that keep the session subscription alive.
        self.subscriptions: dict[str, list[int]] = {}

    def _add_intercept(
        self,
        phases: list[str] | None = None,
        contexts: list[str] | None = None,
        url_patterns: list[Any] | None = None,
    ) -> dict[str, Any]:
        """Add an intercept to the network.

        Args:
            phases: A list of phases to intercept. None or an empty list falls
                back to ["beforeRequestSent"].
            contexts: A list of contexts to intercept. Default is None.
            url_patterns: A list of URL patterns to intercept. Default is None.

        Returns:
            dict: The command result; its "intercept" entry is the intercept
            id, which is also recorded in ``self.intercepts``.
        """
        params: dict[str, Any] = {}
        if contexts is not None:
            params["contexts"] = contexts
        if url_patterns is not None:
            params["urlPatterns"] = url_patterns
        if phases is None or len(phases) == 0:
            params["phases"] = ["beforeRequestSent"]
        else:
            params["phases"] = phases

        result: dict[str, Any] = self.conn.execute(command_builder("network.addIntercept", params))
        self.intercepts.append(result["intercept"])
        return result

    def _remove_intercept(self, intercept: str | None = None) -> None:
        """Remove a specific intercept, or all intercepts.

        Args:
            intercept: The intercept to remove. Default is None.

        Raises:
            Exception: If removing the given intercept fails, including when
                it is not among the recorded intercepts. The original error
                is attached as the cause.

        Note:
            If intercept is None, all intercepts will be removed.
        """
        if intercept is None:
            # Iterate over a copy because the list is mutated inside the loop.
            for intercept_id in self.intercepts.copy():
                self.conn.execute(command_builder("network.removeIntercept", {"intercept": intercept_id}))
                self.intercepts.remove(intercept_id)
            return

        try:
            self.conn.execute(command_builder("network.removeIntercept", {"intercept": intercept}))
            self.intercepts.remove(intercept)
        except Exception as e:
            raise Exception(f"Exception: {e}") from e

    def _on_request(self, event_name: str, callback: Callable[[Request], Any]) -> int:
        """Set a callback function to subscribe to a network event.

        Args:
            event_name: The event to subscribe to.
            callback: The callback function to execute on event.
                Takes Request object as argument.

        Returns:
            int: callback id
        """
        event = NetworkEvent(event_name)

        def _callback(event_data: NetworkEvent) -> None:
            # The "request" entry of the event params holds the request data.
            data = event_data.params["request"]
            request = Request(
                network=self,
                request_id=data.get("request", None),
                body_size=data.get("bodySize", None),
                cookies=data.get("cookies", None),
                resource_type=data.get("goog:resourceType", None),
                headers=data.get("headers", None),
                headers_size=data.get("headersSize", None),
                method=data.get("method", None),
                timings=data.get("timings", None),
                url=data.get("url", None),
            )
            callback(request)

        callback_id: int = self.conn.add_callback(event, _callback)
        self.callbacks.setdefault(event_name, []).append(callback_id)
        return callback_id

    def add_request_handler(
        self,
        event: str,
        callback: Callable[[Request], Any],
        url_patterns: list[Any] | None = None,
        contexts: list[str] | None = None,
    ) -> int:
        """Add a request handler to the network.

        Args:
            event: The event to subscribe to; must be a key of both EVENTS
                and PHASES.
            callback: The callback function to execute on request interception.
                Takes Request object as argument.
            url_patterns: A list of URL patterns to intercept. Default is None.
            contexts: A list of contexts to intercept. Default is None.

        Returns:
            int: callback id

        Raises:
            Exception: If the event name is unknown or has no intercept phase.
        """
        try:
            event_name = self.EVENTS[event]
            phase_name = self.PHASES[event]
        except KeyError:
            raise Exception(f"Event {event} not found")

        result = self._add_intercept(phases=[phase_name], url_patterns=url_patterns, contexts=contexts)
        callback_id = self._on_request(event_name, callback)

        if event_name in self.subscriptions:
            self.subscriptions[event_name].append(callback_id)
        else:
            # First handler for this event: subscribe the session to it.
            params: dict[str, Any] = {"events": [event_name]}
            self.conn.execute(command_builder("session.subscribe", params))
            self.subscriptions[event_name] = [callback_id]

        # Remember which intercept belongs to this handler for later removal.
        self.callbacks[callback_id] = result["intercept"]
        return callback_id

    def remove_request_handler(self, event: str, callback_id: int) -> None:
        """Remove a request handler from the network.

        Also removes the handler's intercept and, when it was the last handler
        for the event, unsubscribes the session from that event.

        Args:
            event: The event to unsubscribe from.
            callback_id: The callback id to remove.

        Raises:
            Exception: If the event name is unknown.
        """
        try:
            event_name = self.EVENTS[event]
        except KeyError:
            raise Exception(f"Event {event} not found")

        net_event = NetworkEvent(event_name)
        self.conn.remove_callback(net_event, callback_id)
        self._remove_intercept(self.callbacks[callback_id])
        del self.callbacks[callback_id]
        self.subscriptions[event_name].remove(callback_id)

        if len(self.subscriptions[event_name]) == 0:
            params: dict[str, Any] = {"events": [event_name]}
            self.conn.execute(command_builder("session.unsubscribe", params))
            del self.subscriptions[event_name]

    def clear_request_handlers(self) -> None:
        """Clear all request handlers from the network.

        Each handler's callback and intercept are removed, then the session is
        unsubscribed from every event that had handlers.
        """
        for event_name in self.subscriptions:
            net_event = NetworkEvent(event_name)
            for callback_id in self.subscriptions[event_name]:
                self.conn.remove_callback(net_event, callback_id)
                self._remove_intercept(self.callbacks[callback_id])
                del self.callbacks[callback_id]
            # One unsubscribe per event, issued once its handlers are gone.
            params: dict[str, Any] = {"events": [event_name]}
            self.conn.execute(command_builder("session.unsubscribe", params))
        self.subscriptions = {}

    def add_auth_handler(self, username: str, password: str) -> int:
        """Add an authentication handler to the network.

        Args:
            username: The username to authenticate with.
            password: The password to authenticate with.

        Returns:
            int: callback id
        """
        event = "auth_required"

        def _callback(request: Request) -> None:
            request._continue_with_auth(username, password)

        return self.add_request_handler(event, _callback)

    def remove_auth_handler(self, callback_id: int) -> None:
        """Remove an authentication handler from the network.

        Args:
            callback_id: The callback id to remove.
        """
        event = "auth_required"
        self.remove_request_handler(event, callback_id)
class Request:
"""Represents an intercepted network request."""
def __init__(
self,
network: Network,
request_id: Any,
body_size: int | None = None,
cookies: Any = None,
resource_type: str | None = None,
headers: Any = None,
headers_size: int | None = None,
method: str | None = None,
timings: Any = None,
url: str | None = None,
) -> None:
self.network = network
self.request_id = request_id
self.body_size = body_size
self.cookies = cookies
self.resource_type = resource_type
self.headers = headers
self.headers_size = headers_size
self.method = method
self.timings = timings
self.url = url
def fail_request(self) -> None:
"""Fail this request."""
if not self.request_id:
raise ValueError("Request not found.")
params: dict[str, Any] = {"request": self.request_id}
self.network.conn.execute(command_builder("network.failRequest", params))
def continue_request(
self,
body: Any = None,
method: str | None = None,
headers: Any = None,
cookies: Any = None,
url: str | None = None,
) -> None:
"""Continue after intercepting this request."""
if not self.request_id:
raise ValueError("Request not found.")
params: dict[str, Any] = {"request": self.request_id}
if body is not None:
params["body"] = body
if method is not None:
params["method"] = method
if headers is not None:
params["headers"] = headers
if cookies is not None:
params["cookies"] = cookies
if url is not None:
params["url"] = url
self.network.conn.execute(command_builder("network.continueRequest", params))
def _continue_with_auth(self, username: str | None = None, password: str | None = None) -> None:
"""Continue with authentication.
Args:
username: The username to authenticate with.
password: The password to authenticate with.
Note:
If username or password is None, it attempts auth with no credentials.
"""
params: dict[str, Any] = {}
params["request"] = self.request_id
if not username or not password: # no credentials is valid option
params["action"] = "default"
else:
params["action"] = "provideCredentials"
params["credentials"] = {"type": "password", "username": username, "password": password}
self.network.conn.execute(command_builder("network.continueWithAuth", params))
@@ -0,0 +1,83 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.webdriver.common.bidi.common import command_builder
class PermissionState:
    """String constants naming the permission states accepted by ``set_permission``."""

    GRANTED = "granted"
    DENIED = "denied"
    PROMPT = "prompt"
class PermissionDescriptor:
    """A permission identified by its name."""

    def __init__(self, name: str):
        """Store the permission name.

        Args:
            name: The permission name.
        """
        self.name = name

    def to_dict(self) -> dict:
        """Return the descriptor as ``{"name": <name>}``."""
        return dict(name=self.name)
class Permissions:
    """BiDi implementation of the permissions module."""

    def __init__(self, conn):
        """Store the connection used to send commands.

        Args:
            conn: Object exposing ``execute`` for sending BiDi commands.
        """
        self.conn = conn

    def set_permission(
        self,
        descriptor: str | PermissionDescriptor,
        state: str,
        origin: str,
        user_context: str | None = None,
    ) -> None:
        """Send ``permissions.setPermission`` for a permission descriptor.

        Args:
            descriptor: The permission name (str) or a PermissionDescriptor.
                Examples: "geolocation", "camera", "microphone".
            state: One of the ``PermissionState`` values
                (granted, denied, prompt).
            origin: The origin for which the permission is set.
            user_context: The user context id; omitted from the command
                when ``None``.

        Raises:
            ValueError: If ``state`` is not a ``PermissionState`` value.
        """
        allowed_states = [PermissionState.GRANTED, PermissionState.DENIED, PermissionState.PROMPT]
        if state not in allowed_states:
            valid_states = f"{PermissionState.GRANTED}, {PermissionState.DENIED}, {PermissionState.PROMPT}"
            raise ValueError(f"Invalid permission state. Must be one of: {valid_states}")

        # A bare name is wrapped so both input forms serialize the same way.
        permission_descriptor = PermissionDescriptor(descriptor) if isinstance(descriptor, str) else descriptor

        params = {
            "descriptor": permission_descriptor.to_dict(),
            "state": state,
            "origin": origin,
        }
        if user_context is not None:
            params["userContext"] = user_context

        self.conn.execute(command_builder("permissions.setPermission", params))
@@ -0,0 +1,547 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import math
from dataclasses import dataclass
from typing import Any
from selenium.common.exceptions import WebDriverException
from selenium.webdriver.common.bidi.common import command_builder
from selenium.webdriver.common.bidi.log import LogEntryAdded
from selenium.webdriver.common.bidi.session import Session
class ResultOwnership:
    """String constants for the result ownership types."""

    NONE = "none"
    ROOT = "root"
class RealmType:
    """String constants for the realm types."""

    WINDOW = "window"
    DEDICATED_WORKER = "dedicated-worker"
    SHARED_WORKER = "shared-worker"
    SERVICE_WORKER = "service-worker"
    WORKER = "worker"
    PAINT_WORKLET = "paint-worklet"
    AUDIO_WORKLET = "audio-worklet"
    WORKLET = "worklet"
@dataclass
class RealmInfo:
"""Represents information about a realm."""
realm: str
origin: str
type: str
context: str | None = None
sandbox: str | None = None
@classmethod
def from_json(cls, json: dict[str, Any]) -> "RealmInfo":
"""Creates a RealmInfo instance from a dictionary.
Args:
json: A dictionary containing the realm information.
Returns:
RealmInfo: A new instance of RealmInfo.
"""
if "realm" not in json:
raise ValueError("Missing required field 'realm' in RealmInfo")
if "origin" not in json:
raise ValueError("Missing required field 'origin' in RealmInfo")
if "type" not in json:
raise ValueError("Missing required field 'type' in RealmInfo")
return cls(
realm=json["realm"],
origin=json["origin"],
type=json["type"],
context=json.get("context"),
sandbox=json.get("sandbox"),
)
@dataclass
class Source:
"""Represents the source of a script message."""
realm: str
context: str | None = None
@classmethod
def from_json(cls, json: dict[str, Any]) -> "Source":
"""Creates a Source instance from a dictionary.
Args:
json: A dictionary containing the source information.
Returns:
Source: A new instance of Source.
"""
if "realm" not in json:
raise ValueError("Missing required field 'realm' in Source")
return cls(
realm=json["realm"],
context=json.get("context"),
)
@dataclass
class EvaluateResult:
"""Represents the result of script evaluation."""
type: str
realm: str
result: dict | None = None
exception_details: dict | None = None
@classmethod
def from_json(cls, json: dict[str, Any]) -> "EvaluateResult":
"""Creates an EvaluateResult instance from a dictionary.
Args:
json: A dictionary containing the evaluation result.
Returns:
EvaluateResult: A new instance of EvaluateResult.
"""
if "realm" not in json:
raise ValueError("Missing required field 'realm' in EvaluateResult")
if "type" not in json:
raise ValueError("Missing required field 'type' in EvaluateResult")
return cls(
type=json["type"],
realm=json["realm"],
result=json.get("result"),
exception_details=json.get("exceptionDetails"),
)
class ScriptMessage:
    """A ``script.message`` event."""

    event_class = "script.message"

    def __init__(self, channel: str, data: dict, source: Source):
        """Store the message fields.

        Args:
            channel: The channel value of the message.
            data: The data payload of the message.
            source: The parsed source of the message.
        """
        self.channel = channel
        self.data = data
        self.source = source

    @classmethod
    def from_json(cls, json: dict[str, Any]) -> "ScriptMessage":
        """Build a ScriptMessage from a dictionary.

        Args:
            json: Mapping with the required keys ``channel``, ``data`` and
                ``source``; ``source`` is parsed with ``Source.from_json``.

        Returns:
            ScriptMessage: A new instance populated from ``json``.

        Raises:
            ValueError: If a required key is absent.
        """
        # Checked in this order so the first missing key is the one reported.
        for required in ("channel", "data", "source"):
            if required not in json:
                raise ValueError(f"Missing required field '{required}' in ScriptMessage")

        parsed_source = Source.from_json(json["source"])
        return cls(channel=json["channel"], data=json["data"], source=parsed_source)
class RealmCreated:
    """A ``script.realmCreated`` event."""

    event_class = "script.realmCreated"

    def __init__(self, realm_info: RealmInfo):
        """Store the realm information carried by the event.

        Args:
            realm_info: The parsed realm information.
        """
        self.realm_info = realm_info

    @classmethod
    def from_json(cls, json: dict[str, Any]) -> "RealmCreated":
        """Build a RealmCreated from a dictionary.

        Args:
            json: Mapping handed unchanged to ``RealmInfo.from_json``.

        Returns:
            RealmCreated: A new instance wrapping the parsed RealmInfo.
        """
        parsed_info = RealmInfo.from_json(json)
        return cls(realm_info=parsed_info)
class RealmDestroyed:
    """A ``script.realmDestroyed`` event."""

    event_class = "script.realmDestroyed"

    def __init__(self, realm: str):
        """Store the realm value carried by the event.

        Args:
            realm: The realm value of the destroyed realm.
        """
        self.realm = realm

    @classmethod
    def from_json(cls, json: dict[str, Any]) -> "RealmDestroyed":
        """Build a RealmDestroyed from a dictionary.

        Args:
            json: Mapping with the required key ``realm``.

        Returns:
            RealmDestroyed: A new instance populated from ``json``.

        Raises:
            ValueError: If ``realm`` is absent.
        """
        if "realm" in json:
            return cls(realm=json["realm"])
        raise ValueError("Missing required field 'realm' in RealmDestroyed")
class Script:
    """BiDi implementation of the script module.

    Combines convenience helpers (log-entry handlers, script pinning,
    function execution) with thin wrappers around the raw ``script.*``
    commands. Every command is sent through ``conn``.
    """

    EVENTS = {
        "message": "script.message",
        "realm_created": "script.realmCreated",
        "realm_destroyed": "script.realmDestroyed",
    }

    def __init__(self, conn, driver=None):
        """Store the connection and optional driver.

        Args:
            conn: Object used to send commands and manage callbacks.
            driver: Optional driver; ``execute`` reads its
                ``current_window_handle``.
        """
        self.conn = conn
        self.driver = driver
        self.log_entry_subscribed = False
        self.subscriptions = {}
        self.callbacks = {}

    # High-level APIs for SCRIPT module

    def add_console_message_handler(self, handler):
        """Call ``handler`` for each log entry whose type is ``"console"``.

        Returns:
            Whatever ``conn.add_callback`` returns for the registration.
        """
        self._subscribe_to_log_entries()
        console_only = self._handle_log_entry("console", handler)
        return self.conn.add_callback(LogEntryAdded, console_only)

    def add_javascript_error_handler(self, handler):
        """Call ``handler`` for each log entry whose type is ``"javascript"``.

        Returns:
            Whatever ``conn.add_callback`` returns for the registration.
        """
        self._subscribe_to_log_entries()
        javascript_only = self._handle_log_entry("javascript", handler)
        return self.conn.add_callback(LogEntryAdded, javascript_only)

    def remove_console_message_handler(self, id):
        """Remove a log-entry callback and unsubscribe if none remain.

        Args:
            id: The callback id returned when the handler was added.
        """
        self.conn.remove_callback(LogEntryAdded, id)
        self._unsubscribe_from_log_entries()

    # Both handler kinds are registered on LogEntryAdded, so removal is shared.
    remove_javascript_error_handler = remove_console_message_handler

    def pin(self, script: str) -> str:
        """Register ``script`` as a preload script.

        Args:
            script: The script to pin.

        Returns:
            str: The ID of the pinned script.
        """
        return self._add_preload_script(script)

    def unpin(self, script_id: str) -> None:
        """Remove a previously pinned script.

        Args:
            script_id: The ID of the pinned script to unpin.
        """
        self._remove_preload_script(script_id)

    def execute(self, script: str, *args) -> dict:
        """Call a script function in the driver's current window.

        Args:
            script: The function declaration to call.
            *args: Python values converted to BiDi local values and passed
                as the function's arguments.

        Returns:
            dict: The ``result`` of a successful call, or ``{}`` when the
            call succeeded without a result.

        Raises:
            WebDriverException: If no driver was supplied, or if the call
                does not report success.
        """
        if self.driver is None:
            raise WebDriverException("Driver reference is required for script execution")

        context_id = self.driver.current_window_handle

        # Convert arguments to the format expected by BiDi call_function (LocalValue Type)
        local_values = [self.__convert_to_local_value(arg) for arg in args]

        outcome = self._call_function(
            function_declaration=script,
            await_promise=True,
            target={"context": context_id},
            arguments=local_values or None,
        )

        if outcome.type == "success":
            return {} if outcome.result is None else outcome.result

        details = outcome.exception_details
        error_message = "Error while executing script"
        if details:
            # Prefer the "text" entry; fall back to "message" when present.
            if "text" in details:
                error_message += f": {details['text']}"
            elif "message" in details:
                error_message += f": {details['message']}"
        raise WebDriverException(error_message)

    def __convert_to_local_value(self, value) -> dict:
        """Convert a Python value to the BiDi LocalValue dictionary form.

        Containers are converted recursively; values of any type not handled
        explicitly are sent as their ``str()`` representation.
        """
        if value is None:
            return {"type": "null"}

        # bool must be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return {"type": "boolean", "value": value}

        if isinstance(value, (int, float)):
            if isinstance(value, float):
                # Special floats are sent as strings.
                if math.isnan(value):
                    return {"type": "number", "value": "NaN"}
                if math.isinf(value):
                    return {"type": "number", "value": "Infinity" if value > 0 else "-Infinity"}
                if value == 0.0 and math.copysign(1.0, value) < 0:
                    return {"type": "number", "value": "-0"}

            JS_MAX_SAFE_INTEGER = 9007199254740991
            outside_safe_range = value > JS_MAX_SAFE_INTEGER or value < -JS_MAX_SAFE_INTEGER
            if isinstance(value, int) and outside_safe_range:
                return {"type": "bigint", "value": str(value)}

            return {"type": "number", "value": value}

        if isinstance(value, str):
            return {"type": "string", "value": value}

        # datetime must be tested before date because datetime subclasses date.
        if isinstance(value, datetime.datetime):
            # Convert Python datetime to JavaScript Date (ISO 8601 format);
            # naive datetimes get a "Z" suffix appended.
            iso_text = value.isoformat()
            return {"type": "date", "value": iso_text + "Z" if value.tzinfo is None else iso_text}

        if isinstance(value, datetime.date):
            # Convert Python date to JavaScript Date: midnight, tagged as UTC.
            midnight = datetime.datetime.combine(value, datetime.time.min)
            as_utc = midnight.replace(tzinfo=datetime.timezone.utc)
            return {"type": "date", "value": as_utc.isoformat()}

        if isinstance(value, set):
            return {"type": "set", "value": [self.__convert_to_local_value(member) for member in value]}

        if isinstance(value, (list, tuple)):
            return {"type": "array", "value": [self.__convert_to_local_value(member) for member in value]}

        if isinstance(value, dict):
            pairs = [
                [self.__convert_to_local_value(key), self.__convert_to_local_value(item)]
                for key, item in value.items()
            ]
            return {"type": "object", "value": pairs}

        # For other types, convert to string
        return {"type": "string", "value": str(value)}

    # low-level APIs for script module

    def _add_preload_script(
        self,
        function_declaration: str,
        arguments: list[dict[str, Any]] | None = None,
        contexts: list[str] | None = None,
        user_contexts: list[str] | None = None,
        sandbox: str | None = None,
    ) -> str:
        """Send ``script.addPreloadScript``.

        Args:
            function_declaration: The function declaration to preload.
            arguments: The arguments to pass to the function.
            contexts: The browsing context IDs to apply the script to.
            user_contexts: The user context IDs to apply the script to.
            sandbox: The sandbox name to apply the script to.

        Returns:
            str: The preload script ID (the ``script`` entry of the reply).

        Raises:
            ValueError: If both contexts and user_contexts are provided.
        """
        if contexts is not None and user_contexts is not None:
            raise ValueError("Cannot specify both contexts and user_contexts")

        params: dict[str, Any] = {"functionDeclaration": function_declaration}
        optional_fields = (
            ("arguments", arguments),
            ("contexts", contexts),
            ("userContexts", user_contexts),
            ("sandbox", sandbox),
        )
        for key, supplied in optional_fields:
            if supplied is not None:
                params[key] = supplied

        reply = self.conn.execute(command_builder("script.addPreloadScript", params))
        return reply["script"]

    def _remove_preload_script(self, script_id: str) -> None:
        """Send ``script.removePreloadScript``.

        Args:
            script_id: The preload script ID to remove.
        """
        self.conn.execute(command_builder("script.removePreloadScript", {"script": script_id}))

    def _disown(self, handles: list[str], target: dict) -> None:
        """Send ``script.disown``.

        Args:
            handles: The handles to disown.
            target: The target realm or context.
        """
        params = {"handles": handles, "target": target}
        self.conn.execute(command_builder("script.disown", params))

    def _call_function(
        self,
        function_declaration: str,
        await_promise: bool,
        target: dict,
        arguments: list[dict] | None = None,
        result_ownership: str | None = None,
        serialization_options: dict | None = None,
        this: dict | None = None,
        user_activation: bool = False,
    ) -> EvaluateResult:
        """Send ``script.callFunction`` and parse the reply.

        Args:
            function_declaration: The function declaration to call.
            await_promise: Whether to await promise resolution.
            target: The target realm or context.
            arguments: The arguments to pass to the function.
            result_ownership: The result ownership type.
            serialization_options: The serialization options.
            this: The 'this' value for the function call.
            user_activation: Whether to trigger user activation.

        Returns:
            EvaluateResult: The result of the function call.
        """
        params = {
            "functionDeclaration": function_declaration,
            "awaitPromise": await_promise,
            "target": target,
            "userActivation": user_activation,
        }
        optional_fields = (
            ("arguments", arguments),
            ("resultOwnership", result_ownership),
            ("serializationOptions", serialization_options),
            ("this", this),
        )
        for key, supplied in optional_fields:
            if supplied is not None:
                params[key] = supplied

        reply = self.conn.execute(command_builder("script.callFunction", params))
        return EvaluateResult.from_json(reply)

    def _evaluate(
        self,
        expression: str,
        target: dict,
        await_promise: bool,
        result_ownership: str | None = None,
        serialization_options: dict | None = None,
        user_activation: bool = False,
    ) -> EvaluateResult:
        """Send ``script.evaluate`` and parse the reply.

        Args:
            expression: The script expression to evaluate.
            target: The target realm or context.
            await_promise: Whether to await promise resolution.
            result_ownership: The result ownership type.
            serialization_options: The serialization options.
            user_activation: Whether to trigger user activation.

        Returns:
            EvaluateResult: The result of the script evaluation.
        """
        params = {
            "expression": expression,
            "target": target,
            "awaitPromise": await_promise,
            "userActivation": user_activation,
        }
        optional_fields = (
            ("resultOwnership", result_ownership),
            ("serializationOptions", serialization_options),
        )
        for key, supplied in optional_fields:
            if supplied is not None:
                params[key] = supplied

        reply = self.conn.execute(command_builder("script.evaluate", params))
        return EvaluateResult.from_json(reply)

    def _get_realms(
        self,
        context: str | None = None,
        type: str | None = None,
    ) -> list[RealmInfo]:
        """Send ``script.getRealms`` and parse the reply.

        Args:
            context: The browsing context ID to filter by.
            type: The realm type to filter by.

        Returns:
            List[RealmInfo]: One RealmInfo per entry of the reply's
            ``realms`` list.
        """
        params = {}
        for key, supplied in (("context", context), ("type", type)):
            if supplied is not None:
                params[key] = supplied

        reply = self.conn.execute(command_builder("script.getRealms", params))
        return [RealmInfo.from_json(entry) for entry in reply["realms"]]

    def _subscribe_to_log_entries(self):
        """Subscribe to log-entry events once; later calls do nothing."""
        if self.log_entry_subscribed:
            return
        session = Session(self.conn)
        self.conn.execute(session.subscribe(LogEntryAdded.event_class))
        self.log_entry_subscribed = True

    def _unsubscribe_from_log_entries(self):
        """Unsubscribe from log-entry events when no callbacks remain.

        Nothing is sent unless this instance subscribed earlier and the
        log-entry event class is no longer present in ``conn.callbacks``.
        """
        if not self.log_entry_subscribed:
            return
        if LogEntryAdded.event_class in self.conn.callbacks:
            return
        session = Session(self.conn)
        self.conn.execute(session.unsubscribe(LogEntryAdded.event_class))
        self.log_entry_subscribed = False

    def _handle_log_entry(self, type, handler):
        """Return a callback that forwards only log entries of ``type``.

        Args:
            type: Value compared against each entry's ``type_`` attribute.
            handler: Callable invoked with each matching entry.
        """

        def _forward_matching(log_entry):
            if log_entry.type_ == type:
                handler(log_entry)

        return _forward_matching
@@ -0,0 +1,134 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.webdriver.common.bidi.common import command_builder
class UserPromptHandlerType:
    """String constants for the behavior of the user prompt handler."""

    ACCEPT = "accept"
    DISMISS = "dismiss"
    IGNORE = "ignore"

    # The complete set of accepted handler types, used for validation.
    VALID_TYPES = {ACCEPT, DISMISS, IGNORE}
class UserPromptHandler:
"""Represents the configuration of the user prompt handler."""
def __init__(
self,
alert: str | None = None,
before_unload: str | None = None,
confirm: str | None = None,
default: str | None = None,
file: str | None = None,
prompt: str | None = None,
):
"""Initialize UserPromptHandler.
Args:
alert: Handler type for alert prompts.
before_unload: Handler type for beforeUnload prompts.
confirm: Handler type for confirm prompts.
default: Default handler type for all prompts.
file: Handler type for file picker prompts.
prompt: Handler type for prompt dialogs.
Raises:
ValueError: If any handler type is not valid.
"""
for field_name, value in [
("alert", alert),
("before_unload", before_unload),
("confirm", confirm),
("default", default),
("file", file),
("prompt", prompt),
]:
if value is not None and value not in UserPromptHandlerType.VALID_TYPES:
raise ValueError(
f"Invalid {field_name} handler type: {value}. Must be one of {UserPromptHandlerType.VALID_TYPES}"
)
self.alert = alert
self.before_unload = before_unload
self.confirm = confirm
self.default = default
self.file = file
self.prompt = prompt
def to_dict(self) -> dict[str, str]:
"""Convert the UserPromptHandler to a dictionary for BiDi protocol.
Returns:
Dictionary representation suitable for BiDi protocol.
"""
field_mapping = {
"alert": "alert",
"before_unload": "beforeUnload",
"confirm": "confirm",
"default": "default",
"file": "file",
"prompt": "prompt",
}
result = {}
for attr_name, dict_key in field_mapping.items():
value = getattr(self, attr_name)
if value is not None:
result[dict_key] = value
return result
class Session:
    """Builder for ``session.*`` BiDi commands."""

    def __init__(self, conn):
        """Store the connection used by ``status``.

        Args:
            conn: Object exposing ``execute`` for sending BiDi commands.
        """
        self.conn = conn

    def subscribe(self, *events, browsing_contexts=None):
        """Build (without sending) a ``session.subscribe`` command.

        Args:
            *events: Event names to subscribe to.
            browsing_contexts: Optional browsing context ids; omitted from
                the command when ``None`` or empty.

        Returns:
            The value produced by ``command_builder`` for the command.
        """
        params = {"events": events}
        if browsing_contexts:
            params["browsingContexts"] = browsing_contexts
        return command_builder("session.subscribe", params)

    def unsubscribe(self, *events, browsing_contexts=None):
        """Build (without sending) a ``session.unsubscribe`` command.

        Args:
            *events: Event names to unsubscribe from.
            browsing_contexts: Optional browsing context ids; omitted from
                the command when ``None`` or empty.

        Returns:
            The value produced by ``command_builder`` for the command.
        """
        params = {"events": events}
        if browsing_contexts:
            params["browsingContexts"] = browsing_contexts
        return command_builder("session.unsubscribe", params)

    def status(self):
        """The session.status command returns information about the remote end's readiness.

        Returns information about the remote end's readiness to create new sessions
        and may include implementation-specific metadata.

        Returns:
            Dictionary containing the ready state (bool), message (str) and metadata.
        """
        return self.conn.execute(command_builder("session.status", {}))
@@ -0,0 +1,413 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from selenium.webdriver.common.bidi.common import command_builder
if TYPE_CHECKING:
from selenium.webdriver.remote.websocket_connection import WebSocketConnection
class SameSite:
    """String constants for the same-site values of cookies."""

    STRICT = "strict"
    LAX = "lax"
    NONE = "none"
    DEFAULT = "default"
class BytesValue:
    """A typed bytes value.

    ``TYPE_BASE64`` and ``TYPE_STRING`` are provided as constants for the
    ``type`` field; the constructor does not validate it.
    """

    TYPE_BASE64 = "base64"
    TYPE_STRING = "string"

    def __init__(self, type: str, value: str):
        """Store the type tag and the value.

        Args:
            type: The type tag of the value.
            value: The value itself.
        """
        self.type = type
        self.value = value

    def to_dict(self) -> dict[str, str]:
        """Convert the BytesValue to a dictionary.

        Returns:
            A dictionary with the keys ``type`` and ``value``.
        """
        return dict(type=self.type, value=self.value)
class Cookie:
    """A cookie as reported by the storage module."""

    def __init__(
        self,
        name: str,
        value: BytesValue,
        domain: str,
        path: str | None = None,
        size: int | None = None,
        http_only: bool | None = None,
        secure: bool | None = None,
        same_site: str | None = None,
        expiry: int | None = None,
    ):
        """Store the cookie fields unchanged on the instance."""
        self.name = name
        self.value = value
        self.domain = domain
        self.path = path
        self.size = size
        self.http_only = http_only
        self.secure = secure
        self.same_site = same_site
        self.expiry = expiry

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Cookie:
        """Build a Cookie from a dictionary.

        Args:
            data: Mapping with non-empty ``name`` and ``domain`` entries and
                optional ``value``, ``path``, ``size``, ``httpOnly``,
                ``secure``, ``sameSite`` and ``expiry`` entries.

        Returns:
            A new instance of Cookie.

        Raises:
            ValueError: If ``name`` or ``domain`` is missing or empty.
        """
        # Missing and empty values are rejected alike.
        name = data.get("name")
        if not name:
            raise ValueError("name is required and cannot be empty")

        domain = data.get("domain")
        if not domain:
            raise ValueError("domain is required and cannot be empty")

        # An absent "value" entry yields a BytesValue with None fields.
        raw_value = data.get("value", {})
        parsed_value = BytesValue(raw_value.get("type"), raw_value.get("value"))

        return cls(
            name=str(name),
            value=parsed_value,
            domain=str(domain),
            path=data.get("path"),
            size=data.get("size"),
            http_only=data.get("httpOnly"),
            secure=data.get("secure"),
            same_site=data.get("sameSite"),
            expiry=data.get("expiry"),
        )
class CookieFilter:
    """A filter for cookies; every field is optional."""

    def __init__(
        self,
        name: str | None = None,
        value: BytesValue | None = None,
        domain: str | None = None,
        path: str | None = None,
        size: int | None = None,
        http_only: bool | None = None,
        secure: bool | None = None,
        same_site: str | None = None,
        expiry: int | None = None,
    ):
        """Store the filter fields unchanged on the instance."""
        self.name = name
        self.value = value
        self.domain = domain
        self.path = path
        self.size = size
        self.http_only = http_only
        self.secure = secure
        self.same_site = same_site
        self.expiry = expiry

    def to_dict(self) -> dict[str, Any]:
        """Convert the CookieFilter to a dictionary.

        Returns:
            A dictionary holding only the fields that are not ``None``,
            keyed by their protocol names.
        """
        payload: dict[str, Any] = {}
        if self.name is not None:
            payload["name"] = self.name
        if self.value is not None:
            # The value serializes itself.
            payload["value"] = self.value.to_dict()

        plain_fields = (
            ("domain", self.domain),
            ("path", self.path),
            ("size", self.size),
            ("httpOnly", self.http_only),
            ("secure", self.secure),
            ("sameSite", self.same_site),
            ("expiry", self.expiry),
        )
        for key, field_value in plain_fields:
            if field_value is not None:
                payload[key] = field_value
        return payload
class PartitionKey:
    """A storage partition key."""

    def __init__(self, user_context: str | None = None, source_origin: str | None = None):
        """Store the optional user context and source origin."""
        self.user_context = user_context
        self.source_origin = source_origin

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> PartitionKey:
        """Build a PartitionKey from a dictionary.

        Args:
            data: Mapping with optional ``userContext`` and ``sourceOrigin``
                entries; absent entries become ``None``.

        Returns:
            A new instance of PartitionKey.
        """
        user_context = data.get("userContext")
        source_origin = data.get("sourceOrigin")
        return cls(user_context=user_context, source_origin=source_origin)
class BrowsingContextPartitionDescriptor:
    """A partition descriptor that refers to a browsing context."""

    def __init__(self, context: str):
        """Store the context and set the descriptor type to ``"context"``.

        Args:
            context: The browsing context value to reference.
        """
        self.type = "context"
        self.context = context

    def to_dict(self) -> dict[str, str]:
        """Convert the BrowsingContextPartitionDescriptor to a dictionary.

        Returns:
            Dict: A dictionary with the keys ``type`` and ``context``.
        """
        return dict(type=self.type, context=self.context)
class StorageKeyPartitionDescriptor:
    """A partition descriptor that refers to a storage key."""

    def __init__(self, user_context: str | None = None, source_origin: str | None = None):
        """Store the optional fields and set the type to ``"storageKey"``."""
        self.type = "storageKey"
        self.user_context = user_context
        self.source_origin = source_origin

    def to_dict(self) -> dict[str, str]:
        """Convert the StorageKeyPartitionDescriptor to a dictionary.

        Returns:
            Dict: A dictionary with ``type`` plus whichever of
            ``userContext`` and ``sourceOrigin`` are not ``None``.
        """
        payload = {"type": self.type}
        optional_fields = (
            ("userContext", self.user_context),
            ("sourceOrigin", self.source_origin),
        )
        for key, field_value in optional_fields:
            if field_value is not None:
                payload[key] = field_value
        return payload
class PartialCookie:
    """A partial cookie used when setting a cookie."""

    def __init__(
        self,
        name: str,
        value: BytesValue,
        domain: str,
        path: str | None = None,
        http_only: bool | None = None,
        secure: bool | None = None,
        same_site: str | None = None,
        expiry: int | None = None,
    ):
        """Store the cookie fields unchanged on the instance."""
        self.name = name
        self.value = value
        self.domain = domain
        self.path = path
        self.http_only = http_only
        self.secure = secure
        self.same_site = same_site
        self.expiry = expiry

    def to_dict(self) -> dict[str, Any]:
        """Convert the PartialCookie to a dictionary.

        Returns:
            Dict: A dictionary that always holds ``name``, ``value`` and
            ``domain``, plus each optional field that is not ``None``.
        """
        payload: dict[str, Any] = {
            "name": self.name,
            "value": self.value.to_dict(),
            "domain": self.domain,
        }
        optional_fields = (
            ("path", self.path),
            ("httpOnly", self.http_only),
            ("secure", self.secure),
            ("sameSite", self.same_site),
            ("expiry", self.expiry),
        )
        for key, field_value in optional_fields:
            if field_value is not None:
                payload[key] = field_value
        return payload
class GetCookiesResult:
    """The result of a getCookies command."""

    def __init__(self, cookies: list[Cookie], partition_key: PartitionKey):
        """Store the parsed cookies and partition key."""
        self.cookies = cookies
        self.partition_key = partition_key

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> GetCookiesResult:
        """Build a GetCookiesResult from a dictionary.

        Args:
            data: Mapping with optional ``cookies`` (parsed entry by entry
                with ``Cookie.from_dict``) and ``partitionKey`` (parsed with
                ``PartitionKey.from_dict``) entries.

        Returns:
            A new instance of GetCookiesResult.
        """
        parsed_cookies = [Cookie.from_dict(entry) for entry in data.get("cookies", [])]
        parsed_key = PartitionKey.from_dict(data.get("partitionKey", {}))
        return cls(cookies=parsed_cookies, partition_key=parsed_key)
class SetCookieResult:
    """The result of a setCookie command."""

    def __init__(self, partition_key: PartitionKey):
        """Store the parsed partition key."""
        self.partition_key = partition_key

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> SetCookieResult:
        """Build a SetCookieResult from a dictionary.

        Args:
            data: Mapping with an optional ``partitionKey`` entry, parsed
                with ``PartitionKey.from_dict``.

        Returns:
            A new instance of SetCookieResult.
        """
        parsed_key = PartitionKey.from_dict(data.get("partitionKey", {}))
        return cls(partition_key=parsed_key)
class DeleteCookiesResult:
    """The result of a deleteCookies command."""

    def __init__(self, partition_key: PartitionKey):
        """Store the parsed partition key."""
        self.partition_key = partition_key

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> DeleteCookiesResult:
        """Build a DeleteCookiesResult from a dictionary.

        Args:
            data: Mapping with an optional ``partitionKey`` entry, parsed
                with ``PartitionKey.from_dict``.

        Returns:
            A new instance of DeleteCookiesResult.
        """
        parsed_key = PartitionKey.from_dict(data.get("partitionKey", {}))
        return cls(partition_key=parsed_key)
class Storage:
    """BiDi implementation of the storage module."""

    def __init__(self, conn: WebSocketConnection) -> None:
        """Store the connection used to send commands."""
        self.conn = conn

    def get_cookies(
        self,
        filter: CookieFilter | None = None,
        partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None,
    ) -> GetCookiesResult:
        """Send ``storage.getCookies`` and parse the reply.

        Args:
            filter: Optional filter to specify which cookies to retrieve.
            partition: Optional partition descriptor to limit the scope of
                the operation.

        Returns:
            A GetCookiesResult containing the cookies and partition key.

        Example:
            result = storage.get_cookies(
                filter=CookieFilter(name="sessionId"),
                partition=BrowsingContextPartitionDescriptor(context_id),
            )
        """
        params = {}
        # Both arguments serialize themselves; ``None`` ones are left out.
        for key, supplied in (("filter", filter), ("partition", partition)):
            if supplied is not None:
                params[key] = supplied.to_dict()

        reply = self.conn.execute(command_builder("storage.getCookies", params))
        return GetCookiesResult.from_dict(reply)

    def set_cookie(
        self,
        cookie: PartialCookie,
        partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None,
    ) -> SetCookieResult:
        """Send ``storage.setCookie`` and parse the reply.

        Args:
            cookie: The cookie to set.
            partition: Optional partition descriptor.

        Returns:
            The result of the set cookie command.
        """
        params = {"cookie": cookie.to_dict()}
        if partition is not None:
            params["partition"] = partition.to_dict()

        reply = self.conn.execute(command_builder("storage.setCookie", params))
        return SetCookieResult.from_dict(reply)

    def delete_cookies(
        self,
        filter: CookieFilter | None = None,
        partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None,
    ) -> DeleteCookiesResult:
        """Send ``storage.deleteCookies`` and parse the reply.

        Args:
            filter: Optional filter to match cookies to delete.
            partition: Optional partition descriptor.

        Returns:
            The result of the delete cookies command.
        """
        params = {}
        # Both arguments serialize themselves; ``None`` ones are left out.
        for key, supplied in (("filter", filter), ("partition", partition)):
            if supplied is not None:
                params[key] = supplied.to_dict()

        reply = self.conn.execute(command_builder("storage.deleteCookies", params))
        return DeleteCookiesResult.from_dict(reply)
@@ -0,0 +1,78 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from selenium.common.exceptions import WebDriverException
from selenium.webdriver.common.bidi.common import command_builder
class WebExtension:
"""BiDi implementation of the webExtension module."""
def __init__(self, conn):
self.conn = conn
def install(self, path=None, archive_path=None, base64_value=None) -> dict:
"""Installs a web extension in the remote end.
You must provide exactly one of the parameters.
Args:
path: Path to an extension directory.
archive_path: Path to an extension archive file.
base64_value: Base64 encoded string of the extension archive.
Returns:
A dictionary containing the extension ID.
"""
if sum(x is not None for x in (path, archive_path, base64_value)) != 1:
raise ValueError("Exactly one of path, archive_path, or base64_value must be provided")
if path is not None:
extension_data = {"type": "path", "path": path}
elif archive_path is not None:
extension_data = {"type": "archivePath", "path": archive_path}
elif base64_value is not None:
extension_data = {"type": "base64", "value": base64_value}
params = {"extensionData": extension_data}
try:
result = self.conn.execute(command_builder("webExtension.install", params))
return result
except WebDriverException as e:
if "Method not available" in str(e):
raise WebDriverException(
f"{e!s}. If you are using Chrome or Edge, add '--enable-unsafe-extension-debugging' "
"and '--remote-debugging-pipe' arguments or set options.enable_webextensions = True"
) from e
raise
def uninstall(self, extension_id_or_result: str | dict) -> None:
"""Uninstalls a web extension from the remote end.
Args:
extension_id_or_result: Either the extension ID as a string or the result dictionary
from a previous install() call containing the extension ID.
"""
if isinstance(extension_id_or_result, dict):
extension_id = extension_id_or_result.get("extension")
else:
extension_id = extension_id_or_result
params = {"extension": extension_id}
self.conn.execute(command_builder("webExtension.uninstall", params))