-
Notifications
You must be signed in to change notification settings - Fork 67
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
39 changed files
with
827 additions
and
139 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import sys | ||
import os | ||
sys.path.append(os.getcwd()) | ||
from training_structures.MFM import train_MFM,test_MFM | ||
from fusions.common_fusions import Concat | ||
from unimodals.MVAE import LeNetEncoder,DeLeNet | ||
from unimodals.common_models import MLP | ||
from torch import nn | ||
import torch | ||
from objective_functions.recon import recon_weighted_sum,sigmloss1dcentercrop | ||
from datasets.avmnist.get_data import get_dataloader | ||
|
||
|
||
|
||
traindata, validdata, testdata = get_dataloader('/data/yiwei/avmnist/_MFAS/avmnist') | ||
channels=6 | ||
|
||
classes=10 | ||
n_latent=200 | ||
fuse=Concat() | ||
|
||
encoders=[LeNetEncoder(1,channels,3,n_latent,twooutput=False).cuda(),LeNetEncoder(1,channels,5,n_latent,twooutput=False).cuda()] | ||
decoders=[DeLeNet(1,channels,3,n_latent).cuda(),DeLeNet(1,channels,5,n_latent).cuda()] | ||
|
||
intermediates=[MLP(n_latent,n_latent//2,n_latent//2).cuda(),MLP(n_latent,n_latent//2,n_latent//2).cuda(),MLP(2*n_latent,n_latent,n_latent//2).cuda()] | ||
head=MLP(n_latent//2,40,classes).cuda() | ||
recon_loss=recon_weighted_sum([sigmloss1dcentercrop(28,34),sigmloss1dcentercrop(112,130)],[1.0,1.0]) | ||
train_MFM(encoders,decoders,head,intermediates,fuse,recon_loss,traindata,validdata,25) | ||
model=torch.load('best.pt') | ||
test_MFM(model,testdata) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import sys | ||
import os | ||
sys.path.append(os.getcwd()) | ||
from training_structures.MVAE_mixed import train_MVAE,test_MVAE | ||
from fusions.MVAE import ProductOfExperts | ||
from unimodals.common_models import MLP | ||
from unimodals.MVAE import LeNetEncoder,DeLeNet | ||
from torch import nn | ||
import torch | ||
from objective_functions.recon import elbo_loss,sigmloss1dcentercrop | ||
from datasets.avmnist.get_data import get_dataloader | ||
|
||
traindata, validdata, testdata = get_dataloader('/data/yiwei/avmnist/_MFAS/avmnist') | ||
|
||
classes=10 | ||
n_latent=200 | ||
fuse=ProductOfExperts((1,40,n_latent)) | ||
|
||
|
||
channels=6 | ||
encoders=[LeNetEncoder(1,channels,3,n_latent).cuda(),LeNetEncoder(1,channels,5,n_latent).cuda()] | ||
decoders=[DeLeNet(1,channels,3,n_latent).cuda(),DeLeNet(1,channels,5,n_latent).cuda()] | ||
head=MLP(n_latent,40,classes).cuda() | ||
elbo=elbo_loss([sigmloss1dcentercrop(28,34),sigmloss1dcentercrop(112,130)],[1.0,1.0],0.0) | ||
train_MVAE(encoders,decoders,head,fuse,traindata,validdata,elbo,20) | ||
mvae=torch.load('best1.pt') | ||
head=torch.load('best2.pt') | ||
test_MVAE(mvae,head,testdata) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import sys | ||
import os | ||
sys.path.append(os.getcwd()) | ||
from training_structures.architecture_search import train,test | ||
from fusions.common_fusions import Concat | ||
from datasets.avmnist.get_data import get_dataloader | ||
from unimodals.common_models import LeNet,MLP,Constant | ||
from torch import nn | ||
import torch | ||
import utils.surrogate as surr | ||
|
||
traindata, validdata, testdata = get_dataloader('/data/yiwei/avmnist/_MFAS/avmnist',batch_size=32) | ||
model = torch.load('temp/best.pt').cuda() | ||
test(model,testdata) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import sys | ||
import os | ||
sys.path.append(os.getcwd()) | ||
from training_structures.gradient_blend import train, test | ||
from fusions.common_fusions import Concat | ||
from datasets.avmnist.get_data import get_dataloader | ||
from unimodals.common_models import LeNet,MLP,Constant | ||
from torch import nn | ||
import torch | ||
|
||
filename='best3.pt' | ||
traindata, validdata, testdata = get_dataloader('/data/yiwei/avmnist/_MFAS/avmnist') | ||
channels=6 | ||
encoders=[LeNet(1,channels,3).cuda(),LeNet(1,channels,5).cuda()] | ||
mult_head=MLP(channels*40,100,10).cuda() | ||
uni_head = [MLP(channels*8,100,10).cuda(),MLP(channels*32,100,10).cuda()] | ||
|
||
fusion=Concat().cuda() | ||
|
||
train(encoders,mult_head,uni_head,fusion,traindata,validdata,300,gb_epoch=10,optimtype=torch.optim.SGD,lr=0.01,savedir=filename) | ||
|
||
print("Testing:") | ||
model=torch.load(filename).cuda() | ||
test(model,testdata) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import sys | ||
import os | ||
sys.path.append(os.getcwd()) | ||
from training_structures.Simple_Late_Fusion import train, test | ||
from fusions.common_fusions import LowRankTensorFusion | ||
from datasets.avmnist.get_data import get_dataloader | ||
from unimodals.common_models import LeNet,MLP,Constant | ||
from torch import nn | ||
import torch | ||
filename = 'lowrank.pt' | ||
traindata, validdata, testdata = get_dataloader('/data/yiwei/avmnist/_MFAS/avmnist') | ||
channels=6 | ||
encoders=[LeNet(1,channels,3).cuda(),LeNet(1,channels,5).cuda()] | ||
head=MLP(channels*20,100,10).cuda() | ||
|
||
fusion=LowRankTensorFusion([channels*8,channels*32],channels*20,40).cuda() | ||
|
||
train(encoders,fusion,head,traindata,validdata,30,optimtype=torch.optim.SGD,lr=0.05,weight_decay=0.0002,save=filename) | ||
|
||
print("Testing:") | ||
model=torch.load(filename).cuda() | ||
test(model,testdata) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import sys | ||
import os | ||
sys.path.append(os.getcwd()) | ||
from training_structures.Simple_Late_Fusion import train, test | ||
from fusions.common_fusions import Concat, MultiplicativeInteractions2Modal | ||
from datasets.avmnist.get_data import get_dataloader | ||
from unimodals.common_models import LeNet,MLP,Constant | ||
from torch import nn | ||
import torch | ||
|
||
filename='bestmi.pt' | ||
traindata, validdata, testdata = get_dataloader('/data/yiwei/avmnist/_MFAS/avmnist') | ||
channels=6 | ||
encoders=[LeNet(1,channels,3).cuda(),LeNet(1,channels,5).cuda()] | ||
head=MLP(channels*40,100,10).cuda() | ||
|
||
#fusion=Concat().cuda() | ||
fusion = MultiplicativeInteractions2Modal([channels*8,channels*32],channels*40,'matrix') | ||
|
||
train(encoders,fusion,head,traindata,validdata,20,optimtype=torch.optim.SGD,lr=0.05,weight_decay=0.0001,save=filename) | ||
|
||
print("Testing:") | ||
model=torch.load(filename).cuda() | ||
test(model,testdata) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import sys | ||
import os | ||
sys.path.append(os.getcwd()) | ||
from training_structures.architecture_search import train | ||
from fusions.common_fusions import Concat | ||
from datasets.mimic.get_data import get_dataloader | ||
from unimodals.common_models import LeNet,MLP,Constant,GRUWithLinear | ||
from torch import nn | ||
import torch | ||
import utils.surrogate as surr | ||
|
||
traindata, validdata, testdata = get_dataloader(1, imputed_path='datasets/mimic/im.pk') | ||
|
||
|
||
s_data=train(['pretrained/mimic/static_encoder_mortality.pt','pretrained/mimic/ts_encoder_mortality.pt'],16,2,[(5,10,10),(288,720,360)], | ||
traindata,validdata,surr.SimpleRecurrentSurrogate().cuda(),(3,3,2),epochs=6) | ||
|
||
""" | ||
print("Testing:") | ||
model=torch.load('best.pt').cuda() | ||
test(model,testdata) | ||
""" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import sys | ||
import os | ||
sys.path.append(os.getcwd()) | ||
from training_structures.architecture_search import train,test | ||
from fusions.common_fusions import Concat | ||
from datasets.mimic.get_data import get_dataloader | ||
from unimodals.common_models import LeNet,MLP,Constant | ||
from torch import nn | ||
import torch | ||
import utils.surrogate as surr | ||
|
||
traindata, validdata, testdata = get_dataloader(1, imputed_path='datasets/mimic/im.pk') | ||
|
||
model = torch.load('temp/best.pt').cuda() | ||
test(model,testdata,auprc=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import sys | ||
import os | ||
sys.path.append(os.getcwd()) | ||
from training_structures.Simple_Late_Fusion import train, test | ||
from fusions.common_fusions import LowRankTensorFusion | ||
from datasets.mimic.get_data import get_dataloader | ||
from unimodals.common_models import MLP, GRU | ||
from torch import nn | ||
import torch | ||
|
||
#get dataloader for icd9 classification task 7 | ||
traindata, validdata, testdata = get_dataloader(1, imputed_path='datasets/mimic/im.pk') | ||
|
||
#build encoders, head and fusion layer | ||
encoders = [MLP(5, 10, 10,dropout=False).cuda(), GRU(12, 30,dropout=False).cuda()] | ||
head = MLP(100, 40, 2, dropout=False).cuda() | ||
fusion = LowRankTensorFusion([10,720],100,40).cuda() | ||
|
||
#train | ||
train(encoders, fusion, head, traindata, validdata, 50, auprc=True) | ||
|
||
#test | ||
print("Testing: ") | ||
model = torch.load('best.pt').cuda() | ||
test(model, testdata, auprc=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Oops, something went wrong.