Commit 3810633

JIT: handle interaction of OSR, PGO, and tail calls (#62263)
When both OSR and PGO are enabled, the jit will add PGO probes to OSR methods. And if the OSR method also has a tail call, the jit must take care not to add block probes to any return block reachable from possible tail call blocks. Instead, instrumentation should create copies of the return block probe in each return block predecessor (possibly splitting critical edges to make this viable).

Because all this happens early on, there are no pred lists. The analysis leverages cheap preds instead, which means it must handle cases where a given pred has multiple pred list entries. It must also be aware that the OSR method's actual flowgraph is a subgraph of the full initial graph.

This came up while scouting what it would take to enable OSR by default. See #61934.
1 parent c8cd6fe commit 3810633
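For intuition, here is a minimal standalone model of the probe relocation described above; the `Block` type and string "statements" are hypothetical stand-ins, not actual JIT code. When a return block is flagged as a tail call successor, its count probe is prepended to each distinct predecessor instead of to the return block itself, and duplicate cheap-pred entries are skipped.

```cpp
// Hypothetical model of relocating a return block's count probe into its
// predecessors; the real logic lives in BlockCountInstrumentor below.
#include <cstdio>
#include <string>
#include <vector>

struct Block
{
    int                      num;
    bool                     tailcallSuccessor = false; // models BBF_TAILCALL_SUCCESSOR
    std::vector<Block*>      cheapPreds;                // may contain duplicates, like bbCheapPreds
    std::vector<std::string> stmts;                     // statements at the start of the block
};

// Place the count probe for 'block'. For a flagged return block, each
// distinct pred gets its own copy, since a tail call never executes
// statements placed in the return block itself.
void placeCountProbe(Block& block)
{
    const std::string probe = "Count[BB" + std::to_string(block.num) + "]++";

    if (!block.tailcallSuccessor)
    {
        block.stmts.insert(block.stmts.begin(), probe); // ordinary case
        return;
    }

    std::vector<Block*> seen;
    for (Block* pred : block.cheapPreds)
    {
        // Skip duplicate pred entries; the real code nulls these out in Prepare.
        bool duplicate = false;
        for (Block* s : seen)
        {
            duplicate = duplicate || (s == pred);
        }
        if (!duplicate)
        {
            pred->stmts.insert(pred->stmts.begin(), probe); // relocated copy
            seen.push_back(pred);
        }
    }
}

int main()
{
    Block pred1{1};
    Block pred2{2};
    Block ret{3};
    ret.tailcallSuccessor = true;
    ret.cheapPreds        = {&pred1, &pred2, &pred1}; // pred1 appears twice

    placeCountProbe(ret);
    printf("BB01: %zu probe(s), BB02: %zu, BB03: %zu\n", pred1.stmts.size(), pred2.stmts.size(),
           ret.stmts.size());
    // Prints: BB01: 1 probe(s), BB02: 1, BB03: 0
    return 0;
}
```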

File tree: 9 files changed, +404 -35 lines changed

src/coreclr/jit/block.h

Lines changed: 2 additions & 1 deletion
```diff
@@ -552,7 +552,8 @@ enum BasicBlockFlags : unsigned __int64
     BBF_PATCHPOINT                     = MAKE_BBFLAG(36), // Block is a patchpoint
     BBF_HAS_CLASS_PROFILE              = MAKE_BBFLAG(37), // BB contains a call needing a class profile
     BBF_PARTIAL_COMPILATION_PATCHPOINT = MAKE_BBFLAG(38), // Block is a partial compilation patchpoint
-    BBF_HAS_ALIGN                      = MAKE_BBFLAG(39), // BB ends with 'align' instruction
+    BBF_HAS_ALIGN                      = MAKE_BBFLAG(39), // BB ends with 'align' instruction
+    BBF_TAILCALL_SUCCESSOR             = MAKE_BBFLAG(40), // BB has pred that has potential tail call

     // The following are sets of flags.
```

src/coreclr/jit/compiler.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -7360,6 +7360,7 @@ class Compiler
 #define OMF_NEEDS_GCPOLLS 0x00000200 // Method needs GC polls
 #define OMF_HAS_FROZEN_STRING 0x00000400 // Method has a frozen string (REF constant int), currently only on CoreRT.
 #define OMF_HAS_PARTIAL_COMPILATION_PATCHPOINT 0x00000800 // Method contains partial compilation patchpoints
+#define OMF_HAS_TAILCALL_SUCCESSOR 0x00001000 // Method has potential tail call in a non BBJ_RETURN block

     bool doesMethodHaveFatPointer()
     {
```

src/coreclr/jit/fgbasic.cpp

Lines changed: 14 additions & 8 deletions
```diff
@@ -536,10 +536,10 @@ void Compiler::fgReplaceSwitchJumpTarget(BasicBlock* blockSwitch, BasicBlock* ne
 // Notes:
 // 1. Only branches are changed: BBJ_ALWAYS, the non-fallthrough path of BBJ_COND, BBJ_SWITCH, etc.
 //    We ignore other block types.
-// 2. Only the first target found is updated. If there are multiple ways for a block
-//    to reach 'oldTarget' (e.g., multiple arms of a switch), only the first one found is changed.
+// 2. All branch targets found are updated. If there are multiple ways for a block
+//    to reach 'oldTarget' (e.g., multiple arms of a switch), all of them are changed.
 // 3. The predecessor lists are not changed.
-// 4. The switch table "unique successor" cache is invalidated.
+// 4. If any switch table entry was updated, the switch table "unique successor" cache is invalidated.
 //
 // This function is most useful early, before the full predecessor lists have been computed.
 //
@@ -569,20 +569,26 @@ void Compiler::fgReplaceJumpTarget(BasicBlock* block, BasicBlock* newTarget, Bas
             break;

         case BBJ_SWITCH:
-            unsigned jumpCnt;
-            jumpCnt = block->bbJumpSwt->bbsCount;
-            BasicBlock** jumpTab;
-            jumpTab = block->bbJumpSwt->bbsDstTab;
+        {
+            unsigned const     jumpCnt = block->bbJumpSwt->bbsCount;
+            BasicBlock** const jumpTab = block->bbJumpSwt->bbsDstTab;
+            bool               changed = false;

             for (unsigned i = 0; i < jumpCnt; i++)
             {
                 if (jumpTab[i] == oldTarget)
                 {
                     jumpTab[i] = newTarget;
-                    break;
+                    changed    = true;
                 }
             }
+
+            if (changed)
+            {
+                InvalidateUniqueSwitchSuccMap();
+            }
             break;
+        }

         default:
             assert(!"Block doesn't have a valid bbJumpKind!!!!");
```
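The early `break` removed above matters whenever a switch has duplicate arms. A minimal standalone sketch of the replace-all behavior, with hypothetical jump-table values and plain ints standing in for `BasicBlock*`:

```cpp
// Sketch of rewriting a switch jump table in which two arms share a target,
// mirroring the corrected loop in fgReplaceJumpTarget above.
#include <cstdio>

int main()
{
    int  jumpTab[] = {3, 5, 3}; // arms 0 and 2 both target hypothetical BB03
    int  oldTarget = 3;
    int  newTarget = 7;
    bool changed   = false;

    for (int& target : jumpTab)
    {
        if (target == oldTarget)
        {
            target  = newTarget;
            changed = true; // keep scanning; no early 'break'
        }
    }

    // With the old early 'break' the table would have ended up {7, 5, 3},
    // leaving a stale arm. The 'changed' flag models the new conditional
    // invalidation of the "unique successor" cache.
    printf("{%d, %d, %d} changed=%d\n", jumpTab[0], jumpTab[1], jumpTab[2], changed);
    return 0;
}
```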

src/coreclr/jit/fgprofile.cpp

Lines changed: 194 additions & 3 deletions
```diff
@@ -367,6 +367,160 @@ void BlockCountInstrumentor::Prepare(bool preImport)
         return;
     }

+    // If this is an OSR method, look for potential tail calls in
+    // blocks that are not BBJ_RETURN.
+    //
+    // If we see any, we need to adjust our instrumentation pattern.
+    //
+    if (m_comp->opts.IsOSR() && ((m_comp->optMethodFlags & OMF_HAS_TAILCALL_SUCCESSOR) != 0))
+    {
+        JITDUMP("OSR + PGO + potential tail call --- preparing to relocate block probes\n");
+
+        // We should be in a root method compiler instance. OSR + PGO does not
+        // currently try and instrument inlinees.
+        //
+        // Relaxing this will require changes below because inlinee compilers
+        // share the root compiler flow graph (and hence bb epoch), and flow
+        // from inlinee tail calls to returns can be more complex.
+        //
+        assert(!m_comp->compIsForInlining());
+
+        // Build cheap preds.
+        //
+        m_comp->fgComputeCheapPreds();
+        m_comp->EnsureBasicBlockEpoch();
+
+        // Keep track of return blocks needing special treatment.
+        // We also need to keep track of duplicate preds.
+        //
+        JitExpandArrayStack<BasicBlock*> specialReturnBlocks(m_comp->getAllocator(CMK_Pgo));
+        BlockSet                         predsSeen = BlockSetOps::MakeEmpty(m_comp);
+
+        // Walk blocks looking for BBJ_RETURNs that are successors of potential tail calls.
+        //
+        // If any such has a conditional pred, we will need to reroute flow from those preds
+        // via an intermediary block. That block will subsequently hold the relocated block
+        // probe for the return for those preds.
+        //
+        // Scrub the cheap pred list for these blocks so that each pred appears at most once.
+        //
+        for (BasicBlock* const block : m_comp->Blocks())
+        {
+            // Ignore blocks that we won't process.
+            //
+            if (!ShouldProcess(block))
+            {
+                continue;
+            }
+
+            if ((block->bbFlags & BBF_TAILCALL_SUCCESSOR) != 0)
+            {
+                JITDUMP("Return " FMT_BB " is successor of possible tail call\n", block->bbNum);
+                assert(block->bbJumpKind == BBJ_RETURN);
+                bool pushed = false;
+                BlockSetOps::ClearD(m_comp, predsSeen);
+                for (BasicBlockList* predEdge = block->bbCheapPreds; predEdge != nullptr; predEdge = predEdge->next)
+                {
+                    BasicBlock* const pred = predEdge->block;
+
+                    // If pred is not to be processed, ignore it and scrub from the pred list.
+                    //
+                    if (!ShouldProcess(pred))
+                    {
+                        JITDUMP(FMT_BB " -> " FMT_BB " is dead edge\n", pred->bbNum, block->bbNum);
+                        predEdge->block = nullptr;
+                        continue;
+                    }
+
+                    BasicBlock* const succ = pred->GetUniqueSucc();
+
+                    if (succ == nullptr)
+                    {
+                        // Flow from pred -> block is conditional, and will require updating.
+                        //
+                        JITDUMP(FMT_BB " -> " FMT_BB " is critical edge\n", pred->bbNum, block->bbNum);
+                        if (!pushed)
+                        {
+                            specialReturnBlocks.Push(block);
+                            pushed = true;
+                        }
+
+                        // Have we seen this pred before?
+                        //
+                        if (BlockSetOps::IsMember(m_comp, predsSeen, pred->bbNum))
+                        {
+                            // Yes, null out the duplicate pred list entry.
+                            //
+                            predEdge->block = nullptr;
+                        }
+                    }
+                    else
+                    {
+                        // We should only ever see one reference to this pred.
+                        //
+                        assert(!BlockSetOps::IsMember(m_comp, predsSeen, pred->bbNum));
+
+                        // Ensure flow from non-critical preds is BBJ_ALWAYS as we
+                        // may add a new block right before block.
+                        //
+                        if (pred->bbJumpKind == BBJ_NONE)
+                        {
+                            pred->bbJumpKind = BBJ_ALWAYS;
+                            pred->bbJumpDest = block;
+                        }
+                        assert(pred->bbJumpKind == BBJ_ALWAYS);
+                    }
+
+                    BlockSetOps::AddElemD(m_comp, predsSeen, pred->bbNum);
+                }
+            }
+        }
+
+        // Now process each special return block.
+        // Create an intermediary that falls through to the return.
+        // Update any critical edges to target the intermediary.
+        //
+        // Note we could also route any non-tail-call pred via the
+        // intermediary. Doing so would cut down on probe duplication.
+        //
+        while (specialReturnBlocks.Size() > 0)
+        {
+            bool              first        = true;
+            BasicBlock* const block        = specialReturnBlocks.Pop();
+            BasicBlock* const intermediary = m_comp->fgNewBBbefore(BBJ_NONE, block, /* extendRegion */ true);
+
+            intermediary->bbFlags |= BBF_IMPORTED;
+            intermediary->inheritWeight(block);
+
+            for (BasicBlockList* predEdge = block->bbCheapPreds; predEdge != nullptr; predEdge = predEdge->next)
+            {
+                BasicBlock* const pred = predEdge->block;
+
+                if (pred != nullptr)
+                {
+                    BasicBlock* const succ = pred->GetUniqueSucc();
+
+                    if (succ == nullptr)
+                    {
+                        // This will update all branch targets from pred.
+                        //
+                        m_comp->fgReplaceJumpTarget(pred, intermediary, block);
+
+                        // Patch the pred list. Note we only need one pred list
+                        // entry pointing at intermediary.
+                        //
+                        predEdge->block = first ? intermediary : nullptr;
+                        first           = false;
+                    }
+                    else
+                    {
+                        assert(pred->bbJumpKind == BBJ_ALWAYS);
+                    }
+                }
+            }
+        }
+    }
+
 #ifdef DEBUG
     // Set schema index to invalid value
     //
@@ -449,7 +603,37 @@ void BlockCountInstrumentor::Instrument(BasicBlock* block, Schema& schema, uint8
     GenTree* lhsNode = m_comp->gtNewIndOfIconHandleNode(typ, addrOfCurrentExecutionCount, GTF_ICON_BBC_PTR, false);
     GenTree* asgNode = m_comp->gtNewAssignNode(lhsNode, rhsNode);

-    m_comp->fgNewStmtAtBeg(block, asgNode);
+    if ((block->bbFlags & BBF_TAILCALL_SUCCESSOR) != 0)
+    {
+        // We should have built and updated cheap preds during the prepare stage.
+        //
+        assert(m_comp->fgCheapPredsValid);
+
+        // Instrument each predecessor.
+        //
+        bool first = true;
+        for (BasicBlockList* predEdge = block->bbCheapPreds; predEdge != nullptr; predEdge = predEdge->next)
+        {
+            BasicBlock* const pred = predEdge->block;
+
+            // We may have scrubbed cheap pred list duplicates during Prepare.
+            //
+            if (pred != nullptr)
+            {
+                JITDUMP("Placing copy of block probe for " FMT_BB " in pred " FMT_BB "\n", block->bbNum, pred->bbNum);
+                if (!first)
+                {
+                    asgNode = m_comp->gtCloneExpr(asgNode);
+                }
+                m_comp->fgNewStmtAtBeg(pred, asgNode);
+                first = false;
+            }
+        }
+    }
+    else
+    {
+        m_comp->fgNewStmtAtBeg(block, asgNode);
+    }

     m_instrCount++;
 }
@@ -589,7 +773,7 @@ void Compiler::WalkSpanningTree(SpanningTreeVisitor* visitor)
     // graph. So for BlockSets and NumSucc, we use the root compiler instance.
     //
     Compiler* const comp = impInlineRoot();
-    comp->NewBasicBlockEpoch();
+    comp->EnsureBasicBlockEpoch();

     // We will track visited or queued nodes with a bit vector.
     //
@@ -1612,7 +1796,7 @@ PhaseStatus Compiler::fgPrepareToInstrumentMethod()
     else
     {
         JITDUMP("Using block profiling, because %s\n",
-                (JitConfig.JitEdgeProfiling() > 0)
+                (JitConfig.JitEdgeProfiling() == 0)
                     ? "edge profiles disabled"
                     : prejit ? "prejitting" : osrMethod ? "OSR" : "tier0 with patchpoints");
@@ -1793,6 +1977,13 @@ PhaseStatus Compiler::fgInstrumentMethod()
     fgCountInstrumentor->InstrumentMethodEntry(schema, profileMemory);
     fgClassInstrumentor->InstrumentMethodEntry(schema, profileMemory);

+    // If we needed to create cheap preds, we're done with them now.
+    //
+    if (fgCheapPredsValid)
+    {
+        fgRemovePreds();
+    }
+
     return PhaseStatus::MODIFIED_EVERYTHING;
 }
```
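To make the Prepare/Instrument interplay concrete, here is a hedged before/after sketch; the block numbers are made up, not taken from the commit:

```
Before (BB20 is a BBJ_RETURN flagged BBF_TAILCALL_SUCCESSOR):

    BB10 (BBJ_COND)             -- critical edge -->  BB20 (BBJ_RETURN)
    BB15 (potential tail call)  ------------------->  BB20

After Prepare inserts the intermediary BB25 just in front of BB20:

    BB10 (BBJ_COND)             -- retargeted ----->  BB25 (BBJ_NONE)  -- falls through -->  BB20
    BB15 (forced to BBJ_ALWAYS) ------------------------------------------------------->    BB20

Instrument then prepends a copy of BB20's count probe to each surviving
pred (BB25 and BB15). In BB15 the probe executes before the potential
tail call; a probe left in BB20 itself would be skipped whenever the
call actually dispatches as a tail call.
```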

src/coreclr/jit/importer.cpp

Lines changed: 38 additions & 23 deletions
```diff
@@ -9797,34 +9797,49 @@ var_types Compiler::impImportCall(OPCODE opcode,
     }

     // A tail recursive call is a potential loop from the current block to the start of the method.
-    if ((tailCallFlags != 0) && canTailCall && gtIsRecursiveCall(methHnd))
+    if ((tailCallFlags != 0) && canTailCall)
     {
-        assert(verCurrentState.esStackDepth == 0);
-        BasicBlock* loopHead = nullptr;
-        if (opts.IsOSR())
+        // If a root method tail call candidate block is not a BBJ_RETURN, it should have a unique
+        // BBJ_RETURN successor. Mark that successor so we can handle it specially during profile
+        // instrumentation.
+        //
+        if (!compIsForInlining() && (compCurBB->bbJumpKind != BBJ_RETURN))
         {
-            // We might not have been planning on importing the method
-            // entry block, but now we must.
-
-            // We should have remembered the real method entry block.
-            assert(fgEntryBB != nullptr);
-
-            JITDUMP("\nOSR: found tail recursive call in the method, scheduling " FMT_BB " for importation\n",
-                    fgEntryBB->bbNum);
-            impImportBlockPending(fgEntryBB);
-            loopHead = fgEntryBB;
+            BasicBlock* const successor = compCurBB->GetUniqueSucc();
+            assert(successor->bbJumpKind == BBJ_RETURN);
+            successor->bbFlags |= BBF_TAILCALL_SUCCESSOR;
+            optMethodFlags |= OMF_HAS_TAILCALL_SUCCESSOR;
         }
-        else
+
+        if (gtIsRecursiveCall(methHnd))
         {
-            // For normal jitting we'll branch back to the firstBB; this
-            // should already be imported.
-            loopHead = fgFirstBB;
-        }
+            assert(verCurrentState.esStackDepth == 0);
+            BasicBlock* loopHead = nullptr;
+            if (opts.IsOSR())
+            {
+                // We might not have been planning on importing the method
+                // entry block, but now we must.

-        JITDUMP("\nFound tail recursive call in the method. Mark " FMT_BB " to " FMT_BB
-                " as having a backward branch.\n",
-                loopHead->bbNum, compCurBB->bbNum);
-        fgMarkBackwardJump(loopHead, compCurBB);
+                // We should have remembered the real method entry block.
+                assert(fgEntryBB != nullptr);
+
+                JITDUMP("\nOSR: found tail recursive call in the method, scheduling " FMT_BB " for importation\n",
+                        fgEntryBB->bbNum);
+                impImportBlockPending(fgEntryBB);
+                loopHead = fgEntryBB;
+            }
+            else
+            {
+                // For normal jitting we'll branch back to the firstBB; this
+                // should already be imported.
+                loopHead = fgFirstBB;
+            }
+
+            JITDUMP("\nFound tail recursive call in the method. Mark " FMT_BB " to " FMT_BB
+                    " as having a backward branch.\n",
+                    loopHead->bbNum, compCurBB->bbNum);
+            fgMarkBackwardJump(loopHead, compCurBB);
+        }
     }

     // Note: we assume that small return types are already normalized by the managed callee
```
