arunppsg / TadGAN

Code for the paper "TadGAN: Time Series Anomaly Detection Using Generative Adversarial Networks"
MIT License

prune_false_positive #9

Open 529261027 opened 2 years ago

529261027 commented 2 years ago

Hi, thanks for porting TadGAN to PyTorch. While applying prune_false_positive, I found part of it confusing. I think we should take the maximum anomaly score of every anomalous sequence, together with the maximum score among the normal points. I modified prune_false_positive accordingly. Could you help me check whether the code below is correct? Looking forward to your reply.

import numpy as np

def prune_false_positive(is_anomaly, anomaly_score, change_threshold):
    # The model might detect a high number of false positives.
    # In such a scenario, pruning of the false positives is suggested.
    # The method used is described in Section 5, part D, "Identifying Anomalous
    # Sequence", sub-part "Mitigating False Positives".
    # TODO: code optimization
    seq_details = []
    delete_sequence = 0
    start_position = 0
    end_position = 0
    anomaly_score = np.abs(anomaly_score)  # use the magnitude of each score, regardless of sign
    max_seq_element = anomaly_score[0]
    for i in range(len(is_anomaly)):
        if is_anomaly[i] == 1 and (i == 0 or is_anomaly[i-1] == 0):  # anomaly start
            start_position = i  # anomaly start position
            max_seq_element = anomaly_score[i]  # first anomaly score
        elif is_anomaly[i] == 1 and anomaly_score[i] > max_seq_element:  # continuing anomaly: keep the largest score
            max_seq_element = anomaly_score[i]
        # anomaly end: either the last point, or the next point is normal
        # (starting the loop at 0 also catches a one-point anomaly at index 0)
        if is_anomaly[i] == 1 and (i + 1 == len(is_anomaly) or is_anomaly[i+1] == 0):
            end_position = i  # anomaly end position
            seq_details.append([start_position, end_position, max_seq_element, delete_sequence])

    max_elements = list()
    max_elements.append(max(anomaly_score[is_anomaly==0]))  # normal data max score
    for i in range(0, len(seq_details)):
        max_elements.append(seq_details[i][2])

    max_elements.sort(reverse=True)
    max_elements = np.array(max_elements)
    change_percent = abs(max_elements[1:] - max_elements[:-1]) / max_elements[1:]

    # Prepend 0 for the first element, which has no change percent
    delete_seq = np.append(np.array([0]), change_percent < change_threshold)

    # Mapping max element and seq details
    for i, max_elt in enumerate(max_elements):
        for j in range(0, len(seq_details)):
            if seq_details[j][2] == max_elt:
                seq_details[j][3] = delete_seq[i]

    for seq in seq_details:
        if seq[3] == 1: # Delete sequence
            is_anomaly[seq[0]:seq[1]+1] = [0] * (seq[1] - seq[0] + 1)

    return is_anomaly
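
For cross-checking the function above, here is a minimal, self-contained sketch of the same idea: take the maximum score of every anomalous sequence plus the maximum score of the normal points, then prune. The pruning rule is an assumption on my side, not something confirmed against the paper or the repository. It sorts the per-sequence maxima in descending order, appends the maximum normal score, computes the relative drop from each value to the next, and keeps only the sequences ranked above the last drop that reaches the threshold. The names find_sequences and prune_sketch and the default change_threshold=0.1 are hypothetical.

```python
import numpy as np


def find_sequences(is_anomaly):
    """Return (start, end) index pairs, end inclusive, for each run of 1s."""
    flags = np.asarray(is_anomaly).astype(int)
    # Pad with zeros so runs touching either edge still get a start and an end.
    edges = np.diff(np.concatenate(([0], flags, [0])))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1) - 1
    return list(zip(starts.tolist(), ends.tolist()))


def prune_sketch(is_anomaly, anomaly_score, change_threshold=0.1):
    """Hypothetical pruning sketch; returns a new 0/1 array, input untouched."""
    flags = np.asarray(is_anomaly).astype(int).copy()
    scores = np.abs(np.asarray(anomaly_score, dtype=float))
    seqs = find_sequences(flags)
    if not seqs or not (flags == 0).any():
        return flags  # nothing to prune, or no normal point to compare against

    seq_max = np.array([scores[s:e + 1].max() for s, e in seqs])
    order = np.argsort(-seq_max)  # sequence indices, highest maximum first
    # Ranked sequence maxima followed by the maximum score of the normal points.
    ranked = np.append(seq_max[order], scores[flags == 0].max())

    # Relative drop from each ranked value to the next (assumes positive maxima).
    drop = (ranked[:-1] - ranked[1:]) / ranked[:-1]
    big = np.flatnonzero(drop >= change_threshold)
    keep = big[-1] + 1 if big.size else 0  # number of top-ranked sequences kept

    for idx in order[keep:]:
        s, e = seqs[idx]
        flags[s:e + 1] = 0  # reclassify the whole sequence as normal
    return flags


flags = np.array([1, 0, 0, 1, 1, 0, 0, 1, 0])
scores = np.array([10.0, 1.0, 1.0, 1.05, 1.02, 1.0, 0.5, 9.0, 1.0])
pruned = prune_sketch(flags, scores)
print(pruned.tolist())  # [1, 0, 0, 0, 0, 0, 0, 1, 0]
```

In this toy input the middle sequence peaks at 1.05, barely above the normal maximum of 1.0, so it is reclassified as normal, while the two sequences peaking at 10 and 9 are kept. The drop is divided by the larger of the two neighbouring values, which differs from the denominator used in change_percent above, so the two functions can disagree near the threshold.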