Skip to content

Commit

Permalink
Add code
Browse files Browse the repository at this point in the history
  • Loading branch information
RuiShu committed Feb 28, 2018
1 parent 2079a15 commit 44acbe5
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
# dirt-t
A DIRT-T Approach to Unsupervised Domain Adaptation (ICLR 2018)

A note about MNIST data set: I accidentally divided images by 256 instead of 255. This will cause some differences in how the the images behave upon reshaping. Just keep that in mind.
9 changes: 5 additions & 4 deletions codebase/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,18 @@ def next_batch(self, bs):
y = self.labeler(x) if self.labels is None else self.labels[idx]
return x, y


class Mnist(object):
def __init__(self, shape=(32, 32, 3)):
"""MNIST domain train/test data
shape - (3,) HWC info
"""
print "Loading MNIST"
data = np.load(os.path.join(PATH, 'mnist.npz'))
trainx = np.concatenate((data['x_train'], data['x_valid']), axis=0)
trainy = np.concatenate((data['y_train'], data['y_valid']))
data = np.load(os.path.join(PATH, 'mnist_784.npz'))

trainx = data['x_train']
trainy = data['y_train']
trainy = np.eye(10)[trainy].astype('float32')

testx = data['x_test']
Expand Down Expand Up @@ -75,7 +77,6 @@ def resize_cast(x, shape):
resized_x = np.tile(resized_x, (1, 1, 1, C))
return resized_x


class Mnistm(object):
def __init__(self, shape=(28, 28, 3), seed=0, npc=None):
"""Mnist-M domain train/test data
Expand Down
1 change: 0 additions & 1 deletion codebase/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import tensorflow as tf
import shutil
import os
import datasets
import numpy as np

def u2t(x):
Expand Down

0 comments on commit 44acbe5

Please sign in to comment.