diff --git a/app.py b/app.py index bdafa3b..02419c1 100644 --- a/app.py +++ b/app.py @@ -130,64 +130,68 @@ if files_uploaded or 'files_already_uploaded' in st.session_state: df.rename({'type': 'relation_type'}, inplace=True, axis=1) # 'type' can't be used as attribute. df.columns = [i.lower() for i in df.columns] # Remove capital letters from column names. - with st.form('select_columns'): - # Find and store target column. - target = find_columns('target', df.columns.tolist()) + # Find and store target column. + target = find_columns('target', df.columns.tolist()) - # Find and store source column. - source = find_columns('source', df.columns.tolist()) - - # Confirm choise. - columns_selected = st.form_submit_button("Done") + # Find and store source column. + source = find_columns('source', df.columns.tolist()) - if columns_selected: - # Remove source and target columns from list of options. - columns = df.columns.tolist() - columns.remove(st.session_state["target"]) - columns.remove(st.session_state["source"]) - - if all([st.session_state["source"] != "", st.session_state["target"] != ""]): - source = st.session_state["source"] - target = st.session_state["target"] + # Remove source and target columns from list of options. + columns = df.columns.tolist() + columns.remove(st.session_state["target"]) + columns.remove(st.session_state["source"]) + + if all([st.session_state["source"] != "", st.session_state["target"] != ""]): - # Let the user chose what columns that should be included. - chosen_columns = st.multiselect( - label="Chose other columns to include.", options=columns, default=columns - ) + source = st.session_state["source"] + target = st.session_state["target"] - if csv_nodes != None: # When a nodes file is uploaded. - df_nodes = pd.read_csv(csv_nodes, sep=st.session_state["sep"]) - df_nodes.columns = [i.lower() for i in df_nodes.columns] # Remove capital letters from column names. + # Let the user chose what columns that should be included. + chosen_columns = st.multiselect( + label="Chose other columns to include.", options=columns, default=columns + ) - label_column = find_columns('label', df_nodes.columns.tolist()) - df_nodes.set_index(label_column, inplace=True) + if csv_nodes != None: # When a nodes file is uploaded. + df_nodes = pd.read_csv(csv_nodes, sep=st.session_state["sep"]) + df_nodes.columns = [i.lower() for i in df_nodes.columns] # Remove capital letters from column names. - else: # If no node file provided. - nodes = list(set(df[source].tolist() + df[target].tolist())) - df_nodes = pd.DataFrame( - nodes, index=range(0, len(nodes)), columns=["labels"] - ) - df_nodes.set_index("labels", inplace=True) - # Make empty graph. - G = nx.MultiDiGraph() - # Add nodes. - G = add_nodes(G, df_nodes) - # Add edges. - G = add_edges( - G, df, source=source, target=target, chosen_columns=chosen_columns - ) - - # Turn the graph into a string. - graph_text = "\n".join([line for line in nx.generate_gexf(G)]) + st.session_state['label_column'] = find_columns('label', df_nodes.columns.tolist()) - # Download gexf-file. - gexf_file = "output.gexf" - st.download_button( - "Download gexf-file", graph_text, file_name=gexf_file + if st.session_state['label_column'] != '': + df_nodes.set_index(st.session_state['label_column'], inplace=True) + + else: # If no node file provided. + nodes = list(set(df[source].tolist() + df[target].tolist())) + df_nodes = pd.DataFrame( + nodes, index=range(0, len(nodes)), columns=["labels"] ) - st.write('Import the file to Gephi/Gephi Light, or try [Gephisto](https://jacomyma.github.io/gephisto/) to get an idea of the network.') + + st.session_state['label_column'] = 'labels' + + if st.session_state['label_column'] != '' and df_nodes.index.name != st.session_state['label_column']: + df_nodes.set_index(st.session_state['label_column'], inplace=True) + + + # Make empty graph. + G = nx.MultiDiGraph() + # Add nodes. + G = add_nodes(G, df_nodes) + # Add edges. + G = add_edges( + G, df, source=source, target=target, chosen_columns=chosen_columns + ) + + # Turn the graph into a string. + graph_text = "\n".join([line for line in nx.generate_gexf(G)]) + + # Download gexf-file. + gexf_file = "output.gexf" + st.download_button( + "Download gexf-file", graph_text, file_name=gexf_file + ) + st.write('Import the file to Gephi/Gephi Light, or try [Gephisto](https://jacomyma.github.io/gephisto/) to get an idea of the network.') # except: # st.markdown(':red[Something went wrong, please try again or [write to me](https://twitter.com/lasseedfast).]') \ No newline at end of file