RAN: Recurrent Attention Networks for Long-text Modeling | Findings of ACL23

4AI/RAN


RAN: Recurrent Attention Network

📢 This project is still a work in progress; the aim is to make long-document modeling easier.

[Figure: The framework of RAN]

⬇️ Installation

stable

python -m pip install -U rannet

latest

python -m pip install git+https://github.com/4AI/RAN.git

environment

  • ⭐ tensorflow>2.0,<=2.10 (tf.keras): run export TF_KERAS=1
  • tensorflow>=1.14,<2.0: install Keras==2.3.1

πŸ›οΈ Pretrained Models

V3 Models

🎯 compatible with: rannet>0.2.1

| Lang | Google Drive | Baidu NetDrive |
|------|--------------|----------------|
| EN | base | base [code: udts] |

Chinese models are still being pretrained.

V2 Models

🎯 compatible with: rannet<=0.2.1

| Lang | Google Drive | Baidu NetDrive |
|------|--------------|----------------|
| EN | base | base [code: djkj] |
| CN | base, small | base [code: e47w], small [code: mdmg] |

V1 Models

V1 models are not publicly released.

🚀 Quick Tour

🈢 w/ pretrained models

Extract semantic features

Set return_sequences=False to get a single pooled feature vector per input instead of per-token outputs.

import numpy as np
from rannet import RanNet, RanNetWordPieceTokenizer


vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)

rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path,
    checkpoint_path=ckpt_path,
    return_sequences=False,
    apply_cell_transform=False,
    cell_pooling='mean'
)
text = 'input text'
tok = tokenizer.encode(text)
vec = rannet_model.predict(np.array([tok.ids]))  # shape: (1, D)
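Pooled vectors like this can be compared directly, for example with cosine similarity for semantic search or deduplication. This is a minimal sketch that assumes each rannet_model.predict call returns an array of shape (1, D), as in the snippet above. The `cosine_similarity` helper is hypothetical and uses only NumPy.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two feature vectors of shape (D,) or (1, D)."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        return 0.0  # treat an all-zero vector as dissimilar to everything
    return float(np.dot(a, b) / denom)


# Usage with two pooled RAN features:
# vec1 = rannet_model.predict(np.array([tokenizer.encode('text one').ids]))
# vec2 = rannet_model.predict(np.array([tokenizer.encode('text two').ids]))
# score = cosine_similarity(vec1, vec2)
```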

For the classification task

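from rannet import RanNet, RanNetWordPieceTokenizer
# Assumes TensorFlow 2.x with TF_KERAS=1 (see Installation);
# with TensorFlow 1.x and Keras==2.3.1, use "import keras" instead.
from tensorflow import keras
L = keras.layers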


vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)

rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path, checkpoint_path=ckpt_path, return_sequences=False)
output = rannet_model.output  # (B, D)
output = L.Dropout(0.1)(output)
output = L.Dense(2, activation='softmax')(output)
model = keras.models.Model(rannet_model.input, output)
model.summary()
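To fine-tune this classifier, token id sequences of different lengths must be batched into one array. The minimal sketch below assumes that id 0 is the padding token in vocab.txt (check your vocabulary) and that labels are integer class ids. Integer labels pair with loss='sparse_categorical_crossentropy' for the two-way softmax above. The helper `make_batch` is hypothetical.

```python
from typing import List, Tuple

import numpy as np


def make_batch(token_ids: List[List[int]], labels: List[int],
               pad_id: int = 0) -> Tuple[np.ndarray, np.ndarray]:
    """Right-pad token id lists to a common length and stack them.

    pad_id=0 is an assumption; use the padding id of your vocabulary.
    """
    max_len = max(len(ids) for ids in token_ids)
    x = np.full((len(token_ids), max_len), pad_id, dtype=np.int32)
    for row, ids in enumerate(token_ids):
        x[row, :len(ids)] = ids  # copy real tokens, leave the tail as padding
    y = np.asarray(labels, dtype=np.int32)  # shape (B,), integer class ids
    return x, y


# Usage with the model built above:
# x, y = make_batch([tokenizer.encode(t).ids for t in texts], labels)
# model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
#               metrics=['accuracy'])
# model.fit(x, y, batch_size=16, epochs=3)
```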

For sequence tasks (e.g., sequence labeling)

from rannet import RanNet, RanNetWordPieceTokenizer


vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)

rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path, checkpoint_path=ckpt_path, return_cell=False)
output = rannet_model.output  # (B, L, D)
rannet_model.summary()
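For sequence labeling such as NER, a per-token classifier is usually stacked on the (B, L, D) output. In Keras this is typically L.Dense(num_tags, activation='softmax'), which is applied to every timestep. The NumPy sketch below shows what such a head computes on a feature array. The random weights, num_tags, and the padding mask are illustrative assumptions and are not part of rannet.

```python
import numpy as np


def tag_tokens(features: np.ndarray, weights: np.ndarray, bias: np.ndarray,
               mask: np.ndarray) -> np.ndarray:
    """Per-token linear tagger over sequence features.

    features: (B, seq_len, D); weights: (D, C); bias: (C,);
    mask: (B, seq_len), 1 for real tokens and 0 for padding.
    Returns (B, seq_len) predicted tag ids, with -1 at padded positions.
    """
    logits = np.einsum("bld,dc->blc", features, weights) + bias  # (B, seq_len, C)
    tags = logits.argmax(axis=-1)  # softmax is monotonic, so argmax of logits suffices
    return np.where(mask.astype(bool), tags, -1)


rng = np.random.default_rng(0)
batch, seq_len, dim, num_tags = 2, 4, 8, 3  # toy sizes; real D comes from the RAN config
features = rng.normal(size=(batch, seq_len, dim))
weights = rng.normal(size=(dim, num_tags))
bias = np.zeros(num_tags)
mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]])
print(tag_tokens(features, weights, bias, mask))
```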

🈚 w/o pretrained models

RAN is implemented as a Keras layer, so you can embed it directly into your own network.

from rannet import RAN

ran = RAN(head_num=8,
          head_size=256,
          window_size=256,
          min_window_size=16,
          activation='swish',
          kernel_initializer='glorot_normal',
          apply_lm_mask=False,
          apply_seq2seq_mask=False,
          apply_memory_review=True,
          dropout_rate=0.0,
          cell_initializer_type='zero')
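# X: the input tensor of shape (batch, seq_len, dim), e.g. the output
# of an embedding layer defined elsewhere in your model.
output, cell = ran(X)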
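As the name suggests, a recurrent attention network processes a long input segment by segment and carries a state from one segment to the next. The sketch below is a rough illustration of what window_size controls: it splits a token sequence into consecutive windows of at most window_size tokens. It is a simplification built on an assumption. It does not reproduce how the RAN layer handles min_window_size, attention within a window, or memory review, and `split_into_windows` is a hypothetical helper.

```python
from typing import List, Sequence


def split_into_windows(tokens: Sequence[int], window_size: int) -> List[List[int]]:
    """Split a token sequence into consecutive, non-overlapping windows.

    Illustrative only: the last window may be shorter than window_size.
    """
    if window_size <= 0:
        raise ValueError("window_size must be positive")
    return [list(tokens[i:i + window_size]) for i in range(0, len(tokens), window_size)]


windows = split_into_windows(list(range(600)), window_size=256)
print([len(w) for w in windows])  # [256, 256, 88]
```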

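w/ history

To carry context across inputs, return the history of one input and feed it as the initial cell of the next.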

import numpy as np
from rannet import RanNet, RanNetWordPieceTokenizer


vocab_path = 'pretrained/vocab.txt'
ckpt_path = 'pretrained/model.ckpt'
config_path = 'pretrained/config.json'
tokenizer = RanNetWordPieceTokenizer(vocab_path, lowercase=True)

rannet, rannet_model = RanNet.load_rannet(
    config_path=config_path,
    checkpoint_path=ckpt_path,
    return_sequences=False,
    apply_cell_transform=False,
    return_history=True,  # return history
    cell_pooling='mean',
    with_cell=True,  # with cell input
)
rannet_model.summary()

text = 'sentence 1'
tok = tokenizer.encode(text)
init_cell = np.zeros((1, 768))  # 768 is embedding size
vec, history = rannet_model.predict([np.array([tok.ids]), init_cell])

text2 = 'sentence 2'
tok = tokenizer.encode(text2)
vec2, history = rannet_model.predict([np.array([tok.ids]), history])  # input history of sentence 1
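The same pattern extends to a document with any number of segments: feed each segment together with the history returned for the previous one. This is a minimal sketch. `encode_segments` is a hypothetical helper, and `step` stands for a callable such as `lambda ids, cell: rannet_model.predict([ids, cell])`, which we assume returns (vec, history) as shown above. The `fake_step` stub exists only so the example runs without a checkpoint.

```python
from typing import Callable, List, Sequence, Tuple

import numpy as np

Step = Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]


def encode_segments(segments: Sequence[Sequence[int]], step: Step,
                    cell_size: int = 768) -> Tuple[List[np.ndarray], np.ndarray]:
    """Encode token-id segments in order, threading the history through.

    Returns the per-segment vectors and the final history.
    """
    history = np.zeros((1, cell_size))  # same zero initial cell as above
    vectors = []
    for ids in segments:
        vec, history = step(np.array([ids]), history)  # history feeds the next segment
        vectors.append(vec)
    return vectors, history


# Stub standing in for rannet_model.predict: adds the mean token id to the cell.
def fake_step(ids: np.ndarray, cell: np.ndarray):
    new_cell = cell + ids.mean()
    return new_cell.copy(), new_cell


vecs, final_history = encode_segments([[1, 2, 3], [4, 5]], fake_step, cell_size=4)
print(final_history)  # [[6.5 6.5 6.5 6.5]]

# Real usage (hypothetical):
# vecs, history = encode_segments(
#     [tokenizer.encode(t).ids for t in texts],
#     lambda ids, cell: rannet_model.predict([ids, cell]))
```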

📚 Citation

If you use our code in your research, please cite our work:

@inproceedings{li-etal-2023-recurrent,
    title = "Recurrent Attention Networks for Long-text Modeling",
    author = "Li, Xianming  and
      Li, Zongxi  and
      Luo, Xiaotian  and
      Xie, Haoran  and
      Lee, Xing  and
      Zhao, Yingbin  and
      Wang, Fu Lee  and
      Li, Qing",
    booktitle = "Findings of the Association for Computational Linguistics: ACL 2023",
    month = jul,
    year = "2023",
    publisher = "Association for Computational Linguistics",
    pages = "3006--3019",
}

📬 Contact

Please contact us as follows: 1) for code problems, create a GitHub issue; 2) for questions about the paper, email [email protected]