Skip to content

Commit

Permalink
Change PseudoData according to new teacher API
Browse files Browse the repository at this point in the history
  • Loading branch information
RuiShu committed Mar 13, 2018
1 parent 44acbe5 commit cbc6def
Showing 1 changed file with 29 additions and 27 deletions.
56 changes: 29 additions & 27 deletions codebase/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,22 @@
from itertools import izip
from utils import u2t, s2t

PATH = '/mnt/ilcompf5d0/user/rshu/data'
PATH = '/home/ruishu/data'

def get_info(domain_id, domain):
train, test = domain.train, domain.test
print '{} info'.format(domain_id)
Y_shape = None if train.labels is None else train.labels.shape
print 'Train X/Y shapes: {}, {}'.format(train.images.shape, Y_shape)
print 'Train X min/max/cast: {}, {}, {}'.format(
train.images.min(),
train.images.max(),
train.cast)
print 'Test shapes: {}, {}'.format(test.images.shape, test.labels.shape)
print 'Test X min/max/cast: {}, {}, {}\n'.format(
test.images.min(),
test.images.max(),
test.cast)

class Data(object):
def __init__(self, images, labels=None, labeler=None, cast=False):
Expand Down Expand Up @@ -38,51 +53,38 @@ def next_batch(self, bs):
y = self.labeler(x) if self.labels is None else self.labels[idx]
return x, y


class Mnist(object):
def __init__(self, shape=(32, 32, 3)):
"""MNIST domain train/test data
shape - (3,) HWC info
"""
raise NotImplementedError('Double check max/min')
print "Loading MNIST"
data = np.load(os.path.join(PATH, 'mnist_784.npz'))
train = loadmat(os.path.join(PATH, 'mnist32_train.mat'))
test = loadmat(os.path.join(PATH, 'mnist32_test.mat'))

trainx = data['x_train']
trainy = data['y_train']
trainx = train['X']
trainy = train['y']
trainy = np.eye(10)[trainy].astype('float32')

testx = data['x_test']
testy = data['y_test'].astype('int')
testx = test['X']
testy = test['y'].astype('int')
testy = np.eye(10)[testy].astype('float32')

trainx = self.resize_cast(trainx, shape)
testx = self.resize_cast(testx, shape)
trainx = trainx.reshape(-1, 32, 32, 1).astype('float32')
testx = testx.reshape(-1, 32, 32, 1).astype('float32')

self.train = Data(trainx, trainy)
self.test = Data(testx, testy)

@staticmethod
def resize_cast(x, shape):
H, W, C = shape
x = x.reshape(-1, 28, 28)

resized_x = np.empty((len(x), H, W), dtype='float32')
for i, img in enumerate(x):
# imresize returns uint8
resized_x[i] = u2t(scipy.misc.imresize(img, (H, W)))

# Retile to make RGB
resized_x = resized_x.reshape(-1, H, W, 1)
resized_x = np.tile(resized_x, (1, 1, 1, C))
return resized_x

class Mnistm(object):
def __init__(self, shape=(28, 28, 3), seed=0, npc=None):
"""Mnist-M domain train/test data
shape - (3,) HWC info
"""
raise NotImplementedError('Did not change mnistm yet')
print "Loading MNIST-M"
data = pkl.load(open(os.path.join(PATH, 'mnistm_data.pkl')))
labels = pkl.load(open(os.path.join(PATH, 'mnistm_labels.pkl')))
Expand Down Expand Up @@ -263,17 +265,17 @@ def __init__(self):
self.test = Data(testx, testy, cast=True)

class PseudoData(object):
def __init__(self, domain_id, domain, M):
def __init__(self, domain_id, domain, teacher):
"""Variable domain with psuedolabeler
domain_id - (str) {Mnist,Mnistm,Svhn,etc}
domain - (obj) {Mnist,Mnistm,Svhn,etc}
M - (TensorDict) Model used for pseudolabeling
teacher - (fn) Teacher model used for pseudolabeling
"""
print "Constructing pseudodata"
cast = 'mnist' not in domain_id
print "{} uses casting: {}".format(domain_id, cast)
labeler = tb.function(M.sess, [M.test_x], M.back_y)
labeler = teacher

self.train = Data(domain.train.images, labeler=labeler, cast=cast)
self.test = Data(domain.test.images, labeler=labeler, cast=cast)
Expand Down

0 comments on commit cbc6def

Please sign in to comment.