Skip to content

Commit ad85f90

Browse files
committed
lower if (x) ... else |e| switch (e) to switch_block_err_union
1 parent 002ec9c commit ad85f90

File tree

4 files changed

+679
-30
lines changed

4 files changed

+679
-30
lines changed

src/AstGen.zig

Lines changed: 145 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,18 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE
838838

839839
.if_simple,
840840
.@"if",
841-
=> return ifExpr(gz, scope, ri.br(), node, tree.fullIf(node).?),
841+
=> {
842+
const if_full = tree.fullIf(node).?;
843+
if (if_full.error_token) |error_token| {
844+
const tag = node_tags[if_full.ast.else_expr];
845+
if ((tag == .@"switch" or tag == .switch_comma) and
846+
std.mem.eql(u8, tree.tokenSlice(error_token), tree.tokenSlice(error_token + 4)))
847+
{
848+
return switchExprErrUnion(gz, scope, ri.br(), node, .@"if");
849+
}
850+
}
851+
return ifExpr(gz, scope, ri.br(), node, if_full);
852+
},
842853

843854
.while_simple,
844855
.while_cont,
@@ -1019,7 +1030,7 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE
10191030
token_tags[catch_token + 4] == .keyword_switch)
10201031
{
10211032
if (std.mem.eql(u8, tree.tokenSlice(catch_token + 2), tree.tokenSlice(catch_token + 6))) {
1022-
return switchExprErrUnion(gz, scope, ri.br(), node);
1033+
return switchExprErrUnion(gz, scope, ri.br(), node, .@"catch");
10231034
}
10241035
}
10251036
break :blk catch_token + 2;
@@ -6860,7 +6871,8 @@ fn switchExprErrUnion(
68606871
parent_gz: *GenZir,
68616872
scope: *Scope,
68626873
ri: ResultInfo,
6863-
catch_node: Ast.Node.Index,
6874+
catch_or_if_node: Ast.Node.Index,
6875+
node_ty: enum { @"catch", @"if" },
68646876
) InnerError!Zir.Inst.Ref {
68656877
const astgen = parent_gz.astgen;
68666878
const gpa = astgen.gpa;
@@ -6869,21 +6881,42 @@ fn switchExprErrUnion(
68696881
const node_tags = tree.nodes.items(.tag);
68706882
const main_tokens = tree.nodes.items(.main_token);
68716883
const token_tags = tree.tokens.items(.tag);
6872-
const operand_node = node_datas[catch_node].lhs;
6873-
const switch_node = node_datas[catch_node].rhs;
6884+
6885+
const if_full = switch (node_ty) {
6886+
.@"catch" => undefined,
6887+
.@"if" => tree.fullIf(catch_or_if_node).?,
6888+
};
6889+
6890+
const switch_node, const operand_node, const error_payload = switch (node_ty) {
6891+
.@"catch" => .{
6892+
node_datas[catch_or_if_node].rhs,
6893+
node_datas[catch_or_if_node].lhs,
6894+
main_tokens[catch_or_if_node] + 2,
6895+
},
6896+
.@"if" => .{
6897+
if_full.ast.else_expr,
6898+
if_full.ast.cond_expr,
6899+
if_full.error_token.?,
6900+
},
6901+
};
6902+
assert(node_tags[switch_node] == .@"switch" or node_tags[switch_node] == .switch_comma);
6903+
68746904
const extra = tree.extraData(node_datas[switch_node].rhs, Ast.Node.SubRange);
68756905
const case_nodes = tree.extra_data[extra.start..extra.end];
68766906

6877-
const need_rl = astgen.nodes_need_rl.contains(catch_node);
6907+
const need_rl = astgen.nodes_need_rl.contains(catch_or_if_node);
68786908
const block_ri: ResultInfo = if (need_rl) ri else .{
68796909
.rl = switch (ri.rl) {
6880-
.ptr => .{ .ty = (try ri.rl.resultType(parent_gz, catch_node)).? },
6910+
.ptr => .{ .ty = (try ri.rl.resultType(parent_gz, catch_or_if_node)).? },
68816911
.inferred_ptr => .none,
68826912
else => ri.rl,
68836913
},
68846914
.ctx = ri.ctx,
68856915
};
68866916

6917+
const payload_is_ref = node_ty == .@"if" and
6918+
if_full.payload_token != null and token_tags[if_full.payload_token.?] == .asterisk;
6919+
68876920
// We need to call `rvalue` to write through to the pointer only if we had a
68886921
// result pointer and aren't forwarding it.
68896922
const LocTag = @typeInfo(ResultInfo.Loc).Union.tag_type.?;
@@ -6951,12 +6984,15 @@ fn switchExprErrUnion(
69516984
}
69526985
}
69536986

6954-
const operand_ri: ResultInfo = .{ .rl = .none, .ctx = .error_handling_expr };
6987+
const operand_ri: ResultInfo = .{
6988+
.rl = if (payload_is_ref) .ref else .none,
6989+
.ctx = .error_handling_expr,
6990+
};
69556991

69566992
astgen.advanceSourceCursorToNode(operand_node);
69576993
const operand_lc = LineColumn{ astgen.source_line - parent_gz.decl_line, astgen.source_column };
69586994

6959-
const raw_operand = try reachableExpr(parent_gz, scope, operand_ri, operand_node, node_datas[catch_node].rhs);
6995+
const raw_operand = try reachableExpr(parent_gz, scope, operand_ri, operand_node, switch_node);
69606996
const item_ri: ResultInfo = .{ .rl = .none };
69616997

69626998
// This contains the data that goes into the `extra` array for the SwitchBlockErrUnion, except
@@ -6997,13 +7033,93 @@ fn switchExprErrUnion(
69977033

69987034
try case_scope.addDbgBlockBegin();
69997035

7000-
const unwrapped_payload = try case_scope.addUnNode(.err_union_payload_unsafe, raw_operand, catch_node);
7001-
const case_result = switch (ri.rl) {
7002-
.ref, .ref_coerced_ty => unwrapped_payload,
7003-
else => try rvalue(&case_scope, block_scope.break_result_info, unwrapped_payload, catch_node),
7004-
};
7005-
try case_scope.addDbgBlockEnd();
7006-
_ = try case_scope.addBreakWithSrcNode(.@"break", switch_block, case_result, catch_node);
7036+
const unwrap_payload_tag: Zir.Inst.Tag = if (payload_is_ref)
7037+
.err_union_payload_unsafe_ptr
7038+
else
7039+
.err_union_payload_unsafe;
7040+
7041+
const unwrapped_payload = try case_scope.addUnNode(
7042+
unwrap_payload_tag,
7043+
raw_operand,
7044+
catch_or_if_node,
7045+
);
7046+
7047+
switch (node_ty) {
7048+
.@"catch" => {
7049+
const case_result = switch (ri.rl) {
7050+
.ref, .ref_coerced_ty => unwrapped_payload,
7051+
else => try rvalue(
7052+
&case_scope,
7053+
block_scope.break_result_info,
7054+
unwrapped_payload,
7055+
catch_or_if_node,
7056+
),
7057+
};
7058+
try case_scope.addDbgBlockEnd();
7059+
_ = try case_scope.addBreakWithSrcNode(
7060+
.@"break",
7061+
switch_block,
7062+
case_result,
7063+
catch_or_if_node,
7064+
);
7065+
},
7066+
.@"if" => {
7067+
var payload_val_scope: Scope.LocalVal = undefined;
7068+
7069+
try case_scope.addDbgBlockBegin();
7070+
const then_node = if_full.ast.then_expr;
7071+
const then_sub_scope = s: {
7072+
assert(if_full.error_token != null);
7073+
if (if_full.payload_token) |payload_token| {
7074+
const token_name_index = payload_token + @intFromBool(payload_is_ref);
7075+
const ident_name = try astgen.identAsString(token_name_index);
7076+
const token_name_str = tree.tokenSlice(token_name_index);
7077+
if (mem.eql(u8, "_", token_name_str))
7078+
break :s &case_scope.base;
7079+
try astgen.detectLocalShadowing(
7080+
&case_scope.base,
7081+
ident_name,
7082+
token_name_index,
7083+
token_name_str,
7084+
.capture,
7085+
);
7086+
payload_val_scope = .{
7087+
.parent = &case_scope.base,
7088+
.gen_zir = &case_scope,
7089+
.name = ident_name,
7090+
.inst = unwrapped_payload,
7091+
.token_src = payload_token,
7092+
.id_cat = .capture,
7093+
};
7094+
try case_scope.addDbgVar(.dbg_var_val, ident_name, unwrapped_payload);
7095+
break :s &payload_val_scope.base;
7096+
} else {
7097+
_ = try case_scope.addUnNode(
7098+
.ensure_err_union_payload_void,
7099+
raw_operand,
7100+
catch_or_if_node,
7101+
);
7102+
break :s &case_scope.base;
7103+
}
7104+
};
7105+
const then_result = try expr(
7106+
&case_scope,
7107+
then_sub_scope,
7108+
block_scope.break_result_info,
7109+
then_node,
7110+
);
7111+
try checkUsed(parent_gz, &case_scope.base, then_sub_scope);
7112+
if (!case_scope.endsWithNoReturn()) {
7113+
try case_scope.addDbgBlockEnd();
7114+
_ = try case_scope.addBreakWithSrcNode(
7115+
.@"break",
7116+
switch_block,
7117+
then_result,
7118+
then_node,
7119+
);
7120+
}
7121+
},
7122+
}
70077123

70087124
const case_slice = case_scope.instructionsSlice();
70097125
// Since we use the switch_block_err_union instruction itself to refer
@@ -7020,9 +7136,18 @@ fn switchExprErrUnion(
70207136
};
70217137
const body_len = refs_len + astgen.countBodyLenAfterFixups(case_slice);
70227138
try payloads.ensureUnusedCapacity(gpa, body_len);
7139+
const capture: Zir.Inst.SwitchBlock.ProngInfo.Capture = switch (node_ty) {
7140+
.@"catch" => .none,
7141+
.@"if" => if (if_full.payload_token == null)
7142+
.none
7143+
else if (payload_is_ref)
7144+
.by_ref
7145+
else
7146+
.by_val,
7147+
};
70237148
payloads.items[body_len_index] = @bitCast(Zir.Inst.SwitchBlock.ProngInfo{
70247149
.body_len = @intCast(body_len),
7025-
.capture = .none,
7150+
.capture = capture,
70267151
.is_inline = false,
70277152
.has_tag_capture = false,
70287153
});
@@ -7032,16 +7157,15 @@ fn switchExprErrUnion(
70327157
appendBodyWithFixupsArrayList(astgen, payloads, case_slice);
70337158
}
70347159

7035-
const err_name, const error_payload = blk: {
7036-
const error_payload = main_tokens[catch_node] + 2;
7160+
const err_name = blk: {
70377161
const err_str = tree.tokenSlice(error_payload);
70387162
if (mem.eql(u8, err_str, "_")) {
70397163
return astgen.failTok(error_payload, "discard of error capture; omit it instead", .{});
70407164
}
70417165
const err_name = try astgen.identAsString(error_payload);
70427166
try astgen.detectLocalShadowing(scope, err_name, error_payload, err_str, .capture);
70437167

7044-
break :blk .{ err_name, error_payload };
7168+
break :blk err_name;
70457169
};
70467170

70477171
// allocate a shared dummy instruction for the error capture
@@ -7234,6 +7358,7 @@ fn switchExprErrUnion(
72347358
.has_else = has_else,
72357359
.scalar_cases_len = @intCast(scalar_cases_len),
72367360
.any_uses_err_capture = any_uses_err_capture,
7361+
.payload_is_ref = payload_is_ref,
72377362
},
72387363
});
72397364

src/Sema.zig

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8908,10 +8908,14 @@ fn zirErrUnionCodePtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileE
89088908
const tracy = trace(@src());
89098909
defer tracy.end();
89108910

8911-
const mod = sema.mod;
89128911
const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
89138912
const src = inst_data.src();
89148913
const operand = try sema.resolveInst(inst_data.operand);
8914+
return sema.analyzeErrUnionCodePtr(block, src, operand);
8915+
}
8916+
8917+
fn analyzeErrUnionCodePtr(sema: *Sema, block: *Block, src: LazySrcLoc, operand: Air.Inst.Ref) CompileError!Air.Inst.Ref {
8918+
const mod = sema.mod;
89158919
const operand_ty = sema.typeOf(operand);
89168920
assert(operand_ty.zigTypeTag(mod) == .Pointer);
89178921

@@ -8926,7 +8930,10 @@ fn zirErrUnionCodePtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileE
89268930
if (try sema.resolveDefinedValue(block, src, operand)) |pointer_val| {
89278931
if (try sema.pointerDeref(block, src, pointer_val, operand_ty)) |val| {
89288932
assert(val.getErrorName(mod) != .none);
8929-
return Air.internedToRef(val.toIntern());
8933+
return Air.internedToRef((try mod.intern(.{ .err = .{
8934+
.ty = result_ty.toIntern(),
8935+
.name = mod.intern_pool.indexToKey(val.toIntern()).error_union.val.err_name,
8936+
} })));
89308937
}
89318938
}
89328939

@@ -11144,7 +11151,6 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp
1114411151
const extra = sema.code.extraData(Zir.Inst.SwitchBlockErrUnion, inst_data.payload_index);
1114511152

1114611153
const raw_operand_val = try sema.resolveInst(extra.data.operand);
11147-
assert(sema.typeOf(raw_operand_val).zigTypeTag(mod) == .ErrorUnion);
1114811154

1114911155
// AstGen guarantees that the instruction immediately preceding
1115011156
// switch_block_err_union is a dbg_stmt
@@ -11175,6 +11181,7 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp
1117511181
const NonError = struct {
1117611182
body: []const Zir.Inst.Index,
1117711183
end: usize,
11184+
capture: Zir.Inst.SwitchBlock.ProngInfo.Capture,
1117811185
};
1117911186

1118011187
const non_error_case: NonError = non_error: {
@@ -11183,6 +11190,7 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp
1118311190
break :non_error .{
1118411191
.body = sema.code.bodySlice(extra_body_start, info.body_len),
1118511192
.end = extra_body_start + info.body_len,
11193+
.capture = info.capture,
1118611194
};
1118711195
};
1118811196

@@ -11207,7 +11215,7 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp
1120711215
.body = sema.code.bodySlice(extra_body_start, info.body_len),
1120811216
.end = extra_body_start + info.body_len,
1120911217
.is_inline = info.is_inline,
11210-
.has_capture = info.capture == .by_val,
11218+
.has_capture = info.capture != .none,
1121111219
};
1121211220
};
1121311221

@@ -11217,7 +11225,10 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp
1121711225
}
1121811226

1121911227
const operand_ty = sema.typeOf(raw_operand_val);
11220-
const operand_err_set_ty = operand_ty.errorUnionSet(mod);
11228+
const operand_err_set_ty = if (extra.data.bits.payload_is_ref)
11229+
operand_ty.childType(mod).errorUnionSet(mod)
11230+
else
11231+
operand_ty.errorUnionSet(mod);
1122111232

1122211233
const block_inst: Air.Inst.Index = @enumFromInt(sema.air_instructions.len);
1122311234
try sema.air_instructions.append(gpa, .{
@@ -11285,7 +11296,12 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp
1128511296
.tag_capture_inst = undefined,
1128611297
};
1128711298

11288-
if (try sema.resolveDefinedValue(&child_block, src, raw_operand_val)) |operand_val| {
11299+
if (try sema.resolveDefinedValue(&child_block, src, raw_operand_val)) |ov| {
11300+
const operand_val = if (extra.data.bits.payload_is_ref)
11301+
(try sema.pointerDeref(&child_block, src, ov, operand_ty)).?
11302+
else
11303+
ov;
11304+
1128911305
if (operand_val.errorUnionIsPayload(mod)) {
1129011306
return sema.resolveBlockBody(block, operand_src, &child_block, non_error_case.body, inst, merges);
1129111307
} else {
@@ -11295,7 +11311,10 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp
1129511311
.name = operand_val.getErrorName(mod).unwrap().?,
1129611312
},
1129711313
}));
11298-
spa.operand = try sema.analyzeErrUnionCode(block, operand_src, raw_operand_val);
11314+
spa.operand = if (extra.data.bits.payload_is_ref)
11315+
try sema.analyzeErrUnionCodePtr(block, operand_src, raw_operand_val)
11316+
else
11317+
try sema.analyzeErrUnionCode(block, operand_src, raw_operand_val);
1129911318

1130011319
if (extra.data.bits.any_uses_err_capture) {
1130111320
sema.inst_map.putAssumeCapacity(err_capture_inst, spa.operand);
@@ -11339,7 +11358,14 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp
1133911358
unreachable;
1134011359
}
1134111360

11342-
const cond = try sema.analyzeIsNonErr(block, src, raw_operand_val);
11361+
const cond = if (extra.data.bits.payload_is_ref) blk: {
11362+
try sema.checkErrorType(block, src, sema.typeOf(raw_operand_val).elemType2(mod));
11363+
const loaded = try sema.analyzeLoad(block, src, raw_operand_val, src);
11364+
break :blk try sema.analyzeIsNonErr(block, src, loaded);
11365+
} else blk: {
11366+
try sema.checkErrorType(block, src, sema.typeOf(raw_operand_val));
11367+
break :blk try sema.analyzeIsNonErr(block, src, raw_operand_val);
11368+
};
1134311369

1134411370
var sub_block = child_block.makeSubBlock();
1134511371
sub_block.runtime_loop = null;
@@ -11351,7 +11377,11 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp
1135111377
const true_instructions = try sub_block.instructions.toOwnedSlice(gpa);
1135211378
defer gpa.free(true_instructions);
1135311379

11354-
spa.operand = try sema.analyzeErrUnionCode(&sub_block, operand_src, raw_operand_val);
11380+
spa.operand = if (extra.data.bits.payload_is_ref)
11381+
try sema.analyzeErrUnionCodePtr(&sub_block, operand_src, raw_operand_val)
11382+
else
11383+
try sema.analyzeErrUnionCode(&sub_block, operand_src, raw_operand_val);
11384+
1135511385
if (extra.data.bits.any_uses_err_capture) {
1135611386
sema.inst_map.putAssumeCapacity(err_capture_inst, spa.operand);
1135711387
}

src/Zir.zig

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2794,9 +2794,10 @@ pub const Inst = struct {
27942794
/// If true, there is an else prong. This is mutually exclusive with `has_under`.
27952795
has_else: bool,
27962796
any_uses_err_capture: bool,
2797+
payload_is_ref: bool,
27972798
scalar_cases_len: ScalarCasesLen,
27982799

2799-
pub const ScalarCasesLen = u29;
2800+
pub const ScalarCasesLen = u28;
28002801
};
28012802

28022803
pub const MultiProng = struct {

0 commit comments

Comments
 (0)