Skip to content

Commit 4e98e54

Browse files
committed
prosumer: show matching global models as starting point for FL process
1 parent 1437d3a commit 4e98e54

6 files changed

Lines changed: 78 additions & 15 deletions

File tree

islands/prosumer/ProsumerStep3.tsx

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import { BasicSelect, SelectOption, SelectProps } from "@/components/Select.tsx";
22
import { batch, Signal, useComputed, useSignal, useSignalEffect } from "@preact/signals";
33
import classNames from "@/utils/classnames.js";
4-
import { ProsumerWorkflowFLData } from "@/utils/types.ts";
4+
import { ModelSearchResponseItem, ProsumerWorkflowFLData } from "@/utils/types.ts";
55

66
export default function ProsumerStep3(props: {
77
process_name: string;
88
disabled: boolean;
99
fl_process?: ProsumerWorkflowFLData;
10+
global_models: ModelSearchResponseItem[];
1011
}) {
1112
console.log("STEP2 disabled", props.disabled);
1213
// const initial_filters = props.criteria
@@ -64,13 +65,16 @@ export default function ProsumerStep3(props: {
6465
// })
6566
// );
6667

68+
const selected_global_model_id = useSignal<number>(0);
6769
const aggregationRule = useSignal<string>(props.fl_process?.computation ?? "Simple Averaging");
6870
const num_of_fl_rounds = useSignal<number>(props.fl_process?.number_of_rounds ?? 1);
6971
const num_of_iterations = useSignal<number>(props.fl_process?.num_of_iterations ?? 1);
7072
const error_num_of_iterations = useSignal<boolean>(false);
7173
const solver = useSignal<string>(props.fl_process?.solver ?? "HQS"); // ["HQS", "ADMM"]
7274
const denoiser = useSignal<string>(props.fl_process?.denoiser ?? "CNN"); // ["CNN", "Autoencoder", "Transformer"]
7375

76+
console.log(props);
77+
7478
const onChangeAggregationRule = (e: Event) => {
7579
const target = e.target as HTMLInputElement;
7680
console.log("GOT", target, target.value);
@@ -197,6 +201,33 @@ export default function ProsumerStep3(props: {
197201
<label>Number of FL Rounds</label>
198202
<i>numbers</i>
199203
</div>
204+
{props.global_models.length > 0 && (
205+
<>
206+
<h5>Global Models to initialize the FL process</h5>
207+
<div class="grid large-space">
208+
{props.global_models.map((r, index) => (
209+
<div class="secondary-container padding s12 m6 l3" id={`model-${index}`}>
210+
<label class="radio extra">
211+
<input
212+
type="radio"
213+
name="model"
214+
value={r.id}
215+
disabled={props.disabled}
216+
checked={r.id == selected_global_model_id.peek()}
217+
/>
218+
<span>Model {r.name ?? ""} (ID: {r.id} - Size: {r.size ?? ""})</span>
219+
</label>
220+
{r.name && <div>Name: {r.name}</div>}
221+
{r.application_type && <div>Application type: {r.application_type}</div>}
222+
{r.round && <div>Rounds: {r.round}</div>}
223+
{r.input && <div>Input: {r.input}</div>}
224+
{r.output && <div>Output: {r.output}</div>}
225+
{r.nn_architecture && <div>Architecture: {r.nn_architecture}</div>}
226+
</div>
227+
))}
228+
</div>
229+
</>
230+
)}
200231
</p>
201232

202233
{!props.disabled && (

routes/consumer/[consumer_id]/step2.tsx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ export default function Step2Page(props: PageProps<Data>) {
6161
const disabled = !!selected_model_id;
6262
const pp = (r: ModelSearchResponseItem) => {
6363
return (
64-
<div class="padding ">
64+
<div class="padding secondary-container s12 m6 l3">
6565
<label class="radio extra">
6666
<input
6767
type="radio"
@@ -70,7 +70,7 @@ export default function Step2Page(props: PageProps<Data>) {
7070
disabled={disabled}
7171
checked={r.id == selected_model_id}
7272
/>
73-
<span>Model {r.name ?? ""} ({r.id} - {r.size ?? ""})</span>
73+
<span>Model {r.name ?? ""} (ID: {r.id} - Size: {r.size ?? ""})</span>
7474
</label>
7575
{r.name && <div>Name: {r.name}</div>}
7676
{r.application_type && <div>Application type: {r.application_type}</div>}
@@ -90,7 +90,7 @@ export default function Step2Page(props: PageProps<Data>) {
9090
<legend>
9191
{disabled ? "You have selected the following model" : "Select one of the models below"}
9292
</legend>
93-
<nav class="vertical">
93+
<nav class="grid large-space">
9494
{results.map(pp)}
9595
</nav>
9696
</fieldset>

routes/prosumer/[prosumer_id]/step1.tsx

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
import { Handlers, PageProps } from "$fresh/server.ts";
2-
import { Domain, ProsumerWorkflowData, SSISearchCriterion, SSISearchCriterionOperator, User } from "@/utils/types.ts";
3-
import { dl_domains, do_kg_get_prosumer_data, do_ssi_search } from "@/utils/backend.ts";
2+
import {
3+
Domain,
4+
ModelSearchAttributeCriterion,
5+
ModelSearchCriterion,
6+
ProsumerWorkflowData,
7+
SSISearchCriterion,
8+
SSISearchCriterionOperator,
9+
User,
10+
} from "@/utils/types.ts";
11+
import { dl_domains, do_dl_model_search, do_kg_get_prosumer_data, do_ssi_search } from "@/utils/backend.ts";
412
import { get_user, redirect_to_login, SessionState } from "@/utils/http.ts";
513
import ProsumerStep1 from "@/islands/prosumer/ProsumerStep1.tsx";
614
import { db_get, db_store, set_user_session_data, user_session_data } from "@/utils/db.ts";
@@ -74,13 +82,35 @@ export const handler: Handlers<unknown, SessionState> = {
7482

7583
console.log("CRITERIA", filters);
7684

85+
// Make a similar query to the DL to locate global models
86+
// that can be used as starting points for the FL process
87+
88+
const model_search_map: Map<Domain, ModelSearchAttributeCriterion[]> = new Map();
89+
filters.forEach((filter) => {
90+
if (model_search_map.has(filter.domain)) {
91+
model_search_map.get(filter.domain)?.push({ attribute: filter.attribute, value: filter.value });
92+
} else {
93+
model_search_map.set(filter.domain, [{ attribute: filter.attribute, value: filter.value }]);
94+
}
95+
});
96+
const model_search_criteria: ModelSearchCriterion[] = model_search_map.entries().map(([domain, criteria]) => ({
97+
domain,
98+
attributes: criteria,
99+
})).toArray();
100+
101+
const global_models = (await do_dl_model_search(
102+
user,
103+
model_search_criteria,
104+
)) ?? [];
105+
77106
const w: ProsumerWorkflowData = {
78107
id: prosumer_id,
79108
name: process_name,
80109
ssi: {
81110
status: "NOT STARTED",
82111
process_id: "",
83112
criteria: filters,
113+
global_models: global_models,
84114
},
85115
models_selected: [],
86116
kg_results: [],

routes/prosumer/[prosumer_id]/step3.tsx

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import { Handlers, PageProps } from "$fresh/server.ts";
2-
import { Domain, ProsumerWorkflowData, ProsumerWorkflowFLData, SSISearchCriterion, User } from "@/utils/types.ts";
3-
import { dl_domains, do_fl_poll, do_fl_submit, do_ssi_poll } from "@/utils/backend.ts";
2+
import { ModelSearchResponseItem, ProsumerWorkflowData, ProsumerWorkflowFLData, User } from "@/utils/types.ts";
3+
import { do_fl_poll, do_fl_submit, do_ssi_poll } from "@/utils/backend.ts";
44
import { get_user, redirect_to_login, SessionState } from "@/utils/http.ts";
5-
import { db_get, db_store, set_user_session_data, user_session_data } from "@/utils/db.ts";
5+
import { db_get, db_store, user_session_data } from "@/utils/db.ts";
66

77
import { redirect } from "@/utils/http.ts";
88
import { prosumer_key } from "@/utils/misc.ts";
@@ -12,6 +12,7 @@ interface Data {
1212
process_name: string;
1313
disabled: boolean;
1414
fl_process?: ProsumerWorkflowFLData;
15+
global_models: ModelSearchResponseItem[];
1516
}
1617

1718
async function user_profile(sessionId: string): Promise<User> {
@@ -117,17 +118,19 @@ export const handler: Handlers<unknown, SessionState> = {
117118
process_name: prosumer_data?.name || "",
118119
fl_process: fl_process_data,
119120
disabled: disabled,
121+
global_models: prosumer_data.ssi.global_models,
120122
});
121123
},
122124
};
123125

124126
export default function Step3Page(props: PageProps<Data>) {
125-
console.log("disabled", props.data.disabled);
127+
console.log("disabled-3", props.data.disabled);
126128
return (
127129
<ProsumerStep3
128130
process_name={props.data.process_name}
129131
disabled={props.data.disabled}
130132
fl_process={props.data.fl_process}
133+
global_models={props.data.global_models}
131134
/>
132135
);
133136
}

routes/prosumer/[prosumer_id]/step4.tsx

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@ import { Handlers, PageProps } from "$fresh/server.ts";
22
import { FLProcessStatusData, ProsumerWorkflowData, ProsumerWorkflowFLData, User } from "@/utils/types.ts";
33
import { do_fl_poll, update_global_model_round } from "@/utils/backend.ts";
44
import { get_user, redirect_to_login, SessionState } from "@/utils/http.ts";
5-
import { db_get, db_store, set_user_session_data, user_session_data } from "@/utils/db.ts";
5+
import { db_get, db_store, user_session_data } from "@/utils/db.ts";
66

77
import { redirect } from "@/utils/http.ts";
88
import { prosumer_key } from "@/utils/misc.ts";
9-
import ProsumerStep3 from "@/islands/prosumer/ProsumerStep3.tsx";
109
import AutoReload from "@/islands/AutoReload.tsx";
11-
import { DATA_CURRENT } from "$fresh/src/constants.ts";
1210

1311
interface Data {
1412
process_name: string;
@@ -28,7 +26,7 @@ async function user_profile(sessionId: string): Promise<User> {
2826
// STEP3: allow the user to select the FL parameters and start the FL process
2927

3028
export const handler: Handlers<unknown, SessionState> = {
31-
async POST(req, ctx) {
29+
async POST(_req, _ctx) {
3230
// const user = await get_user(req, ctx.state.session);
3331
// if (!user) {
3432
// return redirect_to_login(req);
@@ -113,7 +111,7 @@ export const handler: Handlers<unknown, SessionState> = {
113111
const flprocessstatus: FLProcessStatusData = await do_fl_poll(user, fl_process_data.process_id);
114112
const previous_status = prosumer_data.fl_process.status;
115113
if (previous_status.current_round < flprocessstatus.current_round) {
116-
const [global_model_id, dl_error] = await update_global_model_round(
114+
const [global_model_id, _dl_error] = await update_global_model_round(
117115
user,
118116
fl_process_data.process_id,
119117
flprocessstatus.current_round,

utils/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ export interface ProsumerWorkflowSSIData {
6969
process_id: string;
7070
status: SSISearchStatus;
7171
criteria: SSISearchCriterion[];
72+
global_models: ModelSearchResponseItem[]; // Global models matching SSI criteria if found
7273
results?: string[]; // Results, "model ids"
7374
}
7475

0 commit comments

Comments
 (0)