arman-hk/BiGS (forked from jxiw/BiGS)

Bidirectional Gated State Space Models for NLP

Pretraining Without Attention (BiGS)
Official JAX Implementation


This repository contains the JAX model definitions, pretrained model weights, and training and fine-tuning code for BiGS, accompanying our paper on using state space models for pretraining. More details can be found in the paper.

Pretraining Without Attention
Junxiong Wang, Jing Nathan Yan, Albert Gu, Alexander M. Rush
Cornell University, Cornell Tech, DeepMind

Transformers have been essential to pretraining success in NLP. While other architectures have been used, downstream accuracy is either significantly worse, or requires attention layers to match standard benchmarks such as GLUE. This work explores pretraining without attention by using recent advances in sequence routing based on state-space models (SSMs). Our proposed model, Bidirectional Gated SSM (BiGS), combines SSM layers with a multiplicative gating architecture that has been effective in simplified sequence modeling architectures. The model learns static layers that do not consider pair-wise interactions. Even so, BiGS is able to match BERT pretraining accuracy on GLUE and can be extended to long-form pretraining of 4096 tokens without approximation. Analysis shows that while the models have similar accuracy, the approach has significantly different inductive biases than BERT in terms of interactions and syntactic representations.
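The abstract describes BiGS as SSM layers combined with multiplicative gating, run over the sequence in both directions. The sketch below is a simplified illustration of that idea only; it is not the layer from the paper or from `modeling_flax_bigs`. It assumes a diagonal linear recurrence computed with `jax.lax.scan` (the released model may parameterize and compute its SSM differently), a plain layer norm, GELU activations, and an elementwise product to merge the two directions. All function and parameter names (`diag_ssm`, `init_params`, `gated_bidirectional_block`, `w_v`, and so on) are hypothetical.

```python
import jax
import jax.numpy as jnp


def diag_ssm(x, log_rate, b, c):
    """Diagonal linear SSM applied independently to each of D channels (illustrative).

    h_t = a * h_{t-1} + b * x_t   (N state dims per channel),   y_t = sum_n c * h_t
    x: (L, D); log_rate, b, c: (D, N). Returns y: (L, D).
    """
    a = jnp.exp(-jnp.exp(log_rate))  # decay strictly in (0, 1) keeps the recurrence stable

    def step(h, x_t):
        h = a * h + b * x_t[:, None]       # (D, N) state update
        return h, jnp.sum(c * h, axis=-1)  # (D,) readout

    _, ys = jax.lax.scan(step, jnp.zeros_like(b), x)
    return ys


def init_params(key, d_model, d_inner, n_state):
    """Random parameters for the hypothetical block below."""
    ks = jax.random.split(key, 7)

    def dense(k, m, n):
        return jax.random.normal(k, (m, n)) / jnp.sqrt(m)

    def ssm(k):
        k1, k2, k3 = jax.random.split(k, 3)
        return (jax.random.normal(k1, (d_inner, n_state)),
                jax.random.normal(k2, (d_inner, n_state)) / n_state,
                jax.random.normal(k3, (d_inner, n_state)))

    return {
        "w_v": dense(ks[0], d_model, d_inner),
        "w_f": dense(ks[1], d_model, d_inner),
        "w_b": dense(ks[2], d_model, d_inner),
        "w_u": dense(ks[3], d_inner, d_inner),
        "w_o": dense(ks[4], d_inner, d_model),
        "ssm_f": ssm(ks[5]),
        "ssm_b": ssm(ks[6]),
    }


def gated_bidirectional_block(params, x):
    """x: (L, d_model) -> (L, d_model), with a residual connection."""
    z = (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + 1e-6)  # layer norm, no affine
    v = jax.nn.gelu(z @ params["w_v"])          # gate branch
    f = jax.nn.gelu(z @ params["w_f"])          # input to the left-to-right SSM
    b = jax.nn.gelu(z[::-1] @ params["w_b"])    # input to the right-to-left SSM (sequence reversed)
    u_f = diag_ssm(f, *params["ssm_f"])
    u_b = diag_ssm(b, *params["ssm_b"])[::-1]   # undo the reversal
    u = ((u_f * u_b) @ params["w_u"]) * v       # merge directions, then multiplicative gate
    return x + u @ params["w_o"]


params = init_params(jax.random.PRNGKey(0), d_model=16, d_inner=32, n_state=8)
x = jax.random.normal(jax.random.PRNGKey(1), (10, 16))
y = gated_bidirectional_block(params, x)
print(y.shape)  # (10, 16)
```

Each scan is causal on its own, but combining a left-to-right pass with a right-to-left pass lets every output position depend on the whole input, with cost linear in sequence length. There are no pairwise token interactions anywhere in the block.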

This repo contains:

  • 🪐 JAX implementation of BiGS and its variants,
  • 🛸 Pretrained BiGS models for several maximum sequence lengths (128 to 4096 tokens),
  • 💥 Scripts to pretrain BiGS from scratch,
  • 💫 Fine-tuning scripts for GLUE tasks.

Setup

You can run our models on both GPUs and TPUs.

For TPUs,

pip install -r requirements-tpu.txt

For GPUs,

pip install -r requirements-gpu.txt
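After installing either requirements file, a quick sanity check (not part of this repository) is to ask JAX which backend and devices it found. If it reports only CPU devices on a GPU or TPU machine, the accelerator build of `jaxlib` was most likely not installed.

```python
import jax

# Report which accelerator backend JAX picked up after installation.
backend = jax.default_backend()  # typically "cpu", "gpu", or "tpu"
devices = jax.devices()
print(f"JAX backend: {backend}; {len(devices)} device(s): {devices}")
```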

Download Models

Pretrained Models

| Sequence Length | Trained Tokens | Link |
|---|---|---|
| 128 | ~11B | BiGS-11B-128 |
| 128 | ~29B | BiGS-29B-128 |
| 128 | ~97B | BiGS-97B-128 |
| 512 | ~108B | BiGS-108B-512 |
| 1024 | ~110B | BiGS-110B-1024 |
| 4096 | ~110B | BiGS-110B-4096 |
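When choosing among the pretrained checkpoints, one possible policy is to take the shortest-context model that still covers your inputs and, among those, the one trained on the most tokens. The hypothetical helper below encodes the table above that way. The returned strings are the table's link labels; the corresponding Hugging Face Hub repository IDs are not listed here (the examples in this README load IDs such as `JunxiongWang/BiGS_128` and `JunxiongWang/BiGS_512`).

```python
# (max sequence length, approx. pretraining tokens in billions, link label from the table)
PRETRAINED = [
    (128, 11, "BiGS-11B-128"),
    (128, 29, "BiGS-29B-128"),
    (128, 97, "BiGS-97B-128"),
    (512, 108, "BiGS-108B-512"),
    (1024, 110, "BiGS-110B-1024"),
    (4096, 110, "BiGS-110B-4096"),
]


def pick_checkpoint(needed_len):
    """Hypothetical helper: shortest context that fits needed_len, ties broken by more tokens."""
    fits = [c for c in PRETRAINED if c[0] >= needed_len]
    if not fits:
        raise ValueError(f"no pretrained checkpoint covers {needed_len} tokens")
    return min(fits, key=lambda c: (c[0], -c[1]))[2]


print(pick_checkpoint(300))  # BiGS-108B-512
```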

MNLI Checkpoints

| Sequence Length | Trained Tokens | Model |
|---|---|---|
| 128 | ~11B | BiGS-11B-128MNLI |
| 128 | ~29B | BiGS-29B-128MNLI |
| 128 | ~97B | BiGS-97B-128MNLI |
| 512 | ~108B | BiGS-108B-512MNLI |

Example Usage

Load Masked Language Model

import jax
from jax import numpy as jnp
from transformers import BertTokenizer
from BiGS.modeling_flax_bigs import FlaxBiGSForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
model = FlaxBiGSForMaskedLM.from_pretrained('JunxiongWang/BiGS_128')

def top_k_predictions(text, k=10):
    # Pad to a fixed length of 128 tokens.
    encoded_input = tokenizer(text, return_tensors='np', padding='max_length', max_length=128)
    output = model(**encoded_input)
    # Vocabulary logits at the first [MASK] position (id 103 in bert-large-uncased).
    mask_logits = output.logits[encoded_input['input_ids'] == tokenizer.mask_token_id][0]
    probs = jax.nn.softmax(mask_logits)
    top_ids = jnp.flip(jnp.argsort(probs))[:k]  # token ids, most probable first
    return tokenizer.convert_ids_to_tokens(top_ids.tolist()), probs[top_ids]

tokens, probs = top_k_predictions("The goal of life is [MASK].")
# output: ['happiness', 'love', 'peace', 'perfection', 'life', 'enlightenment', 'god', 'survival', 'freedom', 'good']
# probability: [0.16052087, 0.04306792, 0.03651363, 0.03468223, 0.02927081, 0.02549769, 0.02385132, 0.02261189, 0.01672831, 0.01619471]

tokens, probs = top_k_predictions("Paris is the [MASK] of France.")
# output: ['capital', 'centre', 'center', 'city', 'capitol', 'prefecture', 'headquarters', 'president', 'metropolis', 'heart']
# probability: [0.9981787 , 0.00034076, 0.00026992, 0.00026926, 0.00017787, 0.00004816, 0.00004256, 0.00003716, 0.00003634, 0.00002893]
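The example above reads only the first `[MASK]` in the input. For inputs with several masks, a small generalization (a sketch, not part of the repository) gathers the logits at every mask position and uses `jax.lax.top_k` to get the most probable ids for each. The helper name `topk_at_masks` is hypothetical; it only assumes logits of shape `(batch, seq_len, vocab)`, as in the example.

```python
import jax
import jax.numpy as jnp
import numpy as np


def topk_at_masks(logits, input_ids, mask_token_id, k=10):
    """Top-k (probs, ids) for every [MASK], each of shape (num_masks, k).

    logits: (batch, seq_len, vocab); input_ids: (batch, seq_len).
    Rows follow batch-major, left-to-right order of the [MASK] tokens.
    """
    rows, cols = np.nonzero(np.asarray(input_ids) == mask_token_id)
    mask_logits = jnp.asarray(logits)[rows, cols]   # (num_masks, vocab)
    probs = jax.nn.softmax(mask_logits, axis=-1)
    return jax.lax.top_k(probs, k)                  # sorted by descending probability


# Toy check with a 5-token vocabulary; 3 plays the role of the [MASK] id here.
toy_logits = jnp.zeros((1, 4, 5)).at[0, 1, 2].set(5.0).at[0, 3, 0].set(5.0)
toy_ids = np.array([[7, 3, 9, 3]])
probs, ids = topk_at_masks(toy_logits, toy_ids, mask_token_id=3, k=2)
print(ids[:, 0].tolist())  # [2, 0]
```

With the model above this would be called as `topk_at_masks(output.logits, encoded_input['input_ids'], tokenizer.mask_token_id)`, followed by `tokenizer.convert_ids_to_tokens(ids[i].tolist())` for mask `i`; this usage has not been checked against the released checkpoints.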

Load Sequence Classification Model

from BiGS.modeling_flax_bigs import FlaxBiGSForSequenceClassification
model = FlaxBiGSForSequenceClassification.from_pretrained('JunxiongWang/BiGS_512')
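Turning classifier outputs into predictions is independent of the model itself. The sketch below assumes the model returns a Hugging Face style output whose `.logits` has shape `(batch, num_labels)`, as the Flax models in `transformers` do; this is not confirmed here for `FlaxBiGSForSequenceClassification`. The mapping from label ids to names (for example the MNLI classes) depends on the fine-tuning data and is not given in this README. The helper name `predict_labels` is hypothetical.

```python
import jax
import jax.numpy as jnp


def predict_labels(logits):
    """Map (batch, num_labels) logits to predicted label ids and their probabilities."""
    probs = jax.nn.softmax(logits, axis=-1)
    pred = jnp.argmax(probs, axis=-1)
    confidence = jnp.take_along_axis(probs, pred[:, None], axis=-1)[:, 0]
    return pred, confidence


pred, confidence = predict_labels(jnp.array([[0.0, 2.0, 0.0], [3.0, 0.0, 0.0]]))
print(pred.tolist())  # [1, 0]
```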

Load Question Answering Model

from BiGS.modeling_flax_bigs import FlaxBiGSForQuestionAnswering
model = FlaxBiGSForQuestionAnswering.from_pretrained('JunxiongWang/BiGS_512')
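Extractive QA heads in `transformers` return per-token `start_logits` and `end_logits`, and the answer is the span that maximizes their sum. Assuming `FlaxBiGSForQuestionAnswering` follows the same convention (not confirmed here), a minimal span decoder could look like the sketch below. `best_span` and `max_answer_len` are hypothetical names, and a full pipeline would also exclude question and padding tokens from the search.

```python
import jax.numpy as jnp


def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end, score) maximizing start_logits[s] + end_logits[e]
    subject to s <= e < s + max_answer_len. Inputs are 1-D over token positions."""
    n = start_logits.shape[0]
    scores = start_logits[:, None] + end_logits[None, :]       # (n, n) pair scores
    s = jnp.arange(n)[:, None]
    e = jnp.arange(n)[None, :]
    valid = (e >= s) & (e - s < max_answer_len)                # forbid reversed or overlong spans
    scores = jnp.where(valid, scores, -jnp.inf)
    start, end = jnp.unravel_index(jnp.argmax(scores), scores.shape)
    return int(start), int(end), float(scores[start, end])


print(best_span(jnp.array([0.0, 5.0, 1.0, 0.0]), jnp.array([4.0, 0.0, 3.0, 1.0])))  # (1, 2, 8.0)
```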

Load Multiple Choice Classification Model

from BiGS.modeling_flax_bigs import FlaxBiGSForMultipleChoice
model = FlaxBiGSForMultipleChoice.from_pretrained('JunxiongWang/BiGS_512')
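Multiple-choice models in `transformers` (for example `FlaxBertForMultipleChoice`) take inputs of shape `(batch, num_choices, seq_len)` and return logits of shape `(batch, num_choices)`. Assuming `FlaxBiGSForMultipleChoice` uses the same layout (not confirmed here), per-question encodings can be stacked as in this sketch. Calling the tokenizer on a list of (question, choice) pairs with `return_tensors='np'` and `padding='max_length'` yields `(num_choices, seq_len)` arrays for one question. The helper name is hypothetical.

```python
import numpy as np


def to_multiple_choice_batch(per_question):
    """Stack per-question encodings into multiple-choice model inputs.

    per_question: list of dicts, one per question, each mapping an input name
    (e.g. "input_ids", "attention_mask") to an int array of shape
    (num_choices, seq_len), all padded to the same seq_len.
    Returns a dict of arrays shaped (batch, num_choices, seq_len).
    """
    names = per_question[0].keys()
    return {name: np.stack([enc[name] for enc in per_question]) for name in names}


# Toy encodings: 2 questions, 4 choices each, 8 tokens per choice.
fake = [{"input_ids": np.zeros((4, 8), dtype=np.int32),
         "attention_mask": np.ones((4, 8), dtype=np.int32)} for _ in range(2)]
batch = to_multiple_choice_batch(fake)
print(batch["input_ids"].shape)  # (2, 4, 8)
```

Under the same assumption, the predicted choice for each question is `np.argmax(logits, axis=-1)` over the `(batch, num_choices)` output.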

Pretraining

See pretrain.md

Finetuning

GLUE

See GLUE.md and GLUE_freeze.md

Citation

@article{wang2022pretraining,
  title={Pretraining Without Attention},
  author={Wang, Junxiong and Yan, Jing Nathan and Gu, Albert and Rush, Alexander M},
  journal={arXiv preprint arXiv:2212.10544},
  year={2022}
}
