Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 142 additions & 103 deletions howso/ablation.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -610,37 +610,34 @@
))

;mark each case as not being kept at first
(if (> (size (contained_entities (query_exists ".keeping"))) 0)
(map
(lambda
(assign_to_entities (current_value) (assoc ".keeping" .false))
)
all_case_ids
)

(if (= (size (contained_entities (query_exists ".keeping"))) 0)
;create the label because it doesn't exist yet
(map
(lambda
(accum_entity_roots (current_value) (zip_labels
[".keeping"] [.false]
[".keeping"] [(null)]
))
)
all_case_ids
)
)

;store the surprisal to each cases's most similar neighbor
(call !StoreCaseValues (assoc
case_values_map
(declare (assoc
min_neighbor_surprisals
(map
(lambda (apply "min" (values (current_value))))
neighbor_surprisals_map
)
))

;store the surprisal to each cases's most similar neighbor
(call !StoreCaseValues (assoc
case_values_map min_neighbor_surprisals
label_name ".neighbor_surprisal"
))

(declare (assoc
;map of case to its core-set surprisal (the max surprisal to any case in the coreset for all cases)
;map of case to its core-set surprisal (the min surprisal to any case in the coreset for all cases)
;set it to an extremely large value
case_to_css_map (map 10e13 (zip all_case_ids))

Expand Down Expand Up @@ -668,38 +665,37 @@
(if (= (current_index 1) 0)
;on first iteration, just take lowest DC case
(contained_entities
(query_equals ".keeping" .false)
(query_equals ".keeping" (null))
(query_min ".neighbor_surprisal" 1 .true)
)

;otherwise need cases with low neighbor surprisal (ns) that is far from its most similar case in current_cases_to_keep
(let
(assoc
lowest_ns_cases
(contained_entities
(query_equals ".keeping" .false)
(if lowest_ns_cases_trunc_n
(query_min ".neighbor_surprisal" lowest_ns_cases_trunc_n .true)
)
)
)

(declare (assoc
low_ns_case_scores
low_ns_case_scores_map
(map
(lambda
;coreset surprisal / neighbor surprisal, the smaller the neighbor surprisal the larger this score
(/
(get case_to_css_map (current_index))
(retrieve_from_entity (current_index) ".neighbor_surprisal")
(get min_neighbor_surprisals (current_index))
)
)
(compute_on_contained_entities
(query_equals ".keeping" (null) )
(if lowest_ns_cases_trunc_n
(query_min ".neighbor_surprisal" lowest_ns_cases_trunc_n .true)
)
)
(zip lowest_ns_cases)
)
))
)

(declare (assoc
coreset_size (size (contained_entities (query_equals ".keeping" .true)))
coreset_size
(if lowest_ns_cases_trunc_n
(size (contained_entities (query_not_equals ".keeping" (null) )))
(- num_cases (size low_ns_case_scores_map))
)
))

(declare (assoc
Expand All @@ -722,7 +718,7 @@
(if (and
(>= coreset_size !autoAblationMinNumCases)
(<
(apply "max" (values low_ns_case_scores))
(apply "max" (values low_ns_case_scores_map))
ratio_cutoff_value
)
)
Expand All @@ -731,27 +727,28 @@

;sorting low dc cases by *decreasing* "score" and return the right amount
(if (= 1 num_cases_to_keep)
(trunc (index_max low_ns_case_scores) 1)
(trunc (index_max low_ns_case_scores_map) 1)

(sort
(lambda
(-
(get low_ns_case_scores (current_value 1))
(get low_ns_case_scores (current_value))
(get low_ns_case_scores_map (current_value 1))
(get low_ns_case_scores_map (current_value))
)
)
lowest_ns_cases
(indices low_ns_case_scores_map)
num_cases_to_keep
)
)
)
)
iteration (current_index 1)
)

;mark new cases to keep
;mark new cases to keep by setting the iteration value when these cases were selected into the core set
(map
(lambda
(assign_to_entities (current_value) (assoc ".keeping" .true ))
(assign_to_entities (current_value) (assoc ".keeping" iteration))
)
cases_to_add
)
Expand Down Expand Up @@ -779,90 +776,29 @@
)
)
;all non-coreset cases
(zip (contained_entities (query_equals ".keeping" .false)))
(compute_on_contained_entities (query_equals ".keeping" (null)) )
)
))

(if (>=
(size (contained_entities
(query_equals ".keeping" .true)
))
(size (contained_entities (query_not_equals ".keeping" (null)) ))
reduce_max_cases
)
(assign (assoc done .true))
)
)
)

;the list of case ids to be removed are those that did not have an iteration value assigned to ".keeping"
(declare (assoc
;the list of case ids to be removed
cases_to_remove (contained_entities (query_equals ".keeping" .false) )
cases_to_remove (contained_entities (query_equals ".keeping" (null)) )
))

(if !tsTimeFeature
(let
(assoc
entire_series_removal_id_queries
;select those series identifiers where there will be less than 3 cases remaining after this removal pass
;because these entire series should be removed at that point
(filter
(lambda
(<
(size (contained_entities
(query_not_in_entity_list cases_to_remove)
;(current_value) is in the format of (list (query_equals "series_feature_id" value) ... ) for all affected series ids
(current_value)
))
3
)
)
;a list of affected series identifier queries for this batch of 'cases'
(call !GenerateUniqueSeriesQueries (assoc
series_id_features (get !tsFeaturesMap "series_id_features")
case_ids
;of the selected cases, only keep those that were either the first or last case from a series
(append
(contained_entities
(query_in_entity_list cases_to_remove)
(query_equals ".reverse_series_index" 0)
)
(contained_entities
(query_in_entity_list cases_to_remove)
(query_equals ".series_index" 0)
)
)
))
)
)

;do not remove first (.series_index == 0) or last (.reverse_series_index == 0) cases for any series
(assign (assoc
cases_to_remove
(contained_entities
(query_in_entity_list cases_to_remove)
(query_not_equals ".reverse_series_index" 0)
(query_not_equals ".series_index" 0)
)
))

;there were series that will need to be entirely removed, add all those series cases for removal
(if (size entire_series_removal_id_queries)
(accum (assoc
cases_to_remove
(apply "append" (map
(lambda
(contained_entities
(query_not_in_entity_list cases_to_remove)
(current_value)
)
)
entire_series_removal_id_queries
))
))
)
)
(call !FilterCasesToRemoveForTimeSeries (assoc ensure_enough_to_remove .true))
)


(if (size cases_to_remove)
(call !RemoveCases (assoc
cases cases_to_remove
Expand Down Expand Up @@ -897,6 +833,109 @@
(call !Return (assoc payload output))
)

;Helper method in reduce_data used to modify 'cases_to_remove' for time series datasets by
;removing all series edge (first and last in a series) cases from 'cases_to_remove',
;unless a remaining series will have less than 3 cases after removal, then this will instead
;add the remainder case(s) to 'cases_to_remove' so the entire series is removed.
;
;parameters:
; ensure_enough_to_remove: flag, if true checks whether enough 'cases_to_remove' have been selected (within 5% of desired amount),
; and if not, pads them from the core set and reruns this method
#!FilterCasesToRemoveForTimeSeries
(let
(assoc
entire_series_removal_id_queries
;select those series identifiers where there will be less than 3 cases remaining after this removal pass
;because these entire series should be removed at that point
(filter
(lambda
(<
(size (contained_entities
(query_not_in_entity_list cases_to_remove)
;(current_value) is in the format of (list (query_equals "series_feature_id" value) ... ) for all affected series ids
(current_value)
))
3
)
)
;a list of affected series identifier queries for this batch of 'cases'
(call !GenerateUniqueSeriesQueries (assoc
series_id_features (get !tsFeaturesMap "series_id_features")
case_ids
;of the selected cases, only keep those that were either the first or last case from a series
(append
(contained_entities
(query_in_entity_list cases_to_remove)
(query_equals ".reverse_series_index" 0)
)
(contained_entities
(query_in_entity_list cases_to_remove)
(query_equals ".series_index" 0)
)
)
))
)
)

(declare (assoc num_cases_to_remove (size cases_to_remove) ))

;do not remove first (.series_index == 0) or last (.reverse_series_index == 0) cases for any series
(assign (assoc
cases_to_remove
(contained_entities
(query_in_entity_list cases_to_remove)
(query_not_equals ".reverse_series_index" 0)
(query_not_equals ".series_index" 0)
)
))

;there were series that will need to be entirely removed, add all those series cases for removal
(if (size entire_series_removal_id_queries)
(accum (assoc
cases_to_remove
(apply "append" (map
(lambda
(contained_entities
(query_not_in_entity_list cases_to_remove)
(current_value)
)
)
entire_series_removal_id_queries
))
))
)

;if the number of cases to remove has gone down by more than 5%, use the latest cases added to the core set to pad cases_to_remove
(if ensure_enough_to_remove
(let
(assoc
percent_edge_cases_filtered_out (/ (- num_cases_to_remove (size cases_to_remove)) num_cases_to_remove)
)

(if (> percent_edge_cases_filtered_out 0.05)
(seq

(accum (assoc
cases_to_remove
(contained_entities
;find the latest cases added to the core set, those with the highest iteration values
(query_max
".keeping"
;pad cases_to_remove with the extra 'percent_edge_cases_filtered_out', assuming that,
;on average, that many will be filtered out again from these padded cases
(* (- num_cases_to_remove (size cases_to_remove)) (+ 1 percent_edge_cases_filtered_out))
)
)
))

;rerun the filter, but prevent getting stuck in a loop by only doing it once more
(call !FilterCasesToRemoveForTimeSeries (assoc ensure_enough_to_remove .false))
)
)
)
)
)


;helper method that merges duplicate cases during the reduce_data flow
#!ReduceMergeDuplicateCases
Expand Down
Loading