Skip to content

Commit

Permalink
Simplify interface for getting an instruction from a type. (carbon-la…
Browse files Browse the repository at this point in the history
…nguage#3455)

Co-authored-by: Jon Ross-Perkins <[email protected]>
  • Loading branch information
zygoloid and jonmeow committed Dec 8, 2023
1 parent 5897e57 commit cef7eb5
Show file tree
Hide file tree
Showing 14 changed files with 211 additions and 172 deletions.
20 changes: 8 additions & 12 deletions toolchain/check/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ class TypeCompleter {
private:
// Adds `type_id` to the work list, if it's not already complete.
auto Push(SemIR::TypeId type_id) -> void {
if (!context_.sem_ir().IsTypeComplete(type_id)) {
if (!context_.types().IsComplete(type_id)) {
work_list_.push_back({type_id, Phase::AddNestedIncompleteTypes});
}
}
Expand All @@ -553,14 +553,12 @@ class TypeCompleter {

// We might have enqueued the same type more than once. Just skip the
// type if it's already complete.
if (context_.sem_ir().IsTypeComplete(type_id)) {
if (context_.types().IsComplete(type_id)) {
work_list_.pop_back();
return true;
}

auto inst_id = context_.sem_ir().GetTypeAllowBuiltinTypes(type_id);
auto inst = context_.insts().Get(inst_id);

auto inst = context_.types().GetAsInst(type_id);
auto old_work_list_size = work_list_.size();

switch (phase) {
Expand All @@ -583,14 +581,14 @@ class TypeCompleter {
// Also complete the value representation type, if necessary. This
// should never fail: the value representation shouldn't require any
// additional nested types to be complete.
if (!context_.sem_ir().IsTypeComplete(value_rep.type_id)) {
if (!context_.types().IsComplete(value_rep.type_id)) {
work_list_.push_back({value_rep.type_id, Phase::BuildValueRepr});
}
// For a pointer representation, the pointee also needs to be complete.
if (value_rep.kind == SemIR::ValueRepr::Pointer) {
auto pointee_type_id =
context_.sem_ir().GetPointeeType(value_rep.type_id);
if (!context_.sem_ir().IsTypeComplete(pointee_type_id)) {
if (!context_.types().IsComplete(pointee_type_id)) {
work_list_.push_back({pointee_type_id, Phase::BuildValueRepr});
}
}
Expand Down Expand Up @@ -684,9 +682,9 @@ class TypeCompleter {
// Gets the value representation of a nested type, which should already be
// complete.
auto GetNestedValueRepr(SemIR::TypeId nested_type_id) const {
CARBON_CHECK(context_.sem_ir().IsTypeComplete(nested_type_id))
CARBON_CHECK(context_.types().IsComplete(nested_type_id))
<< "Nested type should already be complete";
auto value_rep = context_.sem_ir().GetValueRepr(nested_type_id);
auto value_rep = context_.types().GetValueRepr(nested_type_id);
CARBON_CHECK(value_rep.kind != SemIR::ValueRepr::Unknown)
<< "Complete type should have a value representation";
return value_rep;
Expand Down Expand Up @@ -1128,9 +1126,7 @@ auto Context::GetPointerType(Parse::NodeId parse_node,
}

auto Context::GetUnqualifiedType(SemIR::TypeId type_id) -> SemIR::TypeId {
SemIR::Inst type_inst =
insts().Get(sem_ir_->GetTypeAllowBuiltinTypes(type_id));
if (auto const_type = type_inst.TryAs<SemIR::ConstType>()) {
if (auto const_type = types().TryGetAs<SemIR::ConstType>(type_id)) {
return const_type->inner_id;
}
return type_id;
Expand Down
2 changes: 1 addition & 1 deletion toolchain/check/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ class Context {
auto name_scopes() -> SemIR::NameScopeStore& {
return sem_ir().name_scopes();
}
auto types() -> ValueStore<SemIR::TypeId>& { return sem_ir().types(); }
auto types() -> SemIR::TypeStore& { return sem_ir().types(); }
auto type_blocks() -> SemIR::BlockValueStore<SemIR::TypeBlockId>& {
return sem_ir().type_blocks();
}
Expand Down
31 changes: 13 additions & 18 deletions toolchain/check/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,8 @@ static auto ConvertStructToClass(Context& context, SemIR::StructType src_type,
context.sem_ir().StringifyType(target.type_id));
return SemIR::InstId::BuiltinError;
}
auto dest_struct_type = context.insts().GetAs<SemIR::StructType>(
context.sem_ir().GetTypeAllowBuiltinTypes(class_info.object_repr_id));
auto dest_struct_type =
context.types().GetAs<SemIR::StructType>(class_info.object_repr_id);

// If we're trying to create a class value, form a temporary for the value to
// point to.
Expand Down Expand Up @@ -599,8 +599,7 @@ static auto PerformBuiltinConversion(Context& context, Parse::NodeId parse_node,
auto& sem_ir = context.sem_ir();
auto value = sem_ir.insts().Get(value_id);
auto value_type_id = value.type_id();
auto target_type_inst =
sem_ir.insts().Get(sem_ir.GetTypeAllowBuiltinTypes(target.type_id));
auto target_type_inst = sem_ir.types().GetAsInst(target.type_id);

// Various forms of implicit conversion are supported as builtin conversions,
// either in addition to or instead of `impl`s of `ImplicitAs` in the Carbon
Expand Down Expand Up @@ -660,9 +659,8 @@ static auto PerformBuiltinConversion(Context& context, Parse::NodeId parse_node,
// A tuple (T1, T2, ..., Tn) converts to (U1, U2, ..., Un) if each Ti
// converts to Ui.
if (auto target_tuple_type = target_type_inst.TryAs<SemIR::TupleType>()) {
auto value_type_inst =
sem_ir.insts().Get(sem_ir.GetTypeAllowBuiltinTypes(value_type_id));
if (auto src_tuple_type = value_type_inst.TryAs<SemIR::TupleType>()) {
if (auto src_tuple_type =
sem_ir.types().TryGetAs<SemIR::TupleType>(value_type_id)) {
return ConvertTupleToTuple(context, *src_tuple_type, *target_tuple_type,
value_id, target);
}
Expand All @@ -673,19 +671,17 @@ static auto PerformBuiltinConversion(Context& context, Parse::NodeId parse_node,
// (p(1), ..., p(n)) is a permutation of (1, ..., n) and each Ti converts
// to Ui.
if (auto target_struct_type = target_type_inst.TryAs<SemIR::StructType>()) {
auto value_type_inst =
sem_ir.insts().Get(sem_ir.GetTypeAllowBuiltinTypes(value_type_id));
if (auto src_struct_type = value_type_inst.TryAs<SemIR::StructType>()) {
if (auto src_struct_type =
sem_ir.types().TryGetAs<SemIR::StructType>(value_type_id)) {
return ConvertStructToStruct(context, *src_struct_type,
*target_struct_type, value_id, target);
}
}

// A tuple (T1, T2, ..., Tn) converts to [T; n] if each Ti converts to T.
if (auto target_array_type = target_type_inst.TryAs<SemIR::ArrayType>()) {
auto value_type_inst =
sem_ir.insts().Get(sem_ir.GetTypeAllowBuiltinTypes(value_type_id));
if (auto src_tuple_type = value_type_inst.TryAs<SemIR::TupleType>()) {
if (auto src_tuple_type =
sem_ir.types().TryGetAs<SemIR::TupleType>(value_type_id)) {
return ConvertTupleToArray(context, *src_tuple_type, *target_array_type,
value_id, target);
}
Expand All @@ -696,9 +692,8 @@ static auto PerformBuiltinConversion(Context& context, Parse::NodeId parse_node,
// (a struct with the same fields as the class, plus a base field where
// relevant).
if (auto target_class_type = target_type_inst.TryAs<SemIR::ClassType>()) {
auto value_type_inst =
sem_ir.insts().Get(sem_ir.GetTypeAllowBuiltinTypes(value_type_id));
if (auto src_struct_type = value_type_inst.TryAs<SemIR::StructType>()) {
if (auto src_struct_type =
sem_ir.types().TryGetAs<SemIR::StructType>(value_type_id)) {
return ConvertStructToClass(context, *src_struct_type, *target_class_type,
value_id, target);
}
Expand All @@ -716,7 +711,7 @@ static auto PerformBuiltinConversion(Context& context, Parse::NodeId parse_node,
type_ids.push_back(ExprAsType(context, parse_node, tuple_inst_id));
}
auto tuple_type_id = context.CanonicalizeTupleType(parse_node, type_ids);
return sem_ir.GetTypeAllowBuiltinTypes(tuple_type_id);
return sem_ir.types().GetInstId(tuple_type_id);
}

// `{}` converts to `{} as type`.
Expand All @@ -725,7 +720,7 @@ static auto PerformBuiltinConversion(Context& context, Parse::NodeId parse_node,
if (auto struct_literal = value.TryAs<SemIR::StructLiteral>();
struct_literal &&
struct_literal->elements_id == SemIR::InstBlockId::Empty) {
value_id = sem_ir.GetTypeAllowBuiltinTypes(value_type_id);
value_id = sem_ir.types().GetInstId(value_type_id);
}
}

Expand Down
10 changes: 3 additions & 7 deletions toolchain/check/handle_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,8 @@ auto HandleClassDefinitionStart(Context& context, Parse::NodeId parse_node)
context.PushScope(class_decl_id, class_info.scope_id);

// Introduce `Self`.
context.AddNameToLookup(
parse_node, SemIR::NameId::SelfType,
context.sem_ir().GetTypeAllowBuiltinTypes(class_info.self_type_id));
context.AddNameToLookup(parse_node, SemIR::NameId::SelfType,
context.types().GetInstId(class_info.self_type_id));

context.inst_block_stack().Push();
context.node_stack().Push(parse_node, class_id);
Expand Down Expand Up @@ -227,10 +226,7 @@ auto HandleBaseDecl(Context& context, Parse::NodeId parse_node) -> bool {
// declaration as being final classes.
// TODO: Once we have a better idea of which types are considered to be
// classes, produce a better diagnostic for deriving from a non-class type.
auto base_class =
context.insts()
.Get(context.sem_ir().GetTypeAllowBuiltinTypes(base_type_id))
.TryAs<SemIR::ClassType>();
auto base_class = context.types().TryGetAs<SemIR::ClassType>(base_type_id);
if (!base_class ||
context.classes().Get(base_class->class_id).inheritance_kind ==
SemIR::Class::Final) {
Expand Down
3 changes: 1 addition & 2 deletions toolchain/check/handle_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ auto HandleIndexExpr(Context& context, Parse::NodeId parse_node) -> bool {
operand_inst_id = ConvertToValueOrRefExpr(context, operand_inst_id);
auto operand_inst = context.insts().Get(operand_inst_id);
auto operand_type_id = operand_inst.type_id();
auto operand_type_inst = context.insts().Get(
context.sem_ir().GetTypeAllowBuiltinTypes(operand_type_id));
auto operand_type_inst = context.types().GetAsInst(operand_type_id);

switch (operand_type_inst.kind()) {
case SemIR::ArrayType::Kind: {
Expand Down
13 changes: 5 additions & 8 deletions toolchain/check/handle_name.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ static auto GetExprValueForLookupResult(Context& context,
// If lookup finds a class declaration, the value is its `Self` type.
auto lookup_result = context.insts().Get(lookup_result_id);
if (auto class_decl = lookup_result.TryAs<SemIR::ClassDecl>()) {
return context.sem_ir().GetTypeAllowBuiltinTypes(
return context.types().GetInstId(
context.classes().Get(class_decl->class_id).self_type_id);
}

Expand Down Expand Up @@ -108,10 +108,8 @@ auto HandleMemberAccessExpr(Context& context, Parse::NodeId parse_node)
base_id = ConvertToValueOrRefExpr(context, base_id);
base_type_id = context.insts().Get(base_id).type_id();

auto base_type = context.insts().Get(
context.sem_ir().GetTypeAllowBuiltinTypes(base_type_id));

switch (base_type.kind()) {
switch (auto base_type = context.types().GetAsInst(base_type_id);
base_type.kind()) {
case SemIR::ClassType::Kind: {
// Perform lookup for the name in the class scope.
auto class_scope_id = context.classes()
Expand All @@ -123,10 +121,9 @@ auto HandleMemberAccessExpr(Context& context, Parse::NodeId parse_node)

// Perform instance binding if we found an instance member.
auto member_type_id = context.insts().Get(member_id).type_id();
auto member_type_inst = context.insts().Get(
context.sem_ir().GetTypeAllowBuiltinTypes(member_type_id));
if (auto unbound_element_type =
member_type_inst.TryAs<SemIR::UnboundElementType>()) {
context.types().TryGetAs<SemIR::UnboundElementType>(
member_type_id)) {
// TODO: Check that the unbound element type describes a member of this
// class. Perform a conversion of the base if necessary.

Expand Down
5 changes: 2 additions & 3 deletions toolchain/check/handle_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,9 @@ auto HandlePrefixOperator(Context& context, Parse::NodeId parse_node) -> bool {
value_id = ConvertToValueExpr(context, value_id);
auto type_id =
context.GetUnqualifiedType(context.insts().Get(value_id).type_id());
auto type_inst = context.insts().Get(
context.sem_ir().GetTypeAllowBuiltinTypes(type_id));
auto result_type_id = SemIR::TypeId::Error;
if (auto pointer_type = type_inst.TryAs<SemIR::PointerType>()) {
if (auto pointer_type =
context.types().TryGetAs<SemIR::PointerType>(type_id)) {
result_type_id = pointer_type->pointee_id;
} else if (type_id != SemIR::TypeId::Error) {
CARBON_DIAGNOSTIC(
Expand Down
2 changes: 1 addition & 1 deletion toolchain/lower/handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ auto HandleIntLiteral(FunctionContext& context, SemIR::InstId inst_id,

auto HandleNameRef(FunctionContext& context, SemIR::InstId inst_id,
SemIR::NameRef inst) -> void {
auto type_inst_id = context.sem_ir().GetTypeAllowBuiltinTypes(inst.type_id);
auto type_inst_id = context.sem_ir().types().GetInstId(inst.type_id);
if (type_inst_id == SemIR::InstId::BuiltinNamespaceType) {
return;
}
Expand Down
11 changes: 3 additions & 8 deletions toolchain/lower/handle_aggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,8 @@ static auto GetStructFieldName(FunctionContext& context,
SemIR::ElementIndex index) -> llvm::StringRef {
auto fields = context.sem_ir().inst_blocks().Get(
context.sem_ir()
.insts()
.GetAs<SemIR::StructType>(
context.sem_ir().types().Get(struct_type_id).inst_id)
.types()
.GetAs<SemIR::StructType>(struct_type_id)
.fields_id);
auto field = context.sem_ir().insts().GetAs<SemIR::StructTypeField>(
fields[index.index]);
Expand All @@ -105,11 +104,7 @@ auto HandleClassElementAccess(FunctionContext& context, SemIR::InstId inst_id,
// Find the class that we're performing access into.
auto class_type_id = context.sem_ir().insts().Get(inst.base_id).type_id();
auto class_id =
context.sem_ir()
.insts()
.GetAs<SemIR::ClassType>(
context.sem_ir().GetTypeAllowBuiltinTypes(class_type_id))
.class_id;
context.sem_ir().types().GetAs<SemIR::ClassType>(class_type_id).class_id;
const auto& class_info = context.sem_ir().classes().Get(class_id);

// Translate the class field access into a struct access on the object
Expand Down
12 changes: 12 additions & 0 deletions toolchain/sem_ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ cc_library(
":ids",
":inst",
":inst_kind",
":type_info",
":value_stores",
"//common:check",
"//toolchain/base:value_store",
Expand Down Expand Up @@ -101,12 +102,23 @@ cc_library(
],
)

cc_library(
name = "type_info",
hdrs = ["type_info.h"],
deps = [
":ids",
":inst",
"//common:ostream",
],
)

cc_library(
name = "value_stores",
srcs = ["value_stores.cpp"],
hdrs = ["value_stores.h"],
deps = [
":inst",
":type_info",
"//toolchain/base:value_store",
"//toolchain/base:yaml",
"//toolchain/lex:token_kind",
Expand Down
16 changes: 7 additions & 9 deletions toolchain/sem_ir/file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ static auto GetTypePrecedence(InstKind kind) -> int {
}

auto File::StringifyType(TypeId type_id) const -> std::string {
return StringifyTypeExpr(GetTypeAllowBuiltinTypes(type_id));
return StringifyTypeExpr(types().GetInstId(type_id));
}

auto File::StringifyTypeExpr(InstId outer_inst_id) const -> std::string {
Expand Down Expand Up @@ -280,7 +280,7 @@ auto File::StringifyTypeExpr(InstId outer_inst_id) const -> std::string {
out << "[";
steps.push_back(step.Next());
steps.push_back(
{.inst_id = GetTypeAllowBuiltinTypes(array.element_type_id)});
{.inst_id = types().GetInstId(array.element_type_id)});
} else if (step.index == 1) {
out << "; " << GetArrayBoundValue(array.bound_id) << "]";
}
Expand All @@ -298,7 +298,7 @@ auto File::StringifyTypeExpr(InstId outer_inst_id) const -> std::string {

// Add parentheses if required.
auto inner_type_inst_id =
GetTypeAllowBuiltinTypes(inst.As<ConstType>().inner_id);
types().GetInstId(inst.As<ConstType>().inner_id);
if (GetTypePrecedence(insts().Get(inner_type_inst_id).kind()) <
GetTypePrecedence(inst.kind())) {
out << "(";
Expand All @@ -318,7 +318,7 @@ auto File::StringifyTypeExpr(InstId outer_inst_id) const -> std::string {
case PointerType::Kind: {
if (step.index == 0) {
steps.push_back(step.Next());
steps.push_back({.inst_id = GetTypeAllowBuiltinTypes(
steps.push_back({.inst_id = types().GetInstId(
inst.As<PointerType>().pointee_id)});
} else if (step.index == 1) {
out << "*";
Expand Down Expand Up @@ -346,8 +346,7 @@ auto File::StringifyTypeExpr(InstId outer_inst_id) const -> std::string {
case StructTypeField::Kind: {
auto field = inst.As<StructTypeField>();
out << "." << names().GetFormatted(field.name_id) << ": ";
steps.push_back(
{.inst_id = GetTypeAllowBuiltinTypes(field.field_type_id)});
steps.push_back({.inst_id = types().GetInstId(field.field_type_id)});
break;
}
case TupleType::Kind: {
Expand All @@ -369,15 +368,14 @@ auto File::StringifyTypeExpr(InstId outer_inst_id) const -> std::string {
break;
}
steps.push_back(step.Next());
steps.push_back(
{.inst_id = GetTypeAllowBuiltinTypes(refs[step.index])});
steps.push_back({.inst_id = types().GetInstId(refs[step.index])});
break;
}
case UnboundElementType::Kind: {
if (step.index == 0) {
out << "<unbound element of class ";
steps.push_back(step.Next());
steps.push_back({.inst_id = GetTypeAllowBuiltinTypes(
steps.push_back({.inst_id = types().GetInstId(
inst.As<UnboundElementType>().class_type_id)});
} else {
out << ">";
Expand Down
Loading

0 comments on commit cef7eb5

Please sign in to comment.