Source code for qutree.plot


import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from qutree.ttn.grid import *
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import imageio
import os

def plot_xyz(xyz, f, ranges=None):
    """Build a 3D plotly scatter of grid points coloured by ``f``.

    Parameters
    ----------
    xyz
        Grid object whose ``.grid`` attribute is indexed as ``[:, 0]``,
        ``[:, 1]`` and ``[:, 2]`` for the x, y and z coordinates.
    f
        Marker colour values, passed to plotly as ``marker.color``
        (presumably one value per point).
    ranges
        Optional sequence of three axis ranges (x, y, z), each passed to
        plotly as ``range``. When ``None``, plotly autoranges every axis.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure. It is not shown here; the caller decides.
    """
    points = xyz.grid
    fig = go.Figure(data=[go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(
            size=5,
            color=f,                               # colour each point by its f value
            colorscale='Viridis',
            colorbar=dict(title='Function Value'),
        ),
    )])

    def _axis(title, i):
        # Autorange when no ranges are given; otherwise pin axis i to ranges[i].
        if ranges is None:
            return dict(title=title, autorange=True)
        return dict(title=title, range=ranges[i], autorange=False)

    fig.update_layout(
        scene=dict(
            xaxis=_axis('X Axis', 0),
            yaxis=_axis('Y Axis', 1),
            zaxis=_axis('Z Axis', 2),
            aspectmode='cube',                     # draw the scene box as a cube
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return fig
def plot_tt_diagram(tn, draw_ranks=True):
    """Draw ``tn`` as a row of inner nodes with its leaves directly beneath.

    Layout, with ``n = len(up_leaves(tn))``:

    * leaf ``-k`` (``k = 1 .. n``) is placed at ``(k - n, 0)``;
    * inner node ``i`` (``i = 0 .. n-1``) is placed at ``(i + 1 - n, 0.1)``.

    Leaf ``-(i + 1)`` therefore sits directly below node ``i``.

    NOTE(review): assumes the graph's nodes are exactly ``-n .. -1`` and
    ``0 .. n-1``; any other node has no position and networkx will fail.

    Parameters
    ----------
    tn
        networkx graph to draw.
    draw_ranks
        If true, label each edge with its ``'r'`` edge attribute.

    Returns
    -------
    None
        ``nx.draw`` returns ``None``; the drawing goes to the current
        matplotlib axes.
    """
    nleaves = len(up_leaves(tn))

    # Build the position map explicitly: keys -n .. -1 first, then 0 .. n-1.
    pos = {}
    for k in range(nleaves, 0, -1):
        pos[-k] = np.array([k - nleaves, 0])
    for i in range(nleaves):
        pos[i] = np.array([i + 1 - nleaves, 0.1])

    # Numeric aspect: one y unit is drawn 15x as long as one x unit, which
    # spreads the small 0.1 vertical gap between leaves and inner nodes.
    plt.gca().set_aspect(15)
    nx.draw(tn, pos=pos, with_labels=False, node_size=250, font_size=8)
    if draw_ranks:
        ranks = nx.get_edge_attributes(tn, 'r')
        nx.draw_networkx_edge_labels(tn, pos, edge_labels=ranks, font_size=14)
    return None
def plot_tn_xyz(tn, fun, q_to_x=None):
    """Show a 3D scatter of all node grids of ``tn`` coloured by ``fun``.

    The ``'grid'`` attribute of each node is collected and the grids are
    combined with ``direct_sum``. If ``q_to_x`` is given, the combined grid is
    passed through ``grid.transform(q_to_x)``. ``fun`` is evaluated on the
    result with ``grid.evaluate``, and the figure from ``plot_xyz`` is shown.

    Parameters
    ----------
    tn
        networkx graph whose nodes carry a ``'grid'`` attribute.
    fun
        Function passed to ``grid.evaluate`` to produce the colour values.
    q_to_x
        Optional transformation passed to ``grid.transform``.

    Returns
    -------
    None
        The figure is displayed via ``fig.show()`` and not returned.
    """
    grids = list(nx.get_node_attributes(tn, 'grid').values())
    grid = direct_sum(grids)
    if q_to_x is not None:
        grid = grid.transform(q_to_x)
    fig = plot_xyz(grid, grid.evaluate(fun))
    fig.show()
def tn_to_df(tn, fun):
    """Return a DataFrame with one row per point of every node grid in ``tn``.

    Parameters
    ----------
    tn
        networkx graph whose nodes carry a ``'grid'`` attribute with a
        ``.grid`` array of points.
    fun
        Accepted for interface compatibility but not used.

    Returns
    -------
    pandas.DataFrame
        Columns ``'xyz'`` (each point, i.e. a row of ``grid.grid``) and
        ``'node'``. ``'node'`` is the position of the owning node in the order
        returned by ``nx.get_node_attributes``, not the node key itself.
        NOTE(review): ``tngrid_to_df`` stores real node keys instead; confirm
        which one callers expect.
    """
    grids = nx.get_node_attributes(tn, 'grid').values()
    xyz, node = [], []
    for position, grid in enumerate(grids):
        for point in grid.grid:
            xyz.append(point)
            node.append(position)
    return pd.DataFrame({'xyz': xyz, 'node': node})
def plot_tree(G, draw_ranks=True):
    """Draw tree ``G`` layer by layer, with leaves spread evenly in x.

    ``add_layer_index`` is applied first. Each node is drawn at
    ``y = -layer``. Leaves (negative ids) get x from ``np.linspace(0, 1, n)``,
    where ``n = len(up_leaves(G))``; leaf ``-k`` uses entry ``k - 1``. Every
    other node is centred at the mean x of its ``children``.

    Nodes are placed in sorted id order, and every node starts at ``(0, 0)``.
    A child that has not been placed yet therefore contributes x = 0.
    NOTE(review): presumably children always have smaller ids than their
    parent; confirm.

    Parameters
    ----------
    G
        networkx tree graph.
    draw_ranks
        If true, label each edge with its ``'r'`` edge attribute.
    """
    G = add_layer_index(G)
    leaf_xs = np.linspace(0, 1, len(up_leaves(G)))
    ordered = sorted(G.nodes)
    pos = dict.fromkeys(ordered, (0, 0))

    for nd in ordered:
        depth = -G.nodes[nd]["layer"]
        if nd >= 0:
            kid_xs = np.array([pos[kid][0] for kid in children(G, nd)])
            horizontal = np.mean(kid_xs, axis=0)
        else:
            horizontal = leaf_xs[-nd - 1]
        pos[nd] = (horizontal, depth)

    nx.draw(G, pos=pos, with_labels=False, node_size=500)
    if draw_ranks:
        edge_ranks = nx.get_edge_attributes(G, 'r')
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_ranks, font_size=18)
    plt.draw()
def tngrid_to_df(G, O):
    """Flatten the per-node grids of ``G`` into one row per grid point.

    Parameters
    ----------
    G
        networkx graph whose nodes may carry a ``'grid'`` attribute. Nodes
        whose attribute is ``None`` are skipped; otherwise ``.grid`` is read.
    O
        Object whose ``Err`` method is called once per grid point.

    Returns
    -------
    pandas.DataFrame
        Columns ``node`` (node key), ``f`` (``O.Err(point)``), then
        ``x1 .. xL``, where L is the length of the first point. The index
        is 0..N-1. If no node has a grid, an empty frame with columns
        ``node`` and ``f`` is returned instead of raising ``IndexError``.
        NOTE(review): assumes all points have the same length.
    """
    grids = {
        k: v.grid
        for k, v in nx.get_node_attributes(G, 'grid').items()
        if v is not None
    }
    df = pd.DataFrame(list(grids.items()), columns=['node', 'grid'])
    # One row per point; drop=True discards the pre-explode index directly.
    df = df.explode('grid').reset_index(drop=True)
    if df.empty:
        return pd.DataFrame(columns=['node', 'f'])

    # Call O.Err point by point (a lambda keeps Series.apply on the
    # element-wise path).
    df['f'] = df['grid'].apply(lambda x: O.Err(x))

    # Split each point into coordinate columns x1..xL in one construction
    # instead of a row-wise apply(pd.Series).
    L = len(df['grid'].iloc[0])
    coords = pd.DataFrame(
        df['grid'].tolist(),
        index=df.index,
        columns=['x{}'.format(i + 1) for i in range(L)],
    )
    return pd.concat([df.drop(columns='grid'), coords], axis=1)
def concat_pandas(dfs):
    """Stack DataFrames into one long frame, tagging each with its position.

    Parameters
    ----------
    dfs
        Iterable of DataFrames (a list or a generator). The inputs are not
        modified.

    Returns
    -------
    pandas.DataFrame
        Rows of all frames in order. The columns are:

        * ``time``: position of the source frame in ``dfs``;
        * ``index``: the original row labels, kept by ``reset_index``;
        * ``size``: the constant 1.

    Raises
    ------
    ValueError
        From ``pd.concat`` when ``dfs`` is empty.
    """
    # assign() returns a copy, so the caller's frames do not gain 'time'.
    tagged = [df.assign(time=t) for t, df in enumerate(dfs)]
    out = pd.concat(tagged).reset_index()
    out["size"] = 1
    return out
def grid_animation(df, color='f'):
    """Return an animated 3D scatter of ``df`` with one frame per ``time``.

    ``df`` is expected to have the columns ``x1``, ``x2``, ``x3``, ``time``,
    ``node`` and ``size``, plus the column named by ``color``. Points are
    matched across frames by ``node``. The play button is set to 100 ms per
    frame with a 20 ms transition. Tick labels are hidden and a fixed camera
    is used.
    """
    fig = px.scatter_3d(
        df,
        x="x1",
        y="x2",
        z="x3",
        animation_frame="time",
        animation_group="node",
        size="size",
        color=color,
        hover_name="time",
        size_max=15,
        width=1000,
        height=800,
    )

    # Arguments of the auto-generated "play" button.
    play_args = fig.layout.updatemenus[0].buttons[0].args[1]
    play_args['frame']['duration'] = 100
    play_args['transition']['duration'] = 20

    fig.update_layout(scene=dict(
        xaxis_title="x",
        yaxis_title="y",
        zaxis_title="z",
        xaxis=dict(showticklabels=False),
        yaxis=dict(showticklabels=False),
        zaxis=dict(showticklabels=False),
    ))

    fig.update_layout(scene_camera=dict(
        eye=dict(x=1.5, y=-1.0, z=1.2),   # camera position
        center=dict(x=0, y=0, z=0),       # point the camera looks at
        up=dict(x=0, y=0, z=1),           # upward direction
    ))
    return fig
def grid_animation_to_gif(df, color='f', gif_filename='animation.gif',
                          frames_folder='.frames', time_slice=slice(2, 35),
                          fps=10):
    """Render selected time steps of ``df`` as PNG frames and join them into a GIF.

    Parameters
    ----------
    df : pandas.DataFrame
        Data with columns ``time``, ``x1``, ``x2``, ``x3`` and the column
        named by ``color``.
    color : str
        Column used for marker colour.
    gif_filename : str
        Output path for the GIF.
    frames_folder : str
        Scratch directory for the per-frame PNGs; created if missing. Only
        the frames written here are deleted afterwards. The directory is
        removed only if it ends up empty, so pre-existing files survive.
    time_slice : slice
        Which unique ``time`` values (in order of first appearance) to
        render. The default ``slice(2, 35)`` reproduces the range that used
        to be hard-coded.
    fps : int or float
        Frame rate passed to ``imageio.mimsave``.

    Returns
    -------
    str
        ``gif_filename``.

    Raises
    ------
    ValueError
        If ``time_slice`` selects no time steps.
    """
    times = df['time'].unique()[time_slice]
    if len(times) == 0:
        raise ValueError("time_slice selects no time steps; nothing to render")

    camera = dict(
        eye=dict(x=1.5, y=-1.0, z=1.2),   # camera position
        center=dict(x=0, y=0, z=0),       # point the camera looks at
        up=dict(x=0, y=0, z=1),           # upward direction
    )

    os.makedirs(frames_folder, exist_ok=True)
    frame_paths = []
    try:
        for time in times:
            sub_df = df[df['time'] == time]
            fig = px.scatter_3d(sub_df, x="x1", y="x2", z="x3", color=color,
                                size_max=10, width=1000, height=800)
            fig.update_layout(scene=dict(
                xaxis_title="x", yaxis_title="y", zaxis_title="z",
                xaxis=dict(showticklabels=False),
                yaxis=dict(showticklabels=False),
                zaxis=dict(showticklabels=False),
            ))
            fig.update_traces(marker=dict(opacity=0.75, size=6))
            fig.update_layout(scene_camera=camera)
            path = os.path.join(frames_folder, f'frame_{time:04d}.png')
            # Record the path before writing so a partially written file is
            # still cleaned up.
            frame_paths.append(path)
            fig.write_image(path, format='png')

        images = [imageio.imread(path) for path in frame_paths]
        imageio.mimsave(gif_filename, images, fps=fps)
    finally:
        # Best-effort cleanup of only our own frames, on success or failure.
        for path in frame_paths:
            try:
                os.remove(path)
            except OSError:
                pass  # never written, or already removed
        try:
            os.rmdir(frames_folder)
        except OSError:
            pass  # folder still holds other files; leave it in place
    return gif_filename