initial commit
This commit is contained in:
+16
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Software Freedom Conservancy (SFC) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The SFC licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
+280
@@ -0,0 +1,280 @@
|
||||
# Licensed to the Software Freedom Conservancy (SFC) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The SFC licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from selenium.webdriver.common.bidi.common import command_builder
|
||||
from selenium.webdriver.common.bidi.session import UserPromptHandler
|
||||
from selenium.webdriver.common.proxy import Proxy
|
||||
|
||||
|
||||
class ClientWindowState:
|
||||
"""Represents a window state."""
|
||||
|
||||
FULLSCREEN = "fullscreen"
|
||||
MAXIMIZED = "maximized"
|
||||
MINIMIZED = "minimized"
|
||||
NORMAL = "normal"
|
||||
|
||||
VALID_STATES = {FULLSCREEN, MAXIMIZED, MINIMIZED, NORMAL}
|
||||
|
||||
|
||||
class ClientWindowInfo:
|
||||
"""Represents a client window information."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_window: str,
|
||||
state: str,
|
||||
width: int,
|
||||
height: int,
|
||||
x: int,
|
||||
y: int,
|
||||
active: bool,
|
||||
):
|
||||
self.client_window = client_window
|
||||
self.state = state
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.active = active
|
||||
|
||||
def get_state(self) -> str:
|
||||
"""Gets the state of the client window.
|
||||
|
||||
Returns:
|
||||
str: The state of the client window (one of the ClientWindowState constants).
|
||||
"""
|
||||
return self.state
|
||||
|
||||
def get_client_window(self) -> str:
|
||||
"""Gets the client window identifier.
|
||||
|
||||
Returns:
|
||||
str: The client window identifier.
|
||||
"""
|
||||
return self.client_window
|
||||
|
||||
def get_width(self) -> int:
|
||||
"""Gets the width of the client window.
|
||||
|
||||
Returns:
|
||||
int: The width of the client window.
|
||||
"""
|
||||
return self.width
|
||||
|
||||
def get_height(self) -> int:
|
||||
"""Gets the height of the client window.
|
||||
|
||||
Returns:
|
||||
int: The height of the client window.
|
||||
"""
|
||||
return self.height
|
||||
|
||||
def get_x(self) -> int:
|
||||
"""Gets the x coordinate of the client window.
|
||||
|
||||
Returns:
|
||||
int: The x coordinate of the client window.
|
||||
"""
|
||||
return self.x
|
||||
|
||||
def get_y(self) -> int:
|
||||
"""Gets the y coordinate of the client window.
|
||||
|
||||
Returns:
|
||||
int: The y coordinate of the client window.
|
||||
"""
|
||||
return self.y
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""Checks if the client window is active.
|
||||
|
||||
Returns:
|
||||
bool: True if the client window is active, False otherwise.
|
||||
"""
|
||||
return self.active
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "ClientWindowInfo":
|
||||
"""Creates a ClientWindowInfo instance from a dictionary.
|
||||
|
||||
Args:
|
||||
data: A dictionary containing the client window information.
|
||||
|
||||
Returns:
|
||||
ClientWindowInfo: A new instance of ClientWindowInfo.
|
||||
|
||||
Raises:
|
||||
ValueError: If required fields are missing or have invalid types.
|
||||
"""
|
||||
try:
|
||||
client_window = data["clientWindow"]
|
||||
if not isinstance(client_window, str):
|
||||
raise ValueError("clientWindow must be a string")
|
||||
|
||||
state = data["state"]
|
||||
if not isinstance(state, str):
|
||||
raise ValueError("state must be a string")
|
||||
if state not in ClientWindowState.VALID_STATES:
|
||||
raise ValueError(f"Invalid state: {state}. Must be one of {ClientWindowState.VALID_STATES}")
|
||||
|
||||
width = data["width"]
|
||||
if not isinstance(width, int) or width < 0:
|
||||
raise ValueError(f"width must be a non-negative integer, got {width}")
|
||||
|
||||
height = data["height"]
|
||||
if not isinstance(height, int) or height < 0:
|
||||
raise ValueError(f"height must be a non-negative integer, got {height}")
|
||||
|
||||
x = data["x"]
|
||||
if not isinstance(x, int):
|
||||
raise ValueError(f"x must be an integer, got {type(x).__name__}")
|
||||
|
||||
y = data["y"]
|
||||
if not isinstance(y, int):
|
||||
raise ValueError(f"y must be an integer, got {type(y).__name__}")
|
||||
|
||||
active = data["active"]
|
||||
if not isinstance(active, bool):
|
||||
raise ValueError("active must be a boolean")
|
||||
|
||||
return cls(
|
||||
client_window=client_window,
|
||||
state=state,
|
||||
width=width,
|
||||
height=height,
|
||||
x=x,
|
||||
y=y,
|
||||
active=active,
|
||||
)
|
||||
except (KeyError, TypeError) as e:
|
||||
raise ValueError(f"Invalid data format for ClientWindowInfo: {e}") from e
|
||||
|
||||
|
||||
class Browser:
|
||||
"""BiDi implementation of the browser module."""
|
||||
|
||||
def __init__(self, conn):
|
||||
self.conn = conn
|
||||
|
||||
def create_user_context(
|
||||
self,
|
||||
accept_insecure_certs: bool | None = None,
|
||||
proxy: Proxy | None = None,
|
||||
unhandled_prompt_behavior: UserPromptHandler | None = None,
|
||||
) -> str:
|
||||
"""Creates a new user context.
|
||||
|
||||
Args:
|
||||
accept_insecure_certs: Optional flag to accept insecure TLS certificates.
|
||||
proxy: Optional proxy configuration for the user context.
|
||||
unhandled_prompt_behavior: Optional configuration for handling user prompts.
|
||||
|
||||
Returns:
|
||||
str: The ID of the created user context.
|
||||
"""
|
||||
params: dict[str, Any] = {}
|
||||
|
||||
if accept_insecure_certs is not None:
|
||||
params["acceptInsecureCerts"] = accept_insecure_certs
|
||||
|
||||
if proxy is not None:
|
||||
params["proxy"] = proxy.to_bidi_dict()
|
||||
|
||||
if unhandled_prompt_behavior is not None:
|
||||
params["unhandledPromptBehavior"] = unhandled_prompt_behavior.to_dict()
|
||||
|
||||
result = self.conn.execute(command_builder("browser.createUserContext", params))
|
||||
return result["userContext"]
|
||||
|
||||
def get_user_contexts(self) -> list[str]:
|
||||
"""Gets all user contexts.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of user context IDs.
|
||||
"""
|
||||
result = self.conn.execute(command_builder("browser.getUserContexts", {}))
|
||||
return [context_info["userContext"] for context_info in result["userContexts"]]
|
||||
|
||||
def remove_user_context(self, user_context_id: str) -> None:
|
||||
"""Removes a user context.
|
||||
|
||||
Args:
|
||||
user_context_id: The ID of the user context to remove.
|
||||
|
||||
Raises:
|
||||
ValueError: If the user context ID is "default" or does not exist.
|
||||
"""
|
||||
if user_context_id == "default":
|
||||
raise ValueError("Cannot remove the default user context")
|
||||
|
||||
params = {"userContext": user_context_id}
|
||||
self.conn.execute(command_builder("browser.removeUserContext", params))
|
||||
|
||||
def get_client_windows(self) -> list[ClientWindowInfo]:
|
||||
"""Gets all client windows.
|
||||
|
||||
Returns:
|
||||
List[ClientWindowInfo]: A list of client window information.
|
||||
"""
|
||||
result = self.conn.execute(command_builder("browser.getClientWindows", {}))
|
||||
return [ClientWindowInfo.from_dict(window) for window in result["clientWindows"]]
|
||||
|
||||
def set_download_behavior(
|
||||
self,
|
||||
*,
|
||||
allowed: bool | None = None,
|
||||
destination_folder: str | os.PathLike | None = None,
|
||||
user_contexts: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Set the download behavior for the browser or specific user contexts.
|
||||
|
||||
Args:
|
||||
allowed: True to allow downloads, False to deny downloads, or None to
|
||||
clear download behavior (revert to default).
|
||||
destination_folder: Required when allowed is True. Specifies the folder
|
||||
to store downloads in.
|
||||
user_contexts: Optional list of user context IDs to apply this
|
||||
behavior to. If omitted, updates the default behavior.
|
||||
|
||||
Raises:
|
||||
ValueError: If allowed=True and destination_folder is missing, or if
|
||||
allowed=False and destination_folder is provided.
|
||||
"""
|
||||
params: dict[str, Any] = {}
|
||||
|
||||
if allowed is None:
|
||||
params["downloadBehavior"] = None
|
||||
else:
|
||||
if allowed:
|
||||
if not destination_folder:
|
||||
raise ValueError("destination_folder is required when allowed=True.")
|
||||
params["downloadBehavior"] = {
|
||||
"type": "allowed",
|
||||
"destinationFolder": os.fspath(destination_folder),
|
||||
}
|
||||
else:
|
||||
if destination_folder:
|
||||
raise ValueError("destination_folder should not be provided when allowed=False.")
|
||||
params["downloadBehavior"] = {"type": "denied"}
|
||||
|
||||
if user_contexts is not None:
|
||||
params["userContexts"] = user_contexts
|
||||
|
||||
self.conn.execute(command_builder("browser.setDownloadBehavior", params))
|
||||
+1060
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,515 @@
|
||||
# The MIT License(MIT)
|
||||
#
|
||||
# Copyright(c) 2018 Hyperion Gray
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files(the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
#
|
||||
# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp
|
||||
|
||||
import contextvars
|
||||
import importlib
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import pathlib
|
||||
from collections import defaultdict
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Generator
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import trio
|
||||
from trio_websocket import ConnectionClosed as WsConnectionClosed
|
||||
from trio_websocket import connect_websocket_url
|
||||
|
||||
logger = logging.getLogger("trio_cdp")
|
||||
T = TypeVar("T")
|
||||
MAX_WS_MESSAGE_SIZE = 2**24
|
||||
|
||||
devtools = None
|
||||
version = None
|
||||
|
||||
|
||||
def import_devtools(ver):
|
||||
"""Attempt to load the current latest available devtools into the module cache for use later."""
|
||||
global devtools
|
||||
global version
|
||||
version = ver
|
||||
base = "selenium.webdriver.common.devtools.v"
|
||||
try:
|
||||
devtools = importlib.import_module(f"{base}{ver}")
|
||||
return devtools
|
||||
except ModuleNotFoundError:
|
||||
# Attempt to parse and load the 'most recent' devtools module. This is likely
|
||||
# because cdp has been updated but selenium python has not been released yet.
|
||||
devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools")
|
||||
versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir() and f.name != "latest")
|
||||
latest = max(int(x[1:]) for x in versions)
|
||||
selenium_logger = logging.getLogger(__name__)
|
||||
selenium_logger.debug("Falling back to loading `devtools`: v%s", latest)
|
||||
devtools = importlib.import_module(f"{base}{latest}")
|
||||
return devtools
|
||||
|
||||
|
||||
_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context")
|
||||
_session_context: contextvars.ContextVar = contextvars.ContextVar("session_context")
|
||||
|
||||
|
||||
def get_connection_context(fn_name):
|
||||
"""Look up the current connection.
|
||||
|
||||
If there is no current connection, raise a ``RuntimeError`` with a
|
||||
helpful message.
|
||||
"""
|
||||
try:
|
||||
return _connection_context.get()
|
||||
except LookupError:
|
||||
raise RuntimeError(f"{fn_name}() must be called in a connection context.")
|
||||
|
||||
|
||||
def get_session_context(fn_name):
|
||||
"""Look up the current session.
|
||||
|
||||
If there is no current session, raise a ``RuntimeError`` with a
|
||||
helpful message.
|
||||
"""
|
||||
try:
|
||||
return _session_context.get()
|
||||
except LookupError:
|
||||
raise RuntimeError(f"{fn_name}() must be called in a session context.")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def connection_context(connection):
|
||||
"""Context manager installs ``connection`` as the session context for the current Trio task."""
|
||||
token = _connection_context.set(connection)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_connection_context.reset(token)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_context(session):
|
||||
"""Context manager installs ``session`` as the session context for the current Trio task."""
|
||||
token = _session_context.set(session)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_session_context.reset(token)
|
||||
|
||||
|
||||
def set_global_connection(connection):
|
||||
"""Install ``connection`` in the root context so that it will become the default connection for all tasks.
|
||||
|
||||
This is generally not recommended, except it may be necessary in
|
||||
certain use cases such as running inside Jupyter notebook.
|
||||
"""
|
||||
global _connection_context
|
||||
_connection_context = contextvars.ContextVar("_connection_context", default=connection)
|
||||
|
||||
|
||||
def set_global_session(session):
|
||||
"""Install ``session`` in the root context so that it will become the default session for all tasks.
|
||||
|
||||
This is generally not recommended, except it may be necessary in
|
||||
certain use cases such as running inside Jupyter notebook.
|
||||
"""
|
||||
global _session_context
|
||||
_session_context = contextvars.ContextVar("_session_context", default=session)
|
||||
|
||||
|
||||
class BrowserError(Exception):
|
||||
"""This exception is raised when the browser's response to a command indicates that an error occurred."""
|
||||
|
||||
def __init__(self, obj):
|
||||
self.code = obj.get("code")
|
||||
self.message = obj.get("message")
|
||||
self.detail = obj.get("data")
|
||||
|
||||
def __str__(self):
|
||||
return f"BrowserError<code={self.code} message={self.message}> {self.detail}"
|
||||
|
||||
|
||||
class CdpConnectionClosed(WsConnectionClosed):
|
||||
"""Raised when a public method is called on a closed CDP connection."""
|
||||
|
||||
def __init__(self, reason):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
reason: wsproto.frame_protocol.CloseReason
|
||||
"""
|
||||
self.reason = reason
|
||||
|
||||
def __repr__(self):
|
||||
"""Return representation."""
|
||||
return f"{self.__class__.__name__}<{self.reason}>"
|
||||
|
||||
|
||||
class InternalError(Exception):
|
||||
"""This exception is only raised when there is faulty logic in TrioCDP or the integration with PyCDP."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CmEventProxy:
|
||||
"""A proxy object returned by :meth:`CdpBase.wait_for()``.
|
||||
|
||||
After the context manager executes, this proxy object will have a
|
||||
value set that contains the returned event.
|
||||
"""
|
||||
|
||||
value: Any = None
|
||||
|
||||
|
||||
class CdpBase:
|
||||
def __init__(self, ws, session_id, target_id):
|
||||
self.ws = ws
|
||||
self.session_id = session_id
|
||||
self.target_id = target_id
|
||||
self.channels = defaultdict(set)
|
||||
self.id_iter = itertools.count()
|
||||
self.inflight_cmd = {}
|
||||
self.inflight_result = {}
|
||||
|
||||
async def execute(self, cmd: Generator[dict, T, Any]) -> T:
|
||||
"""Execute a command on the server and wait for the result.
|
||||
|
||||
Args:
|
||||
cmd: any CDP command
|
||||
|
||||
Returns:
|
||||
a CDP result
|
||||
"""
|
||||
cmd_id = next(self.id_iter)
|
||||
cmd_event = trio.Event()
|
||||
self.inflight_cmd[cmd_id] = cmd, cmd_event
|
||||
request = next(cmd)
|
||||
request["id"] = cmd_id
|
||||
if self.session_id:
|
||||
request["sessionId"] = self.session_id
|
||||
request_str = json.dumps(request)
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(f"Sending CDP message: {cmd_id} {cmd_event}: {request_str}")
|
||||
try:
|
||||
await self.ws.send_message(request_str)
|
||||
except WsConnectionClosed as wcc:
|
||||
raise CdpConnectionClosed(wcc.reason) from None
|
||||
await cmd_event.wait()
|
||||
response = self.inflight_result.pop(cmd_id)
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(f"Received CDP message: {response}")
|
||||
if isinstance(response, Exception):
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}")
|
||||
raise response
|
||||
return response
|
||||
|
||||
def listen(self, *event_types, buffer_size=10):
|
||||
"""Listen for events.
|
||||
|
||||
Returns:
|
||||
An async iterator that iterates over events matching the indicated types.
|
||||
"""
|
||||
sender, receiver = trio.open_memory_channel(buffer_size)
|
||||
for event_type in event_types:
|
||||
self.channels[event_type].add(sender)
|
||||
return receiver
|
||||
|
||||
@asynccontextmanager
|
||||
async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]:
|
||||
"""Wait for an event of the given type and return it.
|
||||
|
||||
This is an async context manager, so you should open it inside
|
||||
an async with block. The block will not exit until the indicated
|
||||
event is received.
|
||||
"""
|
||||
sender: trio.MemorySendChannel
|
||||
receiver: trio.MemoryReceiveChannel
|
||||
sender, receiver = trio.open_memory_channel(buffer_size)
|
||||
self.channels[event_type].add(sender)
|
||||
proxy = CmEventProxy()
|
||||
yield proxy
|
||||
async with receiver:
|
||||
event = await receiver.receive()
|
||||
proxy.value = event
|
||||
|
||||
def _handle_data(self, data):
|
||||
"""Handle incoming WebSocket data.
|
||||
|
||||
Args:
|
||||
data: a JSON dictionary
|
||||
"""
|
||||
if "id" in data:
|
||||
self._handle_cmd_response(data)
|
||||
else:
|
||||
self._handle_event(data)
|
||||
|
||||
def _handle_cmd_response(self, data: dict):
|
||||
"""Handle a response to a command.
|
||||
|
||||
This will set an event flag that will return control to the
|
||||
task that called the command.
|
||||
|
||||
Args:
|
||||
data: response as a JSON dictionary
|
||||
"""
|
||||
cmd_id = data["id"]
|
||||
try:
|
||||
cmd, event = self.inflight_cmd.pop(cmd_id)
|
||||
except KeyError:
|
||||
logger.warning("Got a message with a command ID that does not exist: %s", data)
|
||||
return
|
||||
if "error" in data:
|
||||
# If the server reported an error, convert it to an exception and do
|
||||
# not process the response any further.
|
||||
self.inflight_result[cmd_id] = BrowserError(data["error"])
|
||||
else:
|
||||
# Otherwise, continue the generator to parse the JSON result
|
||||
# into a CDP object.
|
||||
try:
|
||||
_ = cmd.send(data["result"])
|
||||
raise InternalError("The command's generator function did not exit when expected!")
|
||||
except StopIteration as exit:
|
||||
return_ = exit.value
|
||||
self.inflight_result[cmd_id] = return_
|
||||
event.set()
|
||||
|
||||
def _handle_event(self, data: dict):
|
||||
"""Handle an event.
|
||||
|
||||
Args:
|
||||
data: event as a JSON dictionary
|
||||
"""
|
||||
global devtools
|
||||
if devtools is None:
|
||||
raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.")
|
||||
event = devtools.util.parse_json_event(data)
|
||||
logger.debug("Received event: %s", event)
|
||||
to_remove = set()
|
||||
for sender in self.channels[type(event)]:
|
||||
try:
|
||||
sender.send_nowait(event)
|
||||
except trio.WouldBlock:
|
||||
logger.error('Unable to send event "%r" due to full channel %s', event, sender)
|
||||
except trio.BrokenResourceError:
|
||||
to_remove.add(sender)
|
||||
if to_remove:
|
||||
self.channels[type(event)] -= to_remove
|
||||
|
||||
|
||||
class CdpSession(CdpBase):
|
||||
"""Contains the state for a CDP session.
|
||||
|
||||
Generally you should not instantiate this object yourself; you should call
|
||||
:meth:`CdpConnection.open_session`.
|
||||
"""
|
||||
|
||||
def __init__(self, ws, session_id, target_id):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
ws: trio_websocket.WebSocketConnection
|
||||
session_id: devtools.target.SessionID
|
||||
target_id: devtools.target.TargetID
|
||||
"""
|
||||
super().__init__(ws, session_id, target_id)
|
||||
|
||||
self._dom_enable_count = 0
|
||||
self._dom_enable_lock = trio.Lock()
|
||||
self._page_enable_count = 0
|
||||
self._page_enable_lock = trio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def dom_enable(self):
|
||||
"""Context manager that executes ``dom.enable()`` when it enters and then calls ``dom.disable()``.
|
||||
|
||||
This keeps track of concurrent callers and only disables DOM
|
||||
events when all callers have exited.
|
||||
"""
|
||||
global devtools
|
||||
async with self._dom_enable_lock:
|
||||
self._dom_enable_count += 1
|
||||
if self._dom_enable_count == 1:
|
||||
await self.execute(devtools.dom.enable())
|
||||
|
||||
yield
|
||||
|
||||
async with self._dom_enable_lock:
|
||||
self._dom_enable_count -= 1
|
||||
if self._dom_enable_count == 0:
|
||||
await self.execute(devtools.dom.disable())
|
||||
|
||||
@asynccontextmanager
|
||||
async def page_enable(self):
|
||||
"""Context manager executes ``page.enable()`` when it enters and then calls ``page.disable()`` when it exits.
|
||||
|
||||
This keeps track of concurrent callers and only disables page
|
||||
events when all callers have exited.
|
||||
"""
|
||||
global devtools
|
||||
async with self._page_enable_lock:
|
||||
self._page_enable_count += 1
|
||||
if self._page_enable_count == 1:
|
||||
await self.execute(devtools.page.enable())
|
||||
|
||||
yield
|
||||
|
||||
async with self._page_enable_lock:
|
||||
self._page_enable_count -= 1
|
||||
if self._page_enable_count == 0:
|
||||
await self.execute(devtools.page.disable())
|
||||
|
||||
|
||||
class CdpConnection(CdpBase, trio.abc.AsyncResource):
|
||||
"""Contains the connection state for a Chrome DevTools Protocol server.
|
||||
|
||||
CDP can multiplex multiple "sessions" over a single connection. This
|
||||
class corresponds to the "root" session, i.e. the implicitly created
|
||||
session that has no session ID. This class is responsible for
|
||||
reading incoming WebSocket messages and forwarding them to the
|
||||
corresponding session, as well as handling messages targeted at the
|
||||
root session itself. You should generally call the
|
||||
:func:`open_cdp()` instead of instantiating this class directly.
|
||||
"""
|
||||
|
||||
def __init__(self, ws):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
ws: trio_websocket.WebSocketConnection
|
||||
"""
|
||||
super().__init__(ws, session_id=None, target_id=None)
|
||||
self.sessions = {}
|
||||
|
||||
async def aclose(self):
|
||||
"""Close the underlying WebSocket connection.
|
||||
|
||||
This will cause the reader task to gracefully exit when it tries
|
||||
to read the next message from the WebSocket. All of the public
|
||||
APIs (``execute()``, ``listen()``, etc.) will raise
|
||||
``CdpConnectionClosed`` after the CDP connection is closed. It
|
||||
is safe to call this multiple times.
|
||||
"""
|
||||
await self.ws.aclose()
|
||||
|
||||
@asynccontextmanager
|
||||
async def open_session(self, target_id) -> AsyncIterator[CdpSession]:
|
||||
"""Context manager opens a session and enables the "simple" style of calling CDP APIs.
|
||||
|
||||
For example, inside a session context, you can call ``await
|
||||
dom.get_document()`` and it will execute on the current session
|
||||
automatically.
|
||||
"""
|
||||
session = await self.connect_session(target_id)
|
||||
with session_context(session):
|
||||
yield session
|
||||
|
||||
async def connect_session(self, target_id) -> "CdpSession":
|
||||
"""Returns a new :class:`CdpSession` connected to the specified target."""
|
||||
global devtools
|
||||
if devtools is None:
|
||||
raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.")
|
||||
session_id = await self.execute(devtools.target.attach_to_target(target_id, True))
|
||||
session = CdpSession(self.ws, session_id, target_id)
|
||||
self.sessions[session_id] = session
|
||||
return session
|
||||
|
||||
async def _reader_task(self):
|
||||
"""Runs in the background and handles incoming messages.
|
||||
|
||||
Dispatches responses to commands and events to listeners.
|
||||
"""
|
||||
global devtools
|
||||
if devtools is None:
|
||||
raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.")
|
||||
while True:
|
||||
try:
|
||||
message = await self.ws.get_message()
|
||||
except WsConnectionClosed:
|
||||
# If the WebSocket is closed, we don't want to throw an
|
||||
# exception from the reader task. Instead we will throw
|
||||
# exceptions from the public API methods, and we can quietly
|
||||
# exit the reader task here.
|
||||
break
|
||||
try:
|
||||
data = json.loads(message)
|
||||
except json.JSONDecodeError:
|
||||
raise BrowserError({"code": -32700, "message": "Client received invalid JSON", "data": message})
|
||||
logger.debug("Received message %r", data)
|
||||
if "sessionId" in data:
|
||||
session_id = devtools.target.SessionID(data["sessionId"])
|
||||
try:
|
||||
session = self.sessions[session_id]
|
||||
except KeyError:
|
||||
raise BrowserError(
|
||||
{
|
||||
"code": -32700,
|
||||
"message": "Browser sent a message for an invalid session",
|
||||
"data": f"{session_id!r}",
|
||||
}
|
||||
)
|
||||
session._handle_data(data)
|
||||
else:
|
||||
self._handle_data(data)
|
||||
|
||||
for _, session in self.sessions.items():
|
||||
for _, senders in session.channels.items():
|
||||
for sender in senders:
|
||||
sender.close()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def open_cdp(url) -> AsyncIterator[CdpConnection]:
|
||||
"""Async context manager opens a connection to the browser then closes the connection when the block exits.
|
||||
|
||||
The context manager also sets the connection as the default
|
||||
connection for the current task, so that commands like ``await
|
||||
target.get_targets()`` will run on this connection automatically. If
|
||||
you want to use multiple connections concurrently, it is recommended
|
||||
to open each on in a separate task.
|
||||
"""
|
||||
async with trio.open_nursery() as nursery:
|
||||
conn = await connect_cdp(nursery, url)
|
||||
try:
|
||||
with connection_context(conn):
|
||||
yield conn
|
||||
finally:
|
||||
await conn.aclose()
|
||||
|
||||
|
||||
async def connect_cdp(nursery, url) -> CdpConnection:
|
||||
"""Connect to the browser specified by ``url`` and spawn a background task in the specified nursery.
|
||||
|
||||
The ``open_cdp()`` context manager is preferred in most situations.
|
||||
You should only use this function if you need to specify a custom
|
||||
nursery. This connection is not automatically closed! You can either
|
||||
use the connection object as a context manager (``async with
|
||||
conn:``) or else call ``await conn.aclose()`` on it when you are
|
||||
done with it. If ``set_context`` is True, then the returned
|
||||
connection will be installed as the default connection for the
|
||||
current task. This argument is for unusual use cases, such as
|
||||
running inside of a notebook.
|
||||
"""
|
||||
ws = await connect_websocket_url(nursery, url, max_message_size=MAX_WS_MESSAGE_SIZE)
|
||||
cdp_conn = CdpConnection(ws)
|
||||
nursery.start_soon(cdp_conn._reader_task)
|
||||
return cdp_conn
|
||||
@@ -0,0 +1,36 @@
|
||||
# Licensed to the Software Freedom Conservancy (SFC) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The SFC licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
def command_builder(method: str, params: dict | None = None) -> Generator[dict, dict, dict]:
|
||||
"""Build a command iterator to send to the BiDi protocol.
|
||||
|
||||
Args:
|
||||
method: The method to execute.
|
||||
params: The parameters to pass to the method. Default is None.
|
||||
|
||||
Returns:
|
||||
The response from the command execution.
|
||||
"""
|
||||
if params is None:
|
||||
params = {}
|
||||
|
||||
command = {"method": method, "params": params}
|
||||
cmd = yield command
|
||||
return cmd
|
||||
+24
@@ -0,0 +1,24 @@
|
||||
# Licensed to the Software Freedom Conservancy (SFC) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The SFC licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Console(Enum):
|
||||
ALL = "all"
|
||||
LOG = "log"
|
||||
ERROR = "error"
|
||||
+524
@@ -0,0 +1,524 @@
|
||||
# Licensed to the Software Freedom Conservancy (SFC) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The SFC licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
from selenium.webdriver.common.bidi.common import command_builder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from selenium.webdriver.remote.websocket_connection import WebSocketConnection
|
||||
|
||||
|
||||
class ScreenOrientationNatural(Enum):
|
||||
"""Natural screen orientation."""
|
||||
|
||||
PORTRAIT = "portrait"
|
||||
LANDSCAPE = "landscape"
|
||||
|
||||
|
||||
class ScreenOrientationType(Enum):
|
||||
"""Screen orientation type."""
|
||||
|
||||
PORTRAIT_PRIMARY = "portrait-primary"
|
||||
PORTRAIT_SECONDARY = "portrait-secondary"
|
||||
LANDSCAPE_PRIMARY = "landscape-primary"
|
||||
LANDSCAPE_SECONDARY = "landscape-secondary"
|
||||
|
||||
|
||||
E = TypeVar("E", ScreenOrientationNatural, ScreenOrientationType)
|
||||
|
||||
|
||||
def _convert_to_enum(value: E | str, enum_class: type[E]) -> E:
|
||||
if isinstance(value, enum_class):
|
||||
return value
|
||||
assert isinstance(value, str)
|
||||
try:
|
||||
return enum_class(value.lower())
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid orientation: {value}")
|
||||
|
||||
|
||||
class ScreenOrientation:
|
||||
"""Represents screen orientation configuration."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
natural: ScreenOrientationNatural | str,
|
||||
type: ScreenOrientationType | str,
|
||||
):
|
||||
"""Initialize ScreenOrientation.
|
||||
|
||||
Args:
|
||||
natural: Natural screen orientation ("portrait" or "landscape").
|
||||
type: Screen orientation type ("portrait-primary", "portrait-secondary",
|
||||
"landscape-primary", or "landscape-secondary").
|
||||
|
||||
Raises:
|
||||
ValueError: If natural or type values are invalid.
|
||||
"""
|
||||
# handle string values
|
||||
self.natural = _convert_to_enum(natural, ScreenOrientationNatural)
|
||||
self.type = _convert_to_enum(type, ScreenOrientationType)
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
return {
|
||||
"natural": self.natural.value,
|
||||
"type": self.type.value,
|
||||
}
|
||||
|
||||
|
||||
class GeolocationCoordinates:
|
||||
"""Represents geolocation coordinates."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
accuracy: float = 1.0,
|
||||
altitude: float | None = None,
|
||||
altitude_accuracy: float | None = None,
|
||||
heading: float | None = None,
|
||||
speed: float | None = None,
|
||||
):
|
||||
"""Initialize GeolocationCoordinates.
|
||||
|
||||
Args:
|
||||
latitude: Latitude coordinate (-90.0 to 90.0).
|
||||
longitude: Longitude coordinate (-180.0 to 180.0).
|
||||
accuracy: Accuracy in meters (>= 0.0), defaults to 1.0.
|
||||
altitude: Altitude in meters or None, defaults to None.
|
||||
altitude_accuracy: Altitude accuracy in meters (>= 0.0) or None, defaults to None.
|
||||
heading: Heading in degrees (0.0 to 360.0) or None, defaults to None.
|
||||
speed: Speed in meters per second (>= 0.0) or None, defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: If coordinates are out of valid range or if altitude_accuracy is provided without altitude.
|
||||
"""
|
||||
self.latitude = latitude
|
||||
self.longitude = longitude
|
||||
self.accuracy = accuracy
|
||||
self.altitude = altitude
|
||||
self.altitude_accuracy = altitude_accuracy
|
||||
self.heading = heading
|
||||
self.speed = speed
|
||||
|
||||
@property
|
||||
def latitude(self) -> float:
|
||||
return self._latitude
|
||||
|
||||
@latitude.setter
|
||||
def latitude(self, value: float) -> None:
|
||||
if not (-90.0 <= value <= 90.0):
|
||||
raise ValueError("latitude must be between -90.0 and 90.0")
|
||||
self._latitude = value
|
||||
|
||||
@property
|
||||
def longitude(self) -> float:
|
||||
return self._longitude
|
||||
|
||||
@longitude.setter
|
||||
def longitude(self, value: float) -> None:
|
||||
if not (-180.0 <= value <= 180.0):
|
||||
raise ValueError("longitude must be between -180.0 and 180.0")
|
||||
self._longitude = value
|
||||
|
||||
@property
|
||||
def accuracy(self) -> float:
|
||||
return self._accuracy
|
||||
|
||||
@accuracy.setter
|
||||
def accuracy(self, value: float) -> None:
|
||||
if value < 0.0:
|
||||
raise ValueError("accuracy must be >= 0.0")
|
||||
self._accuracy = value
|
||||
|
||||
@property
|
||||
def altitude(self) -> float | None:
|
||||
return self._altitude
|
||||
|
||||
@altitude.setter
|
||||
def altitude(self, value: float | None) -> None:
|
||||
self._altitude = value
|
||||
|
||||
@property
|
||||
def altitude_accuracy(self) -> float | None:
|
||||
return self._altitude_accuracy
|
||||
|
||||
@altitude_accuracy.setter
|
||||
def altitude_accuracy(self, value: float | None) -> None:
|
||||
if value is not None and self.altitude is None:
|
||||
raise ValueError("altitude_accuracy cannot be set without altitude")
|
||||
if value is not None and value < 0.0:
|
||||
raise ValueError("altitude_accuracy must be >= 0.0")
|
||||
self._altitude_accuracy = value
|
||||
|
||||
@property
|
||||
def heading(self) -> float | None:
|
||||
return self._heading
|
||||
|
||||
@heading.setter
|
||||
def heading(self, value: float | None) -> None:
|
||||
if value is not None and not (0.0 <= value < 360.0):
|
||||
raise ValueError("heading must be between 0.0 and 360.0")
|
||||
self._heading = value
|
||||
|
||||
@property
|
||||
def speed(self) -> float | None:
|
||||
return self._speed
|
||||
|
||||
@speed.setter
|
||||
def speed(self, value: float | None) -> None:
|
||||
if value is not None and value < 0.0:
|
||||
raise ValueError("speed must be >= 0.0")
|
||||
self._speed = value
|
||||
|
||||
def to_dict(self) -> dict[str, float | None]:
|
||||
result: dict[str, float | None] = {
|
||||
"latitude": self.latitude,
|
||||
"longitude": self.longitude,
|
||||
"accuracy": self.accuracy,
|
||||
}
|
||||
|
||||
if self.altitude is not None:
|
||||
result["altitude"] = self.altitude
|
||||
|
||||
if self.altitude_accuracy is not None:
|
||||
result["altitudeAccuracy"] = self.altitude_accuracy
|
||||
|
||||
if self.heading is not None:
|
||||
result["heading"] = self.heading
|
||||
|
||||
if self.speed is not None:
|
||||
result["speed"] = self.speed
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class GeolocationPositionError:
|
||||
"""Represents a geolocation position error."""
|
||||
|
||||
TYPE_POSITION_UNAVAILABLE = "positionUnavailable"
|
||||
|
||||
def __init__(self, type: str = TYPE_POSITION_UNAVAILABLE):
|
||||
if type != self.TYPE_POSITION_UNAVAILABLE:
|
||||
raise ValueError(f'type must be "{self.TYPE_POSITION_UNAVAILABLE}"')
|
||||
self.type = type
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
return {"type": self.type}
|
||||
|
||||
|
||||
class Emulation:
|
||||
"""BiDi implementation of the emulation module."""
|
||||
|
||||
def __init__(self, conn: WebSocketConnection) -> None:
|
||||
self.conn = conn
|
||||
|
||||
def set_geolocation_override(
|
||||
self,
|
||||
coordinates: GeolocationCoordinates | None = None,
|
||||
error: GeolocationPositionError | None = None,
|
||||
contexts: list[str] | None = None,
|
||||
user_contexts: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Set geolocation override for the given contexts or user contexts.
|
||||
|
||||
Args:
|
||||
coordinates: Geolocation coordinates to emulate, or None.
|
||||
error: Geolocation error to emulate, or None.
|
||||
contexts: List of browsing context IDs to apply the override to.
|
||||
user_contexts: List of user context IDs to apply the override to.
|
||||
|
||||
Raises:
|
||||
ValueError: If both coordinates and error are provided, or if both contexts
|
||||
and user_contexts are provided, or if neither contexts nor
|
||||
user_contexts are provided.
|
||||
"""
|
||||
if coordinates is not None and error is not None:
|
||||
raise ValueError("Cannot specify both coordinates and error")
|
||||
|
||||
if contexts is not None and user_contexts is not None:
|
||||
raise ValueError("Cannot specify both contexts and userContexts")
|
||||
|
||||
if contexts is None and user_contexts is None:
|
||||
raise ValueError("Must specify either contexts or userContexts")
|
||||
|
||||
params: dict[str, Any] = {}
|
||||
|
||||
if coordinates is not None:
|
||||
params["coordinates"] = coordinates.to_dict()
|
||||
elif error is not None:
|
||||
params["error"] = error.to_dict()
|
||||
|
||||
if contexts is not None:
|
||||
params["contexts"] = contexts
|
||||
elif user_contexts is not None:
|
||||
params["userContexts"] = user_contexts
|
||||
|
||||
self.conn.execute(command_builder("emulation.setGeolocationOverride", params))
|
||||
|
||||
def set_timezone_override(
|
||||
self,
|
||||
timezone: str | None = None,
|
||||
contexts: list[str] | None = None,
|
||||
user_contexts: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Set timezone override for the given contexts or user contexts.
|
||||
|
||||
Args:
|
||||
timezone: Timezone identifier (IANA timezone name or offset string like '+01:00'),
|
||||
or None to clear the override.
|
||||
contexts: List of browsing context IDs to apply the override to.
|
||||
user_contexts: List of user context IDs to apply the override to.
|
||||
|
||||
Raises:
|
||||
ValueError: If both contexts and user_contexts are provided, or if neither
|
||||
contexts nor user_contexts are provided.
|
||||
"""
|
||||
if contexts is not None and user_contexts is not None:
|
||||
raise ValueError("Cannot specify both contexts and user_contexts")
|
||||
|
||||
if contexts is None and user_contexts is None:
|
||||
raise ValueError("Must specify either contexts or user_contexts")
|
||||
|
||||
params: dict[str, Any] = {"timezone": timezone}
|
||||
|
||||
if contexts is not None:
|
||||
params["contexts"] = contexts
|
||||
elif user_contexts is not None:
|
||||
params["userContexts"] = user_contexts
|
||||
|
||||
self.conn.execute(command_builder("emulation.setTimezoneOverride", params))
|
||||
|
||||
def set_locale_override(
|
||||
self,
|
||||
locale: str | None = None,
|
||||
contexts: list[str] | None = None,
|
||||
user_contexts: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Set locale override for the given contexts or user contexts.
|
||||
|
||||
Args:
|
||||
locale: Locale string as per BCP 47, or None to clear override.
|
||||
contexts: List of browsing context IDs to apply the override to.
|
||||
user_contexts: List of user context IDs to apply the override to.
|
||||
|
||||
Raises:
|
||||
ValueError: If both contexts and user_contexts are provided, or if neither
|
||||
contexts nor user_contexts are provided, or if locale is invalid.
|
||||
"""
|
||||
if contexts is not None and user_contexts is not None:
|
||||
raise ValueError("Cannot specify both contexts and userContexts")
|
||||
|
||||
if contexts is None and user_contexts is None:
|
||||
raise ValueError("Must specify either contexts or userContexts")
|
||||
|
||||
params: dict[str, Any] = {"locale": locale}
|
||||
|
||||
if contexts is not None:
|
||||
params["contexts"] = contexts
|
||||
elif user_contexts is not None:
|
||||
params["userContexts"] = user_contexts
|
||||
|
||||
self.conn.execute(command_builder("emulation.setLocaleOverride", params))
|
||||
|
||||
def set_scripting_enabled(
|
||||
self,
|
||||
enabled: bool | None = False,
|
||||
contexts: list[str] | None = None,
|
||||
user_contexts: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Set scripting enabled override for the given contexts or user contexts.
|
||||
|
||||
Args:
|
||||
enabled: False to disable scripting, None to clear the override.
|
||||
Note: Only emulation of disabled JavaScript is supported.
|
||||
contexts: List of browsing context IDs to apply the override to.
|
||||
user_contexts: List of user context IDs to apply the override to.
|
||||
|
||||
Raises:
|
||||
ValueError: If both contexts and user_contexts are provided, or if neither
|
||||
contexts nor user_contexts are provided, or if enabled is True.
|
||||
"""
|
||||
if enabled:
|
||||
raise ValueError("Only emulation of disabled JavaScript is supported (enabled must be False or None)")
|
||||
|
||||
if contexts is not None and user_contexts is not None:
|
||||
raise ValueError("Cannot specify both contexts and userContexts")
|
||||
|
||||
if contexts is None and user_contexts is None:
|
||||
raise ValueError("Must specify either contexts or userContexts")
|
||||
|
||||
params: dict[str, Any] = {"enabled": enabled}
|
||||
|
||||
if contexts is not None:
|
||||
params["contexts"] = contexts
|
||||
elif user_contexts is not None:
|
||||
params["userContexts"] = user_contexts
|
||||
|
||||
self.conn.execute(command_builder("emulation.setScriptingEnabled", params))
|
||||
|
||||
def set_screen_orientation_override(
|
||||
self,
|
||||
screen_orientation: ScreenOrientation | None = None,
|
||||
contexts: list[str] | None = None,
|
||||
user_contexts: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Set screen orientation override for the given contexts or user contexts.
|
||||
|
||||
Args:
|
||||
screen_orientation: ScreenOrientation object to emulate, or None to clear the override.
|
||||
contexts: List of browsing context IDs to apply the override to.
|
||||
user_contexts: List of user context IDs to apply the override to.
|
||||
|
||||
Raises:
|
||||
ValueError: If both contexts and user_contexts are provided, or if neither
|
||||
contexts nor user_contexts are provided.
|
||||
"""
|
||||
if contexts is not None and user_contexts is not None:
|
||||
raise ValueError("Cannot specify both contexts and userContexts")
|
||||
|
||||
if contexts is None and user_contexts is None:
|
||||
raise ValueError("Must specify either contexts or userContexts")
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"screenOrientation": screen_orientation.to_dict() if screen_orientation is not None else None
|
||||
}
|
||||
|
||||
if contexts is not None:
|
||||
params["contexts"] = contexts
|
||||
elif user_contexts is not None:
|
||||
params["userContexts"] = user_contexts
|
||||
|
||||
self.conn.execute(command_builder("emulation.setScreenOrientationOverride", params))
|
||||
|
||||
def set_user_agent_override(
|
||||
self,
|
||||
user_agent: str | None = None,
|
||||
contexts: list[str] | None = None,
|
||||
user_contexts: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Set user agent override for the given contexts or user contexts.
|
||||
|
||||
Args:
|
||||
user_agent: User agent string to emulate, or None to clear the override.
|
||||
contexts: List of browsing context IDs to apply the override to.
|
||||
user_contexts: List of user context IDs to apply the override to.
|
||||
|
||||
Raises:
|
||||
ValueError: If both contexts and user_contexts are provided, or if neither
|
||||
contexts nor user_contexts are provided.
|
||||
"""
|
||||
if contexts is not None and user_contexts is not None:
|
||||
raise ValueError("Cannot specify both contexts and user_contexts")
|
||||
|
||||
if contexts is None and user_contexts is None:
|
||||
raise ValueError("Must specify either contexts or user_contexts")
|
||||
|
||||
params: dict[str, Any] = {"userAgent": user_agent}
|
||||
|
||||
if contexts is not None:
|
||||
params["contexts"] = contexts
|
||||
elif user_contexts is not None:
|
||||
params["userContexts"] = user_contexts
|
||||
|
||||
self.conn.execute(command_builder("emulation.setUserAgentOverride", params))
|
||||
|
||||
def set_network_conditions(
|
||||
self,
|
||||
offline: bool = False,
|
||||
contexts: list[str] | None = None,
|
||||
user_contexts: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Set network conditions for the given contexts or user contexts.
|
||||
|
||||
Args:
|
||||
offline: True to emulate offline network conditions, False to clear the override.
|
||||
contexts: List of browsing context IDs to apply the conditions to.
|
||||
user_contexts: List of user context IDs to apply the conditions to.
|
||||
|
||||
Raises:
|
||||
ValueError: If both contexts and user_contexts are provided, or if neither
|
||||
contexts nor user_contexts are provided.
|
||||
"""
|
||||
if contexts is not None and user_contexts is not None:
|
||||
raise ValueError("Cannot specify both contexts and user_contexts")
|
||||
|
||||
if contexts is None and user_contexts is None:
|
||||
raise ValueError("Must specify either contexts or user_contexts")
|
||||
|
||||
params: dict[str, Any] = {}
|
||||
|
||||
if offline:
|
||||
params["networkConditions"] = {"type": "offline"}
|
||||
else:
|
||||
# if offline is False or None, then clear the override
|
||||
params["networkConditions"] = None
|
||||
|
||||
if contexts is not None:
|
||||
params["contexts"] = contexts
|
||||
elif user_contexts is not None:
|
||||
params["userContexts"] = user_contexts
|
||||
|
||||
self.conn.execute(command_builder("emulation.setNetworkConditions", params))
|
||||
|
||||
def set_screen_settings_override(
|
||||
self,
|
||||
width: int | None = None,
|
||||
height: int | None = None,
|
||||
contexts: list[str] | None = None,
|
||||
user_contexts: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Set screen settings override for the given contexts or user contexts.
|
||||
|
||||
Args:
|
||||
width: Screen width in pixels (>= 0). None to clear the override.
|
||||
height: Screen height in pixels (>= 0). None to clear the override.
|
||||
contexts: List of browsing context IDs to apply the override to.
|
||||
user_contexts: List of user context IDs to apply the override to.
|
||||
|
||||
Raises:
|
||||
ValueError: If only one of width/height is provided, or if both contexts
|
||||
and user_contexts are provided, or if neither is provided.
|
||||
"""
|
||||
if (width is None) != (height is None):
|
||||
raise ValueError("Must provide both width and height, or neither to clear the override")
|
||||
|
||||
if contexts is not None and user_contexts is not None:
|
||||
raise ValueError("Cannot specify both contexts and user_contexts")
|
||||
|
||||
if contexts is None and user_contexts is None:
|
||||
raise ValueError("Must specify either contexts or user_contexts")
|
||||
|
||||
screen_area = None
|
||||
if width is not None and height is not None:
|
||||
if not isinstance(width, int) or not isinstance(height, int):
|
||||
raise ValueError("width and height must be integers")
|
||||
if width < 0 or height < 0:
|
||||
raise ValueError("width and height must be >= 0")
|
||||
screen_area = {"width": width, "height": height}
|
||||
|
||||
params: dict[str, Any] = {"screenArea": screen_area}
|
||||
|
||||
if contexts is not None:
|
||||
params["contexts"] = contexts
|
||||
elif user_contexts is not None:
|
||||
params["userContexts"] = user_contexts
|
||||
|
||||
self.conn.execute(command_builder("emulation.setScreenSettingsOverride", params))
|
||||
@@ -0,0 +1,462 @@
|
||||
# Licensed to the Software Freedom Conservancy (SFC) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The SFC licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from selenium.webdriver.common.bidi.common import command_builder
|
||||
from selenium.webdriver.common.bidi.session import Session
|
||||
|
||||
|
||||
class PointerType:
|
||||
"""Represents the possible pointer types."""
|
||||
|
||||
MOUSE = "mouse"
|
||||
PEN = "pen"
|
||||
TOUCH = "touch"
|
||||
|
||||
VALID_TYPES = {MOUSE, PEN, TOUCH}
|
||||
|
||||
|
||||
class Origin:
|
||||
"""Represents the possible origin types."""
|
||||
|
||||
VIEWPORT = "viewport"
|
||||
POINTER = "pointer"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ElementOrigin:
|
||||
"""Represents an element origin for input actions."""
|
||||
|
||||
type: str
|
||||
element: dict
|
||||
|
||||
def __init__(self, element_reference: dict):
|
||||
self.type = "element"
|
||||
self.element = element_reference
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the ElementOrigin to a dictionary."""
|
||||
return {"type": self.type, "element": self.element}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PointerParameters:
|
||||
"""Represents pointer parameters for pointer actions."""
|
||||
|
||||
pointer_type: str = PointerType.MOUSE
|
||||
|
||||
def __post_init__(self):
|
||||
if self.pointer_type not in PointerType.VALID_TYPES:
|
||||
raise ValueError(f"Invalid pointer type: {self.pointer_type}. Must be one of {PointerType.VALID_TYPES}")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the PointerParameters to a dictionary."""
|
||||
return {"pointerType": self.pointer_type}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PointerCommonProperties:
|
||||
"""Common properties for pointer actions."""
|
||||
|
||||
width: int = 1
|
||||
height: int = 1
|
||||
pressure: float = 0.0
|
||||
tangential_pressure: float = 0.0
|
||||
twist: int = 0
|
||||
altitude_angle: float = 0.0
|
||||
azimuth_angle: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.width < 1:
|
||||
raise ValueError("width must be at least 1")
|
||||
if self.height < 1:
|
||||
raise ValueError("height must be at least 1")
|
||||
if not (0.0 <= self.pressure <= 1.0):
|
||||
raise ValueError("pressure must be between 0.0 and 1.0")
|
||||
if not (0.0 <= self.tangential_pressure <= 1.0):
|
||||
raise ValueError("tangential_pressure must be between 0.0 and 1.0")
|
||||
if not (0 <= self.twist <= 359):
|
||||
raise ValueError("twist must be between 0 and 359")
|
||||
if not (0.0 <= self.altitude_angle <= math.pi / 2):
|
||||
raise ValueError("altitude_angle must be between 0.0 and π/2")
|
||||
if not (0.0 <= self.azimuth_angle <= 2 * math.pi):
|
||||
raise ValueError("azimuth_angle must be between 0.0 and 2π")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the PointerCommonProperties to a dictionary."""
|
||||
result: dict[str, Any] = {}
|
||||
if self.width != 1:
|
||||
result["width"] = self.width
|
||||
if self.height != 1:
|
||||
result["height"] = self.height
|
||||
if self.pressure != 0.0:
|
||||
result["pressure"] = self.pressure
|
||||
if self.tangential_pressure != 0.0:
|
||||
result["tangentialPressure"] = self.tangential_pressure
|
||||
if self.twist != 0:
|
||||
result["twist"] = self.twist
|
||||
if self.altitude_angle != 0.0:
|
||||
result["altitudeAngle"] = self.altitude_angle
|
||||
if self.azimuth_angle != 0.0:
|
||||
result["azimuthAngle"] = self.azimuth_angle
|
||||
return result
|
||||
|
||||
|
||||
# Action classes
|
||||
@dataclass
|
||||
class PauseAction:
|
||||
"""Represents a pause action."""
|
||||
|
||||
duration: int | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "pause"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the PauseAction to a dictionary."""
|
||||
result: dict[str, Any] = {"type": self.type}
|
||||
if self.duration is not None:
|
||||
result["duration"] = self.duration
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeyDownAction:
|
||||
"""Represents a key down action."""
|
||||
|
||||
value: str = ""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "keyDown"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the KeyDownAction to a dictionary."""
|
||||
return {"type": self.type, "value": self.value}
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeyUpAction:
|
||||
"""Represents a key up action."""
|
||||
|
||||
value: str = ""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "keyUp"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the KeyUpAction to a dictionary."""
|
||||
return {"type": self.type, "value": self.value}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PointerDownAction:
|
||||
"""Represents a pointer down action."""
|
||||
|
||||
button: int = 0
|
||||
properties: PointerCommonProperties | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "pointerDown"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the PointerDownAction to a dictionary."""
|
||||
result: dict[str, Any] = {"type": self.type, "button": self.button}
|
||||
if self.properties:
|
||||
result.update(self.properties.to_dict())
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class PointerUpAction:
|
||||
"""Represents a pointer up action."""
|
||||
|
||||
button: int = 0
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "pointerUp"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the PointerUpAction to a dictionary."""
|
||||
return {"type": self.type, "button": self.button}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PointerMoveAction:
|
||||
"""Represents a pointer move action."""
|
||||
|
||||
x: float = 0
|
||||
y: float = 0
|
||||
duration: int | None = None
|
||||
origin: str | ElementOrigin | None = None
|
||||
properties: PointerCommonProperties | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "pointerMove"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the PointerMoveAction to a dictionary."""
|
||||
result: dict[str, Any] = {"type": self.type, "x": self.x, "y": self.y}
|
||||
if self.duration is not None:
|
||||
result["duration"] = self.duration
|
||||
if self.origin is not None:
|
||||
if isinstance(self.origin, ElementOrigin):
|
||||
result["origin"] = self.origin.to_dict()
|
||||
else:
|
||||
result["origin"] = self.origin
|
||||
if self.properties:
|
||||
result.update(self.properties.to_dict())
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class WheelScrollAction:
|
||||
"""Represents a wheel scroll action."""
|
||||
|
||||
x: int = 0
|
||||
y: int = 0
|
||||
delta_x: int = 0
|
||||
delta_y: int = 0
|
||||
duration: int | None = None
|
||||
origin: str | ElementOrigin | None = Origin.VIEWPORT
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "scroll"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the WheelScrollAction to a dictionary."""
|
||||
result: dict[str, Any] = {
|
||||
"type": self.type,
|
||||
"x": self.x,
|
||||
"y": self.y,
|
||||
"deltaX": self.delta_x,
|
||||
"deltaY": self.delta_y,
|
||||
}
|
||||
if self.duration is not None:
|
||||
result["duration"] = self.duration
|
||||
if self.origin is not None:
|
||||
if isinstance(self.origin, ElementOrigin):
|
||||
result["origin"] = self.origin.to_dict()
|
||||
else:
|
||||
result["origin"] = self.origin
|
||||
return result
|
||||
|
||||
|
||||
# Source Actions
|
||||
@dataclass
|
||||
class NoneSourceActions:
|
||||
"""Represents a sequence of none actions."""
|
||||
|
||||
id: str = ""
|
||||
actions: list[PauseAction] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "none"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the NoneSourceActions to a dictionary."""
|
||||
return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]}
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeySourceActions:
|
||||
"""Represents a sequence of key actions."""
|
||||
|
||||
id: str = ""
|
||||
actions: list[PauseAction | KeyDownAction | KeyUpAction] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "key"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the KeySourceActions to a dictionary."""
|
||||
return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PointerSourceActions:
|
||||
"""Represents a sequence of pointer actions."""
|
||||
|
||||
id: str = ""
|
||||
parameters: PointerParameters | None = None
|
||||
actions: list[PauseAction | PointerDownAction | PointerUpAction | PointerMoveAction] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.parameters is None:
|
||||
self.parameters = PointerParameters()
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "pointer"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the PointerSourceActions to a dictionary."""
|
||||
result: dict[str, Any] = {
|
||||
"type": self.type,
|
||||
"id": self.id,
|
||||
"actions": [action.to_dict() for action in self.actions],
|
||||
}
|
||||
if self.parameters:
|
||||
result["parameters"] = self.parameters.to_dict()
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class WheelSourceActions:
|
||||
"""Represents a sequence of wheel actions."""
|
||||
|
||||
id: str = ""
|
||||
actions: list[PauseAction | WheelScrollAction] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "wheel"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the WheelSourceActions to a dictionary."""
|
||||
return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]}
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileDialogInfo:
|
||||
"""Represents file dialog information from input.fileDialogOpened event."""
|
||||
|
||||
context: str
|
||||
multiple: bool
|
||||
element: dict | None = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "FileDialogInfo":
|
||||
"""Creates a FileDialogInfo instance from a dictionary.
|
||||
|
||||
Args:
|
||||
data: A dictionary containing the file dialog information.
|
||||
|
||||
Returns:
|
||||
FileDialogInfo: A new instance of FileDialogInfo.
|
||||
"""
|
||||
return cls(context=data["context"], multiple=data["multiple"], element=data.get("element"))
|
||||
|
||||
|
||||
# Event Class
|
||||
class FileDialogOpened:
|
||||
"""Event class for input.fileDialogOpened event."""
|
||||
|
||||
event_class = "input.fileDialogOpened"
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json):
|
||||
"""Create FileDialogInfo from JSON data."""
|
||||
return FileDialogInfo.from_dict(json)
|
||||
|
||||
|
||||
class Input:
|
||||
"""BiDi implementation of the input module."""
|
||||
|
||||
def __init__(self, conn):
|
||||
self.conn = conn
|
||||
self.subscriptions = {}
|
||||
self.callbacks = {}
|
||||
|
||||
def perform_actions(
|
||||
self,
|
||||
context: str,
|
||||
actions: list[NoneSourceActions | KeySourceActions | PointerSourceActions | WheelSourceActions],
|
||||
) -> None:
|
||||
"""Performs a sequence of user input actions.
|
||||
|
||||
Args:
|
||||
context: The browsing context ID where actions should be performed.
|
||||
actions: A list of source actions to perform.
|
||||
"""
|
||||
params = {"context": context, "actions": [action.to_dict() for action in actions]}
|
||||
self.conn.execute(command_builder("input.performActions", params))
|
||||
|
||||
def release_actions(self, context: str) -> None:
|
||||
"""Releases all input state for the given context.
|
||||
|
||||
Args:
|
||||
context: The browsing context ID to release actions for.
|
||||
"""
|
||||
params = {"context": context}
|
||||
self.conn.execute(command_builder("input.releaseActions", params))
|
||||
|
||||
def set_files(self, context: str, element: dict, files: list[str]) -> None:
|
||||
"""Sets files for a file input element.
|
||||
|
||||
Args:
|
||||
context: The browsing context ID.
|
||||
element: The element reference (script.SharedReference).
|
||||
files: A list of file paths to set.
|
||||
"""
|
||||
params = {"context": context, "element": element, "files": files}
|
||||
self.conn.execute(command_builder("input.setFiles", params))
|
||||
|
||||
def add_file_dialog_handler(self, handler) -> int:
|
||||
"""Add a handler for file dialog opened events.
|
||||
|
||||
Args:
|
||||
handler: Callback function that takes a FileDialogInfo object.
|
||||
|
||||
Returns:
|
||||
int: Callback ID for removing the handler later.
|
||||
"""
|
||||
# Subscribe to the event if not already subscribed
|
||||
if FileDialogOpened.event_class not in self.subscriptions:
|
||||
session = Session(self.conn)
|
||||
self.conn.execute(session.subscribe(FileDialogOpened.event_class))
|
||||
self.subscriptions[FileDialogOpened.event_class] = []
|
||||
|
||||
# Add callback - the callback receives the parsed FileDialogInfo directly
|
||||
callback_id = self.conn.add_callback(FileDialogOpened, handler)
|
||||
|
||||
self.subscriptions[FileDialogOpened.event_class].append(callback_id)
|
||||
self.callbacks[callback_id] = handler
|
||||
|
||||
return callback_id
|
||||
|
||||
def remove_file_dialog_handler(self, callback_id: int) -> None:
|
||||
"""Remove a file dialog handler.
|
||||
|
||||
Args:
|
||||
callback_id: The callback ID returned by add_file_dialog_handler.
|
||||
"""
|
||||
if callback_id in self.callbacks:
|
||||
del self.callbacks[callback_id]
|
||||
|
||||
if FileDialogOpened.event_class in self.subscriptions:
|
||||
if callback_id in self.subscriptions[FileDialogOpened.event_class]:
|
||||
self.subscriptions[FileDialogOpened.event_class].remove(callback_id)
|
||||
|
||||
# If no more callbacks for this event, unsubscribe
|
||||
if not self.subscriptions[FileDialogOpened.event_class]:
|
||||
session = Session(self.conn)
|
||||
self.conn.execute(session.unsubscribe(FileDialogOpened.event_class))
|
||||
del self.subscriptions[FileDialogOpened.event_class]
|
||||
|
||||
self.conn.remove_callback(FileDialogOpened, callback_id)
|
||||
@@ -0,0 +1,81 @@
|
||||
# Licensed to the Software Freedom Conservancy (SFC) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The SFC licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LogEntryAdded:
|
||||
event_class = "log.entryAdded"
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: dict[str, Any]) -> ConsoleLogEntry | JavaScriptLogEntry | None:
|
||||
if json["type"] == "console":
|
||||
return ConsoleLogEntry.from_json(json)
|
||||
elif json["type"] == "javascript":
|
||||
return JavaScriptLogEntry.from_json(json)
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConsoleLogEntry:
|
||||
level: str
|
||||
text: str
|
||||
timestamp: str
|
||||
method: str
|
||||
args: list[dict[str, Any]]
|
||||
type_: str
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: dict[str, Any]) -> ConsoleLogEntry:
|
||||
return cls(
|
||||
level=json["level"],
|
||||
text=json["text"],
|
||||
timestamp=json["timestamp"],
|
||||
method=json["method"],
|
||||
args=json["args"],
|
||||
type_=json["type"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JavaScriptLogEntry:
|
||||
level: str
|
||||
text: str
|
||||
timestamp: str
|
||||
stacktrace: dict[str, Any]
|
||||
type_: str
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: dict[str, Any]) -> JavaScriptLogEntry:
|
||||
return cls(
|
||||
level=json["level"],
|
||||
text=json["text"],
|
||||
timestamp=json["timestamp"],
|
||||
stacktrace=json["stackTrace"],
|
||||
type_=json["type"],
|
||||
)
|
||||
|
||||
|
||||
class LogLevel:
|
||||
"""Represents log level."""
|
||||
|
||||
DEBUG = "debug"
|
||||
INFO = "info"
|
||||
WARN = "warn"
|
||||
ERROR = "error"
|
||||
+338
@@ -0,0 +1,338 @@
|
||||
# Licensed to the Software Freedom Conservancy (SFC) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The SFC licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from selenium.webdriver.common.bidi.common import command_builder
|
||||
from selenium.webdriver.remote.websocket_connection import WebSocketConnection
|
||||
|
||||
|
||||
class NetworkEvent:
|
||||
"""Represents a network event."""
|
||||
|
||||
def __init__(self, event_class: str, **kwargs: Any) -> None:
|
||||
self.event_class = event_class
|
||||
self.params = kwargs
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: dict[str, Any]) -> NetworkEvent:
|
||||
return cls(event_class=json.get("event_class", ""), **json)
|
||||
|
||||
|
||||
class Network:
|
||||
EVENTS = {
|
||||
"before_request": "network.beforeRequestSent",
|
||||
"response_started": "network.responseStarted",
|
||||
"response_completed": "network.responseCompleted",
|
||||
"auth_required": "network.authRequired",
|
||||
"fetch_error": "network.fetchError",
|
||||
"continue_request": "network.continueRequest",
|
||||
"continue_auth": "network.continueWithAuth",
|
||||
}
|
||||
|
||||
PHASES = {
|
||||
"before_request": "beforeRequestSent",
|
||||
"response_started": "responseStarted",
|
||||
"auth_required": "authRequired",
|
||||
}
|
||||
|
||||
def __init__(self, conn: WebSocketConnection) -> None:
|
||||
self.conn = conn
|
||||
self.intercepts: list[str] = []
|
||||
self.callbacks: dict[str | int, Any] = {}
|
||||
self.subscriptions: dict[str, list[int]] = {}
|
||||
|
||||
def _add_intercept(
|
||||
self,
|
||||
phases: list[str] | None = None,
|
||||
contexts: list[str] | None = None,
|
||||
url_patterns: list[Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Add an intercept to the network.
|
||||
|
||||
Args:
|
||||
phases: A list of phases to intercept. Default is None (empty list).
|
||||
contexts: A list of contexts to intercept. Default is None.
|
||||
url_patterns: A list of URL patterns to intercept. Default is None.
|
||||
|
||||
Returns:
|
||||
str: intercept id
|
||||
"""
|
||||
if phases is None:
|
||||
phases = []
|
||||
params = {}
|
||||
if contexts is not None:
|
||||
params["contexts"] = contexts
|
||||
if url_patterns is not None:
|
||||
params["urlPatterns"] = url_patterns
|
||||
if len(phases) > 0:
|
||||
params["phases"] = phases
|
||||
else:
|
||||
params["phases"] = ["beforeRequestSent"]
|
||||
cmd = command_builder("network.addIntercept", params)
|
||||
|
||||
result: dict[str, Any] = self.conn.execute(cmd)
|
||||
self.intercepts.append(result["intercept"])
|
||||
return result
|
||||
|
||||
def _remove_intercept(self, intercept: str | None = None) -> None:
|
||||
"""Remove a specific intercept, or all intercepts.
|
||||
|
||||
Args:
|
||||
intercept: The intercept to remove. Default is None.
|
||||
|
||||
Raises:
|
||||
ValueError: If intercept is not found.
|
||||
|
||||
Note:
|
||||
If intercept is None, all intercepts will be removed.
|
||||
"""
|
||||
if intercept is None:
|
||||
intercepts_to_remove = self.intercepts.copy() # create a copy before iterating
|
||||
for intercept_id in intercepts_to_remove: # remove all intercepts
|
||||
self.conn.execute(command_builder("network.removeIntercept", {"intercept": intercept_id}))
|
||||
self.intercepts.remove(intercept_id)
|
||||
else:
|
||||
try:
|
||||
self.conn.execute(command_builder("network.removeIntercept", {"intercept": intercept}))
|
||||
self.intercepts.remove(intercept)
|
||||
except Exception as e:
|
||||
raise Exception(f"Exception: {e}")
|
||||
|
||||
def _on_request(self, event_name: str, callback: Callable[[Request], Any]) -> int:
|
||||
"""Set a callback function to subscribe to a network event.
|
||||
|
||||
Args:
|
||||
event_name: The event to subscribe to.
|
||||
callback: The callback function to execute on event.
|
||||
Takes Request object as argument.
|
||||
|
||||
Returns:
|
||||
int: callback id
|
||||
"""
|
||||
event = NetworkEvent(event_name)
|
||||
|
||||
def _callback(event_data: NetworkEvent) -> None:
|
||||
request = Request(
|
||||
network=self,
|
||||
request_id=event_data.params["request"].get("request", None),
|
||||
body_size=event_data.params["request"].get("bodySize", None),
|
||||
cookies=event_data.params["request"].get("cookies", None),
|
||||
resource_type=event_data.params["request"].get("goog:resourceType", None),
|
||||
headers=event_data.params["request"].get("headers", None),
|
||||
headers_size=event_data.params["request"].get("headersSize", None),
|
||||
timings=event_data.params["request"].get("timings", None),
|
||||
url=event_data.params["request"].get("url", None),
|
||||
)
|
||||
callback(request)
|
||||
|
||||
callback_id: int = self.conn.add_callback(event, _callback)
|
||||
|
||||
if event_name in self.callbacks:
|
||||
self.callbacks[event_name].append(callback_id)
|
||||
else:
|
||||
self.callbacks[event_name] = [callback_id]
|
||||
|
||||
return callback_id
|
||||
|
||||
def add_request_handler(
|
||||
self,
|
||||
event: str,
|
||||
callback: Callable[[Request], Any],
|
||||
url_patterns: list[Any] | None = None,
|
||||
contexts: list[str] | None = None,
|
||||
) -> int:
|
||||
"""Add a request handler to the network.
|
||||
|
||||
Args:
|
||||
event: The event to subscribe to.
|
||||
callback: The callback function to execute on request interception.
|
||||
Takes Request object as argument.
|
||||
url_patterns: A list of URL patterns to intercept. Default is None.
|
||||
contexts: A list of contexts to intercept. Default is None.
|
||||
|
||||
Returns:
|
||||
int: callback id
|
||||
"""
|
||||
try:
|
||||
event_name = self.EVENTS[event]
|
||||
phase_name = self.PHASES[event]
|
||||
except KeyError:
|
||||
raise Exception(f"Event {event} not found")
|
||||
|
||||
result = self._add_intercept(phases=[phase_name], url_patterns=url_patterns, contexts=contexts)
|
||||
callback_id = self._on_request(event_name, callback)
|
||||
|
||||
if event_name in self.subscriptions:
|
||||
self.subscriptions[event_name].append(callback_id)
|
||||
else:
|
||||
params: dict[str, Any] = {}
|
||||
params["events"] = [event_name]
|
||||
self.conn.execute(command_builder("session.subscribe", params))
|
||||
self.subscriptions[event_name] = [callback_id]
|
||||
|
||||
self.callbacks[callback_id] = result["intercept"]
|
||||
return callback_id
|
||||
|
||||
def remove_request_handler(self, event: str, callback_id: int) -> None:
|
||||
"""Remove a request handler from the network.
|
||||
|
||||
Args:
|
||||
event: The event to unsubscribe from.
|
||||
callback_id: The callback id to remove.
|
||||
"""
|
||||
try:
|
||||
event_name = self.EVENTS[event]
|
||||
except KeyError:
|
||||
raise Exception(f"Event {event} not found")
|
||||
|
||||
net_event = NetworkEvent(event_name)
|
||||
|
||||
self.conn.remove_callback(net_event, callback_id)
|
||||
self._remove_intercept(self.callbacks[callback_id])
|
||||
del self.callbacks[callback_id]
|
||||
self.subscriptions[event_name].remove(callback_id)
|
||||
if len(self.subscriptions[event_name]) == 0:
|
||||
params: dict[str, Any] = {}
|
||||
params["events"] = [event_name]
|
||||
self.conn.execute(command_builder("session.unsubscribe", params))
|
||||
del self.subscriptions[event_name]
|
||||
|
||||
def clear_request_handlers(self) -> None:
|
||||
"""Clear all request handlers from the network."""
|
||||
for event_name in self.subscriptions:
|
||||
net_event = NetworkEvent(event_name)
|
||||
for callback_id in self.subscriptions[event_name]:
|
||||
self.conn.remove_callback(net_event, callback_id)
|
||||
self._remove_intercept(self.callbacks[callback_id])
|
||||
del self.callbacks[callback_id]
|
||||
params: dict[str, Any] = {}
|
||||
params["events"] = [event_name]
|
||||
self.conn.execute(command_builder("session.unsubscribe", params))
|
||||
self.subscriptions = {}
|
||||
|
||||
def add_auth_handler(self, username: str, password: str) -> int:
|
||||
"""Add an authentication handler to the network.
|
||||
|
||||
Args:
|
||||
username: The username to authenticate with.
|
||||
password: The password to authenticate with.
|
||||
|
||||
Returns:
|
||||
int: callback id
|
||||
"""
|
||||
event = "auth_required"
|
||||
|
||||
def _callback(request: Request) -> None:
|
||||
request._continue_with_auth(username, password)
|
||||
|
||||
return self.add_request_handler(event, _callback)
|
||||
|
||||
def remove_auth_handler(self, callback_id: int) -> None:
|
||||
"""Remove an authentication handler from the network.
|
||||
|
||||
Args:
|
||||
callback_id: The callback id to remove.
|
||||
"""
|
||||
event = "auth_required"
|
||||
self.remove_request_handler(event, callback_id)
|
||||
|
||||
|
||||
class Request:
|
||||
"""Represents an intercepted network request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
network: Network,
|
||||
request_id: Any,
|
||||
body_size: int | None = None,
|
||||
cookies: Any = None,
|
||||
resource_type: str | None = None,
|
||||
headers: Any = None,
|
||||
headers_size: int | None = None,
|
||||
method: str | None = None,
|
||||
timings: Any = None,
|
||||
url: str | None = None,
|
||||
) -> None:
|
||||
self.network = network
|
||||
self.request_id = request_id
|
||||
self.body_size = body_size
|
||||
self.cookies = cookies
|
||||
self.resource_type = resource_type
|
||||
self.headers = headers
|
||||
self.headers_size = headers_size
|
||||
self.method = method
|
||||
self.timings = timings
|
||||
self.url = url
|
||||
|
||||
def fail_request(self) -> None:
|
||||
"""Fail this request."""
|
||||
if not self.request_id:
|
||||
raise ValueError("Request not found.")
|
||||
|
||||
params: dict[str, Any] = {"request": self.request_id}
|
||||
self.network.conn.execute(command_builder("network.failRequest", params))
|
||||
|
||||
def continue_request(
|
||||
self,
|
||||
body: Any = None,
|
||||
method: str | None = None,
|
||||
headers: Any = None,
|
||||
cookies: Any = None,
|
||||
url: str | None = None,
|
||||
) -> None:
|
||||
"""Continue after intercepting this request."""
|
||||
if not self.request_id:
|
||||
raise ValueError("Request not found.")
|
||||
|
||||
params: dict[str, Any] = {"request": self.request_id}
|
||||
if body is not None:
|
||||
params["body"] = body
|
||||
if method is not None:
|
||||
params["method"] = method
|
||||
if headers is not None:
|
||||
params["headers"] = headers
|
||||
if cookies is not None:
|
||||
params["cookies"] = cookies
|
||||
if url is not None:
|
||||
params["url"] = url
|
||||
|
||||
self.network.conn.execute(command_builder("network.continueRequest", params))
|
||||
|
||||
def _continue_with_auth(self, username: str | None = None, password: str | None = None) -> None:
|
||||
"""Continue with authentication.
|
||||
|
||||
Args:
|
||||
username: The username to authenticate with.
|
||||
password: The password to authenticate with.
|
||||
|
||||
Note:
|
||||
If username or password is None, it attempts auth with no credentials.
|
||||
"""
|
||||
params: dict[str, Any] = {}
|
||||
params["request"] = self.request_id
|
||||
|
||||
if not username or not password: # no credentials is valid option
|
||||
params["action"] = "default"
|
||||
else:
|
||||
params["action"] = "provideCredentials"
|
||||
params["credentials"] = {"type": "password", "username": username, "password": password}
|
||||
|
||||
self.network.conn.execute(command_builder("network.continueWithAuth", params))
|
||||
+83
@@ -0,0 +1,83 @@
|
||||
# Licensed to the Software Freedom Conservancy (SFC) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The SFC licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
from selenium.webdriver.common.bidi.common import command_builder
|
||||
|
||||
|
||||
class PermissionState:
|
||||
"""Represents the possible permission states."""
|
||||
|
||||
GRANTED = "granted"
|
||||
DENIED = "denied"
|
||||
PROMPT = "prompt"
|
||||
|
||||
|
||||
class PermissionDescriptor:
|
||||
"""Represents a permission descriptor."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"name": self.name}
|
||||
|
||||
|
||||
class Permissions:
|
||||
"""BiDi implementation of the permissions module."""
|
||||
|
||||
def __init__(self, conn):
|
||||
self.conn = conn
|
||||
|
||||
def set_permission(
|
||||
self,
|
||||
descriptor: str | PermissionDescriptor,
|
||||
state: str,
|
||||
origin: str,
|
||||
user_context: str | None = None,
|
||||
) -> None:
|
||||
"""Sets a permission state for a given permission descriptor.
|
||||
|
||||
Args:
|
||||
descriptor: The permission name (str) or PermissionDescriptor object.
|
||||
Examples: "geolocation", "camera", "microphone".
|
||||
state: The permission state (granted, denied, prompt).
|
||||
origin: The origin for which the permission is set.
|
||||
user_context: The user context id (optional).
|
||||
|
||||
Raises:
|
||||
ValueError: If the permission state is invalid.
|
||||
"""
|
||||
if state not in [PermissionState.GRANTED, PermissionState.DENIED, PermissionState.PROMPT]:
|
||||
valid_states = f"{PermissionState.GRANTED}, {PermissionState.DENIED}, {PermissionState.PROMPT}"
|
||||
raise ValueError(f"Invalid permission state. Must be one of: {valid_states}")
|
||||
|
||||
if isinstance(descriptor, str):
|
||||
permission_descriptor = PermissionDescriptor(descriptor)
|
||||
else:
|
||||
permission_descriptor = descriptor
|
||||
|
||||
params = {
|
||||
"descriptor": permission_descriptor.to_dict(),
|
||||
"state": state,
|
||||
"origin": origin,
|
||||
}
|
||||
|
||||
if user_context is not None:
|
||||
params["userContext"] = user_context
|
||||
|
||||
self.conn.execute(command_builder("permissions.setPermission", params))
|
||||
+547
@@ -0,0 +1,547 @@
|
||||
# Licensed to the Software Freedom Conservancy (SFC) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The SFC licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import datetime
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from selenium.common.exceptions import WebDriverException
|
||||
from selenium.webdriver.common.bidi.common import command_builder
|
||||
from selenium.webdriver.common.bidi.log import LogEntryAdded
|
||||
from selenium.webdriver.common.bidi.session import Session
|
||||
|
||||
|
||||
class ResultOwnership:
|
||||
"""Represents the possible result ownership types."""
|
||||
|
||||
NONE = "none"
|
||||
ROOT = "root"
|
||||
|
||||
|
||||
class RealmType:
|
||||
"""Represents the possible realm types."""
|
||||
|
||||
WINDOW = "window"
|
||||
DEDICATED_WORKER = "dedicated-worker"
|
||||
SHARED_WORKER = "shared-worker"
|
||||
SERVICE_WORKER = "service-worker"
|
||||
WORKER = "worker"
|
||||
PAINT_WORKLET = "paint-worklet"
|
||||
AUDIO_WORKLET = "audio-worklet"
|
||||
WORKLET = "worklet"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RealmInfo:
|
||||
"""Represents information about a realm."""
|
||||
|
||||
realm: str
|
||||
origin: str
|
||||
type: str
|
||||
context: str | None = None
|
||||
sandbox: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: dict[str, Any]) -> "RealmInfo":
|
||||
"""Creates a RealmInfo instance from a dictionary.
|
||||
|
||||
Args:
|
||||
json: A dictionary containing the realm information.
|
||||
|
||||
Returns:
|
||||
RealmInfo: A new instance of RealmInfo.
|
||||
"""
|
||||
if "realm" not in json:
|
||||
raise ValueError("Missing required field 'realm' in RealmInfo")
|
||||
if "origin" not in json:
|
||||
raise ValueError("Missing required field 'origin' in RealmInfo")
|
||||
if "type" not in json:
|
||||
raise ValueError("Missing required field 'type' in RealmInfo")
|
||||
|
||||
return cls(
|
||||
realm=json["realm"],
|
||||
origin=json["origin"],
|
||||
type=json["type"],
|
||||
context=json.get("context"),
|
||||
sandbox=json.get("sandbox"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Source:
|
||||
"""Represents the source of a script message."""
|
||||
|
||||
realm: str
|
||||
context: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: dict[str, Any]) -> "Source":
|
||||
"""Creates a Source instance from a dictionary.
|
||||
|
||||
Args:
|
||||
json: A dictionary containing the source information.
|
||||
|
||||
Returns:
|
||||
Source: A new instance of Source.
|
||||
"""
|
||||
if "realm" not in json:
|
||||
raise ValueError("Missing required field 'realm' in Source")
|
||||
|
||||
return cls(
|
||||
realm=json["realm"],
|
||||
context=json.get("context"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluateResult:
|
||||
"""Represents the result of script evaluation."""
|
||||
|
||||
type: str
|
||||
realm: str
|
||||
result: dict | None = None
|
||||
exception_details: dict | None = None
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: dict[str, Any]) -> "EvaluateResult":
|
||||
"""Creates an EvaluateResult instance from a dictionary.
|
||||
|
||||
Args:
|
||||
json: A dictionary containing the evaluation result.
|
||||
|
||||
Returns:
|
||||
EvaluateResult: A new instance of EvaluateResult.
|
||||
"""
|
||||
if "realm" not in json:
|
||||
raise ValueError("Missing required field 'realm' in EvaluateResult")
|
||||
if "type" not in json:
|
||||
raise ValueError("Missing required field 'type' in EvaluateResult")
|
||||
|
||||
return cls(
|
||||
type=json["type"],
|
||||
realm=json["realm"],
|
||||
result=json.get("result"),
|
||||
exception_details=json.get("exceptionDetails"),
|
||||
)
|
||||
|
||||
|
||||
class ScriptMessage:
|
||||
"""Represents a script message event."""
|
||||
|
||||
event_class = "script.message"
|
||||
|
||||
def __init__(self, channel: str, data: dict, source: Source):
|
||||
self.channel = channel
|
||||
self.data = data
|
||||
self.source = source
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: dict[str, Any]) -> "ScriptMessage":
|
||||
"""Creates a ScriptMessage instance from a dictionary.
|
||||
|
||||
Args:
|
||||
json: A dictionary containing the script message.
|
||||
|
||||
Returns:
|
||||
ScriptMessage: A new instance of ScriptMessage.
|
||||
"""
|
||||
if "channel" not in json:
|
||||
raise ValueError("Missing required field 'channel' in ScriptMessage")
|
||||
if "data" not in json:
|
||||
raise ValueError("Missing required field 'data' in ScriptMessage")
|
||||
if "source" not in json:
|
||||
raise ValueError("Missing required field 'source' in ScriptMessage")
|
||||
|
||||
return cls(
|
||||
channel=json["channel"],
|
||||
data=json["data"],
|
||||
source=Source.from_json(json["source"]),
|
||||
)
|
||||
|
||||
|
||||
class RealmCreated:
|
||||
"""Represents a realm created event."""
|
||||
|
||||
event_class = "script.realmCreated"
|
||||
|
||||
def __init__(self, realm_info: RealmInfo):
|
||||
self.realm_info = realm_info
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: dict[str, Any]) -> "RealmCreated":
|
||||
"""Creates a RealmCreated instance from a dictionary.
|
||||
|
||||
Args:
|
||||
json: A dictionary containing the realm created event.
|
||||
|
||||
Returns:
|
||||
RealmCreated: A new instance of RealmCreated.
|
||||
"""
|
||||
return cls(realm_info=RealmInfo.from_json(json))
|
||||
|
||||
|
||||
class RealmDestroyed:
|
||||
"""Represents a realm destroyed event."""
|
||||
|
||||
event_class = "script.realmDestroyed"
|
||||
|
||||
def __init__(self, realm: str):
|
||||
self.realm = realm
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: dict[str, Any]) -> "RealmDestroyed":
|
||||
"""Creates a RealmDestroyed instance from a dictionary.
|
||||
|
||||
Args:
|
||||
json: A dictionary containing the realm destroyed event.
|
||||
|
||||
Returns:
|
||||
RealmDestroyed: A new instance of RealmDestroyed.
|
||||
"""
|
||||
if "realm" not in json:
|
||||
raise ValueError("Missing required field 'realm' in RealmDestroyed")
|
||||
|
||||
return cls(realm=json["realm"])
|
||||
|
||||
|
||||
class Script:
|
||||
"""BiDi implementation of the script module."""
|
||||
|
||||
EVENTS = {
|
||||
"message": "script.message",
|
||||
"realm_created": "script.realmCreated",
|
||||
"realm_destroyed": "script.realmDestroyed",
|
||||
}
|
||||
|
||||
def __init__(self, conn, driver=None):
|
||||
self.conn = conn
|
||||
self.driver = driver
|
||||
self.log_entry_subscribed = False
|
||||
self.subscriptions = {}
|
||||
self.callbacks = {}
|
||||
|
||||
# High-level APIs for SCRIPT module
|
||||
|
||||
def add_console_message_handler(self, handler):
|
||||
self._subscribe_to_log_entries()
|
||||
return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("console", handler))
|
||||
|
||||
def add_javascript_error_handler(self, handler):
|
||||
self._subscribe_to_log_entries()
|
||||
return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("javascript", handler))
|
||||
|
||||
def remove_console_message_handler(self, id):
|
||||
self.conn.remove_callback(LogEntryAdded, id)
|
||||
self._unsubscribe_from_log_entries()
|
||||
|
||||
remove_javascript_error_handler = remove_console_message_handler
|
||||
|
||||
def pin(self, script: str) -> str:
|
||||
"""Pins a script to the current browsing context.
|
||||
|
||||
Args:
|
||||
script: The script to pin.
|
||||
|
||||
Returns:
|
||||
str: The ID of the pinned script.
|
||||
"""
|
||||
return self._add_preload_script(script)
|
||||
|
||||
def unpin(self, script_id: str) -> None:
|
||||
"""Unpins a script from the current browsing context.
|
||||
|
||||
Args:
|
||||
script_id: The ID of the pinned script to unpin.
|
||||
"""
|
||||
self._remove_preload_script(script_id)
|
||||
|
||||
def execute(self, script: str, *args) -> dict:
|
||||
"""Executes a script in the current browsing context.
|
||||
|
||||
Args:
|
||||
script: The script function to execute.
|
||||
*args: Arguments to pass to the script function.
|
||||
|
||||
Returns:
|
||||
dict: The result value from the script execution.
|
||||
|
||||
Raises:
|
||||
WebDriverException: If the script execution fails.
|
||||
"""
|
||||
if self.driver is None:
|
||||
raise WebDriverException("Driver reference is required for script execution")
|
||||
browsing_context_id = self.driver.current_window_handle
|
||||
|
||||
# Convert arguments to the format expected by BiDi call_function (LocalValue Type)
|
||||
arguments = []
|
||||
for arg in args:
|
||||
arguments.append(self.__convert_to_local_value(arg))
|
||||
|
||||
target = {"context": browsing_context_id}
|
||||
|
||||
result = self._call_function(
|
||||
function_declaration=script, await_promise=True, target=target, arguments=arguments if arguments else None
|
||||
)
|
||||
|
||||
if result.type == "success":
|
||||
return result.result if result.result is not None else {}
|
||||
else:
|
||||
error_message = "Error while executing script"
|
||||
if result.exception_details:
|
||||
if "text" in result.exception_details:
|
||||
error_message += f": {result.exception_details['text']}"
|
||||
elif "message" in result.exception_details:
|
||||
error_message += f": {result.exception_details['message']}"
|
||||
|
||||
raise WebDriverException(error_message)
|
||||
|
||||
def __convert_to_local_value(self, value) -> dict:
|
||||
"""Converts a Python value to BiDi LocalValue format."""
|
||||
if value is None:
|
||||
return {"type": "null"}
|
||||
elif isinstance(value, bool):
|
||||
return {"type": "boolean", "value": value}
|
||||
elif isinstance(value, (int, float)):
|
||||
if isinstance(value, float):
|
||||
if math.isnan(value):
|
||||
return {"type": "number", "value": "NaN"}
|
||||
elif math.isinf(value):
|
||||
if value > 0:
|
||||
return {"type": "number", "value": "Infinity"}
|
||||
else:
|
||||
return {"type": "number", "value": "-Infinity"}
|
||||
elif value == 0.0 and math.copysign(1.0, value) < 0:
|
||||
return {"type": "number", "value": "-0"}
|
||||
|
||||
JS_MAX_SAFE_INTEGER = 9007199254740991
|
||||
if isinstance(value, int) and (value > JS_MAX_SAFE_INTEGER or value < -JS_MAX_SAFE_INTEGER):
|
||||
return {"type": "bigint", "value": str(value)}
|
||||
|
||||
return {"type": "number", "value": value}
|
||||
|
||||
elif isinstance(value, str):
|
||||
return {"type": "string", "value": value}
|
||||
elif isinstance(value, datetime.datetime):
|
||||
# Convert Python datetime to JavaScript Date (ISO 8601 format)
|
||||
return {"type": "date", "value": value.isoformat() + "Z" if value.tzinfo is None else value.isoformat()}
|
||||
elif isinstance(value, datetime.date):
|
||||
# Convert Python date to JavaScript Date
|
||||
dt = datetime.datetime.combine(value, datetime.time.min).replace(tzinfo=datetime.timezone.utc)
|
||||
return {"type": "date", "value": dt.isoformat()}
|
||||
elif isinstance(value, set):
|
||||
return {"type": "set", "value": [self.__convert_to_local_value(item) for item in value]}
|
||||
elif isinstance(value, (list, tuple)):
|
||||
return {"type": "array", "value": [self.__convert_to_local_value(item) for item in value]}
|
||||
elif isinstance(value, dict):
|
||||
return {
|
||||
"type": "object",
|
||||
"value": [
|
||||
[self.__convert_to_local_value(k), self.__convert_to_local_value(v)] for k, v in value.items()
|
||||
],
|
||||
}
|
||||
else:
|
||||
# For other types, convert to string
|
||||
return {"type": "string", "value": str(value)}
|
||||
|
||||
# low-level APIs for script module
|
||||
def _add_preload_script(
|
||||
self,
|
||||
function_declaration: str,
|
||||
arguments: list[dict[str, Any]] | None = None,
|
||||
contexts: list[str] | None = None,
|
||||
user_contexts: list[str] | None = None,
|
||||
sandbox: str | None = None,
|
||||
) -> str:
|
||||
"""Adds a preload script.
|
||||
|
||||
Args:
|
||||
function_declaration: The function declaration to preload.
|
||||
arguments: The arguments to pass to the function.
|
||||
contexts: The browsing context IDs to apply the script to.
|
||||
user_contexts: The user context IDs to apply the script to.
|
||||
sandbox: The sandbox name to apply the script to.
|
||||
|
||||
Returns:
|
||||
str: The preload script ID.
|
||||
|
||||
Raises:
|
||||
ValueError: If both contexts and user_contexts are provided.
|
||||
"""
|
||||
if contexts is not None and user_contexts is not None:
|
||||
raise ValueError("Cannot specify both contexts and user_contexts")
|
||||
|
||||
params: dict[str, Any] = {"functionDeclaration": function_declaration}
|
||||
|
||||
if arguments is not None:
|
||||
params["arguments"] = arguments
|
||||
if contexts is not None:
|
||||
params["contexts"] = contexts
|
||||
if user_contexts is not None:
|
||||
params["userContexts"] = user_contexts
|
||||
if sandbox is not None:
|
||||
params["sandbox"] = sandbox
|
||||
|
||||
result = self.conn.execute(command_builder("script.addPreloadScript", params))
|
||||
return result["script"]
|
||||
|
||||
def _remove_preload_script(self, script_id: str) -> None:
|
||||
"""Removes a preload script.
|
||||
|
||||
Args:
|
||||
script_id: The preload script ID to remove.
|
||||
"""
|
||||
params = {"script": script_id}
|
||||
self.conn.execute(command_builder("script.removePreloadScript", params))
|
||||
|
||||
def _disown(self, handles: list[str], target: dict) -> None:
|
||||
"""Disowns the given handles.
|
||||
|
||||
Args:
|
||||
handles: The handles to disown.
|
||||
target: The target realm or context.
|
||||
"""
|
||||
params = {
|
||||
"handles": handles,
|
||||
"target": target,
|
||||
}
|
||||
self.conn.execute(command_builder("script.disown", params))
|
||||
|
||||
def _call_function(
|
||||
self,
|
||||
function_declaration: str,
|
||||
await_promise: bool,
|
||||
target: dict,
|
||||
arguments: list[dict] | None = None,
|
||||
result_ownership: str | None = None,
|
||||
serialization_options: dict | None = None,
|
||||
this: dict | None = None,
|
||||
user_activation: bool = False,
|
||||
) -> EvaluateResult:
|
||||
"""Calls a provided function with given arguments in a given realm.
|
||||
|
||||
Args:
|
||||
function_declaration: The function declaration to call.
|
||||
await_promise: Whether to await promise resolution.
|
||||
target: The target realm or context.
|
||||
arguments: The arguments to pass to the function.
|
||||
result_ownership: The result ownership type.
|
||||
serialization_options: The serialization options.
|
||||
this: The 'this' value for the function call.
|
||||
user_activation: Whether to trigger user activation.
|
||||
|
||||
Returns:
|
||||
EvaluateResult: The result of the function call.
|
||||
"""
|
||||
params = {
|
||||
"functionDeclaration": function_declaration,
|
||||
"awaitPromise": await_promise,
|
||||
"target": target,
|
||||
"userActivation": user_activation,
|
||||
}
|
||||
|
||||
if arguments is not None:
|
||||
params["arguments"] = arguments
|
||||
if result_ownership is not None:
|
||||
params["resultOwnership"] = result_ownership
|
||||
if serialization_options is not None:
|
||||
params["serializationOptions"] = serialization_options
|
||||
if this is not None:
|
||||
params["this"] = this
|
||||
|
||||
result = self.conn.execute(command_builder("script.callFunction", params))
|
||||
return EvaluateResult.from_json(result)
|
||||
|
||||
def _evaluate(
|
||||
self,
|
||||
expression: str,
|
||||
target: dict,
|
||||
await_promise: bool,
|
||||
result_ownership: str | None = None,
|
||||
serialization_options: dict | None = None,
|
||||
user_activation: bool = False,
|
||||
) -> EvaluateResult:
|
||||
"""Evaluates a provided script in a given realm.
|
||||
|
||||
Args:
|
||||
expression: The script expression to evaluate.
|
||||
target: The target realm or context.
|
||||
await_promise: Whether to await promise resolution.
|
||||
result_ownership: The result ownership type.
|
||||
serialization_options: The serialization options.
|
||||
user_activation: Whether to trigger user activation.
|
||||
|
||||
Returns:
|
||||
EvaluateResult: The result of the script evaluation.
|
||||
"""
|
||||
params = {
|
||||
"expression": expression,
|
||||
"target": target,
|
||||
"awaitPromise": await_promise,
|
||||
"userActivation": user_activation,
|
||||
}
|
||||
|
||||
if result_ownership is not None:
|
||||
params["resultOwnership"] = result_ownership
|
||||
if serialization_options is not None:
|
||||
params["serializationOptions"] = serialization_options
|
||||
|
||||
result = self.conn.execute(command_builder("script.evaluate", params))
|
||||
return EvaluateResult.from_json(result)
|
||||
|
||||
def _get_realms(
|
||||
self,
|
||||
context: str | None = None,
|
||||
type: str | None = None,
|
||||
) -> list[RealmInfo]:
|
||||
"""Returns a list of all realms, optionally filtered.
|
||||
|
||||
Args:
|
||||
context: The browsing context ID to filter by.
|
||||
type: The realm type to filter by.
|
||||
|
||||
Returns:
|
||||
List[RealmInfo]: A list of realm information.
|
||||
"""
|
||||
params = {}
|
||||
|
||||
if context is not None:
|
||||
params["context"] = context
|
||||
if type is not None:
|
||||
params["type"] = type
|
||||
|
||||
result = self.conn.execute(command_builder("script.getRealms", params))
|
||||
return [RealmInfo.from_json(realm) for realm in result["realms"]]
|
||||
|
||||
def _subscribe_to_log_entries(self):
|
||||
if not self.log_entry_subscribed:
|
||||
session = Session(self.conn)
|
||||
self.conn.execute(session.subscribe(LogEntryAdded.event_class))
|
||||
self.log_entry_subscribed = True
|
||||
|
||||
def _unsubscribe_from_log_entries(self):
|
||||
if self.log_entry_subscribed and LogEntryAdded.event_class not in self.conn.callbacks:
|
||||
session = Session(self.conn)
|
||||
self.conn.execute(session.unsubscribe(LogEntryAdded.event_class))
|
||||
self.log_entry_subscribed = False
|
||||
|
||||
def _handle_log_entry(self, type, handler):
|
||||
def _handle_log_entry(log_entry):
|
||||
if log_entry.type_ == type:
|
||||
handler(log_entry)
|
||||
|
||||
return _handle_log_entry
|
||||
+134
@@ -0,0 +1,134 @@
|
||||
# Licensed to the Software Freedom Conservancy (SFC) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The SFC licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
from selenium.webdriver.common.bidi.common import command_builder
|
||||
|
||||
|
||||
class UserPromptHandlerType:
|
||||
"""Represents the behavior of the user prompt handler."""
|
||||
|
||||
ACCEPT = "accept"
|
||||
DISMISS = "dismiss"
|
||||
IGNORE = "ignore"
|
||||
|
||||
VALID_TYPES = {ACCEPT, DISMISS, IGNORE}
|
||||
|
||||
|
||||
class UserPromptHandler:
|
||||
"""Represents the configuration of the user prompt handler."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
alert: str | None = None,
|
||||
before_unload: str | None = None,
|
||||
confirm: str | None = None,
|
||||
default: str | None = None,
|
||||
file: str | None = None,
|
||||
prompt: str | None = None,
|
||||
):
|
||||
"""Initialize UserPromptHandler.
|
||||
|
||||
Args:
|
||||
alert: Handler type for alert prompts.
|
||||
before_unload: Handler type for beforeUnload prompts.
|
||||
confirm: Handler type for confirm prompts.
|
||||
default: Default handler type for all prompts.
|
||||
file: Handler type for file picker prompts.
|
||||
prompt: Handler type for prompt dialogs.
|
||||
|
||||
Raises:
|
||||
ValueError: If any handler type is not valid.
|
||||
"""
|
||||
for field_name, value in [
|
||||
("alert", alert),
|
||||
("before_unload", before_unload),
|
||||
("confirm", confirm),
|
||||
("default", default),
|
||||
("file", file),
|
||||
("prompt", prompt),
|
||||
]:
|
||||
if value is not None and value not in UserPromptHandlerType.VALID_TYPES:
|
||||
raise ValueError(
|
||||
f"Invalid {field_name} handler type: {value}. Must be one of {UserPromptHandlerType.VALID_TYPES}"
|
||||
)
|
||||
|
||||
self.alert = alert
|
||||
self.before_unload = before_unload
|
||||
self.confirm = confirm
|
||||
self.default = default
|
||||
self.file = file
|
||||
self.prompt = prompt
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
"""Convert the UserPromptHandler to a dictionary for BiDi protocol.
|
||||
|
||||
Returns:
|
||||
Dictionary representation suitable for BiDi protocol.
|
||||
"""
|
||||
field_mapping = {
|
||||
"alert": "alert",
|
||||
"before_unload": "beforeUnload",
|
||||
"confirm": "confirm",
|
||||
"default": "default",
|
||||
"file": "file",
|
||||
"prompt": "prompt",
|
||||
}
|
||||
|
||||
result = {}
|
||||
for attr_name, dict_key in field_mapping.items():
|
||||
value = getattr(self, attr_name)
|
||||
if value is not None:
|
||||
result[dict_key] = value
|
||||
return result
|
||||
|
||||
|
||||
class Session:
|
||||
def __init__(self, conn):
|
||||
self.conn = conn
|
||||
|
||||
def subscribe(self, *events, browsing_contexts=None):
|
||||
params = {
|
||||
"events": events,
|
||||
}
|
||||
if browsing_contexts is None:
|
||||
browsing_contexts = []
|
||||
if browsing_contexts:
|
||||
params["browsingContexts"] = browsing_contexts
|
||||
return command_builder("session.subscribe", params)
|
||||
|
||||
def unsubscribe(self, *events, browsing_contexts=None):
|
||||
params = {
|
||||
"events": events,
|
||||
}
|
||||
if browsing_contexts is None:
|
||||
browsing_contexts = []
|
||||
if browsing_contexts:
|
||||
params["browsingContexts"] = browsing_contexts
|
||||
return command_builder("session.unsubscribe", params)
|
||||
|
||||
def status(self):
|
||||
"""The session.status command returns information about the remote end's readiness.
|
||||
|
||||
Returns information about the remote end's readiness to create new sessions
|
||||
and may include implementation-specific metadata.
|
||||
|
||||
Returns:
|
||||
Dictionary containing the ready state (bool), message (str) and metadata.
|
||||
"""
|
||||
cmd = command_builder("session.status", {})
|
||||
return self.conn.execute(cmd)
|
||||
+413
@@ -0,0 +1,413 @@
|
||||
# Licensed to the Software Freedom Conservancy (SFC) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The SFC licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from selenium.webdriver.common.bidi.common import command_builder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from selenium.webdriver.remote.websocket_connection import WebSocketConnection
|
||||
|
||||
|
||||
class SameSite:
|
||||
"""Represents the possible same site values for cookies."""
|
||||
|
||||
STRICT = "strict"
|
||||
LAX = "lax"
|
||||
NONE = "none"
|
||||
DEFAULT = "default"
|
||||
|
||||
|
||||
class BytesValue:
|
||||
"""Represents a bytes value."""
|
||||
|
||||
TYPE_BASE64 = "base64"
|
||||
TYPE_STRING = "string"
|
||||
|
||||
def __init__(self, type: str, value: str):
|
||||
self.type = type
|
||||
self.value = value
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
"""Converts the BytesValue to a dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary representation of the BytesValue.
|
||||
"""
|
||||
return {"type": self.type, "value": self.value}
|
||||
|
||||
|
||||
class Cookie:
|
||||
"""Represents a cookie."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
value: BytesValue,
|
||||
domain: str,
|
||||
path: str | None = None,
|
||||
size: int | None = None,
|
||||
http_only: bool | None = None,
|
||||
secure: bool | None = None,
|
||||
same_site: str | None = None,
|
||||
expiry: int | None = None,
|
||||
):
|
||||
self.name = name
|
||||
self.value = value
|
||||
self.domain = domain
|
||||
self.path = path
|
||||
self.size = size
|
||||
self.http_only = http_only
|
||||
self.secure = secure
|
||||
self.same_site = same_site
|
||||
self.expiry = expiry
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> Cookie:
|
||||
"""Creates a Cookie instance from a dictionary.
|
||||
|
||||
Args:
|
||||
data: A dictionary containing the cookie information.
|
||||
|
||||
Returns:
|
||||
A new instance of Cookie.
|
||||
"""
|
||||
# Validation for empty strings
|
||||
name = data.get("name")
|
||||
if not name:
|
||||
raise ValueError("name is required and cannot be empty")
|
||||
domain = data.get("domain")
|
||||
if not domain:
|
||||
raise ValueError("domain is required and cannot be empty")
|
||||
|
||||
value = BytesValue(data.get("value", {}).get("type"), data.get("value", {}).get("value"))
|
||||
return cls(
|
||||
name=str(name),
|
||||
value=value,
|
||||
domain=str(domain),
|
||||
path=data.get("path"),
|
||||
size=data.get("size"),
|
||||
http_only=data.get("httpOnly"),
|
||||
secure=data.get("secure"),
|
||||
same_site=data.get("sameSite"),
|
||||
expiry=data.get("expiry"),
|
||||
)
|
||||
|
||||
|
||||
class CookieFilter:
|
||||
"""Represents a filter for cookies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str | None = None,
|
||||
value: BytesValue | None = None,
|
||||
domain: str | None = None,
|
||||
path: str | None = None,
|
||||
size: int | None = None,
|
||||
http_only: bool | None = None,
|
||||
secure: bool | None = None,
|
||||
same_site: str | None = None,
|
||||
expiry: int | None = None,
|
||||
):
|
||||
self.name = name
|
||||
self.value = value
|
||||
self.domain = domain
|
||||
self.path = path
|
||||
self.size = size
|
||||
self.http_only = http_only
|
||||
self.secure = secure
|
||||
self.same_site = same_site
|
||||
self.expiry = expiry
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Converts the CookieFilter to a dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary representation of the CookieFilter.
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
if self.name is not None:
|
||||
result["name"] = self.name
|
||||
if self.value is not None:
|
||||
result["value"] = self.value.to_dict()
|
||||
if self.domain is not None:
|
||||
result["domain"] = self.domain
|
||||
if self.path is not None:
|
||||
result["path"] = self.path
|
||||
if self.size is not None:
|
||||
result["size"] = self.size
|
||||
if self.http_only is not None:
|
||||
result["httpOnly"] = self.http_only
|
||||
if self.secure is not None:
|
||||
result["secure"] = self.secure
|
||||
if self.same_site is not None:
|
||||
result["sameSite"] = self.same_site
|
||||
if self.expiry is not None:
|
||||
result["expiry"] = self.expiry
|
||||
return result
|
||||
|
||||
|
||||
class PartitionKey:
|
||||
"""Represents a storage partition key."""
|
||||
|
||||
def __init__(self, user_context: str | None = None, source_origin: str | None = None):
|
||||
self.user_context = user_context
|
||||
self.source_origin = source_origin
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> PartitionKey:
|
||||
"""Creates a PartitionKey instance from a dictionary.
|
||||
|
||||
Args:
|
||||
data: A dictionary containing the partition key information.
|
||||
|
||||
Returns:
|
||||
A new instance of PartitionKey.
|
||||
"""
|
||||
return cls(
|
||||
user_context=data.get("userContext"),
|
||||
source_origin=data.get("sourceOrigin"),
|
||||
)
|
||||
|
||||
|
||||
class BrowsingContextPartitionDescriptor:
|
||||
"""Represents a browsing context partition descriptor."""
|
||||
|
||||
def __init__(self, context: str):
|
||||
self.type = "context"
|
||||
self.context = context
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
"""Converts the BrowsingContextPartitionDescriptor to a dictionary.
|
||||
|
||||
Returns:
|
||||
Dict: A dictionary representation of the BrowsingContextPartitionDescriptor.
|
||||
"""
|
||||
return {"type": self.type, "context": self.context}
|
||||
|
||||
|
||||
class StorageKeyPartitionDescriptor:
|
||||
"""Represents a storage key partition descriptor."""
|
||||
|
||||
def __init__(self, user_context: str | None = None, source_origin: str | None = None):
|
||||
self.type = "storageKey"
|
||||
self.user_context = user_context
|
||||
self.source_origin = source_origin
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
"""Converts the StorageKeyPartitionDescriptor to a dictionary.
|
||||
|
||||
Returns:
|
||||
Dict: A dictionary representation of the StorageKeyPartitionDescriptor.
|
||||
"""
|
||||
result = {"type": self.type}
|
||||
if self.user_context is not None:
|
||||
result["userContext"] = self.user_context
|
||||
if self.source_origin is not None:
|
||||
result["sourceOrigin"] = self.source_origin
|
||||
return result
|
||||
|
||||
|
||||
class PartialCookie:
|
||||
"""Represents a partial cookie for setting."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
value: BytesValue,
|
||||
domain: str,
|
||||
path: str | None = None,
|
||||
http_only: bool | None = None,
|
||||
secure: bool | None = None,
|
||||
same_site: str | None = None,
|
||||
expiry: int | None = None,
|
||||
):
|
||||
self.name = name
|
||||
self.value = value
|
||||
self.domain = domain
|
||||
self.path = path
|
||||
self.http_only = http_only
|
||||
self.secure = secure
|
||||
self.same_site = same_site
|
||||
self.expiry = expiry
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Converts the PartialCookie to a dictionary.
|
||||
|
||||
Returns:
|
||||
-------
|
||||
Dict: A dictionary representation of the PartialCookie.
|
||||
"""
|
||||
result: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"value": self.value.to_dict(),
|
||||
"domain": self.domain,
|
||||
}
|
||||
if self.path is not None:
|
||||
result["path"] = self.path
|
||||
if self.http_only is not None:
|
||||
result["httpOnly"] = self.http_only
|
||||
if self.secure is not None:
|
||||
result["secure"] = self.secure
|
||||
if self.same_site is not None:
|
||||
result["sameSite"] = self.same_site
|
||||
if self.expiry is not None:
|
||||
result["expiry"] = self.expiry
|
||||
return result
|
||||
|
||||
|
||||
class GetCookiesResult:
|
||||
"""Represents the result of a getCookies command."""
|
||||
|
||||
def __init__(self, cookies: list[Cookie], partition_key: PartitionKey):
|
||||
self.cookies = cookies
|
||||
self.partition_key = partition_key
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> GetCookiesResult:
|
||||
"""Creates a GetCookiesResult instance from a dictionary.
|
||||
|
||||
Args:
|
||||
data: A dictionary containing the get cookies result information.
|
||||
|
||||
Returns:
|
||||
A new instance of GetCookiesResult.
|
||||
"""
|
||||
cookies = [Cookie.from_dict(cookie) for cookie in data.get("cookies", [])]
|
||||
partition_key = PartitionKey.from_dict(data.get("partitionKey", {}))
|
||||
return cls(cookies=cookies, partition_key=partition_key)
|
||||
|
||||
|
||||
class SetCookieResult:
|
||||
"""Represents the result of a setCookie command."""
|
||||
|
||||
def __init__(self, partition_key: PartitionKey):
|
||||
self.partition_key = partition_key
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> SetCookieResult:
|
||||
"""Creates a SetCookieResult instance from a dictionary.
|
||||
|
||||
Args:
|
||||
data: A dictionary containing the set cookie result information.
|
||||
|
||||
Returns:
|
||||
A new instance of SetCookieResult.
|
||||
"""
|
||||
partition_key = PartitionKey.from_dict(data.get("partitionKey", {}))
|
||||
return cls(partition_key=partition_key)
|
||||
|
||||
|
||||
class DeleteCookiesResult:
|
||||
"""Represents the result of a deleteCookies command."""
|
||||
|
||||
def __init__(self, partition_key: PartitionKey):
|
||||
self.partition_key = partition_key
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> DeleteCookiesResult:
|
||||
"""Creates a DeleteCookiesResult instance from a dictionary.
|
||||
|
||||
Args:
|
||||
data: A dictionary containing the delete cookies result information.
|
||||
|
||||
Returns:
|
||||
A new instance of DeleteCookiesResult.
|
||||
"""
|
||||
partition_key = PartitionKey.from_dict(data.get("partitionKey", {}))
|
||||
return cls(partition_key=partition_key)
|
||||
|
||||
|
||||
class Storage:
|
||||
"""BiDi implementation of the storage module."""
|
||||
|
||||
def __init__(self, conn: WebSocketConnection) -> None:
|
||||
self.conn = conn
|
||||
|
||||
def get_cookies(
|
||||
self,
|
||||
filter: CookieFilter | None = None,
|
||||
partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None,
|
||||
) -> GetCookiesResult:
|
||||
"""Gets cookies matching the specified filter.
|
||||
|
||||
Args:
|
||||
filter: Optional filter to specify which cookies to retrieve.
|
||||
partition: Optional partition key to limit the scope of the operation.
|
||||
|
||||
Returns:
|
||||
A GetCookiesResult containing the cookies and partition key.
|
||||
|
||||
Example:
|
||||
result = await storage.get_cookies(
|
||||
filter=CookieFilter(name="sessionId"),
|
||||
partition=PartitionKey(...)
|
||||
)
|
||||
"""
|
||||
params = {}
|
||||
if filter is not None:
|
||||
params["filter"] = filter.to_dict()
|
||||
if partition is not None:
|
||||
params["partition"] = partition.to_dict()
|
||||
|
||||
result = self.conn.execute(command_builder("storage.getCookies", params))
|
||||
return GetCookiesResult.from_dict(result)
|
||||
|
||||
def set_cookie(
|
||||
self,
|
||||
cookie: PartialCookie,
|
||||
partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None,
|
||||
) -> SetCookieResult:
|
||||
"""Sets a cookie in the browser.
|
||||
|
||||
Args:
|
||||
cookie: The cookie to set.
|
||||
partition: Optional partition descriptor.
|
||||
|
||||
Returns:
|
||||
The result of the set cookie command.
|
||||
"""
|
||||
params = {"cookie": cookie.to_dict()}
|
||||
if partition is not None:
|
||||
params["partition"] = partition.to_dict()
|
||||
|
||||
result = self.conn.execute(command_builder("storage.setCookie", params))
|
||||
return SetCookieResult.from_dict(result)
|
||||
|
||||
def delete_cookies(
|
||||
self,
|
||||
filter: CookieFilter | None = None,
|
||||
partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None,
|
||||
) -> DeleteCookiesResult:
|
||||
"""Deletes cookies that match the given parameters.
|
||||
|
||||
Args:
|
||||
filter: Optional filter to match cookies to delete.
|
||||
partition: Optional partition descriptor.
|
||||
|
||||
Returns:
|
||||
The result of the delete cookies command.
|
||||
"""
|
||||
params = {}
|
||||
if filter is not None:
|
||||
params["filter"] = filter.to_dict()
|
||||
if partition is not None:
|
||||
params["partition"] = partition.to_dict()
|
||||
|
||||
result = self.conn.execute(command_builder("storage.deleteCookies", params))
|
||||
return DeleteCookiesResult.from_dict(result)
|
||||
+78
@@ -0,0 +1,78 @@
|
||||
# Licensed to the Software Freedom Conservancy (SFC) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The SFC licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
from selenium.common.exceptions import WebDriverException
|
||||
from selenium.webdriver.common.bidi.common import command_builder
|
||||
|
||||
|
||||
class WebExtension:
|
||||
"""BiDi implementation of the webExtension module."""
|
||||
|
||||
def __init__(self, conn):
|
||||
self.conn = conn
|
||||
|
||||
def install(self, path=None, archive_path=None, base64_value=None) -> dict:
|
||||
"""Installs a web extension in the remote end.
|
||||
|
||||
You must provide exactly one of the parameters.
|
||||
|
||||
Args:
|
||||
path: Path to an extension directory.
|
||||
archive_path: Path to an extension archive file.
|
||||
base64_value: Base64 encoded string of the extension archive.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the extension ID.
|
||||
"""
|
||||
if sum(x is not None for x in (path, archive_path, base64_value)) != 1:
|
||||
raise ValueError("Exactly one of path, archive_path, or base64_value must be provided")
|
||||
|
||||
if path is not None:
|
||||
extension_data = {"type": "path", "path": path}
|
||||
elif archive_path is not None:
|
||||
extension_data = {"type": "archivePath", "path": archive_path}
|
||||
elif base64_value is not None:
|
||||
extension_data = {"type": "base64", "value": base64_value}
|
||||
|
||||
params = {"extensionData": extension_data}
|
||||
|
||||
try:
|
||||
result = self.conn.execute(command_builder("webExtension.install", params))
|
||||
return result
|
||||
except WebDriverException as e:
|
||||
if "Method not available" in str(e):
|
||||
raise WebDriverException(
|
||||
f"{e!s}. If you are using Chrome or Edge, add '--enable-unsafe-extension-debugging' "
|
||||
"and '--remote-debugging-pipe' arguments or set options.enable_webextensions = True"
|
||||
) from e
|
||||
raise
|
||||
|
||||
def uninstall(self, extension_id_or_result: str | dict) -> None:
|
||||
"""Uninstalls a web extension from the remote end.
|
||||
|
||||
Args:
|
||||
extension_id_or_result: Either the extension ID as a string or the result dictionary
|
||||
from a previous install() call containing the extension ID.
|
||||
"""
|
||||
if isinstance(extension_id_or_result, dict):
|
||||
extension_id = extension_id_or_result.get("extension")
|
||||
else:
|
||||
extension_id = extension_id_or_result
|
||||
|
||||
params = {"extension": extension_id}
|
||||
self.conn.execute(command_builder("webExtension.uninstall", params))
|
||||
Reference in New Issue
Block a user