Weighted Bayesian Network Multiclass Text Classification Model

wbn

Weighted Bayesian Network Text Classification

Installation

From source

$ git clone https://github.com/leonkozlowski/wbn.git
$ cd wbn

$ python3.8 -m venv venv
$ source venv/bin/activate
$ pip install -e .

From PyPI

$ pip install wbn

Usage

Building, training, and testing WBN

from sklearn.model_selection import train_test_split

# Import WBN
from wbn.classifier import WBN
from wbn.sample.datasets import load_pr_newswire


# Build the model
wbn = WBN()

# Load a sample dataset
pr_newswire = load_pr_newswire()

# Train/test split
x_train, x_test, y_train, y_test = train_test_split(
    pr_newswire.data, pr_newswire.target, test_size=0.2
)

# Fit the model
wbn.fit(x_train, y_train)

# Predict labels for the test set
pred = wbn.predict(x_test)

# Reverse encode the labels
y_pred = wbn.reverse_encode(target=pred)
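
Once predictions are decoded, they can be compared against the held-out labels. The sketch below is one way to do that with the standard library only. It assumes y_test and y_pred are equal-length sequences of directly comparable labels, which this README does not confirm. The helper names accuracy and per_class_recall, and the demo labels, are hypothetical and not part of wbn.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of predictions that exactly match the true labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if len(y_true) == 0:
        raise ValueError("cannot score an empty label list")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def per_class_recall(y_true, y_pred):
    """Map each true class to the share of its samples predicted correctly."""
    totals = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {label: hits[label] / count for label, count in totals.items()}


# Hypothetical labels standing in for y_test and y_pred from the example above
y_true_demo = ["finance", "health", "finance", "tech"]
y_pred_demo = ["finance", "health", "tech", "tech"]

print(accuracy(y_true_demo, y_pred_demo))
print(per_class_recall(y_true_demo, y_pred_demo))
```

Since scikit-learn is already imported for the train/test split, sklearn.metrics.accuracy_score is an alternative to the hand-written accuracy helper.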

Constructing a new dataset:

import pickle

# Import data structures for dataset creation
from wbn.object import Document, DocumentData, Documents

# Load your dataset (here, from a pickle file)
with open("dataset.pickle", "rb") as infile:
    raw_data = pickle.load(infile)

# De-structure 'data' and 'target'
data = raw_data.get("data")
target = raw_data.get("target")

# Construct a Document for each data/target entry
documents = Documents(
    [
        Document(DocumentData(paragraphs, keywords), target[idx])
        for idx, (paragraphs, keywords) in enumerate(data)
    ]
)
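
The loading code above implies a particular layout for dataset.pickle: a dict with a "data" list of (paragraphs, keywords) pairs and a parallel "target" list of labels. The following is a minimal sketch of writing and re-reading such a file. The sample entries are hypothetical, and representing paragraphs and keywords as lists of strings is an assumption, since this README does not specify their types.

```python
import os
import pickle
import tempfile

# Hypothetical sample entries: each item in "data" is a (paragraphs, keywords)
# pair, and "target" holds one label per item, in the same order.
# Assumption: paragraphs and keywords are lists of strings.
raw_data = {
    "data": [
        (["First paragraph.", "Second paragraph."], ["merger", "acquisition"]),
        (["Only paragraph."], ["earnings"]),
    ],
    "target": ["finance", "finance"],
}

# Write to a temporary directory so the example leaves the working tree alone
path = os.path.join(tempfile.mkdtemp(), "dataset.pickle")
with open(path, "wb") as outfile:
    pickle.dump(raw_data, outfile)

# Read the file back the same way the snippet above does
with open(path, "rb") as infile:
    loaded = pickle.load(infile)

# Every data entry needs a matching label
if len(loaded["data"]) != len(loaded["target"]):
    raise ValueError("data and target must have the same length")
```

Checking that "data" and "target" have equal lengths before building Documents avoids an IndexError from target[idx] partway through the list comprehension.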

Credits

This package was created with Cookiecutter and the audreyr/cookiecutter-pypackage project template.