Skip to content

Commit 9716a1c

Browse files
ehaasVexu
authored andcommitted
translate-c: Add support for cast-to-union
Fixes #10955
1 parent 4a0b037 commit 9716a1c

8 files changed

+95
-6
lines changed

lib/std/zig/c_translation.zig

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ pub fn cast(comptime DestType: type, target: anytype) DestType {
3030
else => {},
3131
}
3232
},
33+
.Union => |info| {
34+
inline for (info.fields) |field| {
35+
if (field.field_type == SourceType) return @unionInit(DestType, field.name, target);
36+
}
37+
@compileError("cast to union type '" ++ @typeName(DestType) ++ "' from type '" ++ @typeName(SourceType) ++ "' which is not present in union");
38+
},
3339
else => {},
3440
}
3541
return @as(DestType, target);

src/clang.zig

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,14 @@ pub const CaseStmt = opaque {
258258
extern fn ZigClangCaseStmt_getSubStmt(*const CaseStmt) *const Stmt;
259259
};
260260

261+
pub const CastExpr = opaque {
262+
pub const getCastKind = ZigClangCastExpr_getCastKind;
263+
extern fn ZigClangCastExpr_getCastKind(*const CastExpr) CK;
264+
265+
pub const getTargetFieldForToUnionCast = ZigClangCastExpr_getTargetFieldForToUnionCast;
266+
extern fn ZigClangCastExpr_getTargetFieldForToUnionCast(*const CastExpr, QualType, QualType) ?*const FieldDecl;
267+
};
268+
261269
pub const CharacterLiteral = opaque {
262270
pub const getBeginLoc = ZigClangCharacterLiteral_getBeginLoc;
263271
extern fn ZigClangCharacterLiteral_getBeginLoc(*const CharacterLiteral) SourceLocation;

src/translate_c.zig

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,14 +1791,31 @@ fn transCStyleCastExprClass(
17911791
stmt: *const clang.CStyleCastExpr,
17921792
result_used: ResultUsed,
17931793
) TransError!Node {
1794+
const cast_expr = @ptrCast(*const clang.CastExpr, stmt);
17941795
const sub_expr = stmt.getSubExpr();
1795-
const cast_node = (try transCCast(
1796+
const dst_type = stmt.getType();
1797+
const src_type = sub_expr.getType();
1798+
const sub_expr_node = try transExpr(c, scope, sub_expr, .used);
1799+
const loc = stmt.getBeginLoc();
1800+
1801+
const cast_node = if (cast_expr.getCastKind() == .ToUnion) blk: {
1802+
const field_decl = cast_expr.getTargetFieldForToUnionCast(dst_type, src_type).?; // C syntax error if target field is null
1803+
const field_name = try c.str(@ptrCast(*const clang.NamedDecl, field_decl).getName_bytes_begin());
1804+
1805+
const union_ty = try transQualType(c, scope, dst_type, loc);
1806+
1807+
const inits = [1]ast.Payload.ContainerInit.Initializer{.{ .name = field_name, .value = sub_expr_node }};
1808+
break :blk try Tag.container_init.create(c.arena, .{
1809+
.lhs = union_ty,
1810+
.inits = try c.arena.dupe(ast.Payload.ContainerInit.Initializer, &inits),
1811+
});
1812+
} else (try transCCast(
17961813
c,
17971814
scope,
1798-
stmt.getBeginLoc(),
1799-
stmt.getType(),
1800-
sub_expr.getType(),
1801-
try transExpr(c, scope, sub_expr, .used),
1815+
loc,
1816+
dst_type,
1817+
src_type,
1818+
sub_expr_node,
18021819
));
18031820
return maybeSuppressResult(c, scope, result_used, cast_node);
18041821
}
@@ -2370,7 +2387,7 @@ fn cIntTypeForEnum(enum_qt: clang.QualType) clang.QualType {
23702387
return enum_decl.getIntegerType();
23712388
}
23722389

2373-
// when modifying this function, make sure to also update std.meta.cast
2390+
// when modifying this function, make sure to also update std.zig.c_translation.cast
23742391
fn transCCast(
23752392
c: *Context,
23762393
scope: *Scope,

src/zig_clang.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2986,6 +2986,18 @@ const struct ZigClangCompoundStmt *ZigClangStmtExpr_getSubStmt(const struct ZigC
29862986
return reinterpret_cast<const ZigClangCompoundStmt *>(casted->getSubStmt());
29872987
}
29882988

2989+
enum ZigClangCK ZigClangCastExpr_getCastKind(const struct ZigClangCastExpr *self) {
2990+
auto casted = reinterpret_cast<const clang::CastExpr *>(self);
2991+
return (ZigClangCK)casted->getCastKind();
2992+
}
2993+
2994+
const struct ZigClangFieldDecl *ZigClangCastExpr_getTargetFieldForToUnionCast(const struct ZigClangCastExpr *self, ZigClangQualType union_type, ZigClangQualType op_type) {
2995+
clang::QualType union_qt = bitcast(union_type);
2996+
clang::QualType op_qt = bitcast(op_type);
2997+
auto casted = reinterpret_cast<const clang::CastExpr *>(self);
2998+
return reinterpret_cast<const ZigClangFieldDecl *>(casted->getTargetFieldForToUnionCast(union_qt, op_qt));
2999+
}
3000+
29893001
struct ZigClangSourceLocation ZigClangCharacterLiteral_getBeginLoc(const struct ZigClangCharacterLiteral *self) {
29903002
auto casted = reinterpret_cast<const clang::CharacterLiteral *>(self);
29913003
return bitcast(casted->getBeginLoc());

src/zig_clang.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ struct ZigClangBuiltinType;
103103
struct ZigClangCStyleCastExpr;
104104
struct ZigClangCallExpr;
105105
struct ZigClangCaseStmt;
106+
struct ZigClangCastExpr;
106107
struct ZigClangCharacterLiteral;
107108
struct ZigClangChooseExpr;
108109
struct ZigClangCompoundAssignOperator;
@@ -1317,6 +1318,9 @@ ZIG_EXTERN_C struct ZigClangQualType ZigClangDecayedType_getDecayedType(const st
13171318

13181319
ZIG_EXTERN_C const struct ZigClangCompoundStmt *ZigClangStmtExpr_getSubStmt(const struct ZigClangStmtExpr *);
13191320

1321+
ZIG_EXTERN_C enum ZigClangCK ZigClangCastExpr_getCastKind(const struct ZigClangCastExpr *);
1322+
ZIG_EXTERN_C const struct ZigClangFieldDecl *ZigClangCastExpr_getTargetFieldForToUnionCast(const struct ZigClangCastExpr *, struct ZigClangQualType, struct ZigClangQualType);
1323+
13201324
ZIG_EXTERN_C struct ZigClangSourceLocation ZigClangCharacterLiteral_getBeginLoc(const struct ZigClangCharacterLiteral *);
13211325
ZIG_EXTERN_C enum ZigClangCharacterLiteral_CharacterKind ZigClangCharacterLiteral_getKind(const struct ZigClangCharacterLiteral *);
13221326
ZIG_EXTERN_C unsigned ZigClangCharacterLiteral_getValue(const struct ZigClangCharacterLiteral *);

test/behavior/translate_c_macros.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ struct Foo {
1515
int a;
1616
};
1717

18+
union U {
19+
long l;
20+
double d;
21+
};
22+
1823
#define SIZE_OF_FOO sizeof(struct Foo)
1924

2025
#define MAP_FAILED ((void *) -1)
@@ -30,3 +35,5 @@ struct Foo {
3035
#define IGNORE_ME_8(x) (volatile void)(x)
3136
#define IGNORE_ME_9(x) (const volatile void)(x)
3237
#define IGNORE_ME_10(x) (volatile const void)(x)
38+
39+
#define UNION_CAST(X) (union U)(X)

test/behavior/translate_c_macros.zig

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,16 @@ test "cast negative integer to pointer" {
4747

4848
try expectEqual(@intToPtr(?*anyopaque, @bitCast(usize, @as(isize, -1))), h.MAP_FAILED);
4949
}
50+
51+
test "casting to union with a macro" {
52+
if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO Sema.zirUnionInitPtr
53+
54+
const l: c_long = 42;
55+
const d: f64 = 2.0;
56+
57+
var casted = h.UNION_CAST(l);
58+
try expectEqual(l, casted.l);
59+
60+
casted = h.UNION_CAST(d);
61+
try expectEqual(d, casted.d);
62+
}

test/run_translated_c.zig

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,4 +1829,26 @@ pub fn addCases(cases: *tests.RunTranslatedCContext) void {
18291829
\\ return 0;
18301830
\\}
18311831
, "");
1832+
1833+
cases.add("Cast-to-union. Issue #10955",
1834+
\\#include <stdlib.h>
1835+
\\struct S { int x; };
1836+
\\union U {
1837+
\\ long l;
1838+
\\ double d;
1839+
\\ struct S s;
1840+
\\};
1841+
\\union U bar(union U u) { return u; }
1842+
\\int main(void) {
1843+
\\ union U u = (union U) 42L;
1844+
\\ if (u.l != 42L) abort();
1845+
\\ u = (union U) 2.0;
1846+
\\ if (u.d != 2.0) abort();
1847+
\\ u = bar((union U)4.0);
1848+
\\ if (u.d != 4.0) abort();
1849+
\\ u = (union U)(struct S){ .x = 5 };
1850+
\\ if (u.s.x != 5) abort();
1851+
\\ return 0;
1852+
\\}
1853+
, "");
18321854
}

0 commit comments

Comments
 (0)