Skip to content

Commit

Permalink
Make device placement be determined only by virtual placer. Make virt…
Browse files Browse the repository at this point in the history
…ual placer private to virtual scheduler. Remove device handling from graph properties. Remove hard-coded default device type from analytical_cost_estimator / virtual_scheduler.

PiperOrigin-RevId: 158625478
  • Loading branch information
tensorflower-gardener committed Jun 10, 2017
1 parent f48c3e0 commit 582ebd2
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 94 deletions.
5 changes: 1 addition & 4 deletions tensorflow/core/grappler/costs/analytical_cost_estimator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
}
}
std::vector<string> inaccurate_nodes;
VirtualPlacer placer(cluster_);
VirtualScheduler scheduler(&item, use_static_shapes_,
"CPU" /* default_device_type */, cluster_,
&placer);
VirtualScheduler scheduler(&item, use_static_shapes_, cluster_);
auto status = scheduler.Init();
if (!status.ok()) {
costs->execution_time = Costs::Duration::max();
Expand Down
19 changes: 0 additions & 19 deletions tensorflow/core/grappler/costs/graph_properties.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,6 @@ Status GraphProperties::InferStatically() {
output_properties.push_back(properties);
}
output_properties_[node->name()] = output_properties;

if (!node->assigned_device_name().empty()) {
device_names_[node->name()] = node->assigned_device_name();
} else if (!node->requested_device().empty()) {
device_names_[node->name()] = node->requested_device();
} else {
device_names_[node->name()] = "not set";
}
}

return Status::OK();
Expand Down Expand Up @@ -250,9 +242,6 @@ Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) {
FindInputFeatures(node, name_to_cost, name_to_node);

input_properties_[node.name()] = inputs;

const CostGraphDef::Node* cost_node = it->second;
device_names_[node.name()] = cost_node->device();
}
return Status::OK();
}
Expand All @@ -279,13 +268,5 @@ std::vector<OpInfo::TensorProperties> GraphProperties::GetOutputProperties(
return std::vector<OpInfo::TensorProperties>();
}

// Looks up the device name recorded for `node_name` in device_names_.
// Returns an empty string when no entry exists for that node.
string GraphProperties::GetDeviceName(const string& node_name) const {
  const auto entry = device_names_.find(node_name);
  if (entry == device_names_.end()) {
    return "";
  }
  return entry->second;
}

} // end namespace grappler
} // end namespace tensorflow
2 changes: 0 additions & 2 deletions tensorflow/core/grappler/costs/graph_properties.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,12 @@ class GraphProperties {
const string& node_name) const;
std::vector<OpInfo::TensorProperties> GetOutputProperties(
const string& node_name) const;
string GetDeviceName(const string& node_name) const;

private:
// Inputs
GrapplerItem item_;
std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_;
std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_;
std::map<string, string> device_names_;
};

} // end namespace grappler
Expand Down
35 changes: 25 additions & 10 deletions tensorflow/core/grappler/costs/virtual_placer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,31 @@ namespace grappler {
VirtualPlacer::VirtualPlacer(const Cluster* cluster) : has_gpu_(false) {
CHECK(cluster);
devices_ = cluster->GetDevices();
for (const auto& device : devices_) {
if (str_util::Lowercase(device.first).find("gpu") != string::npos) {
has_gpu_ = true;
unknown_device_.set_type("UNKNOWN");

if (devices_.empty()) {
// If there are no devices in the cluster, add a single device, "UNKNOWN", to
// the placer's own device map and use it as the default.
default_device_ = "UNKNOWN";
devices_["UNKNOWN"] = unknown_device_;
} else {
for (const auto& device : devices_) {
if (str_util::Lowercase(device.first).find("gpu") != string::npos) {
has_gpu_ = true;
default_device_ = device.first;
break;
}
}
}

unknown_device_.set_type("UNKNOWN");
// If there is no GPU, default to the first device in devices_ (presumably a CPU).
if (!has_gpu_) {
default_device_ = devices_.begin()->first;
}
}
}

const DeviceProperties& VirtualPlacer::get_device(const NodeDef& node) const {
string device = get_canonical_device_name(node);
if (device.empty()) {
return unknown_device_;
}
auto it = devices_.find(device);
DCHECK(it != devices_.end());
return it->second;
Expand All @@ -66,7 +77,7 @@ string VirtualPlacer::get_canonical_device_name(const NodeDef& node) const {
}
}
if (!parsed) {
return "";
return get_default_device_name();
} else {
device = strings::StrCat(
"/job:", parsed_name.job, "/replica:", parsed_name.replica,
Expand All @@ -81,10 +92,14 @@ string VirtualPlacer::get_canonical_device_name(const NodeDef& node) const {
}
}
if (devices_.find(device) == devices_.end()) {
return "";
return get_default_device_name();
}
return device;
}

// Returns the name of the fallback device (default_device_) that
// get_canonical_device_name() reports for nodes whose device cannot be
// parsed or is not among devices_.
const string& VirtualPlacer::get_default_device_name() const {
return default_device_;
}

} // end namespace grappler
} // end namespace tensorflow
2 changes: 2 additions & 0 deletions tensorflow/core/grappler/costs/virtual_placer.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class VirtualPlacer {
std::unordered_map<string, DeviceProperties> devices_;
bool has_gpu_;
DeviceProperties unknown_device_;
string default_device_;
const string& get_default_device_name() const;
};

} // namespace grappler
Expand Down
61 changes: 49 additions & 12 deletions tensorflow/core/grappler/costs/virtual_placer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,38 @@ TEST(VirtualPlacerTest, LocalDevices) {
placer.get_canonical_device_name(node));
}

// With an empty cluster, the only device the placer can fall back to is the
// synthetic "UNKNOWN" device.
TEST(VirtualPlacerTest, FallBackUnknown) {
  std::unordered_map<string, DeviceProperties> no_devices;
  VirtualCluster empty_cluster(no_devices);
  VirtualPlacer placer(&empty_cluster);

  // A node that names an op but carries no device assignment.
  NodeDef conv_node;
  conv_node.set_op("Conv2D");

  // Both the device properties and the canonical name resolve to UNKNOWN.
  EXPECT_EQ("UNKNOWN", placer.get_device(conv_node).type());
  EXPECT_EQ("UNKNOWN", placer.get_canonical_device_name(conv_node));
}

// With a single CPU in the cluster, a node without a device assignment is
// placed on that CPU.
TEST(VirtualPlacerTest, FallBackCPU) {
  DeviceProperties cpu_props;
  cpu_props.set_type("CPU");
  std::unordered_map<string, DeviceProperties> device_map;
  device_map["/job:my_job/replica:0/task:0/cpu:0"] = cpu_props;
  VirtualCluster cpu_only_cluster(device_map);
  VirtualPlacer placer(&cpu_only_cluster);

  NodeDef conv_node;
  conv_node.set_op("Conv2D");

  // No GPU is present, so the lone CPU acts as the default device.
  EXPECT_EQ("CPU", placer.get_device(conv_node).type());
  EXPECT_EQ("/job:my_job/replica:0/task:0/cpu:0",
            placer.get_canonical_device_name(conv_node));
}

TEST(VirtualPlacerTest, RemoteDevices) {
std::unordered_map<string, DeviceProperties> devices;
DeviceProperties cpu_device;
Expand All @@ -64,9 +96,11 @@ TEST(VirtualPlacerTest, RemoteDevices) {

NodeDef node;
node.set_op("Conv2D");
// There is no local device available
EXPECT_EQ("UNKNOWN", placer.get_device(node).type());
EXPECT_EQ("", placer.get_canonical_device_name(node));

// Device falls back to GPU.
EXPECT_EQ("GPU", placer.get_device(node).type());
EXPECT_EQ("/job:my_job/replica:0/task:0/gpu:0",
placer.get_canonical_device_name(node));

node.set_device("/job:my_job/replica:0/task:0/cpu:0");
EXPECT_EQ("CPU", placer.get_device(node).type());
Expand All @@ -78,20 +112,23 @@ TEST(VirtualPlacerTest, RemoteDevices) {
EXPECT_EQ("/job:my_job/replica:0/task:0/gpu:0",
placer.get_canonical_device_name(node));

// There is no local CPU available
// There is no local cpu available. Device falls back to GPU.
node.set_device("CPU");
EXPECT_EQ("UNKNOWN", placer.get_device(node).type());
EXPECT_EQ("", placer.get_canonical_device_name(node));
EXPECT_EQ("GPU", placer.get_device(node).type());
EXPECT_EQ("/job:my_job/replica:0/task:0/gpu:0",
placer.get_canonical_device_name(node));

node.set_device("GPU:0");
// There is no local GPU available
EXPECT_EQ("UNKNOWN", placer.get_device(node).type());
EXPECT_EQ("", placer.get_canonical_device_name(node));
// There is no local GPU available. Fall back to default GPU.
EXPECT_EQ("GPU", placer.get_device(node).type());
EXPECT_EQ("/job:my_job/replica:0/task:0/gpu:0",
placer.get_canonical_device_name(node));

// This isn't a valid name
// This isn't a valid name. Fall back to GPU.
node.set_device("/job:my_job/replica:0/task:0");
EXPECT_EQ("UNKNOWN", placer.get_device(node).type());
EXPECT_EQ("", placer.get_canonical_device_name(node));
EXPECT_EQ("GPU", placer.get_device(node).type());
EXPECT_EQ("/job:my_job/replica:0/task:0/gpu:0",
placer.get_canonical_device_name(node));
}

} // end namespace grappler
Expand Down
38 changes: 5 additions & 33 deletions tensorflow/core/grappler/costs/virtual_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,15 @@ Costs CombineCosts(const Costs& left, const Costs& right) {

VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
const bool use_static_shapes,
const string& default_device_type,
Cluster* cluster, VirtualPlacer* placer)
Cluster* cluster)
: // TODO(dyoon): Use a better way than FIFO.
ready_nodes_(new FIFOManager()),
graph_costs_(Costs::ZeroCosts()),
graph_properties_(*grappler_item),
cluster_(cluster),
grappler_item_(grappler_item),
use_static_shapes_(use_static_shapes),
default_device_type_(default_device_type),
placer_(placer) {
placer_(cluster) {
initialized_ = false;
}

Expand Down Expand Up @@ -240,11 +238,7 @@ bool VirtualScheduler::IsPersistentNode(const NodeDef* node) const {
}

string VirtualScheduler::DeviceName(const NodeDef* node) const {
CHECK(!initialized_) << "DeviceName is called after Init().";

// TODO(dyoon): integrate this part with VirtualPlacer.
return node->device().empty() ? "/device:" + default_device_type_ + ":0"
: node->device();
return placer_.get_canonical_device_name(*node);
}

string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
Expand Down Expand Up @@ -314,31 +308,9 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
NodeInfo VirtualScheduler::GetCurrNodeInfo() const {
const NodeDef* node = ready_nodes_->GetCurrNode();

// This is for compatibility; we can just use placer_->get_device() for all
// cases, once VirtualCluster is properly set up.
// Get the device from the placer.
DeviceProperties device;
if (placer_) {
device = placer_->get_device(*node);
}
if (device.type() == "UNKNOWN") {
string device_type;
int device_id;
DeviceNameUtils::ParsedName parsed;
if (!node->device().empty() &&
DeviceNameUtils::ParseFullName(node_map_.at(node).device_name,
&parsed)) {
device_type = parsed.type;
device_id = parsed.id;
} else {
device_type = default_device_type_;
device_id = 0;
}
if (device_type == "GPU") {
device = GetLocalGPUInfo(device_id);
} else if (device_type == "CPU") {
device = GetLocalCPUInfo();
}
}
device = placer_.get_device(*node);

// Special case for _Send op.
if (IsSend(*node)) {
Expand Down
9 changes: 2 additions & 7 deletions tensorflow/core/grappler/costs/virtual_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,7 @@ struct NodeInfo {
class VirtualScheduler {
public:
VirtualScheduler(const GrapplerItem* grappler_item,
const bool use_static_shapes,
const string& default_device_type, Cluster* cluster,
VirtualPlacer* placer);
const bool use_static_shapes, Cluster* cluster);

// Initializes NodeState and DeviceState from grappler_item_ and
// graph_properties_.
Expand Down Expand Up @@ -222,10 +220,7 @@ class VirtualScheduler {
bool use_static_shapes_;
bool initialized_;

// TODO(dyoon): Once VirtualCluster takes care of device names properly,
// move VirtualPlacer into the scheduler; also, delete default_device_type_.
const string default_device_type_;
VirtualPlacer* placer_; // Not owned.
VirtualPlacer placer_; // owned.
};

} // namespace grappler
Expand Down
10 changes: 3 additions & 7 deletions tensorflow/core/grappler/costs/virtual_scheduler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,8 @@ namespace grappler {
class TestVirtualScheduler : public VirtualScheduler {
public:
TestVirtualScheduler(const GrapplerItem* grappler_item,
const bool use_static_shapes,
const string& default_device_type, Cluster* cluster,
VirtualPlacer* placer)
: VirtualScheduler(grappler_item, use_static_shapes, default_device_type,
cluster, placer) {}
const bool use_static_shapes, Cluster* cluster)
: VirtualScheduler(grappler_item, use_static_shapes, cluster) {}

FRIEND_TEST(VirtualSchedulerTest, CalculateOutputSize);
FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
Expand Down Expand Up @@ -200,8 +197,7 @@ class VirtualSchedulerTest : public ::testing::Test {
// Call this after creating grappler_item_ and setting up dependency_.
void InitScheduler() {
scheduler_.reset(new TestVirtualScheduler(
grappler_item_.get(), true /* use_static_shapes */,
"CPU" /* default_device_type */, cluster_.get(), placer_.get()));
grappler_item_.get(), true /* use_static_shapes */, cluster_.get()));
TF_CHECK_OK(scheduler_->Init());
}

Expand Down

0 comments on commit 582ebd2

Please sign in to comment.