Skip to content

Commit 63c18e5

Browse files
committed
[UR] Replace loader handles with field at start of handle data
All handles from all backends are now required to implement `ddi_getter` and their first field must be a pointer to a `ur_ddi_table_t` (which also implies that they must not have a vtable). Instead of wrapping handles in a special wrapper type, we instead query the DDI table stored in the handle itself. This simplifies the loader greatly.
1 parent 1cb6d33 commit 63c18e5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

86 files changed

+1385
-6220
lines changed

unified-runtime/scripts/generate_code.py

-19
Original file line numberDiff line numberDiff line change
@@ -223,25 +223,6 @@ def _mako_loader_cpp(path, namespace, tags, version, specs, meta):
223223
"make_loader_cpp path %s namespace %s version %s\n" % (path, namespace, version)
224224
)
225225
loc = 0
226-
template = "ldrddi.hpp.mako"
227-
fin = os.path.join(templates_dir, template)
228-
229-
name = "%s_ldrddi" % (namespace)
230-
filename = "%s.hpp" % (name)
231-
fout = os.path.join(path, filename)
232-
233-
print("Generating %s..." % fout)
234-
loc += util.makoWrite(
235-
fin,
236-
fout,
237-
name=name,
238-
ver=version,
239-
namespace=namespace,
240-
tags=tags,
241-
specs=specs,
242-
meta=meta,
243-
)
244-
245226
template = "ldrddi.cpp.mako"
246227
fin = os.path.join(templates_dir, template)
247228

unified-runtime/scripts/templates/helper.py

+31-189
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,7 @@ def get_adapter_manifests(specs):
761761
objs.append(obj)
762762
return objs
763763

764+
764765
"""
765766
Public:
766767
returns a list of all loader API functions' names
@@ -1510,39 +1511,6 @@ def get_initial_null_set(obj):
15101511
return ""
15111512

15121513

1513-
"""
1514-
Public:
1515-
returns true if the function always wraps output pointers in loader handles
1516-
"""
1517-
1518-
1519-
def always_wrap_outputs(obj):
1520-
cname = obj_traits.class_name(obj)
1521-
return (cname, obj["name"]) in [
1522-
("$xProgram", "Link"),
1523-
("$xProgram", "LinkExp"),
1524-
]
1525-
1526-
1527-
"""
1528-
Private:
1529-
returns the list of parameters, filtering based on desc tags
1530-
"""
1531-
1532-
1533-
def _filter_param_list(params, filters1=["[in]", "[in,out]", "[out]"], filters2=[""]):
1534-
lst = []
1535-
for p in params:
1536-
for f1 in filters1:
1537-
if f1 in p["desc"]:
1538-
for f2 in filters2:
1539-
if f2 in p["desc"]:
1540-
lst.append(p)
1541-
break
1542-
break
1543-
return lst
1544-
1545-
15461514
"""
15471515
Public:
15481516
returns a list of dict of each pfntables needed
@@ -1560,131 +1528,6 @@ def get_pfncbtables(specs, meta, namespace, tags):
15601528
return tables
15611529

15621530

1563-
"""
1564-
Public:
1565-
returns a list of dict for converting loader input parameters
1566-
"""
1567-
1568-
1569-
def get_loader_prologue(namespace, tags, obj, meta):
1570-
prologue = []
1571-
1572-
params = _filter_param_list(obj["params"], ["[in]"])
1573-
for item in params:
1574-
if param_traits.is_mbz(item):
1575-
continue
1576-
if type_traits.is_class_handle(item["type"], meta):
1577-
name = subt(namespace, tags, item["name"])
1578-
tname = _remove_const_ptr(subt(namespace, tags, item["type"]))
1579-
1580-
# e.g., "xe_device_handle_t" -> "xe_device_object_t"
1581-
obj_name = re.sub(r"(\w+)_handle_t", r"\1_object_t", tname)
1582-
fty_name = re.sub(r"(\w+)_handle_t", r"\1_factory", tname)
1583-
1584-
if type_traits.is_pointer(item["type"]):
1585-
range_start = param_traits.range_start(item)
1586-
range_end = param_traits.range_end(item)
1587-
prologue.append(
1588-
{
1589-
"name": name,
1590-
"obj": obj_name,
1591-
"range": (range_start, range_end),
1592-
"type": tname,
1593-
"factory": fty_name,
1594-
"pointer": "*",
1595-
}
1596-
)
1597-
else:
1598-
prologue.append(
1599-
{
1600-
"name": name,
1601-
"obj": obj_name,
1602-
"optional": param_traits.is_optional(item),
1603-
"pointer": "",
1604-
}
1605-
)
1606-
1607-
return prologue
1608-
1609-
1610-
"""
1611-
Private:
1612-
Takes a list of struct members and recursively searches for class handles.
1613-
Returns a list of class handles with access chains to reach them (e.g.
1614-
"struct_a->struct_b.handle"). Also handles ranges of class handles and
1615-
ranges of structs with class handle members, although the latter only works
1616-
to one level of recursion i.e. a range of structs with a range of structs
1617-
with a handle member will not work.
1618-
"""
1619-
1620-
1621-
def get_struct_handle_members(
1622-
namespace, tags, meta, members, parent="", is_struct_range=False
1623-
):
1624-
handle_members = []
1625-
for m in members:
1626-
if type_traits.is_class_handle(m["type"], meta):
1627-
m_tname = _remove_const_ptr(subt(namespace, tags, m["type"]))
1628-
m_objname = re.sub(r"(\w+)_handle_t", r"\1_object_t", m_tname)
1629-
# We can deal with a range of handles, but not if it's in a range of structs
1630-
if param_traits.is_range(m) and not is_struct_range:
1631-
handle_members.append(
1632-
{
1633-
"parent": parent,
1634-
"name": m["name"],
1635-
"obj_name": m_objname,
1636-
"type": m_tname,
1637-
"range_start": param_traits.range_start(m),
1638-
"range_end": param_traits.range_end(m),
1639-
}
1640-
)
1641-
else:
1642-
handle_members.append(
1643-
{
1644-
"parent": parent,
1645-
"name": m["name"],
1646-
"obj_name": m_objname,
1647-
"optional": param_traits.is_optional(m),
1648-
}
1649-
)
1650-
elif type_traits.is_struct(m["type"], meta):
1651-
member_struct_members = type_traits.get_struct_members(m["type"], meta)
1652-
if param_traits.is_range(m):
1653-
# If we've hit a range of structs we need to start a new recursion looking
1654-
# for handle members. We do not support range within range, so skip that
1655-
if is_struct_range:
1656-
continue
1657-
range_handle_members = get_struct_handle_members(
1658-
namespace, tags, meta, member_struct_members, "", True
1659-
)
1660-
if range_handle_members:
1661-
handle_members.append(
1662-
{
1663-
"parent": parent,
1664-
"name": m["name"],
1665-
"type": subt(namespace, tags, _remove_const_ptr(m["type"])),
1666-
"range_start": param_traits.range_start(m),
1667-
"range_end": param_traits.range_end(m),
1668-
"handle_members": range_handle_members,
1669-
}
1670-
)
1671-
else:
1672-
# If it's just a struct we can keep recursing in search of handles
1673-
m_is_pointer = type_traits.is_pointer(m["type"])
1674-
new_parent_deref = "->" if m_is_pointer else "."
1675-
new_parent = m["name"] + new_parent_deref
1676-
handle_members += get_struct_handle_members(
1677-
namespace,
1678-
tags,
1679-
meta,
1680-
member_struct_members,
1681-
new_parent,
1682-
is_struct_range,
1683-
)
1684-
1685-
return handle_members
1686-
1687-
16881531
"""
16891532
Public:
16901533
Strips a string of all dereferences.
@@ -1702,37 +1545,6 @@ def strip_deref(string_to_strip):
17021545
return string_to_strip.replace("->", "")
17031546

17041547

1705-
"""
1706-
Public:
1707-
Takes a function object and recurses through its struct parameters to return
1708-
a list of structs that have handle object members the loader will need to
1709-
convert.
1710-
"""
1711-
1712-
1713-
def get_object_handle_structs_to_convert(namespace, tags, obj, meta):
1714-
structs = []
1715-
params = _filter_param_list(obj["params"], ["[in]"])
1716-
1717-
for item in params:
1718-
if type_traits.is_struct(item["type"], meta):
1719-
members = type_traits.get_struct_members(item["type"], meta)
1720-
handle_members = get_struct_handle_members(namespace, tags, meta, members)
1721-
if handle_members:
1722-
name = subt(namespace, tags, item["name"])
1723-
tname = _remove_const_ptr(subt(namespace, tags, item["type"]))
1724-
struct = {
1725-
"name": name,
1726-
"type": tname,
1727-
"optional": param_traits.is_optional(item),
1728-
"members": handle_members,
1729-
}
1730-
1731-
structs.append(struct)
1732-
1733-
return structs
1734-
1735-
17361548
"""
17371549
Public:
17381550
returns an enum object with the given name
@@ -2039,3 +1851,33 @@ def get_etors(obj):
20391851
if etor_traits.is_deprecated_etor(item):
20401852
continue
20411853
yield item
1854+
1855+
1856+
"""
1857+
Public:
1858+
Returns the first non-optional non-native handle for the given function.
1859+
1860+
If it is a range, `name[0]` will be returned instead of `name`.
1861+
"""
1862+
1863+
1864+
def get_dditable_field(obj):
1865+
for p in obj["params"]:
1866+
if param_traits.is_optional(p):
1867+
continue
1868+
if "native_handle_t" in p["type"]:
1869+
continue
1870+
1871+
if param_traits.is_range(p):
1872+
if not p["type"].endswith("_handle_t*"):
1873+
continue
1874+
return p["name"] + "[0]"
1875+
else:
1876+
if not p["type"].endswith("_handle_t"):
1877+
continue
1878+
return p["name"]
1879+
obj_class = obj["class"]
1880+
name = obj["name"]
1881+
raise RuntimeError(
1882+
f"Function {obj_class}::{name} does not have a non-optional handle argument"
1883+
)

0 commit comments

Comments
 (0)