diffusion_client.py
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import once_differentiable

import hivemind
from hivemind.compression import serialize_torch_tensor
from hivemind.moe.client.expert import DUMMY, expert_forward
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.utils import get_logger, nested_compare, nested_flatten, nested_pack

from load_balancer import LoadBalancer, NoModulesFound

logger = get_logger(__name__)

MAX_PROMPT_LENGTH = 512
MAX_NODES = 99999
@dataclass
class GeneratedImage:
    encoded_image: bytes  # compressed image buffer as returned by the server, decodable with cv2.imdecode()
    decoded_image: Optional[np.ndarray]  # RGB uint8 array, or None if skip_decoding=True
    nsfw_score: float  # NSFW score reported by the serving node
class DiffusionClient:
def __init__(
self,
*,
initial_peers: List[str],
dht_prefix: str = "diffusion",
**kwargs
):
dht = hivemind.DHT(initial_peers, client_mode=True, start=True, **kwargs)
self.expert = BalancedRemoteExpert(dht=dht, uid_prefix=dht_prefix + ".")
def draw(self, prompts: List[str], *, skip_decoding: bool = False) -> List[GeneratedImage]:
encoded_prompts = []
for prompt in prompts:
            # Encode the prompt as raw UTF-8 bytes, zero-padded (or truncated) to MAX_PROMPT_LENGTH bytes
            tensor = torch.tensor(list(prompt.encode()), dtype=torch.uint8)
tensor = F.pad(tensor, (0, MAX_PROMPT_LENGTH - len(tensor)))
encoded_prompts.append(tensor)
encoded_prompts = torch.stack(encoded_prompts)
encoded_images, nsfw_scores = self.expert(encoded_prompts)
result = []
for buf, nsfw_score in zip(encoded_images.numpy(), nsfw_scores.detach().numpy()):
decoded_image = None
if not skip_decoding:
decoded_image = cv2.imdecode(buf, 1) # imdecode() returns a BGR image
decoded_image = cv2.cvtColor(decoded_image, cv2.COLOR_BGR2RGB)
result.append(GeneratedImage(
encoded_image=buf.tobytes(),
decoded_image=decoded_image,
nsfw_score=nsfw_score,
))
return result
@property
def n_active_servers(self) -> int:
return self.expert.expert_balancer.n_active_experts
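# Example usage of DiffusionClient (a minimal sketch, left as a comment so the module stays import-only).
# The multiaddr and output path below are placeholders, not values from this repository:
#
#     client = DiffusionClient(initial_peers=["/ip4/203.0.113.7/tcp/31337/p2p/QmPlaceholderPeerID"])
#     images = client.draw(["an astronaut riding a horse"])
#     # draw() returns RGB arrays, while cv2.imwrite expects BGR, hence the conversion back
#     cv2.imwrite("astronaut.png", cv2.cvtColor(images[0].decoded_image, cv2.COLOR_RGB2BGR))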
class BalancedRemoteExpert(nn.Module):
"""
    A torch module that dispatches each request to one RemoteExpert from a pool, chosen by a load balancer
    proportionally to the experts' throughput. The public interface is similar to hivemind.RemoteExpert.
"""
def __init__(
self,
*,
dht: hivemind.DHT,
uid_prefix: str,
grid_size: Tuple[int, ...] = (1, MAX_NODES),
forward_timeout: Optional[float] = None,
backward_timeout: Optional[float] = None,
update_period: float = 30.0,
backward_task_size_multiplier: float = 2.5,
**kwargs,
):
super().__init__()
if uid_prefix.endswith(".0."):
logger.warning(f"BalancedRemoteExperts will look for experts under prefix {self.uid_prefix}0.")
assert len(grid_size) == 2 and grid_size[0] == 1, "only 1xN grids are supported"
self.dht, self.uid_prefix, self.grid_size = dht, uid_prefix, grid_size
self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
self.backward_task_size_multiplier = backward_task_size_multiplier
self.expert_balancer = LoadBalancer(dht, key=f"{self.uid_prefix}0.", update_period=update_period, **kwargs)
self._expert_info = None # expert['info'] from one of experts in the grid
def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
"""
        Call one of the RemoteExperts for the specified inputs and return its output. Compatible with pytorch.autograd.
        :param args: input tensors that will be passed to the chosen expert, batch-first
        :param kwargs: extra keyword tensors that will be passed to the chosen expert, batch-first
        :returns: output of the chosen expert, as a nested structure of batch-first tensors
"""
assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
if self._expert_info is None:
raise NotImplementedError()
# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
forward_inputs = (args, kwargs)
if not nested_compare(forward_inputs, self.info["forward_schema"]):
raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
flat_inputs = list(nested_flatten(forward_inputs))
forward_task_size = flat_inputs[0].shape[0]
# Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
flat_outputs = _BalancedRemoteModuleCall.apply(DUMMY,
self.expert_balancer,
self.info,
self.forward_timeout,
self.backward_timeout,
forward_task_size,
forward_task_size * self.backward_task_size_multiplier,
*flat_inputs)
return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
@property
def info(self):
while self._expert_info is None:
try:
with self.expert_balancer.use_another_expert(1) as chosen_expert:
self._expert_info = chosen_expert.info
except NoModulesFound:
raise
except Exception:
logger.exception(f"Tried to get expert info from {chosen_expert} but caught:")
return self._expert_info
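# BalancedRemoteExpert can also be used directly, as DiffusionClient does above (a sketch; the DHT
# bootstrap peers and prefix values are illustrative):
#
#     dht = hivemind.DHT(initial_peers=[...], client_mode=True, start=True)
#     expert = BalancedRemoteExpert(dht=dht, uid_prefix="diffusion.")
#     encoded_images, nsfw_scores = expert(encoded_prompts)  # same call as in DiffusionClient.draw()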
class _BalancedRemoteModuleCall(torch.autograd.Function):
"""Internal autograd-friendly call of a remote module. For applications, use BalancedRemoteExpert instead."""
@staticmethod
def forward(
ctx,
dummy: torch.Tensor,
expert_balancer: LoadBalancer,
info: Dict[str, Any],
forward_timeout: float,
backward_timeout: float,
forward_task_size: float,
backward_task_size: float,
*inputs: torch.Tensor,
) -> Tuple[torch.Tensor, ...]:
# Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
# detach to avoid pickling the computation graph
ctx.expert_balancer, ctx.info = expert_balancer, info
ctx.forward_timeout, ctx.backward_timeout = forward_timeout, backward_timeout
ctx.forward_task_size, ctx.backward_task_size = forward_task_size, backward_task_size
inputs = tuple(tensor.cpu().detach() for tensor in inputs)
ctx.save_for_backward(*inputs)
serialized_tensors = [
serialize_torch_tensor(inp, proto.compression)
for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
]
while True:
try:
with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
logger.info(f"Query served by: {chosen_expert}")
deserialized_outputs = RemoteExpertWorker.run_coroutine(expert_forward(
chosen_expert.uid, inputs, serialized_tensors, chosen_expert.stub))
break
except NoModulesFound:
raise
except Exception:
logger.exception(f"Tried to call forward for expert {chosen_expert} but caught:")
return tuple(deserialized_outputs)
@staticmethod
@once_differentiable
def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
raise NotImplementedError("Backward is not yet implemented in this example")
# grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
# inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
# backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
# serialized_tensors = [
# serialize_torch_tensor(tensor, proto.compression)
# for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
# ]
# while True:
# try:
# with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
# backward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
        # grad_inputs = chosen_expert.stub.backward(backward_request, timeout=ctx.backward_timeout)
# break
# except NoModulesFound:
# raise
# except Exception:
# logger.exception(f"Tried to call backward for expert {chosen_expert} but caught:")
# deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
# return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)