Skip to content

Commit

Permalink
引擎基类,数据pipeline提交
Browse files Browse the repository at this point in the history
  • Loading branch information
ruanjz6235 committed Dec 18, 2022
1 parent 33a502a commit 4646904
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import feather
import pandas as pd
import numpy as np
import random
from sklearn.utils.random import check_random_state
from data_transform import DataTransform
from pyarrow.filesystem import FileSystem as fs
Expand Down Expand Up @@ -51,8 +52,7 @@ def __init__(self,
origins,
ret_path,
bucket_size,
ret=None,
ascending=True):
ret=None):
self.tokenizer = tokenizer
self.bucket_size = bucket_size

Expand All @@ -74,10 +74,13 @@ def __init__(self,

def __getitem__(self, index):
while True:
len_dates = torch.random
len_dates = np.random.randint(len(self.ret))
if len_dates >= 0.3 * len(self.ret):
break

dates_index = np.array(random.sample(range(len(self.ret)), len_dates))
dates_index.sort()
ret_sample = self.ret[dates_index]
origins_sample = [origin[dates_index] for origin in self.origins]
return ret_sample, origins_sample

def __len__(self):
return len(self.ret)
Expand Down

0 comments on commit 4646904

Please sign in to comment.