-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit c6fe4f3
Showing
28 changed files
with
3,711 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
*__pycache__ | ||
*.pyc | ||
*.log | ||
*.swp | ||
tags | ||
punkt.zip | ||
wordnet.zip | ||
.idea/ | ||
aste/data/ | ||
aste/temp_data/ | ||
models/ | ||
model_outputs/ | ||
outputs/ | ||
pretrained_weight/* | ||
!pretrained_weight/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2022 Chia Yew Ken | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
工具文档 | ||
|
||
工具功能: | ||
|
||
提取句子中的(Aspect, Opinion, Sentiment)三元组。 | ||
|
||
例如:It also has lots of other Korean dishes that are affordable and just as yummy . | ||
|
||
对应的三元组为[(Korean dishes, affordable, POSITIVE), (Korean dishes, yummy, POSITIVE)] | ||
|
||
|
||
|
||
环境配置: | ||
|
||
使用python3.7(推荐使用虚拟环境,如conda) | ||
|
||
安装依赖:bash setup.sh | ||
|
||
修改训练配置:training_config/config.json | ||
|
||
|
||
|
||
数据格式: | ||
|
||
·输入(可参考dataset/14lap.csv): | ||
|
||
输入为一个csv文件{data}.csv,每行是一个待测语句,存放在dataset目录下。 | ||
|
||
·输出(可参考pred/14lap): | ||
|
||
输出为pred/{data}目录下的csv文件。每一行的格式为: | ||
|
||
sentence#### #### ####[triplet_0, ..., triplet_n] | ||
|
||
其中每个triplet为(span_a, span_b, label),每个span是一个单词索引的列表,表示起始和结束处的索引。而label有三个可能值:’POS’, ‘NEU’, ‘NEG’,分别表示Sentiment是’Positive’, ‘Neutral’, ‘Negative’。 | ||
|
||
例如,It is a great size and amazing windows 8 included ! .#### #### ####[([4], [3], 'POS'), ([7, 8], [6], 'POS')] | ||
|
||
这个结果中,对于原句提取出的三元组为(size, great, positive), (windows 8, amazing, positive) | ||
|
||
|
||
|
||
运行: | ||
|
||
1.(可选)修改training_config/config.json中的训练参数 | ||
|
||
2.指定入口脚本main.py中的参数 | ||
|
||
参数说明: | ||
|
||
dataset: 待预测的csv文件名称。 | ||
using_train: 选择重新训练数据(using_train=True),还是直接使用pretrain_weight中的权重(using_train=False)。你可以从这里下载预训练权重,并把它放在pretrained_weight目录下。 | ||
model_name: 训练采用的数据集名称(或者使用的预训练权重的数据名称)。训练后产生的日志和权重会保存在outputs/{model_name}目录中。model_name目前的可选值有:’14lap’, ‘14res’, ‘15res’, ‘16res’ | ||
random_seed: 随机种子。 | ||
|
||
3.运行python main.py ,然后可以在pred/{data}目录下查看提取结果 | ||
|
||
|
||
|
||
引用: | ||
|
||
@inproceedings{xu-etal-2021-learning, | ||
|
||
title = "Learning Span-Level Interactions for Aspect Sentiment Triplet Extraction", | ||
|
||
author = "Xu, Lu and | ||
|
||
Chia, Yew Ken and | ||
|
||
Bing, Lidong", | ||
|
||
booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)", | ||
|
||
month = aug, | ||
|
||
year = "2021", | ||
|
||
address = "Online", | ||
|
||
publisher = "Association for Computational Linguistics", | ||
|
||
url = "https://aclanthology.org/2021.acl-long.367", | ||
|
||
doi = "10.18653/v1/2021.acl-long.367", | ||
|
||
pages = "4755--4766", | ||
|
||
abstract = "Aspect Sentiment Triplet Extraction (ASTE) is the most recent subtask of ABSA which outputs triplets of an aspect target, its associated sentiment, and the corresponding opinion term. Recent models perform the triplet extraction in an end-to-end manner but heavily rely on the interactions between each target word and opinion word. Thereby, they cannot perform well on targets and opinions which contain multiple words. Our proposed span-level approach explicitly considers the interaction between the whole spans of targets and opinions when predicting their sentiment relation. Thus, it can make predictions with the semantics of whole spans, ensuring better sentiment consistency. To ease the high computational cost caused by span enumeration, we propose a dual-channel span pruning strategy by incorporating supervision from the Aspect Term Extraction (ATE) and Opinion Term Extraction (OTE) tasks. This strategy not only improves computational efficiency but also distinguishes the opinion and target spans more properly. Our framework simultaneously achieves strong performance for the ASTE as well as ATE and OTE tasks. In particular, our analysis shows that our span-level approach achieves more significant improvements over the baselines on triplets with multi-word targets or opinions.", | ||
|
||
} | ||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
import json | ||
import random | ||
import sys | ||
from pathlib import Path | ||
from typing import List | ||
|
||
import _jsonnet | ||
import numpy as np | ||
import torch | ||
from allennlp.commands.train import train_model | ||
from allennlp.common import Params | ||
from allennlp.data import DatasetReader, Vocabulary, DataLoader | ||
from allennlp.models import Model | ||
from allennlp.training import Trainer | ||
from fire import Fire | ||
from tqdm import tqdm | ||
|
||
from data_utils import Data, Sentence | ||
from wrapper import SpanModel, safe_divide | ||
|
||
|
||
def set_seed(seed: int): | ||
random.seed(seed) | ||
np.random.seed(seed) | ||
torch.manual_seed(seed) | ||
|
||
|
||
def test_load( | ||
path: str = "training_config/config.json", | ||
path_train: str = "outputs/14lap/seed_0/temp_data/train.json", | ||
path_dev: str = "outputs/14lap/seed_0/temp_data/validation.json", | ||
save_dir="outputs/temp", | ||
): | ||
# Register custom modules | ||
sys.path.append(".") | ||
from span_model.data.dataset_readers.span_model import SpanModelReader | ||
|
||
assert SpanModelReader is not None | ||
params = Params.from_file( | ||
path, | ||
params_overrides=dict( | ||
train_data_path=path_train, | ||
validation_data_path=path_dev, | ||
test_data_path=path_dev, | ||
), | ||
) | ||
|
||
train_model(params, serialization_dir=save_dir, force=True) | ||
breakpoint() | ||
|
||
config = json.loads(_jsonnet.evaluate_file(path)) | ||
set_seed(config["random_seed"]) | ||
reader = DatasetReader.from_params(Params(config["dataset_reader"])) | ||
data_train = reader.read(path_train) | ||
data_dev = reader.read(path_dev) | ||
vocab = Vocabulary.from_instances(data_train + data_dev) | ||
model = Model.from_params(Params(config["model"]), vocab=vocab) | ||
|
||
data_train.index_with(vocab) | ||
data_dev.index_with(vocab) | ||
trainer = Trainer.from_params( | ||
Params(config["trainer"]), | ||
model=model, | ||
data_loader=DataLoader.from_params( | ||
Params(config["data_loader"]), dataset=data_train | ||
), | ||
validation_data_loader=DataLoader.from_params( | ||
Params(config["data_loader"]), dataset=data_dev | ||
), | ||
serialization_dir=save_dir, | ||
) | ||
breakpoint() | ||
trainer.train() | ||
breakpoint() | ||
|
||
|
||
class Scorer: | ||
name: str = "" | ||
|
||
def run(self, path_pred: str, path_gold: str) -> dict: | ||
pred = Data.load_from_full_path(path_pred) | ||
gold = Data.load_from_full_path(path_gold) | ||
assert pred.sentences is not None | ||
assert gold.sentences is not None | ||
assert len(pred.sentences) == len(gold.sentences) | ||
num_pred = 0 | ||
num_gold = 0 | ||
num_correct = 0 | ||
|
||
for i in range(len(gold.sentences)): | ||
tuples_pred = self.make_tuples(pred.sentences[i]) | ||
tuples_gold = self.make_tuples(gold.sentences[i]) | ||
num_pred += len(tuples_pred) | ||
num_gold += len(tuples_gold) | ||
for p in tuples_pred: | ||
for g in tuples_gold: | ||
if p == g: | ||
num_correct += 1 | ||
|
||
precision = safe_divide(num_correct, num_pred) | ||
recall = safe_divide(num_correct, num_gold) | ||
|
||
info = dict( | ||
precision=precision, | ||
recall=recall, | ||
score=safe_divide(2 * precision * recall, precision + recall), | ||
) | ||
return info | ||
|
||
def make_tuples(self, sent: Sentence) -> List[tuple]: | ||
raise NotImplementedError | ||
|
||
|
||
class SentimentTripletScorer(Scorer): | ||
name: str = "sentiment triplet" | ||
|
||
def make_tuples(self, sent: Sentence) -> List[tuple]: | ||
return [(t.o_start, t.o_end, t.t_start, t.t_end, t.label) for t in sent.triples] | ||
|
||
|
||
class TripletScorer(Scorer): | ||
name: str = "triplet" | ||
|
||
def make_tuples(self, sent: Sentence) -> List[tuple]: | ||
return [(t.o_start, t.o_end, t.t_start, t.t_end) for t in sent.triples] | ||
|
||
|
||
class OpinionScorer(Scorer): | ||
name: str = "opinion" | ||
|
||
def make_tuples(self, sent: Sentence) -> List[tuple]: | ||
return sorted(set((t.o_start, t.o_end) for t in sent.triples)) | ||
|
||
|
||
class TargetScorer(Scorer): | ||
name: str = "target" | ||
|
||
def make_tuples(self, sent: Sentence) -> List[tuple]: | ||
return sorted(set((t.t_start, t.t_end) for t in sent.triples)) | ||
|
||
|
||
class OrigScorer(Scorer): | ||
name: str = "orig" | ||
|
||
def make_tuples(self, sent: Sentence) -> List[tuple]: | ||
raise NotImplementedError | ||
|
||
def run(self, path_pred: str, path_gold: str) -> dict: | ||
model = SpanModel(save_dir="", random_seed=0) | ||
return model.score(path_pred, path_gold) | ||
|
||
|
||
def run_eval_domains( | ||
save_dir_template: str, | ||
path_test_template: str, | ||
random_seeds: List[int] = (0, 1, 2, 3, 4), | ||
domain_names: List[str] = ("hotel", "restaurant", "laptop"), | ||
): | ||
print(locals()) | ||
all_results = {} | ||
|
||
for domain in domain_names: | ||
results = [] | ||
for seed in tqdm(random_seeds): | ||
model = SpanModel(save_dir=save_dir_template.format(seed), random_seed=0) | ||
path_pred = str(Path(model.save_dir, f"pred_{domain}.txt")) | ||
path_test = path_test_template.format(domain) | ||
if not Path(path_pred).exists(): | ||
model.predict(path_test, path_pred) | ||
results.append(model.score(path_pred, path_test)) | ||
|
||
precision = sum(r["precision"] for r in results) / len(random_seeds) | ||
recall = sum(r["recall"] for r in results) / len(random_seeds) | ||
score = safe_divide(2 * precision * recall, precision + recall) | ||
all_results[domain] = dict(p=precision, r=recall, f=score) | ||
for k, v in all_results.items(): | ||
print(k, v) | ||
|
||
|
||
def test_scorer(path_pred: str, path_gold: str): | ||
for scorer in [ | ||
OpinionScorer(), | ||
TargetScorer(), | ||
TripletScorer(), | ||
SentimentTripletScorer(), | ||
OrigScorer(), | ||
]: | ||
print(scorer.name) | ||
print(scorer.run(path_pred, path_gold)) | ||
|
||
|
||
if __name__ == "__main__": | ||
Fire() |
Oops, something went wrong.