Added rule for arbitrary SVD (re-)computation
Joloco109 committed Aug 8, 2023
1 parent 81c971a commit 3a849e8
Showing 3 changed files with 69 additions and 17 deletions.
2 changes: 1 addition & 1 deletion workflow/Snakefile
@@ -2,7 +2,7 @@ include: "tools/Snakefile.smk"

report: "report/workflow.rst"

ruleorder: load_GN > load_mSM > trial_selection > locaNMF > parcellate > thresholding > feature_grouping > feature_calculation
ruleorder: load_GN > load_mSM > unify > svd > trial_selection > locaNMF > parcellate > thresholding > feature_grouping > feature_calculation

### Output accumulation rules ###

72 changes: 57 additions & 15 deletions workflow/rules/loading.smk
@@ -10,7 +10,11 @@ def sessions_input_gerion(wildcards):
'''
Matches the corresponding input files for one subject and date in the {session_id} wildcard; task_flatten is only passed to make sure files are present. The rule uses the structured task files from the params.
'''
session_id = wildcards['session_id'].split('-')
session_id = wildcards['session_id']
sessions = config["dataset_sessions"][config["dataset_aliases"]["All"]]
if session_id not in sessions:
raise ValueError(f"{session_id=} not in {sessions=}")
session_id = session_id.split('-')
subject_id = session_id[0]
date = '-'.join(session_id[1:])
input = {
@@ -23,7 +27,11 @@ def sessions_input_simon(wildcards):
'''
Matches the corresponding input files for one subject and date in the {session_id} wildcard; task_flatten is only passed to make sure files are present. The rule uses the structured task files from the params.
'''
session_id = wildcards['session_id'].split('-')
session_id = wildcards['session_id']
sessions = config["dataset_sessions"][config["dataset_aliases"]["All"]]
if session_id not in sessions:
raise ValueError(f"{session_id=} not in {sessions=}")
session_id = session_id.split('-')
subject_id = session_id[0]
date = session_id[1:]
input = {
@@ -36,10 +44,10 @@

rule load_Random:
output:
"results/random.random/SVD/data.h5",
config = "results/random.random/SVD/conf.yaml"
f"results/random.random/{config['loaded_decomposition']}/data.h5",
config = f"results/random.random/{config['loaded_decomposition']}/conf.yaml"
log:
f"results/random.random/SVD/random.log"
f"results/random.random/{config['loaded_decomposition']}/random.log"
conda:
"../envs/environment.yaml"
script:
@@ -52,14 +60,14 @@ rule load_GN:
input:
unpack(sessions_input_gerion)
output:
temp_c(f"{DATA_DIR}/{{session_id}}/SVD/{{session_id}}/data.h5", rule="load"),
align_plot = report(f"{DATA_DIR}/{{session_id}}/SVD/{{session_id}}/alignment.pdf", caption="../report/alignment.rst", category="1 Brain Alignment", labels={"Dataset": "GN", "Subjects":"{{session_id}}"}),
config = f"{DATA_DIR}/{{session_id}}/SVD/{{session_id}}/conf.yaml",
stim_side = report(f"{DATA_DIR}/{{session_id}}/SVD/{{session_id}}/stim_side.pdf", caption="../report/alignment.rst", category="0 Loading", labels={"Dataset": "GN", "Subjects":"{{session_id}}"})
temp_c(f"{DATA_DIR}/{{session_id}}/{config['loaded_decomposition']}/{{session_id}}/data.h5", rule="load"),
align_plot = report(f"{DATA_DIR}/{{session_id}}/{config['loaded_decomposition']}/{{session_id}}/alignment.pdf", caption="../report/alignment.rst", category="1 Brain Alignment", labels={"Dataset": "GN", "Subjects":"{{session_id}}"}),
config = f"{DATA_DIR}/{{session_id}}/{config['loaded_decomposition']}/{{session_id}}/conf.yaml",
stim_side = report(f"{DATA_DIR}/{{session_id}}/{config['loaded_decomposition']}/{{session_id}}/stim_side.pdf", caption="../report/alignment.rst", category="0 Loading", labels={"Dataset": "GN", "Subjects":"{{session_id}}"})
wildcard_constraints:
session_id = r"GN[\w_.\-]*"
log:
f"{DATA_DIR}/{{session_id}}/SVD/{{session_id}}/pipeline_entry.log"
f"{DATA_DIR}/{{session_id}}/{config['loaded_decomposition']}/{{session_id}}/pipeline_entry.log"
conda:
"../envs/environment.yaml"
resources:
@@ -75,20 +83,54 @@ rule load_mSM:
input:
unpack(sessions_input_simon)
output:
temp_c(f"{DATA_DIR}/{{session_id}}/SVD/{{session_id}}/data.h5", rule="load"),
align_plot = report(f"{DATA_DIR}/{{session_id}}/SVD/{{session_id}}/alignment.pdf", caption="../report/alignment.rst", category="1 Brain Alignment", labels={"Dataset": "mSM", "Subjects":"{{session_id}}"}),
config = f"{DATA_DIR}/{{session_id}}/SVD/{{session_id}}/conf.yaml",
temp_c(f"{DATA_DIR}/{{session_id}}/{config['loaded_decomposition']}/{{session_id}}/data.h5", rule="load"),
align_plot = report(f"{DATA_DIR}/{{session_id}}/{config['loaded_decomposition']}/{{session_id}}/alignment.pdf", caption="../report/alignment.rst", category="1 Brain Alignment", labels={"Dataset": "mSM", "Subjects":"{{session_id}}"}),
config = f"{DATA_DIR}/{{session_id}}/{config['loaded_decomposition']}/{{session_id}}/conf.yaml",
wildcard_constraints:
session_id = r"mSM[\w_.\-]*"
log:
f"{DATA_DIR}/{{session_id}}/SVD/{{session_id}}/pipeline_entry.log"
f"{DATA_DIR}/{{session_id}}/{config['loaded_decomposition']}/{{session_id}}/pipeline_entry.log"
conda:
"../envs/environment.yaml"
resources:
mem_mib=lambda wildcards, input, attempt: mem_res(wildcards,input,attempt,4000,1000)
script:
"../scripts/loading/load_mSM.py"

def svd_input(wildcards):
if config['loaded_decomposition'] == "SVD":
raise ValueError("This output should be generated by loading")
session_id = wildcards['dataset_id']
sessions = config["dataset_sessions"][config["dataset_aliases"]["All"]]
if session_id not in sessions:
raise ValueError(f"{session_id=} not in {sessions=}")
input = {
"data" : f"{{data_dir}}/{{dataset_id}}/{config['loaded_decomposition']}/{{dataset_id}}/data.h5",
"config": f"{{data_dir}}/{{dataset_id}}/{config['loaded_decomposition']}/{{dataset_id}}/conf.yaml" }
input.update( config["paths"]["parcellations"].get('SVD') )
return input
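
A worked example of what svd_input resolves to when something other than plain SVD was loaded may help. In the sketch below the 'LocaNMF' decomposition name, the session id, the alias hash and the parcellation path are all hypothetical, and the {data_dir}/{dataset_id} placeholders stay literal for Snakemake to fill in:

# Hypothetical stand-in for Snakemake's global `config`; only the keys that
# svd_input actually reads are populated, and every value is made up.
config = {
    "loaded_decomposition": "LocaNMF",
    "dataset_aliases": {"All": "a1b2c3"},
    "dataset_sessions": {"a1b2c3": ["GN06-2021-01-20"]},
    "paths": {"parcellations": {"SVD": {"svd_conf": "resources/svd.yaml"}}},
}

print(svd_input({"dataset_id": "GN06-2021-01-20"}))
# {'data': '{data_dir}/{dataset_id}/LocaNMF/{dataset_id}/data.h5',
#  'config': '{data_dir}/{dataset_id}/LocaNMF/{dataset_id}/conf.yaml',
#  'svd_conf': 'resources/svd.yaml'}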

rule svd:
'''
Perform SVD on any DecompData object
'''
input:
unpack(svd_input)
output:
temp_c("{data_dir}/{dataset_id}/SVD/{dataset_id}/data.h5", rule="parcellate"),
config = "{data_dir}/{dataset_id}/SVD/{dataset_id}/conf.yaml",
params:
params = {}
log:
"{data_dir}/{dataset_id}/SVD/{dataset_id}/parcellation.log"
conda:
"../envs/environment.yaml"
threads: min(workflow.cores,24)
resources:
mem_mib=lambda wildcards, input, attempt: mem_res(wildcards,input,attempt,16000,8000)
script:
"../scripts/loading/svd.py"

def input_unification(wildcards):
dataset_id = wildcards['dataset_id']
# if dataset_id is a defined alias, replace it with the canonical (hash) id, else leave it be
@@ -123,7 +165,7 @@ rule unify:
method=config["unification_method"]
conda:
"../envs/environment.yaml"
threads: 8
threads: min(workflow.cores,24)
resources:
mem_mib=lambda wildcards, input, attempt: mem_res(wildcards,input,attempt,16000,8000)
script:
12 changes: 11 additions & 1 deletion workflow/tools/Snakefile.smk
@@ -24,9 +24,12 @@ include: "../rules/common.smk"
#print(f"{dataset_groups=}")
#print(f"{dataset_aliases=}")

loaded_decomposition = config['branch_opts'].get('loading', {}).get('loaded_decomposition', 'SVD')

unification_conf = config['branch_opts'].get('unification', {})
print(f"{unification_conf=}")
unified_space = unification_conf.get('unified_space', 'All')
include_individual_sessions = unification_conf.get('include_indiviual_sessions', False)
include_individual_sessions = unification_conf.get('include_individual_sessions', False)
include_subsets = unification_conf.get('include_subsets', True)
unification_method = unification_conf.get('unification_method', 'sv_weighted')

@@ -51,6 +54,10 @@ for set_id, sub_ids in unification_groups.items():
else:
sub_datasets = [set_id]
session_runs[set_id] = sub_datasets

if include_individual_sessions:
for session in dataset_sessions[dataset_aliases['All']]:
session_runs[session] = [session]
print(f"{session_runs=}")

parcells_conf = config["branch_opts"]["parcellations"]
@@ -94,6 +101,9 @@ run_id = hash_config(config)

config["loading"] = {"datasets" : datasets,
"dataset_aliases" : dataset_aliases,
"dataset_sessions" : dataset_sessions,
"loaded_decomposition" : loaded_decomposition,
"parcellations": parcellations,
"unification_method" : unification_method,
} #"subject_dates" :subject_dates}

