Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feat/leverage-dom-for-retrieval'…
Browse files Browse the repository at this point in the history
… into feat/dom-interactive-navigation
  • Loading branch information
adeprez committed Jul 5, 2024
2 parents 3422e34 + fd4ed75 commit 933ada7
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 75 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ screenshot*
node_modules
extension_chrome/dist
extension_chrome/node_modules
.vscode

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
120 changes: 117 additions & 3 deletions lavague-core/lavague/core/base_driver.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
from pathlib import Path
from typing import Any, Callable, Optional, Mapping
from typing import Any, Callable, Optional, Mapping, Dict, Set
from abc import ABC, abstractmethod
from lavague.core.utilities.format_utils import (
extract_code_from_funct,
extract_imports_from_lines,
)
from enum import Enum
import time
from datetime import datetime
import hashlib


class InteractionType(Enum):
    """Kinds of user interaction that a page element can support.

    Drivers convert the output of ``JS_GET_INTERACTIVES`` into members of this
    enum by NAME (``InteractionType["CLICK"]``), so the upper-case member
    names, not the lower-case values, must match the strings pushed by that
    script (``'TYPE'``, ``'CLICK'``, ``'HOVER'``).
    """

    CLICK = "click"
    HOVER = "hover"
    # NOTE(review): JS_GET_INTERACTIVES never emits 'SCROLL'; this member is
    # presumably reserved for other producers -- confirm.
    SCROLL = "scroll"
    TYPE = "type"


# Maps an element's xpath to the set of interactions that element supports;
# this is the return type of BaseDriver.get_possible_interactions().
PossibleInteractionsByXpath = Dict[str, Set[InteractionType]]


class BaseDriver(ABC):
def __init__(self, url: Optional[str], init_function: Optional[Callable[[], Any]]):
"""Init the driver with the init function, and then go to the desired url"""
Expand Down Expand Up @@ -128,8 +139,8 @@ def get_screenshots_whole_page(self) -> list[str]:
return screenshot_paths

@abstractmethod
def check_visibility(self, xpath: str) -> bool:
"""Check an element visibility by its xpath"""
def get_possible_interactions(self) -> PossibleInteractionsByXpath:
"""Get elements that can be interacted with as a dictionary mapped by xpath"""
pass

@abstractmethod
Expand Down Expand Up @@ -220,3 +231,106 @@ def get_current_screenshot_folder(self) -> Path:
@abstractmethod
def get_screenshot_as_png(self) -> bytes:
    """Return a screenshot as bytes (PNG-encoded, going by the method name).

    Abstract: every concrete driver supplies its own implementation.
    """
    pass

@abstractmethod
def resolve_xpath(self, xpath: str):
    """Resolve the element designated by ``xpath``.

    Abstract: every concrete driver supplies its own implementation, and the
    return type is driver-specific (the Playwright driver annotates its
    override as returning a ``Locator``).

    NOTE(review): the HTML retriever calls this with an iframe's xpath before
    reading the frame's HTML and switches back to the parent frame afterwards,
    so implementations presumably enter the frames found along the path --
    confirm per driver.
    """
    pass


JS_SETUP_GET_EVENTS = """
(function() {
Element.prototype._addEventListener = Element.prototype.addEventListener;
Element.prototype.addEventListener = function(a,b,c) {
this._addEventListener(a,b,c);
if(!this.eventListenerList) this.eventListenerList = {};
if(!this.eventListenerList[a]) this.eventListenerList[a] = [];
this.eventListenerList[a].push(b);
};
Element.prototype._removeEventListener = Element.prototype.removeEventListener;
Element.prototype.removeEventListener = function(a, b, c) {
this._removeEventListener(a, b, c);
if(this.eventListenerList && this.eventListenerList[a]) {
const index = this.eventListenerList[a].indexOf(b);
if (index > -1) {
this.eventListenerList[a].splice(index, 1);
if(!this.eventListenerList[a].length) {
delete this.eventListenerList[a];
}
}
}
};
if (!window.getEventListeners) {
window.getEventListeners = function(e) {
return (e && e.eventListenerList) || [];
}
}
})();"""

JS_GET_INTERACTIVES = """
return (function() {
function getInteractions(e) {
const tag = e.tagName.toLowerCase();
if (!e.checkVisibility() || e.hasAttribute('disabled') || e.hasAttribute('readonly') || e.getAttribute('aria-hidden') === 'true'
|| e.getAttribute('aria-disabled') === 'true' || (tag === 'input' && e.getAttribute('type') === 'hidden')) {
return [];
}
const style = getComputedStyle(e);
if (style.display === 'none' || style.visibility === 'hidden' || style.opacity === '0') {
return [];
}
const events = getEventListeners(e);
const role = e.getAttribute('role');
const clickableInputs = ['submit', 'checkbox', 'radio', 'color', 'file', 'image', 'reset'];
function hasEvent(n) {
return events[n]?.length || e.hasAttribute('on' + n);
}
const evts = [];
if (hasEvent('keydown') || hasEvent('keyup') || hasEvent('keypress') || hasEvent('keydown') || hasEvent('input') || e.isContentEditable
|| (
(tag === 'input' || tag === 'textarea' || role === 'searchbox' || role === 'input')
) && !clickableInputs.includes(e.getAttribute('type'))
) {
evts.push('TYPE');
}
if (tag === 'a' || tag === 'button' || role === 'button' || role === 'checkbox' || hasEvent('click') || hasEvent('mousedown') || hasEvent('mouseup')
|| hasEvent('dblclick') || style.cursor === 'pointer' || (tag === 'input' && clickableInputs.includes(e.getAttribute('type')) )
|| e.hasAttribute('aria-haspopup') || tag === 'select' || role === 'select') {
evts.push('CLICK');
}
if (hasEvent('mouseover')) {
evts.push('HOVER');
}
return evts;
}
const results = {};
function traverse(node, xpath) {
if (node.nodeType === Node.ELEMENT_NODE) {
const interactions = getInteractions(node);
if (interactions.length > 0) {
results[xpath] = interactions;
}
}
const countByTag = {};
for (let child = node.firstChild; child; child = child.nextSibling) {
const tag = child.nodeName.toLowerCase();
countByTag[tag] = (countByTag[tag] || 0) + 1;
let childXpath = xpath + '/' + tag;
if (countByTag[tag] > 1) {
childXpath += '[' + countByTag[tag] + ']';
}
if (tag === 'iframe') {
try {
traverse(child.contentWindow.document.body, childXpath + '/html/body');
} catch (e) {
console.error("iframe access blocked", child, e);
}
} else {
traverse(child, childXpath);
}
}
}
traverse(document.body, '/html/body');
return results;
})();
"""
3 changes: 1 addition & 2 deletions lavague-core/lavague/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def init_driver() -> WebDriver:


class SeleniumDriverForEval(SeleniumDriver):
def check_visibility(self, xpath: str) -> bool:
return True
pass


class Evaluator(ABC):
Expand Down
2 changes: 1 addition & 1 deletion lavague-core/lavague/core/navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(
n_attempts: int = 5,
logger: AgentLogger = None,
display: bool = False,
raise_on_error=False,
raise_on_error: bool = False,
):
if llm is None:
llm: BaseLLM = get_default_context().llm
Expand Down
37 changes: 28 additions & 9 deletions lavague-core/lavague/core/retrievers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations
from typing import List
from abc import ABC, abstractmethod
from bs4 import BeautifulSoup, NavigableString
import ast
from bs4 import BeautifulSoup
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core import Document
Expand All @@ -15,6 +14,7 @@
from lavague.core.utilities.format_utils import clean_html
from lavague.core.context import get_default_context


class BaseHtmlRetriever(ABC):
def __init__(
self, driver: BaseDriver, embedding: BaseEmbedding = None, top_k: int = 3
Expand Down Expand Up @@ -73,37 +73,56 @@ def _generate_xpath(self, element, path=""): # used to generate dict nodes
]
if len(siblings) > 1:
count = siblings.index(element) + 1
path = f"/{element.name}[{count}]{path}"
if count == 1:
path = f"/{element.name}{path}"
else:
path = f"/{element.name}[{count}]{path}"
else:
path = f"/{element.name}{path}"
return self._generate_xpath(element.parent, path)
def _add_xpath_attributes(self, html_content, xpath_prefix = ""):

def _add_xpath_attributes(self, html_content, xpath_prefix=""):
soup = BeautifulSoup(html_content, "html.parser")
for element in soup.find_all(True):
element["xpath"] = xpath_prefix + self._generate_xpath(element)
for iframe_tag in soup.find_all('iframe'):
for iframe_tag in soup.find_all("iframe"):
frame_xpath = self._generate_xpath(iframe_tag)
try:
self.driver.resolve_xpath(frame_xpath)
except Exception as e:
continue
frame_soup_str = self._add_xpath_attributes(self.driver.get_html(), xpath_prefix + frame_xpath)
frame_soup_str = self._add_xpath_attributes(
self.driver.get_html(), xpath_prefix + frame_xpath
)
iframe_tag.replace_with(frame_soup_str)
self.driver.driver.switch_to.parent_frame()
return str(soup)

def retrieve_html(self, query: QueryBundle) -> List[NodeWithScore]:
html = self._add_xpath_attributes(self.driver.get_html())
text_list = [html]
documents = [Document(text=t) for t in text_list]
documents = [Document(text=html)]
splitter = LangchainNodeParser(
lc_splitter=RecursiveCharacterTextSplitter.from_language(
language="html",
)
)
nodes = splitter.get_nodes_from_documents(documents)
index = VectorStoreIndex(nodes, embed_model=self.embedding)
possible_interactions = self.driver.get_possible_interactions()

compatible_nodes = []
for node in nodes:
soup = BeautifulSoup(node.text, "html.parser")
for element in soup.find_all(True):
if len(possible_interactions.get(element.get("xpath", ""), set())) > 0:
compatible_nodes.append(node)
break

if len(compatible_nodes) == 0:
# no interactive node matches, let the retriever decide
compatible_nodes = nodes

index = VectorStoreIndex(compatible_nodes, embed_model=self.embedding)
retriever = BM25Retriever.from_defaults(
index=index, similarity_top_k=self.top_k
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,6 @@ def get_screenshot_as_png(self) -> bytes:
def destroy(self) -> None:
pass

def check_visibility(self, xpath: str) -> bool:
return self.send_command_and_get_response_sync("is_visible", xpath)
# try:
# return self.driver.find_element(By.XPATH, xpath).is_displayed()
# except:
# return False

def get_highlighted_element(self, generated_code: str):
# local_scope = {"driver": self.get_driver()}
# assignment_code = keep_assignments(generated_code)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import json
import os
from PIL import Image
from typing import Callable, Optional, Any, Mapping
from lavague.core.utilities.format_utils import (
extract_code_from_funct,
keep_assignments,
return_assigned_variables,
)
from typing import Callable, Optional, Any, Mapping, Iterable, Dict, List
from lavague.core.utilities.format_utils import extract_code_from_funct
from playwright.sync_api import Page, Locator
from lavague.core.base_driver import BaseDriver
from lavague.core.base_driver import (
BaseDriver,
JS_GET_INTERACTIVES,
PossibleInteractionsByXpath,
InteractionType,
)


class PlaywrightDriver(BaseDriver):
Expand All @@ -29,13 +30,15 @@ def __init__(
)
self.headless = headless
self.user_data_dir = user_data_dir
self.width = 1080
self.height = 1080
self.width = width
self.height = height
super().__init__(url, get_sync_playwright_page)

# Before modifying this function, check if your changes are compatible with code_for_init which parses this code
# these imports are necessary as they will be pasted to the output
def default_init_code(self) -> Page:
from lavague.core.base_driver import JS_SETUP_GET_EVENTS

try:
from playwright.sync_api import sync_playwright
except (ImportError, ModuleNotFoundError) as error:
Expand All @@ -45,13 +48,19 @@ def default_init_code(self) -> Page:
p = sync_playwright().__enter__()
user_agent = "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36"
if self.user_data_dir is None:
browser = p.chromium.launch(headless=self.headless)
browser = p.chromium.launch(
headless=self.headless,
args=["--disable-web-security", "--disable-site-isolation-trials"],
)
else:
browser = p.chromium.launch_persistent_context(
user_data_dir=self.user_data_dir,
headless=self.headless,
args=["--disable-web-security", "--disable-site-isolation-trials"],
)

context = browser.new_context(user_agent=user_agent)
context.add_init_script(JS_SETUP_GET_EVENTS)
page = context.new_page()
self.page = page
self.resize_driver(self.width, self.height)
Expand Down Expand Up @@ -138,17 +147,12 @@ def resolve_xpath(self, xpath) -> Locator:
element = self.resolve_xpath(after)
return element

def check_visibility(self, xpath: str) -> bool:
try:
locator = self.page.locator(f"xpath={xpath}")
return locator.is_visible() and locator.is_enabled()
except:
return False

def get_highlighted_element(self, generated_code: str):
elements = []

data = json.loads(generated_code)
if not isinstance(data, Iterable):
data = [data]
for item in data:
action_name = item["action"]["name"]
if action_name != "fail":
Expand All @@ -157,7 +161,7 @@ def get_highlighted_element(self, generated_code: str):
elements.append(elem)

if len(elements) == 0:
raise ValueError(f"No element found.")
raise ValueError("No element found.")

outputs = []
for element in elements:
Expand Down Expand Up @@ -249,10 +253,17 @@ def code_for_execute_script(self, js_code: str, *args) -> str:
return f"page.evaluate(\"(arguments) => {{{js_code}}}\", [{', '.join(str(arg) for arg in args)}])"

def scroll_up(self):
code = self.execute_script("window.scrollBy(0, -window.innerHeight);")
self.execute_script("window.scrollBy(0, -window.innerHeight);")

def scroll_down(self):
code = self.execute_script("window.scrollBy(0, window.innerHeight);")
self.execute_script("window.scrollBy(0, window.innerHeight);")

def get_possible_interactions(self) -> PossibleInteractionsByXpath:
    """Return the page's interactive elements, keyed by xpath.

    Runs ``JS_GET_INTERACTIVES`` in the page and converts each list of
    interaction names it reports into a set of ``InteractionType`` members
    (looked up by member name).
    """
    raw: Dict[str, List[str]] = self.execute_script(JS_GET_INTERACTIVES)
    return {
        xpath: {InteractionType[name] for name in names}
        for xpath, names in raw.items()
    }

def get_capability(self) -> str:
return """
Expand Down
Loading

0 comments on commit 933ada7

Please sign in to comment.