CUDA implementation of Tractable Approximate Gaussian Inference

lhnguyen102/cuTAGI

cuTAGI is a probabilistic array framework built upon the principles of Tractable Approximate Gaussian Inference (TAGI) theory. It focuses on quantifying the uncertainty in Deep Neural Networks (DNNs), with the aim of improving their reliability across supervised, unsupervised, and reinforcement learning tasks.
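The idea of carrying uncertainty through a network can be illustrated with a small NumPy sketch: each layer passes on a mean and a variance instead of a single value. The sketch assumes that every weight, bias, and activation is an independent Gaussian (diagonal covariance only). It is not cuTAGI's kernel code; the function names are hypothetical, and treating ReLU with a first-order (local linearization) approximation is an assumption of this sketch.

```python
import numpy as np


def gaussian_linear(mu_a, var_a, mu_w, var_w, mu_b, var_b):
    """Propagate means and variances through a linear layer with independent
    Gaussian weights and biases (diagonal covariance only).

    Shapes: mu_a, var_a (batch, n_in); mu_w, var_w (n_in, n_out);
    mu_b, var_b (n_out,).
    """
    mu_z = mu_a @ mu_w + mu_b
    # For independent w and a: Var(w * a) = var_w*var_a + var_w*mu_a**2 + var_a*mu_w**2,
    # summed over the n_in inputs feeding each output unit, plus the bias variance.
    var_z = var_a @ var_w + (mu_a**2) @ var_w + var_a @ (mu_w**2) + var_b
    return mu_z, var_z


def linearized_relu(mu_z, var_z):
    """First-order (local linearization) approximation of a ReLU on a Gaussian."""
    jacobian = (mu_z > 0).astype(mu_z.dtype)
    return np.maximum(mu_z, 0.0), jacobian**2 * var_z


# Usage: a noise-free input batch pushed through one Gaussian layer and a ReLU.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4))
mu_w, var_w = rng.normal(scale=0.5, size=(4, 3)), np.full((4, 3), 0.1)
mu_b, var_b = np.zeros(3), np.full(3, 0.1)

mu_z, var_z = gaussian_linear(x, np.zeros_like(x), mu_w, var_w, mu_b, var_b)
mu_h, var_h = linearized_relu(mu_z, var_z)
print(mu_h.shape, var_h.shape)  # (2, 3) (2, 3)
```

With noise-free inputs (`var_a = 0`), the output variance comes only from the weights and biases; the linearized ReLU then sets the variance of every unit with a non-positive mean to zero.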

Some key features of cuTAGI include:

  • Performance-Oriented Kernels: All kernels of DNN layers are written from scratch in C++/CUDA, and pybind11 is used for seamless Python integration. This allows models to run on CPU and CUDA devices through the Python API.
  • Broad Architecture Support: It currently supports the basic DNN layers, including Linear, CNN, transposed CNN, LSTM, average pooling, and normalization layers, enabling the construction of mainstream architectures such as Autoencoders, Transformers, Diffusion Models, and GANs.
  • Model Building and Execution: It currently supports sequential model building, with plans to introduce eager execution in the future for easier debugging.
  • Open Platform: cuTAGI provides open access to its entire codebase. This transparency and accessibility allow researchers and developers to dive deep into cuTAGI's core functionality.

cuTAGI targets machine learning researchers and developers, aiming to improve the reliability of neural network outcomes, learning efficiency, and adaptability to different dataset sizes. The Python API, inspired by the PyTorch framework, is designed to quickly onboard researchers for idea exploration.

Figure: Examples of regression tasks using the diagonal (top left) or full (top right) covariance modes for hidden layers, inference of heteroscedastic aleatory uncertainty (bottom left), and estimation of the derivative of a function modeled by a neural network (bottom right).

Examples

Here is an example of training a classifier using pytagi on the MNIST dataset:

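import numpy as np

from pytagi.nn import Linear, OutputUpdater, ReLU, Sequential
from pytagi import HRCSoftmaxMetric
from examples.data_loader import MnistDataLoader

batch_size = 32
dtl = MnistDataLoader()
metric = HRCSoftmaxMetric(num_classes=10)

# 784 = 28 x 28 flattened pixels; the 11 outputs encode the 10 classes
# for the hierarchical softmax used by HRCSoftmaxMetric
net = Sequential(
    Linear(784, 128),
    ReLU(),
    Linear(128, 128),
    ReLU(),
    Linear(128, 11),
)
# net.to_device("cuda")  # uncomment to run on a CUDA device

udt = OutputUpdater(net.device)
# Observation variance: each label is encoded as 4 observations per sample
var_y = np.full((batch_size * 4,), 1.0, dtype=np.float32)

batch_iter = dtl.create_data_loader(batch_size)

for i, (x, y, idx, label) in enumerate(batch_iter):
    # Forward pass: predictive mean and variance of the outputs
    m_pred, v_pred = net(x)
    # Update the output layer based on the targets at the encoded indices
    udt.update_using_indices(net.output_z_buffer, y, var_y, idx, net.input_delta_z_buffer)
    # Propagate the output update backward through the network, then update the parameters
    net.backward()
    net.step()
    error_rate = metric.error_rate(m_pred, v_pred, label)
    print(f"Iteration: {i} error rate: {error_rate}")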
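The example uses 11 output units and `batch_size * 4` observation variances for 10 classes. One encoding that reproduces these numbers is a binary tree with 4 levels (2^4 = 16 leaves cover 10 classes) and one output unit per distinct path prefix: that gives 1 + 2 + 3 + 5 = 11 units, and each label is observed at 4 of them, one per level. The NumPy sketch below illustrates that encoding together with plain Gaussian conditioning of only the indexed output units, as an illustration of what updating the output layer based on targets at given indices can look like. `hrc_encode` and `update_at_indices` are hypothetical helpers, not pytagi's `HRCSoftmax` or `OutputUpdater` code; the node ordering and the +1/-1 target convention are assumptions.

```python
import numpy as np


def hrc_encode(num_classes):
    """Hypothetical binary-tree label encoding (not pytagi's HRCSoftmax).

    Returns (num_nodes, idx, obs), where idx[c] lists the output unit used at
    each tree level for class c and obs[c] the target (+1 or -1) at that unit.
    """
    depth = max(1, (num_classes - 1).bit_length())  # 4 levels for 10 classes
    node_id = {}
    # One output unit per distinct path prefix, numbered level by level.
    for level in range(depth):
        for c in range(num_classes):
            node_id.setdefault((level, c >> (depth - level)), len(node_id))
    idx = np.zeros((num_classes, depth), dtype=np.int64)
    obs = np.zeros((num_classes, depth), dtype=np.float32)
    for c in range(num_classes):
        for level in range(depth):
            idx[c, level] = node_id[(level, c >> (depth - level))]
            bit = (c >> (depth - level - 1)) & 1  # branch taken at this level
            obs[c, level] = 1.0 if bit else -1.0
    return len(node_id), idx, obs


def update_at_indices(mu_z, var_z, idx, y, var_y):
    """Condition the selected output units on y = z + v, v ~ N(0, var_y).

    mu_z, var_z: (batch, num_nodes); idx, y: (batch, num_obs).
    Units that are not listed in idx keep their prior moments.
    """
    mu_post, var_post = mu_z.copy(), var_z.copy()
    rows = np.arange(mu_z.shape[0])[:, None]
    m, v = mu_z[rows, idx], var_z[rows, idx]
    gain = v / (v + var_y)
    mu_post[rows, idx] = m + gain * (y - m)
    var_post[rows, idx] = v - gain * v
    return mu_post, var_post


# Usage: two samples with labels 3 and 9, unit prior variance, var_y = 1.0.
num_nodes, class_idx, class_obs = hrc_encode(10)
print(num_nodes, class_idx.shape[1])  # 11 4
labels = np.array([3, 9])
rng = np.random.default_rng(0)
mu_z, var_z = rng.normal(size=(2, num_nodes)), np.ones((2, num_nodes))
mu_post, var_post = update_at_indices(
    mu_z, var_z, class_idx[labels], class_obs[labels], 1.0
)
```

With unit prior variance and `var_y = 1.0`, the gain is 0.5, so the variance of each of the 4 indexed units per sample halves while the other 7 units keep their prior moments.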

cuTAGI offers a diverse set of examples to demonstrate its capabilities, including:

  • Regression
    python -m examples.regression
  • Classification on MNIST using various layers such as Linear, CNN, and batch and layer normalization.
    python -m examples.classification
  • Generation of MNIST images using an Autoencoder.
    python -m examples.autoencoder
  • Time series forecasting
    python -m examples.time_series_forecasting

Installation

cuTAGI is available on PyPI. To install it, run the following command in a terminal:

pip install pytagi

Additionally, for those who want the full performance of the native C++/CUDA version, installation instructions are provided in docs/dev_guide.md.

License

cuTAGI is released under the MIT license.

THIS IS AN OPEN SOURCE SOFTWARE FOR RESEARCH PURPOSES ONLY. THIS IS NOT A PRODUCT. NO WARRANTY EXPRESSED OR IMPLIED.

Related Papers

Citation

@misc{cutagi2022,
  author = {Luong-Ha Nguyen and James-A. Goulet},
  title = {cu{TAGI}: a {CUDA} library for {B}ayesian neural networks with Tractable Approximate {G}aussian Inference},
  year = {2022},
  note = {GitHub repository},
  howpublished = {https://github.com/lhnguyen102/cuTAGI}
}