Simple linear thing in Torch, with a scikit-learn compatible API.

stephantul/torchic

torchic

Simple model training in PyTorch, with a scikit-learn compatible API.

It has the following features:

  • scikit-learn-like API (i.e., using fit and predict)
  • Supports NumPy arrays and torch tensors out of the box
  • Automatically converts your tensors between devices
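
To make the three features above concrete, here is a minimal sketch of what a fit/predict wrapper around a linear layer can look like. It is not torchic's actual implementation: the class name TinyLinearClassifier, the epochs parameter, and the choice of Adam with cross-entropy loss are all assumptions made for illustration.

```python
import numpy as np
import torch
from torch import nn


class TinyLinearClassifier(nn.Module):
    """Hypothetical illustration; NOT torchic's actual implementation."""

    def __init__(self, n_features: int, n_labels: int, learning_rate: float = 1e-3) -> None:
        super().__init__()
        self.linear = nn.Linear(n_features, n_labels)
        self.learning_rate = learning_rate

    @property
    def device(self) -> torch.device:
        # Device the parameters currently live on (follows calls to .to(...)).
        return self.linear.weight.device

    def _as_tensor(self, x, dtype: torch.dtype) -> torch.Tensor:
        # Accept NumPy arrays or torch tensors and move them to the model's device.
        return torch.as_tensor(x, dtype=dtype, device=self.device)

    def fit(self, X, y, batch_size: int = 32, epochs: int = 20) -> "TinyLinearClassifier":
        X_t = self._as_tensor(X, torch.float32)
        y_t = self._as_tensor(y, torch.long)
        # Assumed training setup: Adam with cross-entropy loss.
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        loss_fn = nn.CrossEntropyLoss()
        self.train()
        for _ in range(epochs):
            # Shuffle once per epoch, then walk over mini-batches.
            permutation = torch.randperm(len(X_t), device=self.device)
            for start in range(0, len(X_t), batch_size):
                idx = permutation[start:start + batch_size]
                optimizer.zero_grad()
                loss = loss_fn(self.linear(X_t[idx]), y_t[idx])
                loss.backward()
                optimizer.step()
        return self

    @torch.no_grad()
    def predict(self, X) -> np.ndarray:
        self.eval()
        logits = self.linear(self._as_tensor(X, torch.float32))
        # Return labels as a NumPy array, regardless of input type or device.
        return logits.argmax(dim=1).cpu().numpy()


# Toy, linearly separable data: the label depends only on the first feature.
torch.manual_seed(0)
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5)).astype(np.float32)
y = (X[:, 0] > 0).astype(np.int64)

model = TinyLinearClassifier(5, 2, learning_rate=0.05).to("cpu")
model.fit(X, y, batch_size=32, epochs=30)

pred_from_numpy = model.predict(X)                     # NumPy input
pred_from_tensor = model.predict(torch.from_numpy(X))  # torch tensor input
```

Converting inputs inside fit and predict, rather than asking the caller to do it, is what lets the same object take either array type and keep working after it is moved to another device.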

Example

The example below classifies documents from the 20 newsgroups dataset, which scikit-learn provides pre-vectorized with a CountVectorizer. The example requires scikit-learn to be installed.

import numpy as np

from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support

from torchic import Torchic

# NOTE: change this to 'cuda' or 'mps' if you want acceleration.
DEVICE = "cpu"

X, y = fetch_20newsgroups_vectorized(return_X_y=True, remove=("headers", "footers"), subset="train")
X = X[y < 10]
y = y[y < 10]
X = np.asarray(X.todense())

X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True, random_state=44, test_size=.1)
n_features, n_labels = X_train.shape[1], len(set(y))

# Torchic stuff begins here.
t = Torchic(n_features, n_labels, learning_rate=1e-4).to(DEVICE)
t.fit(X_train, y_train, batch_size=128)

pred = t.predict(X_test)

print(precision_recall_fscore_support(y_test, pred, average="macro"))
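
The final print call above outputs a 4-tuple. As a small aside on how to read it, the sketch below runs precision_recall_fscore_support on hand-made toy labels (not the newsgroups data) with average="macro": the per-class scores are averaged with equal weight, and the fourth element (support) is None whenever an average is requested.

```python
from sklearn.metrics import precision_recall_fscore_support

# Toy labels for illustration only; three classes, one mistake.
y_true = [0, 1, 2, 2]
y_pred = [0, 1, 2, 1]

precision, recall, f1, support = precision_recall_fscore_support(
    y_true, y_pred, average="macro"
)

# Per class: precision = (1, 0.5, 1), recall = (1, 1, 0.5), F1 = (1, 2/3, 2/3).
print(f"{precision:.3f} {recall:.3f} {f1:.3f} {support}")
# prints: 0.833 0.833 0.778 None
```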

TODO:

  • Add docstrings
  • Add additional unit tests
