Slider in scatter_3d and scatter makes some data points go missing #4768

Open
Vilin97 opened this issue Sep 16, 2024 · 0 comments
When I use an animation slider, only 2 of the 4 categories are plotted in any given frame. The other data points do not appear at all. Moving the slider changes which categories are plotted. For example, in the MWE below, only TP and FP show up when the slider is below 0.9, and at 0.9 only TN and FN show up.

This happens for both 2D and 3D scatter plots (scatter and scatter_3d). The MWE below uses the 3D version.

import numpy as np
import pandas as pd
import plotly.express as px

def plot_scatter_3d_mwe():
    # Create a small DataFrame with fake data
    data = {
        'Dim1': np.random.rand(10),
        'Dim2': np.random.rand(10),
        'Dim3': np.random.rand(10),
        'due': [1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
        'serial_number': range(10),
        'predicted_probabilities': [0.9, 0.8, 0.4, 0.2, 0.6, 0.7, 0.1, 0.5, 0.3, 0.95]
    }

    df = pd.DataFrame(data)
    thresholds = np.arange(0, 1.1, 0.1)
    all_frames = []

    for threshold in thresholds:
        # Recalculate predictions based on the threshold
        predicted = (df['predicted_probabilities'] >= threshold).astype(int)
        
        # Create the 4 categories for coloring: TP, TN, FP, FN
        conditions = [
            (df['due'] == 1) & (predicted == 1),  # TP
            (df['due'] == 0) & (predicted == 0),  # TN
            (df['due'] == 0) & (predicted == 1),  # FP
            (df['due'] == 1) & (predicted == 0),  # FN
        ]
        categories = ['TP', 'TN', 'FP', 'FN']
        
        # Assign the categories to a new column
        df['category'] = np.select(conditions, categories, default='Unknown')
        df['threshold'] = threshold  # Add threshold as a column for animation frame
        
        all_frames.append(df.copy())

    # Concatenate all frames for animation
    df_all_frames = pd.concat(all_frames)

    # Plot the scatter 3D with the categories as color and animate over thresholds
    fig = px.scatter_3d(df_all_frames,
                        x='Dim1', y='Dim2', z='Dim3',
                        color='category',
                        animation_frame='threshold',
                        animation_group='serial_number')

    fig.show()

# Call the function
plot_scatter_3d_mwe()
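
A possible workaround, sketched under an assumption that has not been confirmed for this issue: the Plotly Express animation documentation notes that animations work best when the categorical values mapped to color are constant across frames. In the MWE the first frame (threshold 0) contains only TP and FP, which may be why the other categories never get drawn. The helper below, pad_missing_categories, is a hypothetical name and not part of Plotly. It adds one placeholder row with NaN coordinates for every category missing from a frame, so that each frame contains all four categories.

```python
import numpy as np
import pandas as pd


def pad_missing_categories(df, frame_col, category_col, categories, coord_cols):
    """Return a copy of df with one NaN-coordinate row added for every
    category that is absent from a given animation frame.

    Hypothetical helper for illustration; not part of Plotly.
    """
    padding = []
    for frame_value, frame in df.groupby(frame_col, sort=False):
        present = set(frame[category_col])
        for cat in categories:
            if cat not in present:
                # NaN coordinates: the row exists but nothing is drawn for it
                row = {col: np.nan for col in coord_cols}
                row[frame_col] = frame_value
                row[category_col] = cat
                padding.append(row)
    if not padding:
        return df.copy()
    # Columns not listed in coord_cols (e.g. serial_number) are NaN in padded rows
    return pd.concat([df, pd.DataFrame(padding)], ignore_index=True)


# Toy data: each frame is missing two of the four categories
demo = pd.DataFrame({
    "threshold": [0.0, 0.0, 1.0, 1.0],
    "category": ["TP", "FP", "TN", "FN"],
    "Dim1": [0.1, 0.2, 0.3, 0.4],
})
padded = pad_missing_categories(
    demo, "threshold", "category", ["TP", "TN", "FP", "FN"], ["Dim1"]
)
```

With the MWE, the padded frame would be built from df_all_frames with frame_col='threshold', category_col='category', and coord_cols=['Dim1', 'Dim2', 'Dim3'], then passed to px.scatter_3d in place of df_all_frames. Padded rows have NaN in serial_number, so whether animation_group still behaves as intended is untested here.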