Skip to content

Commit

Permalink
Merged commit includes the following changes: (tensorflow#7220)
Browse files Browse the repository at this point in the history
* Merged commit includes the following changes:
257930561  by yongzhe:

    Mobile LSTD TfLite Client.

--
257928126  by yongzhe:

    Mobile SSD TfLite client.

--
257921181  by menglong:

    Fix discrepancy between pre_bottleneck = {true, false}

--
257561213  by yongzhe:

    File utils.

--
257449226  by yongzhe:

    Mobile SSD Client.

--
257264654  by yongzhe:

    SSD utils.

--
257235648  by yongzhe:

    Proto bazel build rules.

--
256437262  by Menglong Zhu:

    Fix check for FusedBatchNorm op to only verify it as a prefix.

--
256283755  by yongzhe:

    Bazel build and copybara changes.

--
251947295  by yinxiao:

    Add missing interleaved option in checkpoint restore.

--
251513479  by yongzhe:

    Conversion utils.

--
248783193  by yongzhe:

    Branch protos needed for the LSTD client.

--
248200507  by menglong:

    Fix proto namespace in example config

--

PiperOrigin-RevId: 257930561

* Delete BUILD
  • Loading branch information
yongzhe2160 authored and dreamdragon committed Jul 16, 2019
1 parent 395f6d2 commit 66d00a8
Show file tree
Hide file tree
Showing 28 changed files with 3,154 additions and 6 deletions.
1 change: 1 addition & 0 deletions research/lstm_object_detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ https://scholar.googleusercontent.com/scholar.bib?q=info:rLqvkztmWYgJ:scholar.go
* [email protected]
* [email protected]
* [email protected]
* [email protected]
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# For training on Imagenet Video with LSTM Mobilenet V1

[object_detection.protos.lstm_model] {
[lstm_object_detection.protos.lstm_model] {
train_unroll_length: 4
eval_unroll_length: 4
}
Expand Down
2 changes: 1 addition & 1 deletion research/lstm_object_detection/lstm/lstm_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def __call__(self, inputs, state, scope=None):
bottleneck_concat = lstm_utils.quantizable_concat(
[inputs, h_list[k]],
axis=3,
is_training=False,
is_training=self._is_training,
is_quantized=self._is_quantized,
scope='bottleneck_%d/quantized_concat' % k)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,19 +238,33 @@ def restore_map(self, fine_tune_checkpoint_type='lstm'):
`classification`/`detection`/`interleaved`/`lstm`.
"""
if fine_tune_checkpoint_type not in [
'classification', 'detection', 'lstm'
'classification', 'detection', 'interleaved', 'lstm',
'interleaved_pretrain'
]:
raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
fine_tune_checkpoint_type))

self._restored_networks += 1
base_network_scope = self.get_base_network_scope()
if base_network_scope:
scope_to_replace = '{0}_{1}'.format(base_network_scope,
self._restored_networks)

interleaved_model = False
for variable in tf.global_variables():
if scope_to_replace in variable.op.name:
interleaved_model = True
break

variables_to_restore = {}
for variable in tf.global_variables():
var_name = variable.op.name
if 'global_step' in var_name:
continue

# Remove FeatureExtractor prefix for classification checkpoints.
if fine_tune_checkpoint_type == 'classification':
if (fine_tune_checkpoint_type == 'classification' or
fine_tune_checkpoint_type == 'interleaved_pretrain'):
var_name = (
re.split('^' + self._extract_features_scope + '/', var_name)[-1])

Expand All @@ -260,7 +274,26 @@ def restore_map(self, fine_tune_checkpoint_type='lstm'):
fine_tune_checkpoint_type == 'detection'):
var_name = var_name.replace('FeatureMaps',
self.get_base_network_scope())
variables_to_restore[var_name] = variable

# Load interleaved checkpoint specifically.
if interleaved_model: # Interleaved LSTD.
if 'interleaved' in fine_tune_checkpoint_type:
variables_to_restore[var_name] = variable
else:
# Restore non-base layers from the first checkpoint only.
if self._restored_networks == 1:
if base_network_scope + '_' not in var_name: # LSTM and FeatureMap
variables_to_restore[var_name] = variable
if scope_to_replace in var_name:
var_name = var_name.replace(scope_to_replace, base_network_scope)
variables_to_restore[var_name] = variable
else:
# Restore from the first model of interleaved checkpoints
if 'interleaved' in fine_tune_checkpoint_type:
var_name = var_name.replace(self.get_base_network_scope(),
self.get_base_network_scope() + '_1', 1)

variables_to_restore[var_name] = variable

return variables_to_restore

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_has_fused_batchnorm(self):
pad_to_multiple)
preprocessed_image = feature_extractor.preprocess(image_placeholder)
_ = feature_extractor.extract_features(preprocessed_image, unroll_length=1)
self.assertTrue(any(op.type == 'FusedBatchNorm'
self.assertTrue(any(op.type.startswith('FusedBatchNorm')
for op in tf.get_default_graph().get_operations()))

def test_variables_for_tflite(self):
Expand Down
56 changes: 56 additions & 0 deletions research/lstm_object_detection/tflite/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# BUILD rules for the mobile LSTD/SSD TfLite inference clients
# (research/lstm_object_detection/tflite).
# NOTE(review): leading indentation inside the rules appears to have been
# stripped by the page scrape; rule contents are otherwise unchanged.
package(
default_visibility = ["//visibility:public"],
)

licenses(["notice"])

# Base mobile SSD client. Links the detection/labelmap/options protos and the
# conversion/SSD utility libraries; uses absl, glog and gemmlowp.
cc_library(
name = "mobile_ssd_client",
srcs = ["mobile_ssd_client.cc"],
hdrs = ["mobile_ssd_client.h"],
deps = [
"//protos:box_encodings_cc_proto",
"//protos:detections_cc_proto",
"//protos:labelmap_cc_proto",
"//protos:mobile_ssd_client_options_cc_proto",
"//utils:conversion_utils",
"//utils:ssd_utils",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
"@com_google_glog//:glog",
"@gemmlowp",
],
)

# TfLite-backed SSD client built on :mobile_ssd_client; pulls in the TF Lite
# interpreter, builtin kernels and the NNAPI delegate.
# alwayslink = 1 keeps registration symbols from being dropped by the linker.
cc_library(
name = "mobile_ssd_tflite_client",
srcs = ["mobile_ssd_tflite_client.cc"],
hdrs = ["mobile_ssd_tflite_client.h"],
deps = [
":mobile_ssd_client",
"//protos:anchor_generation_options_cc_proto",
"//utils:file_utils",
"//utils:ssd_utils",
"@com_google_absl//absl/memory",
"@com_google_glog//:glog",
"@org_tensorflow//tensorflow/lite:arena_planner",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
alwayslink = 1,
)

# LSTD (LSTM-SSD) TfLite client layered on the two SSD client libraries above.
cc_library(
name = "mobile_lstd_tflite_client",
srcs = ["mobile_lstd_tflite_client.cc"],
hdrs = ["mobile_lstd_tflite_client.h"],
deps = [
":mobile_ssd_client",
":mobile_ssd_tflite_client",
"@com_google_absl//absl/base:core_headers",
"@com_google_glog//:glog",
],
alwayslink = 1,
)
120 changes: 120 additions & 0 deletions research/lstm_object_detection/tflite/WORKSPACE
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Bazel WORKSPACE for the standalone lstm_object_detection TfLite clients.
# Each external repository below is pinned to a specific commit/release, most
# with a sha256 for reproducible fetches.
# NOTE(review): leading indentation inside the rules appears to have been
# stripped by the page scrape; rule contents are otherwise unchanged.
workspace(name = "lstm_object_detection")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# bazel-skylib, needed for the versions.check() call just below.
http_archive(
name = "bazel_skylib",
sha256 = "bbccf674aa441c266df9894182d80de104cabd19be98be002f6d478aaa31574d",
strip_prefix = "bazel-skylib-2169ae1c374aab4a09aa90e65efe1a3aad4e279b",
urls = ["https://github.com/bazelbuild/bazel-skylib/archive/2169ae1c374aab4a09aa90e65efe1a3aad4e279b.tar.gz"],
)
load("@bazel_skylib//lib:versions.bzl", "versions")
versions.check(minimum_bazel_version = "0.23.0")

# ABSL cpp library.
http_archive(
name = "com_google_absl",
urls = [
"https://github.com/abseil/abseil-cpp/archive/a02f62f456f2c4a7ecf2be3104fe0c6e16fbad9a.tar.gz",
],
sha256 = "d437920d1434c766d22e85773b899c77c672b8b4865d5dc2cd61a29fdff3cf03",
strip_prefix = "abseil-cpp-a02f62f456f2c4a7ecf2be3104fe0c6e16fbad9a",
)

# GoogleTest/GoogleMock framework. Used by most unit-tests.
# NOTE(review): tracks "master" with no sha256, so this fetch is not
# reproducible — consider pinning to a tagged release.
http_archive(
name = "com_google_googletest",
urls = ["https://github.com/google/googletest/archive/master.zip"],
strip_prefix = "googletest-master",
)

# gflags needed by glog
http_archive(
name = "com_github_gflags_gflags",
sha256 = "6e16c8bc91b1310a44f3965e616383dbda48f83e8c1eaa2370a215057b00cabe",
strip_prefix = "gflags-77592648e3f3be87d6c7123eb81cbad75f9aef5a",
urls = [
"https://mirror.bazel.build/github.com/gflags/gflags/archive/77592648e3f3be87d6c7123eb81cbad75f9aef5a.tar.gz",
"https://github.com/gflags/gflags/archive/77592648e3f3be87d6c7123eb81cbad75f9aef5a.tar.gz",
],
)

# glog
http_archive(
name = "com_google_glog",
sha256 = "f28359aeba12f30d73d9e4711ef356dc842886968112162bc73002645139c39c",
strip_prefix = "glog-0.4.0",
urls = ["https://github.com/google/glog/archive/v0.4.0.tar.gz"],
)

# zlib, using the BUILD file shipped inside the protobuf repo.
http_archive(
name = "zlib",
build_file = "@com_google_protobuf//:third_party/zlib.BUILD",
sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
strip_prefix = "zlib-1.2.11",
urls = ["https://zlib.net/zlib-1.2.11.tar.gz"],
)

# gemmlowp, used by the mobile_ssd_client BUILD target.
http_archive(
name = "gemmlowp",
sha256 = "6678b484d929f2d0d3229d8ac4e3b815a950c86bb9f17851471d143f6d4f7834",
strip_prefix = "gemmlowp-12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3",
urls = [
"https://mirror.tensorflow.org/github.com/google/gemmlowp/archive/12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3.zip",
"https://github.com/google/gemmlowp/archive/12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3.zip",
],
)

#-----------------------------------------------------------------------------
# proto
#-----------------------------------------------------------------------------
# proto_library, cc_proto_library and java_proto_library rules implicitly depend
# on @com_google_protobuf//:proto, @com_google_protobuf//:cc_toolchain and
# @com_google_protobuf//:java_toolchain, respectively.
# This statement defines the @com_google_protobuf repo.
http_archive(
name = "com_google_protobuf",
strip_prefix = "protobuf-3.8.0",
urls = ["https://github.com/google/protobuf/archive/v3.8.0.zip"],
sha256 = "1e622ce4b84b88b6d2cdf1db38d1a634fe2392d74f0b7b74ff98f3a51838ee53",
)

# java_lite_proto_library rules implicitly depend on
# @com_google_protobuf_javalite//:javalite_toolchain, which is the JavaLite proto
# runtime (base classes and common utilities).
http_archive(
name = "com_google_protobuf_javalite",
strip_prefix = "protobuf-384989534b2246d413dbcd750744faab2607b516",
urls = ["https://github.com/google/protobuf/archive/384989534b2246d413dbcd750744faab2607b516.zip"],
sha256 = "79d102c61e2a479a0b7e5fc167bcfaa4832a0c6aad4a75fa7da0480564931bcc",
)

# Alternative unpinned protobuf repo, intentionally left commented out in
# favor of the pinned v3.8.0 archive above.
#
# http_archive(
# name = "com_google_protobuf",
# strip_prefix = "protobuf-master",
# urls = ["https://github.com/protocolbuffers/protobuf/archive/master.zip"],
# )

# Needed by TensorFlow
http_archive(
name = "io_bazel_rules_closure",
sha256 = "e0a111000aeed2051f29fcc7a3f83be3ad8c6c93c186e64beb1ad313f0c7f9f9",
strip_prefix = "rules_closure-cf1e44edb908e9616030cc83d085989b8e6cd6df",
urls = [
"https://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz", # 2019-04-04
],
)


# TensorFlow r1.14-rc0
http_archive(
name = "org_tensorflow",
strip_prefix = "tensorflow-1.14.0-rc0",
sha256 = "76404a6157a45e8d7a07e4f5690275256260130145924c2a7c73f6eda2a3de10",
urls = ["https://github.com/tensorflow/tensorflow/archive/v1.14.0-rc0.zip"],
)

# Pull in TensorFlow's own transitive workspace dependencies.
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
tf_workspace(tf_repo_name = "org_tensorflow")
Loading

0 comments on commit 66d00a8

Please sign in to comment.