Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,15 @@ class typed_primitive_inst<kv_cache> : public typed_primitive_inst_base<kv_cache
return max_pad;
}
void update_shape_info_tensor(const kernel_impl_params& params) override;
void before_prepare() override;
void cleanup() override;

typed_primitive_inst(network& network, const kv_cache_node& desc);
typed_primitive_inst(network& network) : parent(network), memory_state::variable("") {}

private:
size_t kv_cache_id = 0;
std::vector<std::weak_ptr<memory>> _shallow_outputs;
};

using kv_cache_inst = typed_primitive_inst<kv_cache>;
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/primitive_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,10 @@ class primitive_inst {

void reset_events();

virtual void before_prepare() {}
void prepare_primitive();
void execute();
virtual void cleanup() {}
void init_kernels(const kernels_cache& kernels_cache) {
_impl->init_kernels(kernels_cache, *_impl_params);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class typed_primitive_inst<read_value> : public typed_primitive_inst_base<read_v
typed_primitive_inst(network& network) : parent(network), memory_state::variable("") {}

void update_output_memory() override;
void cleanup() override;

protected:
void on_execute() override;
Expand Down
28 changes: 28 additions & 0 deletions src/plugins/intel_gpu/src/graph/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,32 @@ void kv_cache_inst::update_shape_info_tensor(const kernel_impl_params& params) {
}
}

void kv_cache_inst::before_prepare() {
if (_shallow_outputs.size() != _outputs.size())
_shallow_outputs.resize(_outputs.size());
// if resources has been moved to shallow in previour execution, try recover it
for (size_t i = 0; i < _outputs.size(); ++i) {
auto& shallow_output = _shallow_outputs[i];
auto& output = _outputs[i];
if (!output && !shallow_output.expired()) {
output = shallow_output.lock();
}
shallow_output.reset();
}
}

void kv_cache_inst::cleanup() {
// if there's variable state, it should hold a reference of tensor same as outputs
if (!get_network().has_variable(variable_id()))
return;
if (_shallow_outputs.size() != _outputs.size())
_shallow_outputs.resize(_outputs.size());
// move outputs to shallow, so it can be released when varaiblestate get reset
for (size_t i = 0; i < _outputs.size(); ++i) {
auto& output = _outputs[i];
_shallow_outputs[i] = output;
output.reset();
}
}

} // namespace cldnn
3 changes: 3 additions & 0 deletions src/plugins/intel_gpu/src/graph/network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,8 @@ void network::execute_impl(const std::vector<event::ptr>& events) {
for (auto& inst : _exec_order) {
NODE_DEBUG(*inst);

inst->before_prepare();

inst->reset_events();

if (inst->is_input()) {
Expand All @@ -790,6 +792,7 @@ void network::execute_impl(const std::vector<event::ptr>& events) {
// Reset all flags for the next execution
for (auto& inst : _exec_order) {
inst->reset_flags();
inst->cleanup();
}
}

Expand Down
11 changes: 11 additions & 0 deletions src/plugins/intel_gpu/src/graph/read_value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ void read_value_inst::on_execute() {
update_output_memory();
}

void read_value_inst::cleanup() {
// readvalue simply assign outputs from variablestate,
// does not need to keep reference in outputs after execution
if (!can_be_optimized() || !get_network().has_variable(variable_id()))
return;
for (size_t i = 0; i < _outputs.size() && i < 3; ++i) {
auto& output = _outputs[i];
output.reset();
}
}

void read_value_inst::update_output_memory() {
if (!can_be_optimized() || !get_network().has_variable(variable_id()))
return;
Expand Down