Skip to content

Commit a1109e7

Browse files
committed
feat(python_ffi): 支持向 Stream 传入数据
Signed-off-by: YdrMaster <[email protected]>
1 parent 38d6aa9 commit a1109e7

File tree

9 files changed

+91
-24
lines changed

9 files changed

+91
-24
lines changed

src/00common/include/common/fp16_t.h

-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ namespace refactor {
9999
}
100100
};
101101

102-
103102
inline const fp16_t fp16_t::ZERO = fp16_t(0.0f);
104103
inline const fp16_t fp16_t::ONE = fp16_t(1.0f);
105104
inline const fp16_t fp16_t::INF = fp16_t((uint16_t) 0b0'11111'0000000000);

src/03runtime/include/runtime/stream.h

+8-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ namespace refactor::runtime {
2121
bool isBlob() const noexcept;
2222
bool isOffset() const noexcept;
2323

24-
size_t getOffset() const;
24+
auto blob() const noexcept -> mem_manager::SharedForeignBlob const &;
25+
auto offset() const noexcept -> size_t;
2526
};
2627

2728
class Stream {
@@ -31,14 +32,19 @@ namespace refactor::runtime {
3132

3233
Resources _resources;
3334
mem_manager::SharedForeignBlob _stack;
35+
std::vector<size_t> _outputsSize;
3436
_G _internal;
3537

3638
public:
3739
Stream(Resources,
38-
mem_manager::SharedForeignBlob,
40+
size_t stack,
41+
std::vector<size_t> outputs,
3942
graph_topo::GraphTopo,
4043
std::vector<_N>,
4144
std::vector<_E>);
45+
void setInput(uint_lv1, void const *, size_t);
46+
void setInput(uint_lv1, mem_manager::SharedForeignBlob);
47+
std::vector<uint_lv1> prepare();
4248
void run();
4349
};
4450

src/03runtime/src/stream.cc

+46-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "runtime/stream.h"
2+
#include "runtime/mem_manager.hh"
23

34
namespace refactor::runtime {
5+
using mem_manager::ForeignBlob;
46

57
void emptyRoutine(runtime::Resources &, void const **, void **) {}
68

@@ -15,23 +17,62 @@ namespace refactor::runtime {
1517
bool Address::isOffset() const noexcept {
1618
return std::holds_alternative<size_t>(value);
1719
}
18-
19-
size_t Address::getOffset() const {
20+
auto Address::blob() const noexcept -> mem_manager::SharedForeignBlob const & {
21+
return std::get<mem_manager::SharedForeignBlob>(value);
22+
}
23+
auto Address::offset() const noexcept -> size_t {
2024
return std::get<size_t>(value);
2125
}
2226

2327
Stream::Stream(Resources resources,
24-
mem_manager::SharedForeignBlob stack,
28+
size_t stack,
29+
std::vector<size_t> outputs,
2530
graph_topo::GraphTopo topology,
2631
std::vector<_N> routines,
2732
std::vector<_E> offsets)
2833
: _resources(std::move(resources)),
29-
_stack(std::move(stack)),
34+
_stack(ForeignBlob::share(_resources.fetch<MemManager>()->manager, stack)),
35+
_outputsSize(std::move(outputs)),
3036
_internal(_G{
3137
std::move(topology),
3238
std::move(routines),
3339
std::move(offsets),
34-
}) {}
40+
}) {
41+
}
42+
43+
void Stream::setInput(uint_lv1 i, void const *data, size_t size) {
44+
auto globalInputs = _internal.topology.globalInputs();
45+
ASSERT(i < globalInputs.size(), "input index out of range");
46+
47+
auto allocator = _resources.fetch<MemManager>()->manager;
48+
auto blob = ForeignBlob::share(std::move(allocator), size);
49+
blob->copyIn(data, size);
50+
_internal.edges[globalInputs[i]].value = {std::move(blob)};
51+
}
52+
void Stream::setInput(uint_lv1 i, mem_manager::SharedForeignBlob blob) {
53+
auto globalInputs = _internal.topology.globalInputs();
54+
ASSERT(i < globalInputs.size(), "input index out of range");
55+
56+
_internal.edges[globalInputs[i]].value = {std::move(blob)};
57+
}
58+
59+
std::vector<uint_lv1> Stream::prepare() {
60+
auto globalInputs = _internal.topology.globalInputs();
61+
std::vector<uint_lv1> unknownInputs;
62+
for (auto i : range0_(globalInputs.size())) {
63+
if (!_internal.edges[globalInputs[i]].blob()) {
64+
unknownInputs.push_back(i);
65+
}
66+
}
67+
if (unknownInputs.empty()) {
68+
auto allocator = _resources.fetch<MemManager>()->manager;
69+
auto outputs = _internal.topology.globalOutputs();
70+
for (auto i : range0_(outputs.size())) {
71+
_internal.edges[outputs[i]].value = {ForeignBlob::share(allocator, _outputsSize[i])};
72+
}
73+
}
74+
return unknownInputs;
75+
}
3576

3677
void Stream::run() {
3778
auto map = [this](auto i) { return _internal.edges[i](*_stack); };

src/04kernel/src/allocators/reusable_allocator.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace refactor::kernel {
2727
if (!--edgeRc[inputIdx]) {
2828
// indicate that this tensor will no longer be used and perform memory free
2929
if (addresses[inputIdx].isOffset()) {
30-
calculator.free(addresses[inputIdx].getOffset(), g.edges[inputIdx].size);
30+
calculator.free(addresses[inputIdx].offset(), g.edges[inputIdx].size);
3131
}
3232
}
3333
}

src/04kernel/src/graph.cc

+9-4
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,18 @@ namespace refactor::kernel {
2424
? node.kernel->lower()
2525
: refactor::runtime::emptyRoutine;
2626
});
27-
auto [size, offsets] = allocator(_internal, sizeof(uint64_t));
28-
auto memManager = _target.memManager();
27+
auto [stack, offsets] = allocator(_internal, sizeof(uint64_t));
28+
auto outputs = _internal.topology.globalOutputs();
29+
std::vector<size_t> outputs_(outputs.size());
30+
std::transform(outputs.begin(), outputs.end(),
31+
outputs_.begin(),
32+
[this](auto const &edge) { return _internal.edges[edge].size; });
2933
runtime::Resources res;
30-
res.fetchOrStore<runtime::MemManager>(memManager);
34+
res.fetchOrStore<runtime::MemManager>(_target.memManager());
3135
return runtime::Stream(
3236
std::move(res),
33-
mem_manager::ForeignBlob::share(std::move(memManager), size),
37+
stack,
38+
std::move(outputs_),
3439
_internal.topology,
3540
std::move(routines),
3641
std::move(offsets));

src/09python_ffi/src/compiler.cc

+5-6
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,11 @@ namespace refactor::python_ffi {
6868
UNREACHABLE();
6969
}
7070

71-
auto kernel = computation.lower(target_);
72-
if (allocator == "flat") {
73-
return std::make_shared<Executor>(kernel.lower(kernel::flatAllocate));
74-
} else {
75-
return std::make_shared<Executor>(kernel.lower(kernel::reusableAllocate));
76-
}
71+
return std::make_shared<Executor>(
72+
computation.lower(target_),
73+
allocator == "flat"
74+
? kernel::flatAllocate
75+
: kernel::reusableAllocate);
7776
}
7877

7978
std::optional<py::array>

src/09python_ffi/src/executor.cc

+11-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,17 @@
22

33
namespace refactor::python_ffi {
44

5-
Executor::Executor(runtime::Stream stream)
6-
: _stream(std::move(stream)) {
5+
Executor::Executor(kernel::Graph graph, kernel::Allocator allocator)
6+
: _graph(std::move(graph)),
7+
_allocator(allocator),
8+
_stream(_graph.lower(_allocator)) {}
9+
10+
void Executor::setInput(uint_lv1 i, SharedTensor tensor) {
11+
_stream.setInput(i, tensor->data->operator const void *(), tensor->bytesSize());
12+
}
13+
14+
std::vector<uint_lv1> Executor::prepare() {
15+
return _stream.prepare();
716
}
817

918
}// namespace refactor::python_ffi

src/09python_ffi/src/executor.h

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
#ifndef PYTHON_FFI_EXECUTOR_H
22
#define PYTHON_FFI_EXECUTOR_H
33

4-
#include "runtime/stream.h"
4+
#include "frontend/tensor.h"
5+
#include "kernel/graph.h"
56

67
namespace refactor::python_ffi {
8+
using SharedTensor = Arc<frontend::Tensor>;
79

810
class Executor {
11+
kernel::Graph _graph;
12+
kernel::Allocator _allocator;
913
runtime::Stream _stream;
1014

1115
public:
12-
explicit Executor(runtime::Stream);
16+
Executor(kernel::Graph, kernel::Allocator);
17+
void setInput(uint_lv1, SharedTensor);
18+
std::vector<uint_lv1> prepare();
1319
};
1420

1521
}// namespace refactor::python_ffi

src/09python_ffi/src/main.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ namespace refactor::python_ffi {
3333
.def("get_tensor" , &Compiler::getTensor , return_::move )
3434
.def("compile" , &Compiler::compile , return_::move );
3535

36-
py::class_<Executor , Arc<Executor>>(m, "Executor" );
36+
py::class_<Executor , Arc<Executor>>(m, "Executor" )
37+
.def("setInput" , &Executor::setInput , return_::automatic )
38+
.def("prepare" , &Executor::prepare , return_::move );
3739

3840
// clang-format on
3941
}

0 commit comments

Comments
 (0)