Creating a Retrieval-Augmented Generation Pipeline for Ollama

Introduction

With the introduction of Large Language Models (LLMs), the general public has largely assumed that they work on magic and dreams. However, getting a model to use large amounts of new data is difficult from both a software and a hardware perspective. There are two common methods for doing it: fine-tuning an LLM, or Retrieval-Augmented Generation (RAG).

In this article, I will show you how to create a fairly quick, yet robust RAG pipeline, and how we can feed our own data to a large language model. I will be running the LLMs through Ollama, which provides a multitude of models. I will also be doing all of this on a 16-inch M1 Pro MacBook Pro.

Retrieval-Augmented Generation

Retrieval-Augmented Generation (RAG) is the concept of providing LLMs with additional information from some external source, outside of the data the LLM was trained on. While this sounds like fine-tuning an LLM, the two are different: fine-tuning makes an LLM more viable in specific situations by training it further, while RAG connects external data and sources to an existing LLM using retrieval methods that are too advanced for the lowly software engineer that I am.

Also, something to note is that choosing between RAG and fine-tuning is completely situational. I will break down additional edge cases in another article, but overall, since we want

This is how we can keep our LLMs up to date and reduce hallucinations. Most people would assume that hallucinations can be circumvented by fine-tuning the LLM, but what we actually want to use is RAG. It lets the LLM draw on the external sources that I mentioned above at query time, which is a great way of cutting down the all too prominent hallucinations everyone has been commenting on.

In the end, RAG is just a concept that helps us make our LLMs better, but it won't solve every issue LLMs have. We want to pair this with other concepts, like fine-tuning, to ensure our models are performing well. More importantly, understanding the differences between RAG and fine-tuning will ultimately draw lines on when to use these concepts.
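To make the idea concrete before we reach for any libraries, here is a minimal, dependency-free sketch of the two halves of RAG: retrieve the most relevant text, then augment the prompt with it. It is not how LangChain does it; the word-count similarity, the `DOCS` snippets, and every function name here are assumptions made up for illustration. A real pipeline uses learned embeddings and a vector store instead.

```python
import math
import re
from collections import Counter


def tokenize(text: str) -> list[str]:
    """Lowercase the text and split it into alphanumeric word tokens."""
    return re.findall(r"[a-z0-9]+", text.lower())


def cosine_similarity(a: Counter, b: Counter) -> float:
    """Cosine similarity between two bag-of-words count vectors."""
    dot = sum(a[token] * b[token] for token in a)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def retrieve(question: str, documents: list[str], k: int = 1) -> list[str]:
    """Return the k documents most similar to the question."""
    query_vec = Counter(tokenize(question))
    ranked = sorted(
        documents,
        key=lambda doc: cosine_similarity(query_vec, Counter(tokenize(doc))),
        reverse=True,
    )
    return ranked[:k]


def build_prompt(question: str, context: list[str]) -> str:
    """Augment the question with the retrieved context before it reaches the LLM."""
    joined = "\n".join(context)
    return f"Context:\n{joined}\n\nQuestion: {question}\nAnswer:"


# Hypothetical snippets standing in for internal documentation.
DOCS = [
    "VPN access is requested through the IT service desk portal.",
    "Deployments run every Tuesday after the release meeting.",
    "Expense reports are due on the last Friday of each month.",
]

question = "When do deployments run?"
prompt = build_prompt(question, retrieve(question, DOCS))
print(prompt)
```

The model itself never changes here. Only the prompt does, which is the key difference from fine-tuning.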

Why am I Talking About This?

Besides being extremely cool, there's a problem that actually spawned this article.

Finding and reading documentation has been a struggle when starting new positions. Sometimes the documentation itself is great, and the only thing holding it back is a terrible search engine. Thus, I want to feed an LLM my company's internal documentation so that new hires have a faster and better way to onboard, with a custom LLM that knows the documentation and where to find it.

Also, a lot of people don't really understand when to use RAG over fine-tuning, so breaking down and showing an example from scratch will hopefully help my friends stop bugging me for explanations of these concepts.

Getting Started

I will be using the following tools to implement our RAG pipeline:

  • langchain/langchain-community - This will be how we ingest and implement our RAG
  • Ollama - How we will be running our LLM
    • We will be using llama2 for now. It might be better if we used codellama but weighing different LLMs could be a topic for another day.
  • streamlit/streamlit-chat - I'm too lazy to write the front-end since I'm burnt out on UI quirks, so for now, we will handle user input with streamlit, which gives us a chat-style user interface with minimal effort.

And that's it! Woah, those dependencies seem pretty straightforward, right?

If you just want to see my RAG implementation, here is the GitHub repository, which is a pretty quick MVP. I won't post my actual code, since it uses internal APIs for web scraping, but the MVP will allow you to feed a website URL to the RAG pipeline, and the LLM will answer using that external data source.

Running Ollama

Once you have Ollama set up on your device, you need to open up your terminal of choice and grab your desired model.

Note that there are many LLMs available. I'm going to be using llama2, so the command will be:

shell
ollama pull llama2

If you're running this on your own PC, be wary of the number of parameters. I have 16 GB of RAM, which is just fine for the 7B parameter version of llama2, but the model is also available in sizes up to 70B. Make sure you know what you're doing and read the documentation.
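As a rough sanity check on why parameter count matters, the arithmetic below estimates the memory needed for the weights alone. This is a back-of-the-envelope sketch, not a measurement: it assumes 4-bit quantized weights (a common choice for local inference) next to 16-bit weights for comparison, and it ignores the context cache and runtime overhead, so real usage will be higher. The function name is my own.

```python
def estimate_weight_memory_gb(num_parameters: float, bits_per_weight: int) -> float:
    """Approximate memory for the weights alone, in gigabytes (10^9 bytes)."""
    bytes_total = num_parameters * bits_per_weight / 8
    return bytes_total / 1e9


# llama2 is published in 7B, 13B, and 70B parameter sizes.
for label, params in [("7B", 7e9), ("13B", 13e9), ("70B", 70e9)]:
    four_bit = estimate_weight_memory_gb(params, 4)
    sixteen_bit = estimate_weight_memory_gb(params, 16)
    print(f"{label}: ~{four_bit:.1f} GB at 4-bit, ~{sixteen_bit:.1f} GB at 16-bit")
```

Under these assumptions the 7B model fits comfortably in 16 GB of RAM, while the 70B model does not come close.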

Setting up rag.py

This file will set up our RAG implementation. Take note that we will probably want to create an actual data pipeline that scrapes something like Confluence, or wherever your company stores its documentation, but in this case, we will just be creating a single document upload pipeline and class.

First, we want to create our RAG pipeline. We should import the following packages:

python
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOllama
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_community.document_loaders import WebBaseLoader
from langchain.chains import RetrievalQA
from langchain.schema.output_parser import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.runnable import RunnablePassthrough
from langchain.prompts import PromptTemplate
from langchain.vectorstores.utils import filter_complex_metadata

This might be intimidating for people who are new to langchain, but let's break down all these components:

  • Chroma - ChromaDB vector store. This will store the vector data we extract from the website the user provides
  • ChatOllama - LangChain's chat model wrapper for Ollama. This is how we talk to the model that Ollama is running
  • GPT4AllEmbeddings - The embedding model, provided by GPT4All, that turns our text chunks into vectors
  • WebBaseLoader - The website loader that will extract data from the website
  • RetrievalQA - Initializes the actual RAG chain. This takes the query the user provides and retrieves the relevant data from the index that gets stored inside our vector_store
  • StrOutputParser - Parses the LLM's output into a plain string
  • RecursiveCharacterTextSplitter - A text splitter that recursively splits text on a list of separators until the chunks fit a specific chunk size
  • RunnablePassthrough - Passes through inputs without modifying the data
  • PromptTemplate - A prompt template for LLMs to ingest
  • filter_complex_metadata - Filters out metadata types not supported by the vector store in langchain

Our RAG Implementation

We will be creating a new class to help keep track of our functions and data. This also makes it easier to create mocks for testing in the future. Our class will have an __init__ function and an ingest function, and it keeps track of three things: the vector store, the retriever for the chain, and the chain itself.

Initial Class Loadout

Our initial class is described as follows:

python
class RAGImplementation:
    vector_store = None
    retriever = None
    chain = None

    def __init__(self):
        self.model = ChatOllama(model="llama2")
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=10, chunk_overlap=8)
        self.prompt = PromptTemplate.from_template(
            """
            <s> [INST] You are an assistant for question-answering tasks.
            Use the following context to answer the question. If you don't
            know the answer, just say you don't know. Use three sentences
            and be concise in your answer. [/INST] </s>
            [INST] Question: {question}
            Context: {context}
            Answer: [/INST]
            """
        )

In our __init__ function, we set up the model we will be using through ChatOllama, and set that model to llama2.

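The text splitter is how we split our documents into chunks. Both chunk_size and chunk_overlap are measured in characters. Right now I'm setting chunk_size to be only slightly larger than chunk_overlap, since we don't have a ton of data to work with, so the overlap is 80% of the chunk size. These values are tiny, though. For real documentation pages you will likely want chunks of several hundred characters or more.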
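To see how chunk size and overlap interact, here is a simplified sliding-window splitter. It is only a sketch: RecursiveCharacterTextSplitter splits on separators such as paragraph and line breaks rather than at fixed offsets, and `split_with_overlap` is a name I made up. The arithmetic of the overlap is the same, though.

```python
def split_with_overlap(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    """Split text into fixed-size character windows that share chunk_overlap characters."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far the window advances each time
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reached the end of the text
    return chunks


sample = "abcdefghijklmnop"  # 16 characters
chunks = split_with_overlap(sample, chunk_size=10, chunk_overlap=8)
print(chunks)
```

With a chunk size of 10 and an overlap of 8, each new chunk contributes only two new characters, so the store ends up holding roughly five times the original text. That redundancy is the price of a high overlap ratio.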

Lastly, we define a prompt template. It's a default template I pulled from the internet that tells the model to answer the question from the context we supply. I'm not a prompt engineer, so I'll just believe it does its job for now.

Ingesting data

Now we need a way to ingest data from the url that the user provides.

python
    def ingest(self, url: str):
        # Load the page and split it into chunks for embedding
        loader = WebBaseLoader(url)
        docs = loader.load()
        chunks = self.text_splitter.split_documents(docs)
        chunks = filter_complex_metadata(chunks)

        # Embed the chunks and keep the store on the instance
        self.vector_store = Chroma.from_documents(
            documents=chunks, embedding=GPT4AllEmbeddings())
        self.retriever = self.vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": 3,
                "score_threshold": 0.5,
            },
        )

        # Use the configured retriever and our prompt template in the chain
        self.chain = RetrievalQA.from_chain_type(
            self.model,
            retriever=self.retriever,
            chain_type_kwargs={"prompt": self.prompt},
        )

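The ingest function loads the page at the given URL, splits it into chunks, and filters out any metadata the vector store can't handle. We then embed those chunks with the GPT4AllEmbeddings() model that we imported and store them in Chroma. We also configure a retriever that returns up to three chunks with a relevance score of at least 0.5, and build the chain that will answer the user's queries against that data later on.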
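The "k" and "score_threshold" settings are easier to reason about with a small example. The sketch below applies the same two rules to a list of already-scored chunks. The scores and chunk names are hypothetical, and `select_relevant` is my own helper, not a LangChain function.

```python
def select_relevant(
    scored_chunks: list[tuple[str, float]], k: int = 3, score_threshold: float = 0.5
) -> list[str]:
    """Keep chunks scoring at least score_threshold, then return the top k by score."""
    passing = [(chunk, score) for chunk, score in scored_chunks if score >= score_threshold]
    passing.sort(key=lambda pair: pair[1], reverse=True)
    return [chunk for chunk, _ in passing[:k]]


# Hypothetical relevance scores in [0, 1], where higher means more relevant.
candidates = [
    ("chunk about onboarding", 0.91),
    ("chunk about lunch menus", 0.12),
    ("chunk about laptop setup", 0.64),
    ("chunk about VPN access", 0.58),
    ("chunk about build tooling", 0.55),
]
print(select_relevant(candidates))
```

Note that the threshold can leave you with fewer than k chunks, or none at all, in which case the model has no context to answer from.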

Lastly, we want to create an ask method that takes in the query the user provides.

python
    def ask(self, query: str):
        if not self.chain:
            return "Please, add a URL first."
        invocation = self.chain.invoke(query)
        print(invocation)
        return invocation['result']

It's pretty straightforward: it invokes the chain with the query that was provided. The chain retrieves the relevant chunks, injects them and the query into a prompt, and returns the result.
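Since keeping everything in a class is meant to make mocking easier, here is a sketch of how the ask logic could be exercised without Ollama running. `FakeChain` is a hypothetical stand-in I wrote for this example, and the standalone `ask` function simply mirrors the method above with the chain passed in as an argument.

```python
class FakeChain:
    """Hypothetical stand-in for the RetrievalQA chain; returns a canned answer."""

    def __init__(self, answer: str):
        self.answer = answer
        self.queries = []  # records every query so a test can inspect it

    def invoke(self, query: str) -> dict:
        self.queries.append(query)
        # Mirrors the dict shape that ask() reads its 'result' key from.
        return {"query": query, "result": self.answer}


def ask(chain, query: str) -> str:
    """Same logic as RAGImplementation.ask, taking the chain as an argument."""
    if not chain:
        return "Please, add a URL first."
    return chain.invoke(query)["result"]


fake = FakeChain("Deployments run every Tuesday.")
print(ask(None, "When do deployments run?"))
print(ask(fake, "When do deployments run?"))
```

In a real test you could assign a fake like this to the chain attribute of a RAGImplementation instance and call its ask method directly.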

And that's all for rag.py! Now for the last part, we need to create the UI, so the user can input their data, and retrieve the results, using streamlit.

Our app.py

I'm not going to be explaining the UI much, since it's not really relevant, but I wanted to include it either way.

python
import streamlit as st
from streamlit_chat import message
from rag import RAGImplementation

st.set_page_config(page_title="Confluence Semantic Search MVP")


def display_messages():
    st.subheader("Chat")
    for i, (msg, is_user) in enumerate(st.session_state["messages"]):
        message(msg, is_user=is_user, key=str(i))
    st.session_state["thinking_spinner"] = st.empty()


def process_input():
    if st.session_state["user_input"] and len(st.session_state['user_input'].strip()) > 0:
        user_text = st.session_state["user_input"].strip()
        with st.session_state["thinking_spinner"], st.spinner("Thinking"):
            agent_text = st.session_state["assistant"].ask(user_text)

        st.session_state["messages"].append((user_text, True))
        st.session_state["messages"].append((agent_text, False))

def process_url():
    if st.session_state['url_input']:
        with st.session_state['thinking_spinner'], st.spinner('Fetching Website...'):
            st.session_state['assistant'].ingest(st.session_state['url_input'])
    else:
        print('url was invalid')


def page():
    if len(st.session_state) == 0:
        st.session_state["messages"] = []
        st.session_state["assistant"] = RAGImplementation()

    st.header("Confluence Semantic Search MVP")

    st.subheader("Upload a document")
    st.text_input('URL', key='url_input', on_change=process_url)
    st.session_state["ingestion_spinner"] = st.empty()

    display_messages()
    st.text_input("Message", key="user_input", on_change=process_input)


if __name__ == "__main__":
    page()

One thing to note, however, is how we are storing the actual data. Streamlit reruns the whole script on every interaction, so we store our assistant in the session state. That keeps everything succinct and lets every function on the page grab the same assistant instance.
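The session-state pattern can be sketched without Streamlit at all. Below, a plain dict stands in for st.session_state, and `init_state` is a hypothetical helper that creates each value only when it is missing. Checking per key like this is a slightly more defensive variant of the length check used in page(), not what the app above does.

```python
def init_state(state, make_assistant) -> None:
    """Create each session value only if it is missing, leaving existing ones alone."""
    if "messages" not in state:
        state["messages"] = []
    if "assistant" not in state:
        state["assistant"] = make_assistant()


# A plain dict stands in for st.session_state here.
state = {}
init_state(state, make_assistant=object)
first_assistant = state["assistant"]

state["messages"].append(("hello", True))
init_state(state, make_assistant=object)  # simulates a Streamlit rerun

print(state["assistant"] is first_assistant)
print(state["messages"])
```

In the app, the equivalent call would pass st.session_state and the RAGImplementation class, so the ingested vector store survives every rerun.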

We can then run our app.py file with Streamlit by running streamlit run app.py in the terminal.

Conclusion

RAG is a powerful tool to help LLMs draw on external data sources. Keep in mind that this implementation currently doesn't retain the context of the conversation, so there's still work to be done in the future. Maybe I'll periodically update the page with additional information, but for now, I'll leave it at that.

If you made it this far, kudos to you, and I hope you learned something new today. If you have any questions, feel free to reach out to me!