Hi,
Thanks for the translation of TadGAN into PyTorch.
When I applied the prune_false_positive function, I found some parts of it confusing. I think that for every anomalous sequence we should find its maximum anomaly score, and that we should also include the maximum score of the normal (non-anomalous) points. So I modified the code of prune_false_positive as shown below. Could you help me determine whether this code is correct?
Looking forward to your reply.
def prune_false_positive(is_anomaly, anomaly_score, change_threshold):
    """Clear the flags of anomalous sequences that look like false positives.

    The model might detect a high number of false positives; in such a
    scenario pruning is suggested.  The method follows the approach referred
    to as Section 5, part D "Identifying Anomalous Sequences", sub-part
    "Mitigating False Positives".

    Procedure:
      1. Every maximal run of points flagged ``1`` is one anomalous sequence,
         represented by the largest absolute score inside it.
      2. Those peaks, together with the largest absolute score among the
         points flagged ``0`` (when there are any), are ranked in descending
         order.
      3. For each ranked entry except the first, the relative change to the
         entry ranked just above it is computed.  A sequence whose change is
         below ``change_threshold`` has its flags reset to ``0``.

    NOTE(review): the change is normalised by the smaller value of each
    adjacent pair and every sequence is judged on its own; please confirm
    this matches the pruning rule of the referenced paper section.

    Args:
        is_anomaly: Sequence of 0/1 flags, one per point.  Modified in place.
        anomaly_score: Per-point scores, same length as ``is_anomaly``.
            Only their absolute values are used.
        change_threshold: Minimum relative change a sequence needs in order
            to stay flagged.

    Returns:
        The same ``is_anomaly`` object, with pruned sequences set to ``0``.
    """
    anomaly_score = np.abs(anomaly_score)
    flags = np.asarray(is_anomaly)
    if len(flags) == 0:
        return is_anomaly

    sequences = _find_anomalous_sequences(flags, anomaly_score)
    if not sequences:
        return is_anomaly

    # Entries [0, len(sequences)) are sequence peaks; an optional trailing
    # entry is the normal-data maximum, which takes part in the ranking but
    # is never "pruned" itself.
    peaks = [peak for _, _, peak in sequences]
    normal_scores = anomaly_score[flags == 0]
    if normal_scores.size > 0:
        peaks.append(normal_scores.max())
    peaks = np.asarray(peaks, dtype=float)

    # Rank by index instead of matching float values afterwards, so that
    # sequences with identical peaks cannot overwrite each other's verdict.
    order = np.argsort(-peaks, kind="stable")
    ranked = peaks[order]
    # A zero peak yields inf/nan here; neither compares below the threshold,
    # so such entries are simply kept.
    with np.errstate(divide="ignore", invalid="ignore"):
        change_percent = np.abs(ranked[1:] - ranked[:-1]) / ranked[1:]

    # The top-ranked entry has no predecessor and is therefore never pruned.
    prune = np.append(np.array([False]), change_percent < change_threshold)

    for rank, idx in enumerate(order):
        if idx < len(sequences) and prune[rank]:
            start, end, _ = sequences[idx]
            is_anomaly[start:end + 1] = [0] * (end - start + 1)

    return is_anomaly


def _find_anomalous_sequences(flags, scores):
    """Return ``(start, end, peak)`` for every maximal run of flags equal to 1.

    ``start`` and ``end`` are inclusive indices; ``peak`` is the largest
    score within the run.  Runs touching either end of the data are included.
    """
    runs = []
    start = None
    peak = None
    for i, flag in enumerate(flags):
        if flag == 1:
            if start is None:
                start, peak = i, scores[i]
            elif scores[i] > peak:
                peak = scores[i]
        elif start is not None:
            runs.append((start, i - 1, peak))
            start = None
    if start is not None:
        # The data ends while still inside an anomalous run.
        runs.append((start, len(flags) - 1, peak))
    return runs