-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_pairs.py
47 lines (41 loc) · 1.45 KB
/
create_pairs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import os
import random
import sys
from typing import List

from tqdm import tqdm

from datautils.utils import *
# Seed the global `random` RNG once at import time so that anything in this
# process drawing from it (presumably including negative sampling inside
# `create_pairs` -- verify in datautils.utils) is reproducible across runs.
random.seed(2022)
# NOTE(review): valid modes: ['default', 'rel_thresh'] -- nothing in this
# script takes a `mode` argument; presumably this refers to an option of a
# helper in datautils.utils. Confirm and move it there, or remove it.
def main(data_path: str, pairs_path: str):
    """Build NL-code pairs from ``data_path``, caching them as JSON at ``pairs_path``.

    If ``pairs_path`` already exists, its contents are loaded and returned
    without touching ``data_path`` at all. Otherwise the records read from
    ``data_path`` with ``read_jsonl`` are grouped with ``get_posts``, turned
    into pairs with ``create_pairs(..., neg_to_pos_ratio=1)``, written to
    ``pairs_path`` (indented JSON) and returned.

    Args:
        data_path: JSONL file with the raw records.
        pairs_path: JSON file used as a cache for the generated pairs.

    Returns:
        The pairs, either freshly built by ``create_pairs`` or deserialized
        from the cache. If ``create_pairs`` yields tuples, the cached copy
        will contain lists instead (JSON has no tuple type).
    """
    # Check the cache first: reading and grouping the raw dataset is wasted
    # work (and fails needlessly if the raw file is absent) when the pairs
    # are already cached.
    if os.path.exists(pairs_path):
        with open(pairs_path, encoding="utf-8") as f:
            return json.load(f)

    data: List[dict] = read_jsonl(data_path)
    # get_posts returns a mapping; only its values (the per-post record
    # lists) are needed here.
    posts: List[List[dict]] = list(get_posts(data).values())
    pairs = create_pairs(posts, neg_to_pos_ratio=1)
    print(f"caching data at {pairs_path}")
    with open(pairs_path, "w", encoding="utf-8") as f:
        json.dump(pairs, f, indent=4)
    return pairs
if __name__ == "__main__":
    # Build (or load the cached) NL-code pairs, then write a shuffled
    # training split next to the cache file.
    os.makedirs("triples", exist_ok=True)
    pairs_path = "triples/nl_code_pairs.json"
    pairs = main(data_path="data/conala-mined.jsonl",
                 pairs_path=pairs_path)
    print(f"dataset has {len(pairs)} NL-PL pairs")

    val_ratio: float = 0.2
    val_size = int(len(pairs) * val_ratio)
    stem, ext = os.path.splitext(pairs_path)
    # Shuffle with a dedicated, freshly seeded RNG rather than the global one.
    # `main` only calls `create_pairs` on a cache miss; if `create_pairs`
    # draws from the global `random` state (presumably, for negative
    # sampling), the global RNG would sit at a different position on the
    # first run than on later cached runs, yielding a different train/val
    # split each time and leaking examples between the two.
    random.Random(2022).shuffle(pairs)
    train_path = stem + "_train" + ext
    val_path = stem + "_val" + ext
    train_data = pairs[val_size:]
    val_data = pairs[:val_size]
    print(train_path)
    with open(train_path, "w", encoding="utf-8") as f:
        json.dump(train_data, f, indent=4)

    # The validation split is computed but not saved by default; set this to
    # True to also write it to `val_path`.
    save_val = False
    if save_val:
        with open(val_path, "w", encoding="utf-8") as f:
            json.dump(val_data, f, indent=4)