HRNet / Lite-HRNet

This is an official PyTorch implementation of Lite-HRNet: A Lightweight High-Resolution Network.
Apache License 2.0
819 stars 126 forks source link

Need a feature that can draw line plots #89

Open hanlonegen opened 4 months ago

hanlonegen commented 4 months ago

I need a feature that draws line plots from the training log files, showing how the model performs on the validation set over the course of training.

JacobBai123 commented 4 months ago

This may help:

import os
import re
import matplotlib.pyplot as plt
import mplcursors

def extract_float_numbers(line):
    """
    Return the decimal value following the first "AP: " label in ``line``.

    Only values written with a decimal point (e.g. ``0.6021``) are recognised;
    an "AP: " not followed by such a value is skipped and the search continues
    further along the line. Returns ``None`` if no match exists.
    """
    found = re.search(r'AP: (\d+\.\d+)', line)
    return None if found is None else float(found.group(1))

def read_log_file(file_path):
    """
    Collect the validation AP values recorded in a training log file.

    Every line containing "Epoch(val)" is passed to ``extract_float_numbers``;
    lines for which it finds no value are skipped.

    Args:
        file_path: Path to the text log file.

    Returns:
        List of floats, in the order their lines appear in the file.
    """
    float_numbers = []
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default encoding (e.g. cp1252 on Windows). Undecodable bytes
    # are replaced instead of aborting the whole plot; only the ASCII
    # "Epoch(val)" / "AP: " text is needed anyway.
    with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
        for line in file:
            if "Epoch(val)" not in line:
                continue
            number = extract_float_numbers(line)
            if number is not None:
                float_numbers.append(number)
    return float_numbers

def plot_line_chart(float_numbers, xlim=(0, 25), ylim=(0.4, 0.75), fontsize=2):
    """
    Draw a line chart of ``float_numbers`` against their index on a new figure.

    Each point is annotated with its value (two decimals), offset 10 points
    above it. The figure is not shown; call ``plt.show()`` or save it via the
    returned axes.

    Args:
        float_numbers: Sequence of y-values, plotted at x = 0, 1, 2, ...
        xlim: ``(left, right)`` x-axis limits, or ``None`` to let matplotlib
            autoscale. The default keeps the original fixed range, which hides
            points beyond index 25.
        ylim: ``(bottom, top)`` y-axis limits, or ``None`` to autoscale. The
            default keeps the original fixed range.
        fontsize: Font size of the per-point value labels (original value: 2).

    Returns:
        The matplotlib ``Axes`` the chart was drawn on.
    """
    _, ax = plt.subplots()
    ax.plot(range(len(float_numbers)), float_numbers)
    ax.set(xlabel='Index', ylabel='Value', title='Line Chart')
    if xlim is not None:
        ax.set_xlim(*xlim)
    if ylim is not None:
        ax.set_ylim(*ylim)

    # Label every data point with its value, slightly above the point.
    for i, value in enumerate(float_numbers):
        ax.annotate(f'{value:.2f}', (i, value), textcoords="offset points",
                    xytext=(0, 10), ha='center', fontsize=fontsize)

    return ax

# Line colours, cycled through when more log files are plotted than colours exist.
colorss = ['g', 'r', 'c', 'm', 'y', 'k', 'b']
# Vertical offset (in points) of the first file's value labels. Each further
# file's labels are shifted 40 points higher, so labels of different curves
# are less likely to overlap.
offs = -20


def main(log_dir='.', xlim=(0, 25), ylim=(0.5, 0.75)):
    """
    Plot the validation AP curve of every ``.log`` file in ``log_dir``.

    Files are processed in sorted name order, so colours are assigned
    deterministically. Each curve is labelled with its file name in a legend.
    Every point is annotated with its AP value, and hovering shows a tooltip
    via ``mplcursors``.

    Args:
        log_dir: Directory searched (non-recursively) for ``.log`` files.
            Defaults to the current working directory.
        xlim: ``(left, right)`` x-axis limits, or ``None`` to autoscale.
        ylim: ``(bottom, top)`` y-axis limits, or ``None`` to autoscale.
    """
    # os.listdir order is arbitrary; sort so colours/legend order are stable.
    log_files = sorted(f for f in os.listdir(log_dir) if f.endswith('.log'))
    if not log_files:
        print(f"No .log files found in {os.path.abspath(log_dir)}")
        return

    _, ax = plt.subplots()
    # NOTE(review): 'epoch/10' assumes validation runs every 10 epochs;
    # confirm against the training config.
    ax.set(xlabel='epoch/10', ylabel='AP', title='Line Chart')
    if xlim is not None:
        ax.set_xlim(*xlim)
    if ylim is not None:
        ax.set_ylim(*ylim)

    for j, name in enumerate(log_files):
        float_numbers = read_log_file(os.path.join(log_dir, name))
        ax.plot(range(len(float_numbers)), float_numbers,
                color=colorss[j % len(colorss)], label=name)
        for i, value in enumerate(float_numbers):
            ax.annotate(f'{value:.3f}', (i, value), textcoords="offset points",
                        xytext=(0, offs + 40 * j), ha='center')

    # Without a legend the coloured curves cannot be matched to their files.
    ax.legend()
    mplcursors.cursor(hover=True)

    plt.show()


if __name__ == "__main__":
    main()