
[Python] st.rerun() in Streamlit doesn't function as expected

Discussion in 'Python' started by Stack, October 7, 2024 at 00:32.


    I have a module with two scripts. The first script, where things kick off, is main.py, shown below.

    main.py

    import streamlit as st
    from src.platform_intelligence.language.process_input import process_chat_response, tools
    from src.utils.display_utility import stream_message, get_user_input

    def run_chatbot():
        """Run the Streamlit chatbot interface."""
        # Modularize chatbot UI elements
        language = st.sidebar.radio("Choose Response Language", ('English', 'German', 'French'), horizontal=True)

        if "messages" not in st.session_state:
            st.session_state.messages = [{"role": "assistant", "content": "Hello, I'm Eliza, your Segmentation Assistant. I can currently help you create segments."}]

        if "current_stage" not in st.session_state:
            st.session_state.current_stage = 'ask'

        # Replay the conversation history on every script run
        for message in st.session_state.messages:
            if message["role"] == "assistant":
                with st.chat_message(message["role"], avatar=""):
                    st.markdown(f'<div style="color: black;">{message["content"]}</div>', unsafe_allow_html=True)
            else:
                with st.chat_message(message["role"]):
                    st.markdown(message["content"])

        if st.session_state.current_stage == 'ask':
            # Handle chatbot flow via chat input
            if user_input := st.chat_input("Ask Something", key="main_chat_key", disabled=False):
                with st.chat_message("user", avatar=None):
                    st.markdown(user_input)
                assistant_response = process_chat_response(user_input, tools=tools, language=language).strip()
                # Stream assistant's response in real time
                st.session_state.messages.append({"role": "user", "content": user_input})
                stream_message(assistant_response)
                print(f"stage of code: From process() -> main: {st.session_state.current_stage}")
                st.session_state.current_stage = "ask"
                st.rerun()
        st.stop()


    As you can see, it only calls st.chat_input if st.session_state.current_stage == "ask", which is the case right when the app starts.
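
    For context, my understanding of the execution model is that Streamlit re-runs the whole of main.py from top to bottom on every widget interaction and on every st.rerun(), and st.chat_input() returns None until the user actually submits something during that run. Here is a stripped-down sketch of the stage-gating pattern I believe I am following (made-up names, not my real code):

    import streamlit as st

    if "current_stage" not in st.session_state:
        st.session_state.current_stage = "ask"

    # Every run starts here again, so only the branch matching the
    # current stage executes in a given pass.
    if st.session_state.current_stage == "ask":
        # chat_input returns None until the user submits something
        if user_input := st.chat_input("Ask something"):
            st.session_state.answer = user_input.upper()
            st.session_state.current_stage = "show"
            st.rerun()  # start a fresh run with the new stage

    elif st.session_state.current_stage == "show":
        st.write(st.session_state.answer)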

    It then passes control to process_chat_response() from process_input.py, shown below. process_chat_response(), at the end of the script, works fine and hands control to validate_input_query() in the same script.

    process_input.py

    def validate_input_query(merged_arguments: Dict[str, Any], schema_model: BaseModel = InputQuery) -> BaseModel:
        """
        Processes the chat response from a previous LLM agent and validates it against schema_model.

        Args:
            merged_arguments (Dict): The parsed arguments from the previous LLM
            schema_model (BaseModel): InputQuery schema

        Returns:
            BaseModel: Validated Pydantic model
        """

        if "current_stage" not in st.session_state:
            st.session_state.current_stage = "validate"

        # Initialize or retrieve stored errors
        if "errors" not in st.session_state:
            st.session_state.errors = []

        errors = {}
        print(f"Current session_state: {st.session_state.current_stage}")
        print("stage of code: Entered Validate_query()")

        # Validation stage
        if st.session_state.current_stage == "validate":
            st.chat_input("...", disabled=True)
            try:
                validated_query = schema_model(**merged_arguments)
                # Convert the validated Pydantic model into a dictionary for easier iteration
                validated_data = validated_query.dict()
                message = "Great! Here's what we learnt from your intent:\n\n"
                # Stream each key-value pair one by one with bold keys
                for key, value in validated_data.items():
                    message += f"**{key}**: {value}\n\n"  # Key is bolded

                stream_message(message, avatar="")  # Stream the full message to the chat

                print(f"Before stage change: {st.session_state.current_stage}")
                st.session_state.current_stage = "ask_additional"  # Mark validation as complete
                print(f"After stage change: {st.session_state.current_stage}")
                print(f'validated query in validate_input_query: {validated_query}\n')
                print()
                print("stage of code: From Validate_query() to Ask_Additional()")

                st.rerun()

            except ValidationError as e:
                # Collect all validation errors without prompting for input right away
                st.session_state.current_stage = "correct"

                print(f"All Validation Errors: {e.errors()}")
                print()

                # Capture each error from ValidationError and append it to st.session_state.errors
                for error in e.errors():
                    if error['loc']:
                        field = error.get('loc', [None])[0]  # Field where the error occurred
                        error_message = error.get('msg', 'Unknown error')
                        # Store the error in session state
                        st.session_state.errors.append({"field": field, "error_message": error_message})
                    else:
                        print(f"Missing Mandatory Fields Error: {error['msg']}")
                        missing_fields_message = error['msg']
                        if "Missing required fields" in missing_fields_message:
                            missing_fields = missing_fields_message.split(": ")[-1].split(", ")
                            # Add an error entry for each missing field
                            for field in missing_fields:
                                # Store each missing-field error in the same format as field errors
                                st.session_state.errors.append({
                                    "field": field,
                                    "error_message": f"Missing required field: {field}"
                                })

                print(f"All stored errors in session_state: {st.session_state.errors}")

                # Optional: log the collected errors
                for err in st.session_state.errors:
                    print(f"Collected error: {err}")

                st.session_state.current_stage = 'correct'
                st.rerun()

        if st.session_state.current_stage == "correct":
            print("stage of code: From Validation -> Correct stage")
            # After errors are collected, ask the user to correct them
            if st.session_state.errors:
                stream_message("The following fields need fixing:\n\n")

                for error in st.session_state.errors:
                    stream_message(f"- **{error['field']}**: {error['error_message']}\n\n")

                # Prompt the user to correct each specific error
                for error in st.session_state.errors:
                    # Ask for input for the field
                    stream_message(f"Please provide a valid value for '{error['field']}':\n\n")
                    if user_input := st.chat_input(f"Please provide a valid value for '{error['field']}':", key=f"correct_{error['field']}"):
                        merged_arguments[error['field']] = user_input
                        st.session_state.errors.remove(error)
            else:
                st.session_state.current_stage = "validate"
                st.rerun()
            st.stop()

        print()
        print("stage of code: Entered Ask_Additional()")
        print(f"Current session_state: {st.session_state.current_stage}")
        if st.session_state.current_stage == "ask_additional":
            stream_message("Would you like to 1) Add more details, or 2) Correct specific fields? Enter 1 or 2 (or 'n' to finalize):\n")

            if user_input := st.chat_input("Enter your choice: 1) Add more details, or 2) Correct specific fields? Enter 1 or 2 (or 'n' to finalize)"):
                if user_input == 'n':
                    st.session_state.current_stage = "ask"
                    return validated_query
                st.rerun()

            elif st.session_state.user_input == '1':
                stream_message("Please enter additional details:")
                if user_input := st.chat_input("Please enter additional details:", key=get_unique_key("additional_info_key")):
                    validated_query = process_chat_response(user_input, tools=tools)
                    if isinstance(validated_query, str):
                        print(f'Yes {validated_query} is string')

            stream_message("Invalid option. Please enter 1, 2, or 'n'.")
        st.stop()



    def process_chat_response(
        user_query: str,
        system_prompt: str = SYSTEM_PROMPT,
        tools: Any = None,
        schema_model: BaseModel = InputQuery,
        language: str = "English"
    ) -> Union[BaseModel, str]:
        """
        Processes the chat response from a completion request and outputs function details and arguments.

        This function sends a list of messages to a chat completion request, processes the response to extract
        function calls and arguments, and prints relevant information. It also merges function arguments and
        initializes an `InputQuery` object using the merged arguments.

        Args:
            user_query (str): Query entered by the user
            system_prompt (str): A custom system prompt, if the user has one.
            tools (Any): The tools to be used with the chat completion request.
            schema_model (BaseModel): The Pydantic model class used for validating the user query.
            language (str): Language in which the response should be returned.

        Returns:
            response_output (Union[BaseModel, str]): The response for the query.
        """

        # Ensure chat_history is initialized in session state
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = []

        # Convert chat history into a formatted string
        chat_history_str = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" for msg in st.session_state.chat_history])

        # Format the system prompt with the chat history
        formatted_system_prompt = system_prompt.format(chat_history=chat_history_str, language=language)

        messages = [{"role": "system", "content": formatted_system_prompt}]
        messages.append({"role": "user", "content": user_query})

        print(f'Full Prompt: {messages}')

        response = chat_completion_request(messages, tools=tools, response_format={"type": "text"})

        print(f'\n{response}')

        merged_arguments = defaultdict(lambda: None)

        if response.choices[0].finish_reason == "tool_calls":
            for tool_call in response.choices[0].message.tool_calls:
                function_arguments = json.loads(tool_call.function.arguments)
                merged_arguments.update(function_arguments)

            merged_arguments = dict(merged_arguments)

            print()
            print(f'function call arguments: {function_arguments}')
            print(f"Merged Arguments: {merged_arguments}")

            # Convert merged_arguments to a JSON-like string and escape curly braces
            merged_arguments_str = str(merged_arguments).replace("{", "{{").replace("}", "}}")

            # Append the user's query and the assistant's response to the chat history
            st.session_state.chat_history.append({"role": "user", "content": user_query})
            st.session_state.chat_history.append({"role": "assistant", "content": merged_arguments_str})

            # (Verifier LLM agent check is currently commented out)

            st.session_state.current_stage = "validate"

            print()
            print("stage of code: From process() -> Validate_query()")

            # Validate the InputQuery object, re-prompting if necessary
            final_response = validate_input_query(merged_arguments, schema_model)
            print(f"Process Chat Final Response: {final_response}")

        elif response.choices[0].finish_reason == 'stop' and response.choices[0].message.content is not None:
            final_response = response.choices[0].message.content.strip()

            # (Verifier LLM agent check is currently commented out)

            # Append the user's query and the assistant's response to the chat history
            st.session_state.chat_history.append({"role": "user", "content": user_query})
            st.session_state.chat_history.append({"role": "assistant", "content": final_response})

        print(f"chat history: {st.session_state.chat_history}")

        return final_response


    As you can see, within process_input.py, validate_input_query() is where I use st.rerun() and try to create an st.chat_input() depending on which st.session_state.current_stage we are in.

    Yet the moment it hits the first st.rerun(), it simply stops execution and never runs the code for the next st.session_state.current_stage. This happens every time.
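
    To convince myself of what the rerun itself does, I tried a minimal standalone sketch (a made-up example, not my real code), and it confirms that nothing after st.rerun() runs in the same pass; the script just starts again from the top:

    import streamlit as st

    if "stage" not in st.session_state:
        st.session_state.stage = "validate"

    st.write(f"run started, stage = {st.session_state.stage}")

    if st.session_state.stage == "validate":
        st.session_state.stage = "correct"
        st.rerun()
        st.write("this line never executes")  # execution halts at st.rerun()

    if st.session_state.stage == "correct":
        st.write("correct stage is only reached on the next run, from the top")

    In that tiny sketch the "correct" branch does run on the next pass, because the stage check sits at the top level of the script. I suspect the difference in my app is that the equivalent stage checks live inside validate_input_query(), which main.py only calls again via process_chat_response() after new chat input, but I am not sure.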

    What is going wrong? I have spent days trying to figure out the issue, but to no avail.

    Any guidance is much appreciated.

    Continue reading...
