Skip to content

Commit

Permalink
Always pass function as first argument to JITed functions
Browse files Browse the repository at this point in the history
Summary:
This modifies our native signature to always pass the Python function in `rdi` as the first argument. This will open up the possibility of using the function definitions inside the v-table so that we can go directly to the static entry point on method invokes and bypass the type checks. The goal here is to get rid of the `Ci_Py_TPFLAGS_IS_STATICALLY_DEFINED` flag.

It also means we no longer have the restriction that we can't invoke functions which have closures. This probably doesn't matter much in practice, as the only case where that would happen today is when a method contains `super()`. Otherwise the closure belongs to a nested function, which we will never do an `INVOKE_FUNCTION` against (as there's no way to resolve the inner function by name).

Reviewed By: carljm

Differential Revision: D44230044

fbshipit-source-id: c3c64327180eedfb11cf13ab6361f9620dc649ac
  • Loading branch information
DinoV authored and facebook-github-bot committed Mar 29, 2023
1 parent ca8f4a5 commit d4a8494
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 64 deletions.
18 changes: 10 additions & 8 deletions Jit/codegen/gen_asm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ void* NativeGenerator::getVectorcallEntry() {
// arguments.
const std::vector<TypedArgument>& checks = GetFunction()->typed_args;

for (size_t i = 0, check_index = 0, gp_index = 0, fp_index = 0;
// gp_index starts at 1 because the first argument is reserved for the
// function
for (size_t i = 0, check_index = 0, gp_index = 1, fp_index = 0;
i < static_cast<size_t>(GetFunction()->numArgs());
i++) {
auto add_gp = [&]() {
Expand Down Expand Up @@ -842,7 +844,7 @@ void NativeGenerator::generatePrologue(
auto frame_cursor = as_->cursor();
as_->bind(setup_frame);

constexpr auto kFuncPtrReg = x86::rax;
constexpr auto kFuncPtrReg = x86::rdi;
constexpr auto kArgsReg = x86::r10;
constexpr auto kArgsPastSixReg = kArgsReg;

Expand Down Expand Up @@ -873,7 +875,7 @@ void NativeGenerator::generatePrologue(
// deal with loading them from here...
as_->lea(
kArgsPastSixReg,
x86::ptr(kArgsReg, ARGUMENT_REGS.size() * sizeof(void*)));
x86::ptr(kArgsReg, (ARGUMENT_REGS.size() - 1) * sizeof(void*)));
}

// Finally allocate the saved space required for the actual function
Expand Down Expand Up @@ -1240,6 +1242,7 @@ void NativeGenerator::generateStaticEntryPoint(
std::vector<std::pair<const x86::Reg&, const x86::Reg&>> save_regs;

if (!isGen()) {
save_regs.emplace_back(x86::rdi, x86::rdi);
for (size_t i = 0, check_index = 0, arg_index = 0, fp_index = 0;
i < total_args;
i++) {
Expand Down Expand Up @@ -1279,8 +1282,8 @@ void NativeGenerator::generateStaticEntryPoint(
}
}

if (arg_index < ARGUMENT_REGS.size()) {
switch (ARGUMENT_REGS[arg_index++]) {
if (arg_index + 1 < ARGUMENT_REGS.size()) {
switch (ARGUMENT_REGS[++arg_index]) {
case PhyLocation::RDI:
save_regs.emplace_back(x86::rdi, x86::rdi);
break;
Expand Down Expand Up @@ -1308,7 +1311,7 @@ void NativeGenerator::generateStaticEntryPoint(

loadOrGenerateLinkFrame(x86::r11, save_regs);

if (total_args > ARGUMENT_REGS.size()) {
if (total_args + 1 > ARGUMENT_REGS.size()) {
as_->lea(x86::r10, x86::ptr(x86::rbp, 16));
}
as_->jmp(native_entry_point);
Expand All @@ -1324,8 +1327,7 @@ void NativeGenerator::generateStaticEntryPoint(

bool NativeGenerator::hasStaticEntry() const {
PyCodeObject* code = GetFunction()->code;
return (code->co_flags & CO_STATICALLY_COMPILED) &&
!GetFunction()->uses_runtime_func;
return (code->co_flags & CO_STATICALLY_COMPILED);
}

void NativeGenerator::generateCode(CodeHolder& codeholder) {
Expand Down
12 changes: 8 additions & 4 deletions Jit/hir/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1975,7 +1975,7 @@ bool HIRBuilder::emitInvokeFunction(
Register* funcreg = temps_.AllocateStack();
if (target.container_is_immutable) {
// try to emit a direct x64 call (InvokeStaticFunction/CallStatic) if we can
if (!target.uses_runtime_func) {

if (target.is_function && target.is_statically_typed) {
if (_PyJIT_CompileFunction(target.func()) == PYJIT_RESULT_RETRY) {
JIT_DLOG(
Expand All @@ -1988,11 +1988,16 @@ bool HIRBuilder::emitInvokeFunction(
// it'll just have an extra indirection if not JIT compiled.
Register* out = temps_.AllocateStack();
Type typ = target.return_type;
tc.emit<LoadConst>(funcreg, Type::fromObject(target.callable));

auto call =
tc.emit<InvokeStaticFunction>(nargs, out, target.func(), typ);
tc.emit<InvokeStaticFunction>(nargs + 1, out, target.func(), typ);

call->SetOperand(0, funcreg);

for (auto i = nargs - 1; i >= 0; i--) {
Register* operand = tc.frame.stack.pop();
call->SetOperand(i, operand);
call->SetOperand(i + 1, operand);
}
call->setFrameState(tc.frame);

Expand All @@ -2003,7 +2008,6 @@ bool HIRBuilder::emitInvokeFunction(
target.is_builtin && tryEmitDirectMethodCall(target, tc, nargs)) {
return false;
}
}

// we couldn't emit an x64 call, but we know what object we'll vectorcall,
// so load it directly
Expand Down
4 changes: 2 additions & 2 deletions Jit/hir/optimization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ struct AbstractCall {
Register* arg(std::size_t i) const {
if (instr->opcode() == Opcode::kInvokeStaticFunction) {
auto f = dynamic_cast<InvokeStaticFunction*>(instr);
return f->arg(i);
return f->arg(i + 1);
}
if (auto f = dynamic_cast<VectorCallBase*>(instr)) {
return f->arg(i);
Expand Down Expand Up @@ -1090,7 +1090,7 @@ void InlineFunctionCalls::Run(Function& irfunc) {
} else if (instr.IsInvokeStaticFunction()) {
auto call = static_cast<InvokeStaticFunction*>(&instr);
to_inline.emplace_back(
AbstractCall(call->func(), call->NumArgs(), call));
AbstractCall(call->func(), call->NumArgs() - 1, call));
}
}
}
Expand Down
22 changes: 11 additions & 11 deletions Jit/jit_rt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1507,29 +1507,29 @@ JITRT_StaticCallReturn JITRT_FailedDeferredCompileShim(
}

// PyObject** args is:
// arg0
// arg1
// arg0 - function object
// arg1 - first real argument
// arg2
// arg3
// arg4
// arg5
// previous rbp
// return address to JITed code
// memory argument 0
// memory argument 0 (6th real argument)
// memory argument 1
// ...

PyObject** dest_args;
PyObject* final_args[total_args];
if (total_args <= 6) {
if (total_args <= 5) {
// no gap in args to worry about
dest_args = args;
dest_args = args + 1;
} else {
for (int i = 0; i < 6; i++) {
final_args[i] = args[i];
for (int i = 0; i < 5; i++) {
final_args[i] = args[i + 1];
}
for (int i = 6; i < total_args; i++) {
final_args[i] = args[i + 2];
for (int i = 5; i < total_args; i++) {
final_args[i] = args[i + 3];
}
dest_args = final_args;
}
Expand All @@ -1545,7 +1545,7 @@ JITRT_StaticCallReturn JITRT_FailedDeferredCompileShim(
for (Py_ssize_t i = 0; i < Py_SIZE(arg_info); i++) {
if (arg_info->tai_args[i].tai_primitive_type != -1) {
// primitive type, box...
int arg = arg_info->tai_args[i].tai_argnum;
int arg = arg_info->tai_args[i].tai_argnum + 1;
uint64_t arg_val;
if (arg >= 6) {
arg += 4;
Expand Down Expand Up @@ -1597,7 +1597,7 @@ JITRT_StaticCallReturn JITRT_FailedDeferredCompileShim(
// we can update the incoming arg array, either it's
// the pushed values on the stack by the trampoline, or
// it's final_args we allocated above.
dest_args[arg] = new_val;
dest_args[arg - 1] = new_val;
allocated_args[allocated_count++] = new_val;
}
}
Expand Down
7 changes: 2 additions & 5 deletions Jit/lir/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ BasicBlock* LIRGenerator::GenerateEntryBlock() {
bindVReg("__asm_extra_args", jit::codegen::PhyLocation::R10);
env_->asm_tstate = bindVReg("__asm_tstate", jit::codegen::PhyLocation::R11);
if (func_->uses_runtime_func) {
env_->asm_func = bindVReg("__asm_func", jit::codegen::PhyLocation::RAX);
env_->asm_func = bindVReg("__asm_func", jit::codegen::PhyLocation::RDI);
}

return block;
Expand Down Expand Up @@ -1842,10 +1842,7 @@ LIRGenerator::TranslatedBlock LIRGenerator::TranslateOneBasicBlock(
PyFunctionObject* func = instr->func();

std::stringstream ss;
JIT_CHECK(
!usesRuntimeFunc(func->func_code),
"Can't statically invoke given function: %s",
PyUnicode_AsUTF8(func->func_qualname));

if (_PyJIT_IsCompiled((PyObject*)func)) {
ss << fmt::format(
"Call {}, {}",
Expand Down
30 changes: 30 additions & 0 deletions Lib/test/test_compiler/test_static/invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,33 @@ def g():
r"Missing value for positional-only arg 0",
at="f(x=1)",
)

def test_invoke_super_final(self):
"""tests invoke against a function which has a free var which
gets introduced due to the super() call"""
codestr = """
from typing import final
import sys
class B:
def f(self):
return 42
@final
class C(B):
def f(self):
return super().f()
def x(c: C):
return c.f()
"""

code = self.compile(codestr, modname="foo")
x = self.find_code(code, "x")
self.assertInBytecode(x, "INVOKE_FUNCTION", (("foo", "C", "f"), 1))

with self.in_strict_module(codestr) as mod:
c = mod.C()
self.assertEqual(mod.C.f.__code__.co_freevars, ("__class__",))
self.assertEqual(mod.x(c), 42)
self.assertEqual(mod.x(c), 42)
9 changes: 6 additions & 3 deletions RuntimeTests/hir_tests/hir_builder_static_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,8 @@ fun jittestmodule:test {
bb 0 {
v0 = LoadConst<MortalUnicodeExact["hello"]>
v1 = LoadConst<ImmortalLongExact[123]>
v3 = InvokeStaticFunction<jittestmodule.x, 2, Unicode> v0 v1 {
v2 = LoadConst<MortalFunc[function:0xdeadbeef]>
v3 = InvokeStaticFunction<jittestmodule.x, 3, Unicode> v2 v0 v1 {
FrameState {
NextInstrOffset 10
}
Expand All @@ -964,7 +965,8 @@ fun jittestmodule:test {
}
v1 = LoadConst<MortalUnicodeExact["hello"]>
v2 = LoadConst<ImmortalLongExact[123]>
v4 = InvokeStaticFunction<jittestmodule.x, 2, Object> v1 v2 {
v3 = LoadConst<MortalFunc[function:0xdeadbeef]>
v4 = InvokeStaticFunction<jittestmodule.x, 3, Object> v3 v1 v2 {
FrameState {
NextInstrOffset 12
}
Expand Down Expand Up @@ -1157,7 +1159,8 @@ fun jittestmodule:test {
Stack<5> v2 v3 v8 v13 v14
}
}
v20 = InvokeStaticFunction<jittestmodule.x, 5, Long> v2 v3 v8 v13 v18 {
v19 = LoadConst<MortalFunc[function:0xdeadbeef]>
v20 = InvokeStaticFunction<jittestmodule.x, 6, Long> v19 v2 v3 v8 v13 v18 {
FrameState {
NextInstrOffset 42
Locals<1> v0
Expand Down
44 changes: 25 additions & 19 deletions RuntimeTests/hir_tests/inliner_elimination_static_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ def test():
---
fun jittestmodule:test {
bb 0 {
v5:ImmortalLongExact[4] = LoadConst<ImmortalLongExact[4]>
Return v5
v2:MortalFunc[function:0xdeadbeef] = LoadConst<MortalFunc[function:0xdeadbeef]>
v6:ImmortalLongExact[4] = LoadConst<ImmortalLongExact[4]>
Return v6
}
}
---
Expand All @@ -33,14 +34,16 @@ def test():
---
fun jittestmodule:test {
bb 0 {
v10:ImmortalLongExact[3] = LoadConst<ImmortalLongExact[3]>
v14:ImmortalLongExact[4] = LoadConst<ImmortalLongExact[4]>
UseType<LongExact> v10
UseType<LongExact> v14
UseType<ImmortalLongExact[3]> v10
UseType<ImmortalLongExact[4]> v14
v17:ImmortalLongExact[7] = LoadConst<ImmortalLongExact[7]>
Return v17
v5:MortalFunc[function:0xdeadbeef] = LoadConst<MortalFunc[function:0xdeadbeef]>
v12:ImmortalLongExact[3] = LoadConst<ImmortalLongExact[3]>
v7:MortalFunc[function:0xdeadbeef] = LoadConst<MortalFunc[function:0xdeadbeef]>
v16:ImmortalLongExact[4] = LoadConst<ImmortalLongExact[4]>
UseType<LongExact> v12
UseType<LongExact> v16
UseType<ImmortalLongExact[3]> v12
UseType<ImmortalLongExact[4]> v16
v19:ImmortalLongExact[7] = LoadConst<ImmortalLongExact[7]>
Return v19
}
}
---
Expand All @@ -56,17 +59,18 @@ fun jittestmodule:test {
bb 0 {
v4:ImmortalLongExact[3] = LoadConst<ImmortalLongExact[3]>
v5:MortalUnicodeExact["x"] = LoadConst<MortalUnicodeExact["x"]>
v6:MortalFunc[function:0xdeadbeef] = LoadConst<MortalFunc[function:0xdeadbeef]>
BeginInlinedFunction<jittestmodule:add> {
NextInstrOffset 10
}
v15:Object = BinaryOp<Add> v4 v5 {
v16:Object = BinaryOp<Add> v4 v5 {
FrameState {
NextInstrOffset 8
Locals<2> v4 v5
}
}
EndInlinedFunction
Return v15
Return v16
}
}
---
Expand All @@ -91,13 +95,15 @@ fun jittestmodule:test {
NextInstrOffset 4
}
}
v15:CInt8[4] = LoadConst<CInt8[4]>
v17:Nullptr = LoadConst<Nullptr>
UseType<CInt8> v15
StoreField<foo@16> v5 v15 v17
v19:NoneType = LoadConst<NoneType>
v26:CInt8 = LoadField<foo@16, CInt8, borrowed> v5
Return<CInt8> v26
v6:MortalFunc[function:0xdeadbeef] = LoadConst<MortalFunc[function:0xdeadbeef]>
v17:CInt8[4] = LoadConst<CInt8[4]>
v19:Nullptr = LoadConst<Nullptr>
UseType<CInt8> v17
StoreField<foo@16> v5 v17 v19
v21:NoneType = LoadConst<NoneType>
v8:MortalFunc[function:0xdeadbeef] = LoadConst<MortalFunc[function:0xdeadbeef]>
v28:CInt8 = LoadField<foo@16, CInt8, borrowed> v5
Return<CInt8> v28
}
}
---
Loading

0 comments on commit d4a8494

Please sign in to comment.