Merge commit for internal changes
caisq committed Jul 1, 2017
2 parents 0261a4f + 9999dd3 commit af0ca35
Showing 51 changed files with 2,064 additions and 476 deletions.
5 changes: 5 additions & 0 deletions tensorflow/compiler/xla/service/BUILD
@@ -88,11 +88,13 @@ cc_library(
deps = [
":hlo",
":hlo_query",
":shape_inference",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
],
@@ -106,12 +108,15 @@ cc_test(
":hlo",
":hlo_evaluator",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
55 changes: 41 additions & 14 deletions tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -2035,9 +2035,9 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall(
}

StatusOr<llvm::Value*> IrEmitter::EmitTargetAddressForOp(
const HloInstruction* op) {
const Shape& target_shape = op->shape();
if (op == op->parent()->root_instruction()) {
const HloInstruction* op, const ShapeIndex& shape_index) {
const Shape& target_shape = ShapeUtil::GetSubshape(op->shape(), shape_index);
if (op == op->parent()->root_instruction() && shape_index.empty()) {
// For the root node, we write directly to the output buffer of the
// function.
llvm::Argument* retval = GetResultArgument();
@@ -2069,19 +2069,46 @@ Status IrEmitter::EmitTargetElementLoop(
TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
EmitTargetAddressForOp(target_op));
VLOG(2) << " target address: " << llvm_ir::DumpToString(*target_address);
llvm_ir::IrArray target_array(target_address, target_shape);
AddAliasingInformationToIrArray(*target_op, &target_array);

if (num_dynamic_loop_bounds_ > 0 &&
target_op == target_op->parent()->root_instruction()) {
// Emit parallel loop for root instruction if dynamic outer-dimension loop
// bounds were specified.
TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop(
target_shape, element_generator, &target_array));
} else {

if (target_op->IsMultiOutputFusion()) {
// For multiple outputs fusion, we need to emit each operand and the root.
TF_RET_CHECK(num_dynamic_loop_bounds_ == 0);
std::vector<llvm_ir::IrArray> output_arrays;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(target_shape); ++i) {
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
assignment_.GetUniqueSlice(target_op, {i}));
const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i});
llvm::Value* op_target_address =
EmitTempBufferPointer(slice, element_shape);
output_arrays.push_back(
llvm_ir::IrArray(op_target_address, element_shape));
}
TF_RETURN_IF_ERROR(
llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_)
llvm_ir::LoopEmitter(element_generator, output_arrays, &ir_builder_)
.EmitLoop());

std::vector<llvm::Value*> tuple_operand_ptrs;
for (int64 i = 0; i < output_arrays.size(); ++i) {
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
}
llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, target_shape),
tuple_operand_ptrs, &ir_builder_);

} else {
llvm_ir::IrArray target_array(target_address, target_shape);
AddAliasingInformationToIrArray(*target_op, &target_array);

if (num_dynamic_loop_bounds_ > 0 &&
target_op == target_op->parent()->root_instruction()) {
// Emit parallel loop for root instruction if dynamic outer-dimension loop
// bounds were specified.
TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop(
target_shape, element_generator, &target_array));
} else {
TF_RETURN_IF_ERROR(
llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_)
.EmitLoop());
}
}

emitted_value_[target_op] = target_address;
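The multi-output branch added to EmitTargetElementLoop above follows one pattern: build an IrArray per tuple element of the fusion's shape, run the element loop once so every output is written in lockstep, then store the per-output base pointers into the result tuple. Below is a minimal standalone sketch of that pattern; OutputBuffer, ElementGenerator, and the vector of raw pointers standing in for the emitted tuple are hypothetical stand-ins, not the real XLA/LLVM classes.

#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

struct OutputBuffer { std::vector<float> data; };  // stand-in for llvm_ir::IrArray

// Stand-in for the element generator: given a linear element index, produce
// one value per fusion output.
using ElementGenerator = std::function<std::vector<float>(int64_t)>;

// Run the element loop once, writing all outputs in lockstep, then return the
// per-output base pointers (the analogue of llvm_ir::EmitTuple's operands).
std::vector<float*> EmitMultiOutputLoop(const ElementGenerator& generator,
                                        std::vector<OutputBuffer>& outputs,
                                        int64_t num_elements) {
  for (int64_t i = 0; i < num_elements; ++i) {
    const std::vector<float> values = generator(i);
    for (size_t o = 0; o < outputs.size(); ++o) {
      outputs[o].data[i] = values[o];
    }
  }
  std::vector<float*> tuple_operand_ptrs;
  for (OutputBuffer& out : outputs) {
    tuple_operand_ptrs.push_back(out.data.data());
  }
  return tuple_operand_ptrs;
}

int main() {
  const int64_t n = 4;
  std::vector<OutputBuffer> outputs(2, OutputBuffer{std::vector<float>(n)});
  // Two fused outputs computed by the same loop: x + 1 and x * 2.
  const auto tuple = EmitMultiOutputLoop(
      [](int64_t i) { return std::vector<float>{i + 1.0f, i * 2.0f}; },
      outputs, n);
  std::cout << tuple[0][3] << " " << tuple[1][3] << "\n";  // prints: 4 6
  return 0;
}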
3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -286,7 +286,8 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// Emit IR to compute the target address of the buffer for the given op.
// The returned Value is a pointer to a IR type that represents the op's
// element type.
StatusOr<llvm::Value*> EmitTargetAddressForOp(const HloInstruction* op);
StatusOr<llvm::Value*> EmitTargetAddressForOp(
const HloInstruction* op, const ShapeIndex& shape_index = {});

// Structurizes "array_elements" into an MD array that represents "shape".
// This is a recursive function, and "dimension_index" indicates the index of
25 changes: 22 additions & 3 deletions tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -98,7 +98,13 @@ double CalculateFlopsToBytesRatio(HloInstruction* fusion) {
// Calculate total bytes transferred in/out.
double bytes = CalculateBytesReadByFusionInstruction(fusion);
// Add bytes written to root instructions buffer.
bytes += ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape());
if (fusion->IsMultiOutputFusion()) {
for (auto& operand : fusion->fused_expression_root()->operands()) {
bytes += ShapeUtil::ByteSizeOf(operand->shape());
}
} else {
bytes += ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape());
}
// Calculate flops for all fused instructions. Use a null shape size function
// because we don't care about bytes accessed by the ops.
HloCostAnalysis analysis([](const Shape& shape) { return 0; });
@@ -112,8 +118,15 @@ double CalculateFlopsToBytesRatio(HloInstruction* fusion) {
double GetCurrentBytesTransferred(HloInstruction* fusion) {
CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
const double bytes_read = CalculateBytesReadByFusionInstruction(fusion);
const double bytes_written =
ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape());
double bytes_written = 0;
if (fusion->IsMultiOutputFusion()) {
for (auto& operand : fusion->fused_expression_root()->operands()) {
bytes_written += ShapeUtil::ByteSizeOf(operand->shape());
}
} else {
bytes_written =
ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape());
}
// Current bytes transferred (ignoring non 'fusion' user operands) is bytes
// read and written by 'fusion', plus reads of size 'bytes_written' for each
// user.
@@ -198,6 +211,12 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
++num_fail_not_loop_fusion_;
return Status::OK();
}

// Skip multiple output fusion. It's not yet supported.
if (fusion->IsMultiOutputFusion()) {
++num_fail_not_loop_fusion_;
return Status::OK();
}
// Skip 'fusion' instruction if we cannot merge into all of its users.
// Merging into all users enables the removal of 'fusion' from the
// computation.
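Both changes in this file adjust the bytes-written estimate the merger uses: a single-output fusion writes only its root's buffer, while a multi-output fusion writes one buffer per operand of its root tuple. A minimal standalone sketch of that accounting, with hypothetical Shape/Fusion structs standing in for xla::Shape and HloInstruction:

#include <cstdint>
#include <iostream>
#include <vector>

struct Shape { int64_t byte_size; };  // stand-in for xla::Shape + ShapeUtil::ByteSizeOf

struct Fusion {
  bool is_multi_output;
  Shape root_shape;                        // shape of the fused expression root
  std::vector<Shape> root_operand_shapes;  // per-output shapes when multi-output
};

int64_t BytesWritten(const Fusion& fusion) {
  if (fusion.is_multi_output) {
    // A multi-output fusion writes one buffer per operand of the root tuple.
    int64_t bytes = 0;
    for (const Shape& shape : fusion.root_operand_shapes) {
      bytes += shape.byte_size;
    }
    return bytes;
  }
  // A single-output fusion writes only the root's buffer.
  return fusion.root_shape.byte_size;
}

int main() {
  const Fusion single{false, {1024}, {}};
  const Fusion multi{true, {0}, {{1024}, {256}}};
  std::cout << BytesWritten(single) << " " << BytesWritten(multi) << "\n";  // 1024 1280
  return 0;
}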
73 changes: 49 additions & 24 deletions tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -86,23 +86,35 @@ void HloToIrBindings::EmitBasePointersForHlos(
continue;
}

// A non-IO HLO with a buffer is bound to
// (1) an alloca if it is thread-local, or
// (2) an internal pointer in temp_buffer_base according to its offset.
const BufferAllocation::Slice slice =
buffer_assignment_->GetUniqueTopLevelSlice(non_io_hlo)
.ConsumeValueOrDie();
if (slice.allocation()->is_thread_local()) {
llvm::Type* pointee_type =
llvm_ir::ShapeToIrType(non_io_hlo->shape(), ir_builder_);
BindHloToIrValue(*non_io_hlo, ir_builder_->CreateAlloca(pointee_type));
} else {
const int64 offset = slice.offset();
CHECK_NE(nullptr, temp_buffer_base_);
BindHloToIrValue(*non_io_hlo,
ir_builder_->CreateInBoundsGEP(
temp_buffer_base_, ir_builder_->getInt64(offset)));
}
ShapeUtil::ForEachSubshape(
non_io_hlo->shape(),
[&](const Shape& /*subshape*/, const ShapeIndex& index) {
// A non-IO HLO with a buffer is bound to
// (1) an alloca if it is thread-local, or
// (2) an internal pointer in temp_buffer_base according to its
// offset.
auto slice_result =
buffer_assignment_->GetUniqueSlice(non_io_hlo, index);
if (!slice_result.ok()) {
return;
}
const BufferAllocation::Slice slice =
slice_result.ConsumeValueOrDie();
if (slice.allocation()->is_thread_local()) {
llvm::Type* pointee_type =
llvm_ir::ShapeToIrType(non_io_hlo->shape(), ir_builder_);
BindHloToIrValue(*non_io_hlo,
ir_builder_->CreateAlloca(pointee_type), index);
} else {
const int64 offset = slice.offset();
CHECK_NE(nullptr, temp_buffer_base_);
BindHloToIrValue(
*non_io_hlo,
ir_builder_->CreateInBoundsGEP(temp_buffer_base_,
ir_builder_->getInt64(offset)),
index);
}
});
}
}

@@ -112,16 +124,18 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
return llvm_ir::EmitGetTupleElement(
gte->shape(), gte->tuple_index(), /*alignment=*/1,
GetTypedIrValue(*gte->operand(0), base_ptr), ir_builder_);
GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_);
}
return llvm_ir::EmitGetTupleElement(
gte->shape(), gte->tuple_index(), /*alignment=*/1,
EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_);
}

llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
const ShapeIndex& shape_index,
llvm::Value* ir_value) {
llvm::Type* pointee_type = llvm_ir::ShapeToIrType(hlo.shape(), ir_builder_);
llvm::Type* pointee_type = llvm_ir::ShapeToIrType(
ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_builder_);
llvm::Type* dest_type = pointee_type->getPointerTo();

llvm::Value* typed_ir_value;
@@ -139,13 +153,24 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
}

void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
llvm::Value* ir_value) {
llvm::Value* ir_value,
const ShapeIndex& shape_index) {
VLOG(2) << "Binding " << hlo.ToString();
InsertOrDie(&base_ptrs_, &hlo, GetTypedIrValue(hlo, ir_value));

const Shape& hlo_shape = hlo.shape();
llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value);

if (!BoundToIrValue(hlo)) {
// Set the root of ShapeTree first before assigning the element ir value.
InsertOrDie(&base_ptrs_, &hlo, ShapeTree<llvm::Value*>(hlo_shape, nullptr));
}
*(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value;
}

llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo) {
llvm_ir::IrArray ir_array(GetBasePointer(hlo), hlo.shape());
llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
const ShapeIndex& shape_index) {
llvm_ir::IrArray ir_array(GetBasePointer(hlo, shape_index),
ShapeUtil::GetSubshape(hlo.shape(), shape_index));
alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array);
return ir_array;
}
@@ -154,7 +179,7 @@ void HloToIrBindings::UnbindAllLocalIrValues() {
std::vector<const HloInstruction*> hlos_to_unbind;
for (auto& key_value : base_ptrs_) {
if (!llvm::isa<llvm::GlobalVariable>(
key_value.second->stripPointerCasts())) {
(key_value.second.element({}))->stripPointerCasts())) {
hlos_to_unbind.push_back(key_value.first);
}
}
20 changes: 14 additions & 6 deletions tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
@@ -48,7 +48,8 @@ class HloToIrBindings {
tensorflow::gtl::ArraySlice<const HloInstruction*> non_io_hlos);

// Rebinds the given HLO to the LLVM IR value that represent its address.
void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value);
void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value,
const ShapeIndex& shape_index = {});

// Unbinds all IR values that's defined in an LLVM function, e.g., function
// arguments and stack variables. Global variables will be kept in bindings_.
@@ -64,15 +65,18 @@

llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; }

// A helper method that returns the base pointer of the IrArray for "inst".
llvm::Value* GetBasePointer(const HloInstruction& hlo) const {
// A helper method that returns the base pointer of the IrArray containing the
// output of "inst".at the given ShapeIndex.
llvm::Value* GetBasePointer(const HloInstruction& hlo,
const ShapeIndex& shape_index = {}) const {
auto it = base_ptrs_.find(&hlo);
CHECK(it != base_ptrs_.end());
return it->second;
return it->second.element(shape_index);
}

// Return the underlying IrArray of the output of the given instruction.
llvm_ir::IrArray GetIrArray(const HloInstruction& hlo);
llvm_ir::IrArray GetIrArray(const HloInstruction& hlo,
const ShapeIndex& shape_index = {});

private:
// Emits IR to resolve (possibly) recursive GetTupleElement instructions.
@@ -81,6 +85,7 @@

// Returns an llvm typed ir representation of 'ir_value' based on 'hlo' shape.
llvm::Value* GetTypedIrValue(const HloInstruction& hlo,
const ShapeIndex& shape_index,
llvm::Value* ir_value);

const BufferAssignment* buffer_assignment_;
@@ -90,7 +95,10 @@
llvm::IRBuilder<>* ir_builder_;

// Stores the underlying llvm::IrArray for each HloInstruction.
std::unordered_map<const HloInstruction*, llvm::Value*> base_ptrs_;
// For an instruction that generates multiple outputs, the root will be a
// tuple shape. The IrArray for each element output is stored in the subnode
// in the ShapeTree.
std::unordered_map<const HloInstruction*, ShapeTree<llvm::Value*>> base_ptrs_;

// The address of the memory block that contains all temporary buffers.
llvm::Value* temp_buffer_base_;
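With this change, base_ptrs_ stores a ShapeTree<llvm::Value*> per instruction, so a multi-output instruction can bind one base pointer per ShapeIndex (the empty index addressing the tuple itself, {i} addressing element i). A minimal standalone sketch of that lookup structure, assuming a plain map keyed by (instruction, shape index) in place of a real ShapeTree:

#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

using ShapeIndex = std::vector<int64_t>;  // {} = top level, {i} = tuple element i
using Instruction = std::string;          // stand-in for const HloInstruction*
using IrValue = int;                      // stand-in for llvm::Value*

class Bindings {
 public:
  // Analogue of BindHloToIrValue(hlo, ir_value, shape_index).
  void Bind(const Instruction& hlo, IrValue value, const ShapeIndex& index = {}) {
    base_ptrs_[{hlo, index}] = value;
  }
  // Analogue of GetBasePointer(hlo, shape_index).
  IrValue GetBasePointer(const Instruction& hlo, const ShapeIndex& index = {}) const {
    const auto it = base_ptrs_.find({hlo, index});
    assert(it != base_ptrs_.end());
    return it->second;
  }

 private:
  std::map<std::pair<Instruction, ShapeIndex>, IrValue> base_ptrs_;
};

int main() {
  Bindings bindings;
  // A multi-output fusion binds one address per tuple element plus the tuple itself.
  bindings.Bind("fusion", /*tuple base=*/100);
  bindings.Bind("fusion", /*output 0=*/200, {0});
  bindings.Bind("fusion", /*output 1=*/300, {1});
  return bindings.GetBasePointer("fusion", {1}) == 300 ? 0 : 1;
}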
5 changes: 3 additions & 2 deletions tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -118,8 +118,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
IrEmitterContext* ir_emitter_context, bool is_nested);

// A convenient helper for calling HloToIrBindings::GetIrArray.
llvm_ir::IrArray GetIrArray(const HloInstruction& inst) {
return bindings_.GetIrArray(inst);
llvm_ir::IrArray GetIrArray(const HloInstruction& inst,
const ShapeIndex& shape_index = {}) {
return bindings_.GetIrArray(inst, shape_index);
}
// A convenient helper for calling HloToIrBindings::GetBasePointer.
llvm::Value* GetBasePointer(const HloInstruction& inst) const {
35 changes: 29 additions & 6 deletions tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1886,15 +1886,38 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
const HloInstruction& hlo,
const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) {
const Shape& element_shape = hlo.IsMultiOutputFusion()
? ShapeUtil::GetSubshape(hlo.shape(), {0})
: hlo.shape();
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
hlo.shape(), ir_emitter_context_->device_description());
element_shape, ir_emitter_context_->device_description());
UpdateLaunchDimensions(launch_dimensions, thunk,
ir_emitter_context_->llvm_module());
// Otherwise, emit a parallel loop that computes the partition that each
// thread is in charge of.
return ParallelLoopEmitter(element_generator, GetIrArray(hlo),
launch_dimensions, &ir_builder_)
.EmitLoop();
if (!hlo.IsMultiOutputFusion()) {
return ParallelLoopEmitter(element_generator, GetIrArray(hlo),
launch_dimensions, &ir_builder_)
.EmitLoop();
}

// For multiple outputs fusion, we need to emit each operand and the root.
std::vector<llvm_ir::IrArray> output_arrays;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
output_arrays.push_back(GetIrArray(hlo, {i}));
}
TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays,
launch_dimensions, &ir_builder_)
.EmitLoop());

std::vector<llvm::Value*> tuple_operand_ptrs;
for (int64 i = 0; i < output_arrays.size(); ++i) {
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
}
ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator());
// const HloInstruction* root = hlo.fused_expression_root();
llvm_ir::EmitTuple(
GetIrArray(*hlo.fused_expression_root()->fusion_instruction()),
tuple_operand_ptrs, &ir_builder_);
return Status::OK();
}

Status IrEmitterUnnested::EmitTargetElementLoop(
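In EmitTargetElementLoopInThunk above, the launch dimensions for a multi-output fusion are derived from the shape of tuple element {0} rather than from the tuple shape itself, since all outputs of a loop fusion share the same element count. A minimal standalone sketch of that shape selection, with a hypothetical Shape struct standing in for xla::Shape:

#include <cstdint>
#include <iostream>
#include <vector>

struct Shape {
  std::vector<int64_t> dimensions;     // empty for a tuple shape
  std::vector<Shape> tuple_elements;   // non-empty only for tuple shapes
};

int64_t ElementCount(const Shape& shape) {
  int64_t count = 1;
  for (int64_t d : shape.dimensions) count *= d;
  return count;
}

// Pick the shape that drives the element loop / launch dimensions.
int64_t LoopElementCount(const Shape& hlo_shape, bool is_multi_output_fusion) {
  const Shape& element_shape =
      is_multi_output_fusion ? hlo_shape.tuple_elements[0] : hlo_shape;
  return ElementCount(element_shape);
}

int main() {
  const Shape tuple{{}, {Shape{{16, 16}, {}}, Shape{{16, 16}, {}}}};
  std::cout << LoopElementCount(tuple, /*is_multi_output_fusion=*/true) << "\n";  // 256
  return 0;
}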
7 changes: 7 additions & 0 deletions tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
@@ -36,6 +36,13 @@ ParallelLoopEmitter::ParallelLoopEmitter(
: LoopEmitter(body_emitter, shape, ir_builder),
launch_dimensions_(launch_dimensions) {}

ParallelLoopEmitter::ParallelLoopEmitter(
const llvm_ir::ElementGenerator& target_element_generator,
tensorflow::gtl::ArraySlice<llvm_ir::IrArray> target_arrays,
const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder)
: LoopEmitter(target_element_generator, target_arrays, ir_builder),
launch_dimensions_(launch_dimensions) {}

ParallelLoopEmitter::ParallelLoopEmitter(
const llvm_ir::ElementGenerator& target_element_generator,
const llvm_ir::IrArray& target_array,