Skip to content

Commit

Permalink
Add boolean type (eyurtsev#165)
Browse files Browse the repository at this point in the history
This PR adds a boolean type to Kor.
  • Loading branch information
eyurtsev authored May 24, 2023
1 parent d47fb0c commit e6266e8
Show file tree
Hide file tree
Showing 12 changed files with 2,056 additions and 1,938 deletions.
3 changes: 2 additions & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ representation automatically.
.. autosummary::

Object
Number
Text
Number
Bool
Selection
Option

Expand Down
3 changes: 2 additions & 1 deletion kor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
create_extraction_chain,
extract_from_documents,
)
from .nodes import Number, Object, Option, Selection, Text
from .nodes import Bool, Number, Object, Option, Selection, Text
from .type_descriptors import (
BulletPointDescriptor,
TypeDescriptor,
Expand All @@ -22,6 +22,7 @@
"Extraction",
"from_pydantic",
"JSONEncoder",
"Bool",
"Number",
"Object",
"Option",
Expand Down
4 changes: 2 additions & 2 deletions kor/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from pydantic import BaseModel

from .nodes import ExtractionSchemaNode, Number, Object, Option, Selection, Text
from .nodes import Bool, ExtractionSchemaNode, Number, Object, Option, Selection, Text
from .validators import PydanticValidator, Validator

# Not going to support dicts or lists since that requires recursive checks.
Expand Down Expand Up @@ -86,7 +86,7 @@ def _translate_pydantic_to_kor(
name=field_name,
)
elif issubclass(type_, bool):
attribute = Text(
attribute = Bool(
id=field_name,
examples=field_examples,
description=field_description,
Expand Down
17 changes: 15 additions & 2 deletions kor/documents/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import re
from typing import Tuple

import markdownify
from bs4 import BeautifulSoup
from langchain.schema import Document

from kor.documents.typedefs import AbstractDocumentProcessor
Expand All @@ -16,6 +14,13 @@

def _get_mini_html(html: str, *, tags_to_remove: Tuple[str, ...] = tuple()) -> str:
"""Clean up HTML tags."""
try:
from bs4 import BeautifulSoup
except ImportError:
raise ImportError(
"Please install BeautifulSoup to use the HTML document processor. "
"You can do so by running `pip install beautifulsoup4`."
)
# Parse the HTML document using BeautifulSoup
soup = BeautifulSoup(html, "html.parser")

Expand All @@ -34,6 +39,14 @@ def _get_mini_html(html: str, *, tags_to_remove: Tuple[str, ...] = tuple()) -> s

def _clean_html(html: str, *, tags_to_remove: Tuple[str, ...] = tuple()) -> str:
"""Clean up HTML and convert to markdown using markdownify."""
try:
import markdownify
except ImportError:
raise ImportError(
"Please install markdownify to use the HTML document processor. "
"You can do so by running `pip install markdownify`."
)

html = _get_mini_html(html, tags_to_remove=tags_to_remove)
md = markdownify.markdownify(html)
return CONSECUTIVE_NEW_LINES.sub("\n\n", md).strip()
Expand Down
2 changes: 1 addition & 1 deletion kor/extraction/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
try: # Handle breaking change in langchain
from langchain.base_language import BaseLanguageModel
except ImportError:
from langchain.schema import BaseLanguageModel
from langchain.schema import BaseLanguageModel # type: ignore

from kor.encoders import Encoder, InputFormatter, initialize_encoder
from kor.extraction.typedefs import DocumentExtraction, Extraction
Expand Down
28 changes: 21 additions & 7 deletions kor/nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Definitions of input elements."""
from __future__ import annotations

import abc
import copy
import re
Expand Down Expand Up @@ -31,30 +33,34 @@
class AbstractVisitor(Generic[T], abc.ABC):
"""An abstract visitor."""

def visit_text(self, node: "Text", **kwargs: Any) -> T:
def visit_text(self, node: Text, **kwargs: Any) -> T:
"""Visit text node."""
return self.visit_default(node, **kwargs)

def visit_number(self, node: "Number", **kwargs: Any) -> T:
def visit_number(self, node: Number, **kwargs: Any) -> T:
"""Visit text node."""
return self.visit_default(node, **kwargs)

def visit_object(self, node: "Object", **kwargs: Any) -> T:
def visit_object(self, node: Object, **kwargs: Any) -> T:
"""Visit object node."""
return self.visit_default(node, **kwargs)

def visit_selection(self, node: "Selection", **kwargs: Any) -> T:
def visit_selection(self, node: Selection, **kwargs: Any) -> T:
"""Visit selection node."""
return self.visit_default(node, **kwargs)

def visit_option(self, node: "Option", **kwargs: Any) -> T:
def visit_option(self, node: Option, **kwargs: Any) -> T:
"""Visit option node."""
return self.visit_default(node, **kwargs)

def visit_default(self, node: "AbstractSchemaNode", **kwargs: Any) -> T:
def visit_default(self, node: AbstractSchemaNode, **kwargs: Any) -> T:
"""Default node implementation."""
raise NotImplementedError()

def visit_bool(self, node: Bool, **kwargs: Any) -> T:
"""Visit bool node."""
return self.visit_default(node, **kwargs)


class AbstractSchemaNode(BaseModel):
"""Abstract schema node.
Expand Down Expand Up @@ -140,6 +146,14 @@ def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
return visitor.visit_text(self, **kwargs)


class Bool(ExtractionSchemaNode):
"""Built-in bool input."""

def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T:
"""Accept a visitor."""
return visitor.visit_bool(self, **kwargs)


class Option(AbstractSchemaNode):
"""Built-in option input must be part of a selection input."""

Expand Down Expand Up @@ -219,7 +233,7 @@ class Object(AbstractSchemaNode):
"""

attributes: Sequence[Union[ExtractionSchemaNode, Selection, "Object"]]
attributes: Sequence[Union[ExtractionSchemaNode, Selection, Object]]

examples: Sequence[
Tuple[
Expand Down
3 changes: 3 additions & 0 deletions kor/type_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from kor.nodes import (
AbstractSchemaNode,
AbstractVisitor,
Bool,
Number,
Object,
Selection,
Expand Down Expand Up @@ -76,6 +77,8 @@ def visit_default(self, node: "AbstractSchemaNode", **kwargs: Any) -> List[str]:
finalized_type = "string"
elif isinstance(node, Number):
finalized_type = "number"
elif isinstance(node, Bool):
finalized_type = "boolean"
else:
raise NotImplementedError()

Expand Down
Loading

0 comments on commit e6266e8

Please sign in to comment.