Skip to content

Commit 3dc5db3

Browse files
committed
complete homework 1
1 parent 411714a commit 3dc5db3

6 files changed

Lines changed: 272 additions & 2 deletions

File tree

include/core/allocator.h

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,11 @@ namespace infini {
2626
// =================================== 作业 ===================================
2727
// TODO:可能需要设计一个数据结构来存储free block,以便于管理和合并
2828
// HINT: 可以使用一个 map 来存储 free block,key 为 block 的起始/结尾地址,value 为 block 的大小
29+
// Free block management:
30+
// - freeBlocksByAddr: map from start address to size, for allocation search
31+
// - freeBlocksByEnd: map from end address to size, for merging adjacent blocks
32+
std::map<size_t, size_t> freeBlocksByAddr; // key: start address, value: size
33+
std::map<size_t, size_t> freeBlocksByEnd; // key: end address, value: size
2934
// =================================== 作业 ===================================
3035

3136
public:

src/core/allocator.cc

Lines changed: 108 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,45 @@ namespace infini
3131

3232
// =================================== 作业 ===================================
3333
// TODO: 设计一个算法来分配内存,返回起始地址偏移量
34+
// Use First Fit algorithm: find the first free block that is large enough
35+
for (auto it = freeBlocksByAddr.begin(); it != freeBlocksByAddr.end(); ++it)
36+
{
37+
size_t addr = it->first;
38+
size_t blockSize = it->second;
39+
40+
if (blockSize >= size)
41+
{
42+
// Found a suitable block
43+
// Remove this block from both maps
44+
freeBlocksByAddr.erase(it);
45+
freeBlocksByEnd.erase(addr + blockSize);
46+
47+
// If the block is larger than needed, add the remaining part back
48+
if (blockSize > size)
49+
{
50+
size_t newAddr = addr + size;
51+
size_t newSize = blockSize - size;
52+
freeBlocksByAddr[newAddr] = newSize;
53+
freeBlocksByEnd[newAddr + newSize] = newSize;
54+
}
55+
56+
// Update memory usage statistics
57+
used += size;
58+
if (used > peak)
59+
{
60+
peak = used;
61+
}
62+
63+
return addr;
64+
}
65+
}
66+
67+
// No suitable free block found, allocate at the end
68+
size_t addr = peak;
69+
used += size;
70+
peak += size;
71+
72+
return addr;
3473
// =================================== 作业 ===================================
3574

3675
return 0;
@@ -43,6 +82,75 @@ namespace infini
4382

4483
// =================================== 作业 ===================================
4584
// TODO: 设计一个算法来回收内存
85+
// Update memory usage
86+
used -= size;
87+
88+
size_t blockStart = addr;
89+
size_t blockEnd = addr + size;
90+
size_t blockSize = size;
91+
92+
// Special case: if freeing the block at the end, just reduce peak
93+
if (blockEnd == peak)
94+
{
95+
// Check if we can merge with a previous free block at the end
96+
auto prevIt = freeBlocksByEnd.find(blockStart);
97+
if (prevIt != freeBlocksByEnd.end())
98+
{
99+
// Merge with the previous block and reduce peak further
100+
size_t prevSize = prevIt->second;
101+
size_t prevStart = blockStart - prevSize;
102+
103+
// Remove the previous block from both maps
104+
freeBlocksByAddr.erase(prevStart);
105+
freeBlocksByEnd.erase(blockStart);
106+
107+
// Reduce peak to the start of the merged block
108+
peak = prevStart;
109+
}
110+
else
111+
{
112+
// Just reduce peak
113+
peak = blockStart;
114+
}
115+
return;
116+
}
117+
118+
// Try to merge with the previous adjacent free block
119+
auto prevIt = freeBlocksByEnd.find(blockStart);
120+
if (prevIt != freeBlocksByEnd.end())
121+
{
122+
// Found a previous adjacent block
123+
size_t prevSize = prevIt->second;
124+
size_t prevStart = blockStart - prevSize;
125+
126+
// Remove the previous block from both maps
127+
freeBlocksByAddr.erase(prevStart);
128+
freeBlocksByEnd.erase(blockStart);
129+
130+
// Merge: extend the current block backwards
131+
blockStart = prevStart;
132+
blockSize += prevSize;
133+
}
134+
135+
// Try to merge with the next adjacent free block
136+
auto nextIt = freeBlocksByAddr.find(blockEnd);
137+
if (nextIt != freeBlocksByAddr.end())
138+
{
139+
// Found a next adjacent block
140+
size_t nextSize = nextIt->second;
141+
142+
// Remove the next block from both maps
143+
freeBlocksByAddr.erase(blockEnd);
144+
freeBlocksByEnd.erase(blockEnd + nextSize);
145+
146+
// Merge: extend the current block forwards
147+
blockSize += nextSize;
148+
blockEnd += nextSize;
149+
}
150+
151+
// Add the merged (or original) free block to both maps
152+
freeBlocksByAddr[blockStart] = blockSize;
153+
freeBlocksByEnd[blockEnd] = blockSize;
46154
// =================================== 作业 ===================================
47155
}
48156

src/core/graph.cc

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,95 @@ namespace infini
151151
// =================================== 作业 ===================================
152152
// TODO:利用 allocator 给计算图分配内存
153153
// HINT: 获取分配好的内存指针后,可以调用 tensor 的 setDataBlob 函数给 tensor 绑定内存
154+
155+
// Track tensor lifetime: when a tensor is last used (by which operator index)
156+
std::unordered_map<TensorObj*, int> tensorLastUse;
157+
std::unordered_map<TensorObj*, size_t> tensorAddress;
158+
159+
// Initialize: all tensors are used at least once initially (for outputs without targets)
160+
for (auto &tensor : tensors)
161+
{
162+
tensorLastUse[tensor.get()] = -1;
163+
}
164+
165+
// Allocate memory for input tensors (tensors without source) first
166+
for (auto &tensor : tensors)
167+
{
168+
if (!tensor->getSource())
169+
{
170+
size_t offset = allocator.alloc(tensor->getBytes());
171+
tensorAddress[tensor.get()] = offset;
172+
}
173+
}
174+
175+
// Calculate last use for each tensor based on the operators
176+
for (size_t i = 0; i < ops.size(); ++i)
177+
{
178+
auto &op = ops[i];
179+
180+
// Check inputs - update their last use time
181+
for (auto &input : op->getInputs())
182+
{
183+
if (input)
184+
{
185+
tensorLastUse[input.get()] = i;
186+
}
187+
}
188+
}
189+
190+
// For output tensors that have no targets, they should live until the end
191+
for (auto &tensor : tensors)
192+
{
193+
if (tensor->getTargets().size() == 0 && tensor->getSource())
194+
{
195+
// This is a graph output, it should live until the end
196+
tensorLastUse[tensor.get()] = ops.size();
197+
}
198+
}
199+
200+
// Process each operator in topological order
201+
for (size_t i = 0; i < ops.size(); ++i)
202+
{
203+
auto &op = ops[i];
204+
205+
// Allocate memory for outputs
206+
for (auto &output : op->getOutputs())
207+
{
208+
if (output && tensorAddress.find(output.get()) == tensorAddress.end())
209+
{
210+
size_t offset = allocator.alloc(output->getBytes());
211+
tensorAddress[output.get()] = offset;
212+
}
213+
}
214+
215+
// Free inputs that are no longer needed after this operator
216+
for (auto &input : op->getInputs())
217+
{
218+
if (input && tensorLastUse[input.get()] == (int)i)
219+
{
220+
// This is the last use of this tensor
221+
if (tensorAddress.find(input.get()) != tensorAddress.end())
222+
{
223+
allocator.free(tensorAddress[input.get()], input->getBytes());
224+
}
225+
}
226+
}
227+
}
228+
229+
// Get the actual memory pointer from allocator
230+
void *basePtr = allocator.getPtr();
231+
232+
// Bind memory to each tensor
233+
for (auto &tensor : tensors)
234+
{
235+
if (tensorAddress.find(tensor.get()) != tensorAddress.end())
236+
{
237+
size_t offset = tensorAddress[tensor.get()];
238+
void *tensorPtr = reinterpret_cast<char*>(basePtr) + offset;
239+
auto blob = make_ref<BlobObj>(runtime, tensorPtr);
240+
tensor->setDataBlob(blob);
241+
}
242+
}
154243
// =================================== 作业 ===================================
155244

156245
allocator.info();

src/operators/concat.cc

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,30 @@ optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) {
1616
// =================================== 作业 ===================================
1717
// TODO:修改 dims,返回正确的 concat 后的 shape
1818
// REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13
19+
20+
// All inputs should have the same shape except for the dimension being concatenated
21+
// Sum up the sizes along the concatenation dimension
22+
int concatDimSize = 0;
23+
for (size_t i = 0; i < inputs.size(); ++i) {
24+
auto inputDims = inputs[i]->getDims();
25+
26+
// Verify that all other dimensions match
27+
if (inputDims.size() != dims.size()) {
28+
return std::nullopt;
29+
}
30+
31+
for (size_t j = 0; j < dims.size(); ++j) {
32+
if ((int)j != dim && inputDims[j] != dims[j]) {
33+
return std::nullopt;
34+
}
35+
}
36+
37+
// Accumulate the size of the concatenation dimension
38+
concatDimSize += inputDims[dim];
39+
}
40+
41+
// Update the output shape
42+
dims[dim] = concatDimSize;
1943
// =================================== 作业 ===================================
2044

2145
return {{dims}};

src/operators/transpose.cc

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,9 +32,15 @@ namespace infini
3232
// =================================== 作业 ===================================
3333
// TODO:修改 output_dim,返回正确的 transpose 后的 shape
3434
// REF: https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-21
35+
36+
// Apply the permutation to get the output shape
37+
for (int i = 0; i < rank; ++i)
38+
{
39+
output_dim[i] = input_dim[transposePermute[i]];
40+
}
3541
// =================================== 作业 ===================================
3642

37-
return std::nullopt;
43+
return {{output_dim}};
3844
}
3945

4046
std::string TransposeObj::toString() const

src/utils/operator_utils.cc

Lines changed: 39 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,47 @@ Shape infer_broadcast(const Shape &A, const Shape &B) {
88
// =================================== 作业 ===================================
99
// TODO:对 A 和 B 进行双向广播,返回广播后的形状。
1010
// REF: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
11+
12+
// Broadcasting rules:
13+
// 1. If two shapes have different ranks, prepend 1s to the shorter one
14+
// 2. For each dimension, the output dimension is max(dim_A, dim_B)
15+
// 3. Dimensions are compatible if they are equal or one of them is 1
16+
17+
size_t rankA = A.size();
18+
size_t rankB = B.size();
19+
size_t maxRank = std::max(rankA, rankB);
20+
21+
Shape result(maxRank);
22+
23+
// Iterate from the trailing dimensions
24+
for (size_t i = 0; i < maxRank; ++i) {
25+
int dimA = 1, dimB = 1;
26+
27+
// Get dimension from A (if exists)
28+
if (i < rankA) {
29+
dimA = A[rankA - 1 - i];
30+
}
31+
32+
// Get dimension from B (if exists)
33+
if (i < rankB) {
34+
dimB = B[rankB - 1 - i];
35+
}
36+
37+
// Check compatibility and compute output dimension
38+
if (dimA == dimB) {
39+
result[maxRank - 1 - i] = dimA;
40+
} else if (dimA == 1) {
41+
result[maxRank - 1 - i] = dimB;
42+
} else if (dimB == 1) {
43+
result[maxRank - 1 - i] = dimA;
44+
} else {
45+
// Incompatible dimensions
46+
IT_ASSERT(false, "Incompatible broadcast dimensions");
47+
}
48+
}
1149
// =================================== 作业 ===================================
1250

13-
return {};
51+
return result;
1452
}
1553

1654
int get_real_axis(const int &axis, const int &rank) {

0 commit comments

Comments (0)