Skip to content

Commit

Permalink
Fix and add visualizor
Browse files Browse the repository at this point in the history
  • Loading branch information
zzh8829 committed Dec 21, 2019
1 parent 6d574cc commit 64175c4
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def main(_argv):
logging.info('classes loaded')

if FLAGS.tfrecord:
dataset = load_tfrecord_dataset(FLAGS.tfrecord, FLAGS.classes)
dataset = load_tfrecord_dataset(FLAGS.tfrecord, FLAGS.classes, FLAGS.size)
dataset = dataset.shuffle(1024)
img_raw, label = next(iter(dataset.take(1)))
img = tf.expand_dims(img_raw, 0)
Expand Down
7 changes: 5 additions & 2 deletions tools/visualize_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,24 @@ def main(_argv):
class_names = [c.strip() for c in open(FLAGS.classes).readlines()]
logging.info('classes loaded')

dataset = load_tfrecord_dataset(FLAGS.dataset, FLAGS.classes)
dataset = load_tfrecord_dataset(FLAGS.dataset, FLAGS.classes, FLAGS.size)
dataset = dataset.shuffle(1024)

for image, labels in dataset.take(1):
boxes = []
scores = []
classes = []
for x1, y1, x2, y2, label in labels:
if x1 == 0 and x2 == 0:
continue

boxes.append((x1, y1, x2, y2))
scores.append(1)
classes.append(label)
nums = [len(boxes)]
boxes = [boxes]
scores = [scores]
classes = [classes]
nums = [len(boxes)]

logging.info('labels:')
for i in range(nums[0]):
Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def main(_argv):
train_dataset = dataset.load_fake_dataset()
if FLAGS.dataset:
train_dataset = dataset.load_tfrecord_dataset(
FLAGS.dataset, FLAGS.classes)
FLAGS.dataset, FLAGS.classes, FLAGS.size)
train_dataset = train_dataset.shuffle(buffer_size=1024) # TODO: not 1024
train_dataset = train_dataset.batch(FLAGS.batch_size)
train_dataset = train_dataset.map(lambda x, y: (
Expand All @@ -74,7 +74,7 @@ def main(_argv):
val_dataset = dataset.load_fake_dataset()
if FLAGS.val_dataset:
val_dataset = dataset.load_tfrecord_dataset(
FLAGS.val_dataset, FLAGS.classes)
FLAGS.val_dataset, FLAGS.classes, FLAGS.size)
val_dataset = val_dataset.batch(FLAGS.batch_size)
val_dataset = val_dataset.map(lambda x, y: (
dataset.transform_images(x, FLAGS.size),
Expand Down
8 changes: 4 additions & 4 deletions yolov3_tf2/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ def transform_images(x_train, size):
}


def parse_tfrecord(tfrecord, class_table):
def parse_tfrecord(tfrecord, class_table, size):
x = tf.io.parse_single_example(tfrecord, IMAGE_FEATURE_MAP)
x_train = tf.image.decode_jpeg(x['image/encoded'], channels=3)
x_train = tf.image.resize(x_train, (416, 416))
x_train = tf.image.resize(x_train, (size, size))

class_text = tf.sparse.to_dense(
x['image/object/class/text'], default_value='')
Expand All @@ -118,14 +118,14 @@ def parse_tfrecord(tfrecord, class_table):
return x_train, y_train


def load_tfrecord_dataset(file_pattern, class_file):
def load_tfrecord_dataset(file_pattern, class_file, size=416):
LINE_NUMBER = -1 # TODO: use tf.lookup.TextFileIndex.LINE_NUMBER
class_table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer(
class_file, tf.string, 0, tf.int64, LINE_NUMBER, delimiter="\n"), -1)

files = tf.data.Dataset.list_files(file_pattern)
dataset = files.flat_map(tf.data.TFRecordDataset)
return dataset.map(lambda x: parse_tfrecord(x, class_table))
return dataset.map(lambda x: parse_tfrecord(x, class_table, size))


def load_fake_dataset():
Expand Down

0 comments on commit 64175c4

Please sign in to comment.