SiFive - October 10, 2024

LLM Optimization and Deployment on SiFive RISC-V Intelligence Products

Bruce Lai, Darren Hsieh and Hong-Rong Hsu

Large language models (LLMs) have become essential to numerous applications, thanks to their powerful capabilities in natural language understanding and text generation. However, their large model sizes and high computational demands present significant challenges for efficient deployment and real-time performance.

Can you imagine running Llama on a RISC-V machine while achieving real-time performance? The ML compiler plays a crucial role in making this possible. For more details, please refer to our previous article “ML Compiler for RISC-V Vector”. In this article, we will share how we optimized and deployed Llama on the SiFive Intelligence X390 platform[ref0], the future computing platform for AI/ML workloads.

**Want to learn more about this? Join our webinar on October 16th; webinar details are provided at the end of this article.**

SiFive AI/ML Software Stack

To enable LLM models on RISC-V, multiple components need to interact seamlessly. In this section, we will introduce each component of the SiFive AI/ML Software Stack.

The orange blocks in Figure 1 represent the fundamental RISC-V building blocks, primarily owned and maintained by SiFive:

  • SiFive Intelligence and High-Performance core series: Core families such as the X280/X390 (Intelligence) and P470/P670/P870 (High-Performance).

  • SiFive accelerators: In-house hardware that accelerates domain-specific workloads, such as the SiFive XM Series[ref9].

  • SiFive LLVM Compiler: Provides RISC-V C/C++ compilation and optimization, RVV intrinsic programming, auto-vectorization, and efficient RISC-V backend code generation for MLIR. SiFive offers a proprietary version of LLVM with advanced u-architecture optimizations and custom instruction compilation/IR for SiFive cores, though users can also access a generic version from the upstream compiler directly.

  • SiFive System Software: Provides FSFL (Freedom SDK for Linux), the Yocto/OpenEmbedded-based RISC-V Linux solution, and FSFM (Freedom SDK for Metal), a reference ASM/C/C++ bare-metal environment for SiFive cores.

  • SKL (SiFive Kernel Library): A fully optimized C/C++ library with a set of tuned routines that maximize algorithm throughput on SiFive processors. Hot-spot operations in IREE are offloaded to SKL to maximize performance on SiFive processors.

Figure 1: SiFive AI/ML Software Stack

The blue blocks in Figure 1 represent the components that SiFive leverages from, and contributes back to, open-source projects.

ML Compiler and Runtime: IREE[ref1] is an open-source, MLIR-based compiler and runtime. We leverage most of its generic optimizations while adding SiFive architecture-specific functions and optimizations.

VCIX MLIR Dialect: SiFive has open-sourced the VCIX MLIR dialect, allowing users to lower their models and delegate them to custom TPUs with minimum effort.

ML Interpreters: For customers requiring a more lightweight framework, we offer:

  • Customized TFLite with RVV optimizations.
  • Upstream XNNPACK with RVV optimizations.
  • ONNX Runtime with delegation to XNNPACK.
  • Customized llama.cpp with additional RVV optimizations.
  • Other open-source libraries such as libyuv, libc, OpenSSL, and zlib-ng, optimized with RVV and fine-tuned for the SiFive u-architecture to accelerate non-AI/ML domains.

The yellow blocks in Figure 1 represent the components provided by third-party vendors or communities, which SiFive leverages to create a comprehensive solution.

PyTorch End-to-End Flow

In this section, we’ll introduce the implementation of a PyTorch end-to-end flow for large language models (LLMs) using SHARK-Turbine, a tool supported by AMD that facilitates various model deployment activities. The complete workflow for enabling the LLM demonstration is illustrated in Figure 2.

Figure 2: PyTorch end-to-end large language model demo flow in SHARK-Turbine

SHARK-Turbine[ref2] provides two key scripts to enable the LLM demonstration. The first script, stateless_llm.py, is responsible for converting a PyTorch Hugging Face LLM into a VMFB (VM FlatBuffer) and a parameter file. This conversion process utilizes the FX/Dynamo-based torch-mlir compiler along with the IREE compiler. Currently we are using upstream IREE along with several SiFive patches, which incorporate SiFive MLIR optimizations and leverage SiFive LLVM to achieve optimal performance. Detailed optimization strategies are discussed in the subsequent section.
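For illustration, here is a minimal sketch of the IREE compilation step that such a flow performs, assuming the model has already been exported to MLIR. The file names are placeholders, and the exact flag names can differ across IREE versions; this is not the actual stateless_llm.py.

```python
# Hedged sketch: compile an exported MLIR module into a RISC-V VMFB with the
# IREE compiler Python API. File names and target features are placeholders.
import iree.compiler as ireec

ireec.compile_file(
    "tinyllama_stateless.mlir",          # hypothetical exported MLIR module
    output_file="tinyllama_rv64.vmfb",   # VM FlatBuffer consumed by the IREE runtime
    target_backends=["llvm-cpu"],
    extra_args=[
        # Cross-compile for a 64-bit RISC-V Linux target with RVV enabled.
        "--iree-llvmcpu-target-triple=riscv64-unknown-linux-gnu",
        "--iree-llvmcpu-target-cpu-features=+m,+a,+f,+d,+c,+v",
    ],
)
```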

After the LLM is compiled, the second script, llm_runner.py, utilizes the IREE Python bindings to invoke the IREE runtime, load the VMFB model, and set up the runtime context. This allows the LLM to run within a Python environment on SiFive Intelligence cores. Users can input queries (prompts), and the SiFive Intelligence platform will execute multiple inferences to generate meaningful responses.
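As a rough sketch of what such a runner does under the hood (this is not the actual llm_runner.py), loading and invoking a compiled module through the IREE runtime Python bindings looks roughly like this; the entry-point name, inputs, and file name are placeholders.

```python
# Hedged sketch: load a VMFB with the IREE runtime Python bindings and call an
# exported function. Entry-point name, inputs, and file names are placeholders.
import numpy as np
import iree.runtime as ireert

config = ireert.Config("local-task")        # CPU driver
ctx = ireert.SystemContext(config=config)
vm_module = ireert.VmModule.mmap(config.vm_instance, "tinyllama_rv64.vmfb")
ctx.add_vm_module(vm_module)

# Hypothetical exported entry point that consumes a batch of token ids.
run_forward = ctx.modules.module["run_initialize"]
token_ids = np.array([[1, 15043, 3186]], dtype=np.int64)
logits = run_forward(token_ids)
print(np.asarray(logits).shape)
```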

Now, you might be wondering: Why not simply use llama.cpp for LLM inference, since the GGUF format (a file format for storing models for inference with llama.cpp) works fine? We will explain this in the performance section.

Llama Optimization

Model Architecture and Bottleneck

Understanding the model architecture helps us identify the performance bottleneck. LLaMA (Large Language Model Meta AI) is a series of transformer-based[ref3] large language models developed by Meta. Figure 3 shows an overview of the LLaMA architecture. This architecture utilizes self-attention mechanisms and feed-forward neural networks to process input sequences. The key components of the architecture include:

  • Embedding Layer: Converts tokens into dense vectors.
  • Multi-Head Self-Attention Mechanism: Computes relationships between all pairs of tokens, enabling the model to capture long-range dependencies.
  • Feed-Forward Networks (FFNs): Contain a series of matrix multiplication (matmul) operations and non-linear functions to provide non-linearity and feature transformation.
  • Layer Normalization: Normalizes activations for stability and faster convergence.
  • Residual Connections: Help avoid vanishing gradients during backpropagation.

Figure 3: Overview of the LLaMA architecture[ref4]

Figure 3 shows that the model contains N self-attention layers and N feed-forward networks, both implemented through a series of matrix multiplication operations. This extensive reliance on matrix multiplication makes it the primary bottleneck in transformer-based models.
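To make the matmul-heavy structure concrete, here is a deliberately simplified, single-head decoder layer in NumPy (no KV-cache, no rotary embeddings, random untrained weights, and ReLU standing in for SwiGLU); essentially every significant line is a matrix multiplication.

```python
# Simplified single-head decoder layer: attention + FFN, reduced to matmuls.
# Weights are random and untrained; this only illustrates the shapes involved.
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

seq, d_model, d_ffn = 16, 256, 704
x = np.random.randn(seq, d_model).astype(np.float32)

# Self-attention projections: four [d_model, d_model] matmuls.
Wq, Wk, Wv, Wo = (0.02 * np.random.randn(d_model, d_model).astype(np.float32)
                  for _ in range(4))
q, k, v = x @ Wq, x @ Wk, x @ Wv
scores = softmax(q @ k.T / np.sqrt(d_model))   # [seq, seq] attention matmul
attn_out = (scores @ v) @ Wo                   # two more matmuls

# Feed-forward network: two large matmuls around a non-linearity.
W1 = 0.02 * np.random.randn(d_model, d_ffn).astype(np.float32)
W2 = 0.02 * np.random.randn(d_ffn, d_model).astype(np.float32)
ffn_out = np.maximum(attn_out @ W1, 0.0) @ W2  # ReLU stand-in for SwiGLU

print(attn_out.shape, ffn_out.shape)           # (16, 256) (16, 256)
```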

Performance Results

Table 1 presents the performance profiling results for TinyLlama using IREE during the decode phase. Different operations can be grouped into a single dispatch; for example, a matrix multiplication (matmul) followed by a bias addition can be combined into one dispatch. The "Operation Type" column indicates the primary operation within each dispatch. "D" in a shape denotes a dynamic dimension, which handles the growing dimension from the KV-cache.

Table 1: TinyLlama decode-phase profiling with IREE

The next section provides more detail about the prefill and generation phases. Briefly, in the prefill phase the matrix multiplication (matmul) operations follow the General Matrix Multiply (GEMM) pattern, where the M dimension is greater than or equal to 1. In contrast, during the decode phase the matmul operations follow the General Matrix-Vector Multiply (GEMV) pattern, where the M dimension is always equal to 1. The shape of a matmul operation is represented as [BxMxNxK], where B is the (optional) batch size, M is the number of rows in the output, N is the number of columns in the output, and K is the reduction dimension. For dispatch_60, the batch size is 32 (from the number of attention heads) and D indicates a dynamic size in the N dimension.

Table 1 demonstrates that matmul operations account for over 95% of the decode phase inference time, making them the primary performance bottleneck in LLaMA.

Since matmul operations are the main performance hotspot in LLM inference, the next section will focus on optimizing these operations.

Optimization through IREE Compiler

IREE (Intermediate Representation Execution Environment) is an MLIR-based end-to-end AI/ML compiler and runtime. The architecture overview is shown in Figure 4. In IREE, the input model is lowered to MLIR, different levels of optimization are applied (such as kernel fusion, tiling, and loop unrolling), and the result is finally translated to target-dependent VM bytecode, which is executed by the IREE runtime.

Figure 4: IREE Architecture Overview[ref1]

Although IREE is a powerful framework that supports both CPU and GPU code generation and optimization, its RISC-V backend with RVV vectorization has not yet been fully tuned. Figure 5 shows the significant performance improvement between the generated code before and after our optimizations.

Figure 5: Performance gap between before and after optimizations (TinyLlama on SiFive X390)

Several optimizations, demonstrated in the following subsections, improve the performance of the matmul operation and lead to significant improvements in LLM performance.

Cache & Register Tiling Optimizations for matmul

In LLM inference, the process is divided into two distinct phases: Prefill (Prompt) and Decode (Generation). The LLM inference process is visualized in Figure 6.

Figure 6: Prefill (Prompt) and Decode (Generation) phases in LLM inference[ref5]

Prefill or Prompt Phase: During this phase, the input prompt has a length greater than or equal to 1. The model processes the initial input sequence and constructs the KV-Cache, which will be reused in subsequent decoding steps. Assuming the matmul has a problem size of [m,n,k] where the left-hand side (LHS) is [m,k], the dimension m is always greater than or equal to 1.

Decode or Generation Phase: This phase begins with the output of the prefill phase as the first input. The model uses the KV-Cache to efficiently generate tokens in the following iterations. The matmul operations in this phase have the dimension m=1. This smaller m value (1) is a key distinction compared to traditional operations in neural network models.
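As a small illustration of this shape difference (not SiFive's implementation), the same weight matrix is consumed as a GEMM during prefill and as a GEMV-like matmul during decode; the dimensions below are arbitrary placeholders.

```python
# Illustration of prefill (GEMM, m >= 1) vs. decode (GEMV, m == 1) matmul shapes.
import numpy as np

k_dim, n_dim = 2048, 2048
W = np.random.randn(k_dim, n_dim).astype(np.float32)    # a projection weight

# Prefill: the whole prompt (m = prompt length) is processed at once.
prompt_len = 128
lhs_prefill = np.random.randn(prompt_len, k_dim).astype(np.float32)
out_prefill = lhs_prefill @ W           # GEMM: [128, 2048] x [2048, 2048]

# Decode: one new token per step, so the LHS collapses to a single row.
lhs_decode = np.random.randn(1, k_dim).astype(np.float32)
out_decode = lhs_decode @ W             # GEMV-like: [1, 2048] x [2048, 2048]

print(out_prefill.shape, out_decode.shape)   # (128, 2048) (1, 2048)
```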

When optimizing the matmul operation with the RISC-V Vector extension (RVV), two key factors affect efficiency:

Register Tiling: This focuses on maximizing the utilization of the CPU's vector registers during matrix multiplication. By carefully organizing data to fit into vector registers, register tiling helps minimize memory access and maximize computational throughput.

Cache Tiling: This strategy optimizes matrix multiplication by improving data locality, ensuring that data remains in the cache hierarchy for as long as possible. Efficient cache tiling reduces memory latency and improves the performance of the matmul operation by minimizing cache misses.

In IREE, the matmul operation is implemented using an outer product approach. The output register tile size is [m, n], where m and n represent the number of elements in the output rows and columns, respectively. Implementing the outer product with RVV requires m vector grouped registers for output accumulators and 1 additional vector grouped register to load data from the RHS matrix. Each vector grouped register has a length of n. Hence, the vector register utilization rate can be calculated using the formula:

utilization rate = ((m + 1) × n × element_bits) / (32 × VLEN)

where 32 is the number of architectural RVV vector registers and element_bits is the width of the data type in bits.

For instance, if VLEN = 1024, the data type is float32, and [m,n] = [8,32]:

utilization rate = ((8 + 1) × 32 × 32) / (32 × 1024) = 9216 / 32768 ≈ 28%
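To make the outer-product register-tiling scheme concrete, here is a scalar Python sketch (not the actual IREE code generation): each step loads one element per LHS row and a length-n RHS row, then accumulates an m x n output tile. On RVV, each row of the accumulator and the RHS row would occupy one grouped vector register.

```python
# Scalar sketch of outer-product register tiling for C[m, n] += A[m, k] * B[k, n].
# On RVV, each row of `acc` and the `b_row` would live in a grouped vector register.
import numpy as np

def outer_product_tile(A, B, m, n):
    k = A.shape[1]
    acc = np.zeros((m, n), dtype=np.float32)   # m accumulator "registers"
    for kk in range(k):
        b_row = B[kk, :n]                      # 1 register loaded from the RHS
        for i in range(m):
            acc[i] += A[i, kk] * b_row         # vector fused multiply-add
    return acc

A = np.random.randn(8, 64).astype(np.float32)
B = np.random.randn(64, 32).astype(np.float32)
np.testing.assert_allclose(outer_product_tile(A, B, 8, 32), A @ B,
                           rtol=1e-3, atol=1e-3)
```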

Upon analyzing the output assembly code, we discovered that IREE generates code with a fixed output register tile size of [m,n]=[8,32] on RISC-V platforms. This configuration reduces the vector register utilization rate to 28% when VLEN = 1024. The optimal case occurs when VLEN = 512, where the utilization rate improves to 56%. However, for smaller VLEN values, such as VLEN = 128, the generated code suffers from excessive vector register spilling, leading to poor performance.

To enhance performance, we adjusted the register tile size to [m, n] = [7, m4], where m4 corresponds to LMUL = 4. In this case, if VLEN = 1024 and the data type is float32, the value of n will be 128, which is the maximum vector length when LMUL = 4. This adjustment allows us to achieve a 100% utilization rate of the vector registers, maximizing performance.

The register tile configuration [m,n]=[7,m4] is well-suited for the prefill phase of LLM inference, where the matrix dimension M typically exceeds 7. However, this setup is suboptimal for the decode phase. During decoding, we modified the register tile size to [m,n]=[1,m8], which significantly improved the vector register utilization rate from 25% to 50%.
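The utilization numbers above can be reproduced with a quick back-of-the-envelope check using the formula from the previous subsection (32 architectural vector registers, with n derived from LMUL where applicable):

```python
# Vector register utilization = (m + 1) * n * elem_bits / (32 * VLEN),
# where m + 1 grouped registers each hold n elements of elem_bits bits.
def utilization(m, n, elem_bits, vlen):
    return (m + 1) * n * elem_bits / (32 * vlen)

def n_for_lmul(lmul, elem_bits, vlen):
    # Maximum number of elements in a grouped register built from LMUL registers.
    return lmul * vlen // elem_bits

VLEN = 1024
print(utilization(8, 32, 32, 512))                        # default [8,32], VLEN=512:  ~0.56
print(utilization(8, 32, 32, VLEN))                       # default [8,32], VLEN=1024: ~0.28
print(utilization(7, n_for_lmul(4, 32, VLEN), 32, VLEN))  # tuned [7, m4]:  1.0
print(utilization(1, n_for_lmul(8, 32, VLEN), 32, VLEN))  # decode [1, m8]: 0.5
```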

Following the adjustment to the register tiling policy for enhanced performance, we extended our optimization to the cache tiling strategy within the IREE compiler. The cache tiling size is now dynamically selected based on the register tiling size n. For instance, when n=128, IREE defaults to an n-dimension cache tile size of 128. By fine-tuning this parameter to match the specific requirements of the system, we are able to further enhance cache efficiency, yielding notable performance improvements.
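The following is only a simplified illustration of the idea, not IREE's actual selection logic: pick the n-dimension cache tile as a multiple of the register-tile n such that the corresponding RHS panel fits within a budgeted share of the L2 cache. The cache size and budget fraction here are assumed values.

```python
# Illustrative cache-tile selection (not IREE's actual heuristic): choose the
# largest multiple of the register-tile n whose RHS panel [k_tile, n_tile]
# fits in a budgeted share of L2. Cache size and budget are assumptions.
def pick_n_cache_tile(reg_n, k_tile, elem_bytes=4,
                      l2_bytes=2 * 1024 * 1024, budget=0.5):
    limit = int(l2_bytes * budget)
    n_tile = reg_n
    while (n_tile + reg_n) * k_tile * elem_bytes <= limit:
        n_tile += reg_n
    return n_tile

print(pick_n_cache_tile(reg_n=128, k_tile=512))   # 512 with these assumed defaults
```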

Using the IREE compiler, it is fast and convenient to tune the tiling policy at each level of the memory hierarchy. Compared to library-based solutions, IREE significantly reduces the required engineering effort.

Performance Results

Originally, generating a single token on a 32.5 MHz FPGA took several hours. After applying the IREE optimizations, the f16 TinyLlama model generates a token in about 5 seconds. Normalized to a 1 GHz real-chip frequency, this translates to 5.37 tokens per second, achieving a real-time user experience on a single X390 core.

We also compared IREE with llama.cpp, a pure C/C++ LLM inference project. While llama.cpp is convenient for deployment, its performance optimizations are limited: even with hand-written RVV kernels, it is constrained by algorithm limitations and lacks graph-level optimization.

Table 2: Performance (tokens/second) of TinyLlama-1.1b on X390, llama.cpp vs. IREE

The data also shows that the SiFive X390 is competitive, as a single core can achieve real-time user experience performance on TinyLlama.

Moreover, because TinyLlama and Llama2-7b share the same model architecture, the optimizations we made also allow us to run Llama2-7b on the SiFive Intelligence platform, reusing the same end-to-end flow and scripts.

Table 3: Llama2-7b-Q4 performance comparison

Accuracy Verification

After optimization, preserving accuracy is equally important for deployment. This section describes how we verify model accuracy.

Leveraging MLPerf's Llama Accuracy Flow for TinyLlama

To align with industry standards, we leverage MLPerf’s benchmarks to evaluate the accuracy of our models, ensuring our solutions are both reliable and effective. Currently, our focus is solely on accuracy—we’ve yet to leverage the performance benchmarking aspects of MLPerf.

While MLPerf picked the Llama2 70B model as a benchmark, we adapt it by using TinyLlama, which provides a more manageable, scaled-down version of the original model.

Adapting MLPerf’s Llama Accuracy Flow for TinyLlama

Our approach revolves around using MLPerf's Llama accuracy flow to ensure that the TinyLlama model maintains accuracy while running through the IREE runtime on SiFive Intelligence cores. This flow provides a robust framework to evaluate how well our models perform on the OpenOrca dataset. With the Llama accuracy script, we continuously assess the accuracy of TinyLlama, allowing us to refine the model and make the adjustments necessary to maximize precision within the constraints of our hardware.

Figure 7: How LoadGen interacts with the inference system[ref7]

  1. Benchmark knows the model, dataset, and preprocessing.
  2. Benchmark hands dataset sample IDs to LoadGen.
  3. LoadGen starts generating queries of sample IDs.
  4. Benchmark creates requests to the backend.
  5. Result is post-processed and forwarded to LoadGen.
  6. LoadGen outputs logs for analysis.
  7. Accuracy script uses LoadGen logs and reference responses from the dataset to calculate the ROUGE scores (a simplified scoring sketch is shown after this list).
  8. Accuracy script outputs the ROUGE scores.
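For reference, the ROUGE calculation in step 7 is conceptually similar to the following sketch using the rouge_score package. The MLPerf accuracy script has its own implementation and settings, and the strings below are placeholders.

```python
# Hedged sketch of ROUGE-1/2/L scoring between a generated answer and its
# reference, similar in spirit to MLPerf's accuracy script (not identical).
from rouge_score import rouge_scorer   # pip install rouge-score

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"],
                                  use_stemmer=True)
reference = "The Eiffel Tower is located in Paris, France."
generated = "The Eiffel Tower stands in Paris."
scores = scorer.score(reference, generated)
for name, s in scores.items():
    print(f"{name}: f-measure = {s.fmeasure:.3f}")
```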

“Result not verified by MLCommons Association.” This statement follows the MLPerf Policies[ref8].

Table 4: MLPerf OpenOrca Accuracy Comparison of TinyLlama-1.1b Across Different Frameworks and Platforms

The MLPerf taskforce selected ROUGE scores to evaluate how closely a generated text aligns with its reference. They decided to use three specific ROUGE metrics for this benchmark: ROUGE-1, ROUGE-2, and ROUGE-L.

Limited by the FPGA emulation speed, we executed only 100 samples to validate accuracy. For the FP32 model, the results indicate that the SiFive X390 with IREE achieved scores identical to the NVIDIA 2080 Ti with the HuggingFace backend. For the FP16 model, the X390 with IREE demonstrated slightly better accuracy than FP32. It is important to note that the full MLPerf OpenOrca dataset consists of 24,576 samples, and reference results from NVIDIA and Hugging Face show a decline in ROUGE scores after processing the entire dataset.

Unverified MLPerf® v4.0 Inference Closed Llama2 offline. Result not verified by MLCommons Association. The MLPerf name and logo are registered and unregistered trademarks of MLCommons Association in the United States and other countries. All rights reserved. Unauthorized use strictly prohibited. See www.mlcommons.org for more information.

Demo Snapshot

The following snapshot was taken from a demo running on the X390 FPGA. The operating system is the Linux kernel from the SiFive FSFL (Freedom SDK for Linux). The prerequisites include pre-compiled VMFB, model weights (in .safetensors format), the IREE runtime, the llm_runner.py script, and the IREE Python binding package.

Figure 8: Demo snapshot on the X390 FPGA

Conclusion and Future Work

In this article, we walked through the end-to-end process of deploying TinyLlama and Llama2-7b-Q4, from PyTorch to the SiFive Intelligence X390 platform, demonstrating the readiness of running LLMs on the cutting-edge RISC-V ecosystem. We also highlighted how using IREE for optimization can achieve significantly better performance compared to library-based frameworks. Accuracy validation post-optimization is equally important, which is why we have integrated our software stack into the MLPerf framework to ensure accuracy. We plan to upstream the generic optimizations to the IREE repository, so stay tuned if you're interested in our work.

Exciting tasks we're currently working on include:

  • For the upcoming SiFive XM series platform[ref9], we are enabling end-to-end lowering and optimization for SiFive Intelligence X390 cores and the AI matrix engine via IREE.
  • Supporting the installation and native execution of PyTorch on RISC-V platforms.
  • Enabling Llama 3.2 on SiFive platforms.

We look forward to sharing more in our next technical blog!

Webinar

SiFive is hosting a live webinar series titled Advanced LLM Optimization and Deployment on RISC-V: Techniques You Need to Know on Wednesday, October 16, 2024.

You’ll learn about:

  • SiFive AI/ML Software Stack for RISC-V
  • End-to-end deployment of PyTorch Llama models
  • Challenges and solutions in optimizing LLM models
  • Achieving real-time Llama performance with MLIR-based IREE

English Session: October 16, 2024, 8am PDT | 11am EDT | 5pm CET (Register)

Chinese Session: October 16, 2024, 10am CST (Register: Taiwan | China)

Reference

[ref0] SiFive Intelligence X390
[ref1] IREE
[ref2] SHARK-Turbine
[ref3] Attention Is All You Need
[ref4] LLaMA architecture
[ref5] LLM Inference Serving: Survey of Recent Advances and Opportunities
[ref6] MLPerf Llama2 70B
[ref7] MLPerf LoadGen
[ref8] MLPerf Policies
[ref9] SiFive XM Series

https://cloud-v.co/blog/risc-v-1/accelerating-llama-cpp-with-risc-v-vector-extension-3
PerfXLM: https://www.youtube.com/watch?v=tVXejqZCL_c