Skip to content

Commit 29b71ab

Browse files
authored
Merge pull request #87 from justinboswell/ctad
Added support for template deduction guides
2 parents 64c5290 + 88a7048 commit 29b71ab

File tree

5 files changed

+164
-11
lines changed

5 files changed

+164
-11
lines changed

cxxheaderparser/parser.py

+40-11
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Concept,
2626
DecltypeSpecifier,
2727
DecoratedType,
28+
DeductionGuide,
2829
EnumDecl,
2930
Enumerator,
3031
Field,
@@ -1868,10 +1869,9 @@ def _parse_parameters(
18681869
_auto_return_typename = PQName([AutoSpecifier()])
18691870

18701871
def _parse_trailing_return_type(
1871-
self, fn: typing.Union[Function, FunctionType]
1872-
) -> None:
1872+
self, return_type: typing.Optional[DecoratedType]
1873+
) -> DecoratedType:
18731874
# entry is "->"
1874-
return_type = fn.return_type
18751875
if not (
18761876
isinstance(return_type, Type)
18771877
and not return_type.const
@@ -1890,8 +1890,7 @@ def _parse_trailing_return_type(
18901890

18911891
dtype = self._parse_cv_ptr(parsed_type)
18921892

1893-
fn.has_trailing_return = True
1894-
fn.return_type = dtype
1893+
return dtype
18951894

18961895
def _parse_fn_end(self, fn: Function) -> None:
18971896
"""
@@ -1918,7 +1917,9 @@ def _parse_fn_end(self, fn: Function) -> None:
19181917
fn.raw_requires = self._parse_requires(rtok)
19191918

19201919
if self.lex.token_if("ARROW"):
1921-
self._parse_trailing_return_type(fn)
1920+
return_type = self._parse_trailing_return_type(fn.return_type)
1921+
fn.has_trailing_return = True
1922+
fn.return_type = return_type
19221923

19231924
if self.lex.token_if("{"):
19241925
self._discard_contents("{", "}")
@@ -1966,7 +1967,9 @@ def _parse_method_end(self, method: Method) -> None:
19661967
elif tok_value in ("&", "&&"):
19671968
method.ref_qualifier = tok_value
19681969
elif tok_value == "->":
1969-
self._parse_trailing_return_type(method)
1970+
return_type = self._parse_trailing_return_type(method.return_type)
1971+
method.has_trailing_return = True
1972+
method.return_type = return_type
19701973
if self.lex.token_if("{"):
19711974
self._discard_contents("{", "}")
19721975
method.has_body = True
@@ -2000,6 +2003,7 @@ def _parse_function(
20002003
is_friend: bool,
20012004
is_typedef: bool,
20022005
msvc_convention: typing.Optional[LexToken],
2006+
is_guide: bool = False,
20032007
) -> bool:
20042008
"""
20052009
Assumes the caller has already consumed the return type and name, this consumes the
@@ -2076,7 +2080,21 @@ def _parse_function(
20762080
self.visitor.on_method_impl(state, method)
20772081

20782082
return method.has_body or method.has_trailing_return
2079-
2083+
elif is_guide:
2084+
assert isinstance(state, (ExternBlockState, NamespaceBlockState))
2085+
if not self.lex.token_if("ARROW"):
2086+
raise self._parse_error(None, expected="Trailing return type")
2087+
return_type = self._parse_trailing_return_type(
2088+
Type(PQName([AutoSpecifier()]))
2089+
)
2090+
guide = DeductionGuide(
2091+
return_type,
2092+
name=pqname,
2093+
parameters=params,
2094+
doxygen=doxygen,
2095+
)
2096+
self.visitor.on_deduction_guide(state, guide)
2097+
return False
20802098
else:
20812099
assert return_type is not None
20822100
fn = Function(
@@ -2210,7 +2228,9 @@ def _parse_cv_ptr_or_fn(
22102228
assert not isinstance(dtype, FunctionType)
22112229
dtype = dtype_fn = FunctionType(dtype, fn_params, vararg)
22122230
if self.lex.token_if("ARROW"):
2213-
self._parse_trailing_return_type(dtype_fn)
2231+
return_type = self._parse_trailing_return_type(dtype_fn.return_type)
2232+
dtype_fn.has_trailing_return = True
2233+
dtype_fn.return_type = return_type
22142234

22152235
else:
22162236
msvc_convention = None
@@ -2391,6 +2411,7 @@ def _parse_decl(
23912411
destructor = False
23922412
op = None
23932413
msvc_convention = None
2414+
is_guide = False
23942415

23952416
# If we have a leading (, that's either an obnoxious grouping
23962417
# paren or it's a constructor
@@ -2441,8 +2462,15 @@ def _parse_decl(
24412462
# grouping paren like "void (name(int x));"
24422463
toks = self._consume_balanced_tokens(tok)
24432464

2444-
# .. not sure what it's grouping, so put it back?
2445-
self.lex.return_tokens(toks[1:-1])
2465+
# check to see if the next token is an arrow, and thus a trailing return
2466+
if self.lex.token_peek_if("ARROW"):
2467+
self.lex.return_tokens(toks)
2468+
# the leading name of the class/ctor has been parsed as a type before the parens
2469+
pqname = parsed_type.typename
2470+
is_guide = True
2471+
else:
2472+
# .. not sure what it's grouping, so put it back?
2473+
self.lex.return_tokens(toks[1:-1])
24462474

24472475
if dtype:
24482476
msvc_convention = self.lex.token_if_val(*self._msvc_conventions)
@@ -2473,6 +2501,7 @@ def _parse_decl(
24732501
is_friend,
24742502
is_typedef,
24752503
msvc_convention,
2504+
is_guide,
24762505
)
24772506
elif msvc_convention:
24782507
raise self._parse_error(msvc_convention)

cxxheaderparser/simple.py

+9
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from .types import (
3636
ClassDecl,
3737
Concept,
38+
DeductionGuide,
3839
EnumDecl,
3940
Field,
4041
ForwardDecl,
@@ -123,6 +124,9 @@ class NamespaceScope:
123124
#: Child namespaces
124125
namespaces: typing.Dict[str, "NamespaceScope"] = field(default_factory=dict)
125126

127+
#: Deduction guides
128+
deduction_guides: typing.List[DeductionGuide] = field(default_factory=list)
129+
126130

127131
Block = typing.Union[ClassScope, NamespaceScope]
128132

@@ -317,6 +321,11 @@ def on_class_friend(self, state: SClassBlockState, friend: FriendDecl) -> None:
317321
def on_class_end(self, state: SClassBlockState) -> None:
318322
pass
319323

324+
def on_deduction_guide(
325+
self, state: SNonClassBlockState, guide: DeductionGuide
326+
) -> None:
327+
state.user_data.deduction_guides.append(guide)
328+
320329

321330
def parse_string(
322331
content: str,

cxxheaderparser/types.py

+18
Original file line numberDiff line numberDiff line change
@@ -896,3 +896,21 @@ class UsingAlias:
896896

897897
#: Documentation if present
898898
doxygen: typing.Optional[str] = None
899+
900+
901+
@dataclass
902+
class DeductionGuide:
903+
"""
904+
.. code-block:: c++
905+
906+
template <class T>
907+
MyClass(T) -> MyClass(int);
908+
"""
909+
910+
#: Only constructors and destructors don't have a return type
911+
result_type: typing.Optional[DecoratedType]
912+
913+
name: PQName
914+
parameters: typing.List[Parameter]
915+
916+
doxygen: typing.Optional[str] = None

cxxheaderparser/visitor.py

+13
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from .types import (
1111
Concept,
12+
DeductionGuide,
1213
EnumDecl,
1314
Field,
1415
ForwardDecl,
@@ -236,6 +237,13 @@ def on_class_end(self, state: ClassBlockState) -> None:
236237
``on_variable`` for each instance declared.
237238
"""
238239

240+
def on_deduction_guide(
241+
self, state: NonClassBlockState, guide: DeductionGuide
242+
) -> None:
243+
"""
244+
Called when a deduction guide is encountered
245+
"""
246+
239247

240248
class NullVisitor:
241249
"""
@@ -318,5 +326,10 @@ def on_class_method(self, state: ClassBlockState, method: Method) -> None:
318326
def on_class_end(self, state: ClassBlockState) -> None:
319327
return None
320328

329+
def on_deduction_guide(
330+
self, state: NonClassBlockState, guide: DeductionGuide
331+
) -> None:
332+
return None
333+
321334

322335
null_visitor = NullVisitor()

tests/test_template.py

+84
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
BaseClass,
66
ClassDecl,
77
DecltypeSpecifier,
8+
DeductionGuide,
89
Field,
910
ForwardDecl,
1011
Function,
@@ -2163,3 +2164,86 @@ def test_member_class_template_specialization() -> None:
21632164
]
21642165
)
21652166
)
2167+
2168+
2169+
def test_template_deduction_guide() -> None:
2170+
content = """
2171+
template <class CharT, class Traits = std::char_traits<CharT>>
2172+
Error(std::basic_string_view<CharT, Traits>) -> Error<std::string>;
2173+
"""
2174+
data = parse_string(content, cleandoc=True)
2175+
2176+
assert data == ParsedData(
2177+
namespace=NamespaceScope(
2178+
deduction_guides=[
2179+
DeductionGuide(
2180+
result_type=Type(
2181+
typename=PQName(
2182+
segments=[
2183+
NameSpecifier(
2184+
name="Error",
2185+
specialization=TemplateSpecialization(
2186+
args=[
2187+
TemplateArgument(
2188+
arg=Type(
2189+
typename=PQName(
2190+
segments=[
2191+
NameSpecifier(name="std"),
2192+
NameSpecifier(
2193+
name="string"
2194+
),
2195+
]
2196+
)
2197+
)
2198+
)
2199+
]
2200+
),
2201+
)
2202+
]
2203+
)
2204+
),
2205+
name=PQName(segments=[NameSpecifier(name="Error")]),
2206+
parameters=[
2207+
Parameter(
2208+
type=Type(
2209+
typename=PQName(
2210+
segments=[
2211+
NameSpecifier(name="std"),
2212+
NameSpecifier(
2213+
name="basic_string_view",
2214+
specialization=TemplateSpecialization(
2215+
args=[
2216+
TemplateArgument(
2217+
arg=Type(
2218+
typename=PQName(
2219+
segments=[
2220+
NameSpecifier(
2221+
name="CharT"
2222+
)
2223+
]
2224+
)
2225+
)
2226+
),
2227+
TemplateArgument(
2228+
arg=Type(
2229+
typename=PQName(
2230+
segments=[
2231+
NameSpecifier(
2232+
name="Traits"
2233+
)
2234+
]
2235+
)
2236+
)
2237+
),
2238+
]
2239+
),
2240+
),
2241+
]
2242+
)
2243+
)
2244+
)
2245+
],
2246+
)
2247+
]
2248+
)
2249+
)

0 commit comments

Comments
 (0)