Skip to content

Commit

Permalink
Support derived to base conversions. (carbon-language#3487)
Browse files Browse the repository at this point in the history
This is enough to support calling methods that take a `Base` or `Base*`
as their `self`. But name lookup doesn't look in the base class yet, so
base class methods aren't actually found.
  • Loading branch information
zygoloid committed Dec 11, 2023
1 parent 6037b11 commit 87d341c
Show file tree
Hide file tree
Showing 6 changed files with 603 additions and 26 deletions.
96 changes: 96 additions & 0 deletions toolchain/check/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,80 @@ static auto ConvertStructToClass(Context& context, SemIR::StructType src_type,
return result_id;
}

// An inheritance path is a sequence of `BaseDecl`s in order from derived to
// base.
using InheritancePath = llvm::SmallVector<SemIR::InstId>;

// Computes the inheritance path from class `derived_id` to class `base_id`.
// Returns nullopt if `derived_id` is not a class derived from `base_id`.
static auto ComputeInheritancePath(Context& context, SemIR::TypeId derived_id,
SemIR::TypeId base_id)
-> std::optional<InheritancePath> {
// We intend for NRVO to be applied to `result`. All `return` statements in
// this function should `return result;`.
std::optional<InheritancePath> result(std::in_place);
if (!context.TryToCompleteType(derived_id)) {
// TODO: Should we give an error here? If we don't, and there is an
// inheritance path when the class is defined, we may have a coherence
// problem.
result = std::nullopt;
return result;
}
while (derived_id != base_id) {
auto derived_class_type =
context.types().TryGetAs<SemIR::ClassType>(derived_id);
if (!derived_class_type) {
result = std::nullopt;
break;
}
auto& derived_class = context.classes().Get(derived_class_type->class_id);
if (!derived_class.base_id.is_valid()) {
result = std::nullopt;
break;
}
result->push_back(derived_class.base_id);
derived_id = context.insts()
.GetAs<SemIR::BaseDecl>(derived_class.base_id)
.base_type_id;
}
return result;
}

// Performs a conversion from a derived class value or reference to a base class
// value or reference.
static auto ConvertDerivedToBase(Context& context, Parse::NodeId parse_node,
SemIR::InstId value_id,
const InheritancePath& path) -> SemIR::InstId {
// Materialize a temporary if necessary.
value_id = ConvertToValueOrRefExpr(context, value_id);

// Add a series of `.base` accesses.
for (auto base_id : path) {
auto base_decl = context.insts().GetAs<SemIR::BaseDecl>(base_id);
value_id = context.AddInst(SemIR::ClassElementAccess{
parse_node, base_decl.base_type_id, value_id, base_decl.index});
}
return value_id;
}

// Performs a conversion from a derived class pointer to a base class pointer.
static auto ConvertDerivedPointerToBasePointer(
Context& context, Parse::NodeId parse_node, SemIR::PointerType src_ptr_type,
SemIR::TypeId dest_ptr_type_id, SemIR::InstId ptr_id,
const InheritancePath& path) -> SemIR::InstId {
// Form `*p`.
ptr_id = ConvertToValueExpr(context, ptr_id);
auto ref_id = context.AddInst(
SemIR::Deref{parse_node, src_ptr_type.pointee_id, ptr_id});

// Convert as a reference expression.
ref_id = ConvertDerivedToBase(context, parse_node, ref_id, path);

// Take the address.
return context.AddInst(
SemIR::AddressOf{parse_node, dest_ptr_type_id, ref_id});
}

// Returns whether `category` is a valid expression category to produce as a
// result of a conversion with kind `target_kind`, or at most needs a temporary
// to be materialized.
Expand Down Expand Up @@ -697,6 +771,28 @@ static auto PerformBuiltinConversion(Context& context, Parse::NodeId parse_node,
return ConvertStructToClass(context, *src_struct_type, *target_class_type,
value_id, target);
}

// An expression of type T converts to U if T is a class derived from U.
if (auto path =
ComputeInheritancePath(context, value_type_id, target.type_id);
path && !path->empty()) {
return ConvertDerivedToBase(context, parse_node, value_id, *path);
}
}

// A pointer T* converts to U* if T is a class derived from U.
if (auto target_pointer_type = target_type_inst.TryAs<SemIR::PointerType>()) {
if (auto src_pointer_type =
sem_ir.types().TryGetAs<SemIR::PointerType>(value_type_id)) {
if (auto path =
ComputeInheritancePath(context, src_pointer_type->pointee_id,
target_pointer_type->pointee_id);
path && !path->empty()) {
return ConvertDerivedPointerToBasePointer(
context, parse_node, *src_pointer_type, target.type_id, value_id,
*path);
}
}
}

if (target.type_id == SemIR::TypeId::TypeType) {
Expand Down
193 changes: 193 additions & 0 deletions toolchain/check/testdata/class/derived_to_base.carbon
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// AUTOUPDATE

base class A {
var a: i32;
}

base class B {
extend base: A;
var b: i32;
}

class C {
extend base: B;
var c: i32;
}

fn ConvertCToB(p: C*) -> B* { return p; }
fn ConvertBToA(p: B*) -> A* { return p; }
fn ConvertCToA(p: C*) -> A* { return p; }

fn ConvertValue(c: C) {
let a: A = c;
}

fn ConvertRef(c: C*) -> A* {
return &(*c as A);
}

fn ConvertInit() {
let a: A = {.base = {.base = {.a = 1}, .b = 2}, .c = 3} as C;
}

// CHECK:STDOUT: --- derived_to_base.carbon
// CHECK:STDOUT:
// CHECK:STDOUT: constants {
// CHECK:STDOUT: %.loc9: type = struct_type {.a: i32}
// CHECK:STDOUT: %.loc7: type = ptr_type {.a: i32}
// CHECK:STDOUT: %.loc14_1.1: type = struct_type {.base: A, .b: i32}
// CHECK:STDOUT: %.loc14_1.2: type = struct_type {.base: {.a: i32}*, .b: i32}
// CHECK:STDOUT: %.loc14_1.3: type = ptr_type {.base: {.a: i32}*, .b: i32}
// CHECK:STDOUT: %.loc11: type = ptr_type {.base: A, .b: i32}
// CHECK:STDOUT: %.loc19_1.1: type = struct_type {.base: B, .c: i32}
// CHECK:STDOUT: %.loc19_1.2: type = struct_type {.base: {.base: A, .b: i32}*, .c: i32}
// CHECK:STDOUT: %.loc19_1.3: type = ptr_type {.base: {.base: A, .b: i32}*, .c: i32}
// CHECK:STDOUT: %.loc16: type = ptr_type {.base: B, .c: i32}
// CHECK:STDOUT: %.loc34_48: type = struct_type {.base: {.a: i32}, .b: i32}
// CHECK:STDOUT: %.loc34_57: type = struct_type {.base: {.base: {.a: i32}, .b: i32}, .c: i32}
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: file {
// CHECK:STDOUT: package: <namespace> = namespace {.A = %A.decl, .B = %B.decl, .C = %C.decl, .ConvertCToB = %ConvertCToB, .ConvertBToA = %ConvertBToA, .ConvertCToA = %ConvertCToA, .ConvertValue = %ConvertValue, .ConvertRef = %ConvertRef, .ConvertInit = %ConvertInit}
// CHECK:STDOUT: %A.decl = class_decl @A, ()
// CHECK:STDOUT: %A: type = class_type @A
// CHECK:STDOUT: %B.decl = class_decl @B, ()
// CHECK:STDOUT: %B: type = class_type @B
// CHECK:STDOUT: %C.decl = class_decl @C, ()
// CHECK:STDOUT: %C: type = class_type @C
// CHECK:STDOUT: %ConvertCToB: <function> = fn_decl @ConvertCToB
// CHECK:STDOUT: %ConvertBToA: <function> = fn_decl @ConvertBToA
// CHECK:STDOUT: %ConvertCToA: <function> = fn_decl @ConvertCToA
// CHECK:STDOUT: %ConvertValue: <function> = fn_decl @ConvertValue
// CHECK:STDOUT: %ConvertRef: <function> = fn_decl @ConvertRef
// CHECK:STDOUT: %ConvertInit: <function> = fn_decl @ConvertInit
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: class @A {
// CHECK:STDOUT: %.loc8_8.1: type = unbound_element_type A, i32
// CHECK:STDOUT: %.loc8_8.2: <unbound element of class A> = field_decl a, element0
// CHECK:STDOUT: %a: <unbound element of class A> = bind_name a, %.loc8_8.2
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .a = %a
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: class @B {
// CHECK:STDOUT: %A.ref: type = name_ref A, file.%A
// CHECK:STDOUT: %.loc12_17.1: type = unbound_element_type B, A
// CHECK:STDOUT: %.loc12_17.2: <unbound element of class B> = base_decl A, element0
// CHECK:STDOUT: %.loc13_8.1: type = unbound_element_type B, i32
// CHECK:STDOUT: %.loc13_8.2: <unbound element of class B> = field_decl b, element1
// CHECK:STDOUT: %b: <unbound element of class B> = bind_name b, %.loc13_8.2
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .base = %.loc12_17.2
// CHECK:STDOUT: .b = %b
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: class @C {
// CHECK:STDOUT: %B.ref: type = name_ref B, file.%B
// CHECK:STDOUT: %.loc17_17.1: type = unbound_element_type C, B
// CHECK:STDOUT: %.loc17_17.2: <unbound element of class C> = base_decl B, element0
// CHECK:STDOUT: %.loc18_8.1: type = unbound_element_type C, i32
// CHECK:STDOUT: %.loc18_8.2: <unbound element of class C> = field_decl c, element1
// CHECK:STDOUT: %c: <unbound element of class C> = bind_name c, %.loc18_8.2
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .base = %.loc17_17.2
// CHECK:STDOUT: .c = %c
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @ConvertCToB(%p: C*) -> B* {
// CHECK:STDOUT: !entry:
// CHECK:STDOUT: %p.ref: C* = name_ref p, %p
// CHECK:STDOUT: %.loc21_39.1: ref C = deref %p.ref
// CHECK:STDOUT: %.loc21_39.2: ref B = class_element_access %.loc21_39.1, element0
// CHECK:STDOUT: %.loc21_39.3: B* = address_of %.loc21_39.2
// CHECK:STDOUT: %.loc21_39.4: B* = converted %p.ref, %.loc21_39.3
// CHECK:STDOUT: return %.loc21_39.4
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @ConvertBToA(%p: B*) -> A* {
// CHECK:STDOUT: !entry:
// CHECK:STDOUT: %p.ref: B* = name_ref p, %p
// CHECK:STDOUT: %.loc22_39.1: ref B = deref %p.ref
// CHECK:STDOUT: %.loc22_39.2: ref A = class_element_access %.loc22_39.1, element0
// CHECK:STDOUT: %.loc22_39.3: A* = address_of %.loc22_39.2
// CHECK:STDOUT: %.loc22_39.4: A* = converted %p.ref, %.loc22_39.3
// CHECK:STDOUT: return %.loc22_39.4
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @ConvertCToA(%p: C*) -> A* {
// CHECK:STDOUT: !entry:
// CHECK:STDOUT: %p.ref: C* = name_ref p, %p
// CHECK:STDOUT: %.loc23_39.1: ref C = deref %p.ref
// CHECK:STDOUT: %.loc23_39.2: ref B = class_element_access %.loc23_39.1, element0
// CHECK:STDOUT: %.loc23_39.3: ref A = class_element_access %.loc23_39.2, element0
// CHECK:STDOUT: %.loc23_39.4: A* = address_of %.loc23_39.3
// CHECK:STDOUT: %.loc23_39.5: A* = converted %p.ref, %.loc23_39.4
// CHECK:STDOUT: return %.loc23_39.5
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @ConvertValue(%c: C) {
// CHECK:STDOUT: !entry:
// CHECK:STDOUT: %A.ref: type = name_ref A, file.%A
// CHECK:STDOUT: %c.ref: C = name_ref c, %c
// CHECK:STDOUT: %.loc26_15.1: ref B = class_element_access %c.ref, element0
// CHECK:STDOUT: %.loc26_15.2: ref A = class_element_access %.loc26_15.1, element0
// CHECK:STDOUT: %.loc26_15.3: ref A = converted %c.ref, %.loc26_15.2
// CHECK:STDOUT: %.loc26_15.4: A = bind_value %.loc26_15.3
// CHECK:STDOUT: %a: A = bind_name a, %.loc26_15.4
// CHECK:STDOUT: return
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @ConvertRef(%c: C*) -> A* {
// CHECK:STDOUT: !entry:
// CHECK:STDOUT: %c.ref: C* = name_ref c, %c
// CHECK:STDOUT: %.loc30_12: ref C = deref %c.ref
// CHECK:STDOUT: %A.ref: type = name_ref A, file.%A
// CHECK:STDOUT: %.loc30_15.1: ref B = class_element_access %.loc30_12, element0
// CHECK:STDOUT: %.loc30_15.2: ref A = class_element_access %.loc30_15.1, element0
// CHECK:STDOUT: %.loc30_15.3: ref A = converted %.loc30_12, %.loc30_15.2
// CHECK:STDOUT: %.loc30_10: A* = address_of %.loc30_15.3
// CHECK:STDOUT: return %.loc30_10
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @ConvertInit() {
// CHECK:STDOUT: !entry:
// CHECK:STDOUT: %A.ref: type = name_ref A, file.%A
// CHECK:STDOUT: %.loc34_38: i32 = int_literal 1
// CHECK:STDOUT: %.loc34_39.1: {.a: i32} = struct_literal (%.loc34_38)
// CHECK:STDOUT: %.loc34_47: i32 = int_literal 2
// CHECK:STDOUT: %.loc34_48.1: {.base: {.a: i32}, .b: i32} = struct_literal (%.loc34_39.1, %.loc34_47)
// CHECK:STDOUT: %.loc34_56: i32 = int_literal 3
// CHECK:STDOUT: %.loc34_57.1: {.base: {.base: {.a: i32}, .b: i32}, .c: i32} = struct_literal (%.loc34_48.1, %.loc34_56)
// CHECK:STDOUT: %C.ref: type = name_ref C, file.%C
// CHECK:STDOUT: %.loc34_57.2: ref C = temporary_storage
// CHECK:STDOUT: %.loc34_57.3: ref B = class_element_access %.loc34_57.2, element0
// CHECK:STDOUT: %.loc34_48.2: ref A = class_element_access %.loc34_57.3, element0
// CHECK:STDOUT: %.loc34_39.2: ref i32 = class_element_access %.loc34_48.2, element0
// CHECK:STDOUT: %.loc34_39.3: init i32 = initialize_from %.loc34_38 to %.loc34_39.2
// CHECK:STDOUT: %.loc34_39.4: init A = class_init (%.loc34_39.3), %.loc34_48.2
// CHECK:STDOUT: %.loc34_39.5: init A = converted %.loc34_39.1, %.loc34_39.4
// CHECK:STDOUT: %.loc34_48.3: ref i32 = class_element_access %.loc34_57.3, element1
// CHECK:STDOUT: %.loc34_48.4: init i32 = initialize_from %.loc34_47 to %.loc34_48.3
// CHECK:STDOUT: %.loc34_48.5: init B = class_init (%.loc34_39.5, %.loc34_48.4), %.loc34_57.3
// CHECK:STDOUT: %.loc34_48.6: init B = converted %.loc34_48.1, %.loc34_48.5
// CHECK:STDOUT: %.loc34_57.4: ref i32 = class_element_access %.loc34_57.2, element1
// CHECK:STDOUT: %.loc34_57.5: init i32 = initialize_from %.loc34_56 to %.loc34_57.4
// CHECK:STDOUT: %.loc34_57.6: init C = class_init (%.loc34_48.6, %.loc34_57.5), %.loc34_57.2
// CHECK:STDOUT: %.loc34_57.7: ref C = temporary %.loc34_57.2, %.loc34_57.6
// CHECK:STDOUT: %.loc34_57.8: ref C = converted %.loc34_57.1, %.loc34_57.7
// CHECK:STDOUT: %.loc34_63.1: ref B = class_element_access %.loc34_57.8, element0
// CHECK:STDOUT: %.loc34_63.2: ref A = class_element_access %.loc34_63.1, element0
// CHECK:STDOUT: %.loc34_63.3: ref A = converted %.loc34_57.8, %.loc34_63.2
// CHECK:STDOUT: %.loc34_63.4: A = bind_value %.loc34_63.3
// CHECK:STDOUT: %a: A = bind_name a, %.loc34_63.4
// CHECK:STDOUT: return
// CHECK:STDOUT: }
// CHECK:STDOUT:
Loading

0 comments on commit 87d341c

Please sign in to comment.