Skip to content

Commit

Permalink
Merge pull request #70 from EleutherAI/python_filter
Browse files Browse the repository at this point in the history
Merge commit I forgot about
  • Loading branch information
zhangir-azerbayev committed Oct 16, 2023
2 parents a8aa78f + 50cd77e commit 0222c58
Showing 1 changed file with 37 additions and 6 deletions.
43 changes: 37 additions & 6 deletions proof_pile_2/algebraic_stack/process_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,30 +177,58 @@ def haskell_filter(example):

return has_package and standard_filter(example)

def python_map(example):
    """Decide whether a Python source file should be kept, and record why.

    Called through ``ds.map(python_map, ...)`` in ``main``; the companion
    ``python_filter`` then reads the ``include`` key produced here.

    Args:
        example: Row mapping. ``example["content"]`` is the file text; the
            whole row is also passed to ``standard_filter``.

    Returns:
        A dict with two keys:

        * ``include`` (bool): whether the row should be kept.
        * ``reason`` (str): ``"standard"`` if ``standard_filter`` rejected
          the row, ``"notebook"`` if the text starts with ``{``,
          ``"not found"`` if no tracked import statement occurs, otherwise
          the comma-joined list of matched import statements.
    """
    text = example["content"]

    if not standard_filter(example):
        return {"include": False, "reason": "standard"}

    # Disabled blacklist, kept for reference. Every entry needs its trailing
    # comma, otherwise adjacent string literals are silently concatenated.
    # blacklist = (
    #     "import matplotlib", "from matplotlib",
    #     "import pandas", "from pandas",
    #     "import polars", "from polars",
    #     "import tensorflow", "from tensorflow",
    #     "import torch", "from torch",
    # )
    #
    # if any(x in text for x in blacklist):
    #     return {"include": False, "reason": "blacklist"}

    # Removes notebooks and jsons. ``startswith`` (rather than indexing the
    # first character) keeps empty / whitespace-only content from raising
    # IndexError; such content simply falls through to "not found".
    if text.lstrip().startswith("{"):
        return {"include": False, "reason": "notebook"}

    # Each package is listed once so a match is never reported twice.
    packages = (
        "numpy",
        "scipy",
        "sympy",
        "sage",
        "numba",
        "numexpr",
        "theano",
        "statsmodels",
        "networkx",
        "mpmath",
        "pymc3",
        "astropy",
        "cupy",
        "pycuda",
        "cvxpy",
        "pyomo",
        "jax",
    )
    keywords = [
        f"{verb} {pack}" for pack in packages for verb in ("import", "from")
    ]

    # NOTE: plain substring search, so e.g. "import jax" also matches
    # "import jaxlib".
    found = [keyword for keyword in keywords if keyword in text]
    if found:
        return {"include": True, "reason": ",".join(found)}
    return {"include": False, "reason": "not found"}

def python_filter(example):
    """Return the ``include`` flag stored on the row.

    Used as the predicate for ``ds.filter(...)`` in ``main`` after
    ``python_map`` has been applied.

    Args:
        example: Row mapping that contains an ``include`` key.

    Returns:
        The value of ``example["include"]``.

    Raises:
        KeyError: If the row has no ``include`` key.
    """
    return example["include"]


def c_filter(example):
Expand Down Expand Up @@ -491,7 +519,10 @@ def main(args):
elif lang == "haskell":
ds = ds.filter(haskell_filter, **filter_kwargs)
elif lang == "python":
ds = ds.filter(py_filter, **filter_kwargs)
print("map...")
ds = ds.map(python_map, **filter_kwargs)
print("filter...")
ds = ds.filter(python_filter, **filter_kwargs)
elif lang == "c":
ds = ds.filter(c_filter, **filter_kwargs)
elif lang == "cpp":
Expand Down

0 comments on commit 0222c58

Please sign in to comment.