1+ // Copyright 2025 Huawei Technologies Co., Ltd
2+ // Copyright 2024 KVCache.AI
3+ //
4+ // Licensed under the Apache License, Version 2.0 (the "License");
5+ // you may not use this file except in compliance with the License.
6+ // You may obtain a copy of the License at
7+ //
8+ // http://www.apache.org/licenses/LICENSE-2.0
9+ //
10+ // Unless required by applicable law or agreed to in writing, software
11+ // distributed under the License is distributed on an "AS IS" BASIS,
12+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ // See the License for the specific language governing permissions and
14+ // limitations under the License.
15+
16+ #ifndef HETEROGENEOUS_TCP_TRANSPORT_H_
17+ #define HETEROGENEOUS_TCP_TRANSPORT_H_
18+
19+ #include " transport/tcp_transport/tcp_transport.h"
20+ #include " acl/acl.h"
21+ #include < atomic>
22+ #include < new>
23+ #include < condition_variable>
24+
25+ #define HUGE_HOST_SIZE (3ULL * 1024 * 1024 * 1024 )
26+ #define HUGE_DEVICE_SIZE (8 * 1024 * 1024 )
27+ #define HUGE_DEVICE_NUM 4
28+
29+ namespace mooncake {
30+
31+ class HeterogeneousTcpTransport : public Transport {
32+ public:
33+ HeterogeneousTcpTransport ();
34+
35+ ~HeterogeneousTcpTransport ();
36+
37+ int install (std::string &local_server_name,
38+ std::shared_ptr<TransferMetadata> meta,
39+ std::shared_ptr<Topology> topo) override ;
40+
41+ const char *getName () const override { return " ascend" ; }
42+
43+ int registerLocalMemory (void *addr, size_t length,
44+ const std::string &location, bool remote_accessible,
45+ bool update_metadata) override ;
46+
47+ int unregisterLocalMemory (void *addr, bool update_metadata = true ) override ;
48+
49+ int registerLocalMemoryBatch (const std::vector<BufferEntry> &buffer_list,
50+ const std::string &location) override ;
51+
52+ int unregisterLocalMemoryBatch (
53+ const std::vector<void *> &addr_list) override ;
54+
55+ int createStream ();
56+
57+ Status submitTransfer (BatchID batch_id,
58+ const std::vector<TransferRequest> &entries) override ;
59+
60+ Status submitTransferTask (
61+ const std::vector<TransferTask *> &task_list) override ;
62+
63+ Status getTransferStatus (BatchID batch_id, size_t task_id,
64+ TransferStatus &status) override ;
65+ std::unique_ptr<TcpTransport> transport_{};
66+
67+ private:
68+ void transferLoop ();
69+
70+ private:
71+ struct TransferTaskTCP {
72+ std::vector<TransferTask *> tasks;
73+ uint64_t total_length;
74+ uint64_t devId;
75+
76+ TransferTaskTCP (TransferTaskTCP &&) = default ;
77+ TransferTaskTCP &operator =(TransferTaskTCP &&) = default ;
78+
79+ TransferTaskTCP (const TransferTaskTCP &) = delete ;
80+ TransferTaskTCP &operator =(const TransferTaskTCP &) = delete ;
81+
82+ TransferTaskTCP (std::vector<TransferTask *> taskList, uint64_t len,
83+ uint64_t id)
84+ : tasks(std::move(taskList)), total_length(len), devId(id) {}
85+ };
86+ bool running_ = false ;
87+ aclrtStream stream_;
88+ void *hostAddr_ = nullptr ;
89+ void *devAddr_ = nullptr ;
90+ std::vector<void *> hugeDevAddrs;
91+ int deviceLogicId_;
92+ bool firstSubmit_ = true ;
93+ std::mutex memcpy_mutex_;
94+ uint64_t offset_ = 0 ;
95+ std::thread transferThread_;
96+ std::queue<TransferTaskTCP> transferQueues_;
97+ std::mutex transfer_mutex_;
98+ std::condition_variable transfer_cond_;
99+ std::atomic<int > transfer_counter_{0 };
100+ int devId_ = 0 ;
101+ std::array<bool , HUGE_DEVICE_NUM> mem_blocks = {false , false , false , false };
102+ std::mutex dev_mtx_;
103+ std::condition_variable dev_cv_;
104+ };
105+
106+ using TransferRequest = Transport::TransferRequest;
107+ using TransferStatus = Transport::TransferStatus;
108+ using TransferStatusEnum = Transport::TransferStatusEnum;
109+ using SegmentID = Transport::SegmentID;
110+ using BatchID = Transport::BatchID;
111+
112+ } // namespace mooncake
113+
114+ #endif // HETEROGENEOUS_TCP_TRANSPORT_H_
0 commit comments