[Learning Deep Learning Compilers from Scratch] XV. MLIR Toy Tutorial Study Notes: Lowering to LLVM IR

Keywords: Pytorch Deep Learning

0x0. Preface

In the previous section, we lowered the operations of the Toy Dialect to the Affine Dialect, MemRef Dialect, and Standard Dialect while leaving the toy.print operation unchanged, which is why it is called partial lowering. That lowering exposes the lower-level implementation logic of the Toy Dialect's operations, which opens up more optimization opportunities and yields a better MLIR expression. In this section, we fully lower the mixed MLIR expression obtained in the previous section to the LLVM Dialect and then generate LLVM IR. We can then use MLIR's JIT compilation engine to run the final MLIR expression and print the computed results.

0x1. Lowering the IR to the LLVM Dialect

In this section, we show how to fully lower the MLIR expression from the end of the previous section to the LLVM Dialect. Let's review that final MLIR expression:

func @main() {
  %cst = arith.constant 1.000000e+00 : f64
  %cst_0 = arith.constant 2.000000e+00 : f64
  %cst_1 = arith.constant 3.000000e+00 : f64
  %cst_2 = arith.constant 4.000000e+00 : f64
  %cst_3 = arith.constant 5.000000e+00 : f64
  %cst_4 = arith.constant 6.000000e+00 : f64

  // Allocating buffers for the inputs and outputs.
  %0 = memref.alloc() : memref<3x2xf64>
  %1 = memref.alloc() : memref<2x3xf64>

  // Initialize the input buffer with the constant values.
  affine.store %cst, %1[0, 0] : memref<2x3xf64>
  affine.store %cst_0, %1[0, 1] : memref<2x3xf64>
  affine.store %cst_1, %1[0, 2] : memref<2x3xf64>
  affine.store %cst_2, %1[1, 0] : memref<2x3xf64>
  affine.store %cst_3, %1[1, 1] : memref<2x3xf64>
  affine.store %cst_4, %1[1, 2] : memref<2x3xf64>

  affine.for %arg0 = 0 to 3 {
    affine.for %arg1 = 0 to 2 {
      // Load the transpose value from the input buffer.
      %2 = affine.load %1[%arg1, %arg0] : memref<2x3xf64>

      // Multiply and store into the output buffer.
      %3 = arith.mulf %2, %2 : f64
      affine.store %3, %0[%arg0, %arg1] : memref<3x2xf64>
    }
  }

  // Print the value held by the buffer.
  toy.print %0 : memref<3x2xf64>
  memref.dealloc %1 : memref<2x3xf64>
  memref.dealloc %0 : memref<3x2xf64>
  return
}

We want to fully lower the MLIR expression above, which mixes three dialects, to the LLVM Dialect. Note that the LLVM Dialect is a dialect within MLIR that models LLVM IR; it is not LLVM IR itself. The overall process of lowering to the LLVM Dialect can be divided into the following steps:

1. Lowering the toy.print operation

The toy.print operation was left untouched by the previous lowering, so we lower it first. We lower toy.print to a non-affine loop nest that calls printf for each element. The dialect conversion framework supports transitive lowering, so we do not need to emit LLVM Dialect operations directly. Transitive lowering means the framework can apply multiple patterns in sequence to legalize an operation (here, legalization means fully lowering it to the LLVM Dialect). In this case, transitive lowering shows up in the fact that toy.print is first lowered to a structured loop nest rather than directly to the LLVM Dialect.
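To make the behavior of that loop nest concrete, here is a minimal plain C++ sketch of what the lowered toy.print computes for a 2-D buffer. It is not the Toy lowering code itself: the function name formatLikeToyPrint is made up for illustration, and the format strings "%f " and "\n" are assumptions chosen to match the trailing spaces visible in the tutorial output. The sketch builds a string instead of calling printf directly so the result is easy to check.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical helper: formats a row-major rows x cols buffer the way the
// lowered toy.print loop nest is assumed to print it, i.e. "%f " for every
// element and "\n" after every row.
std::string formatLikeToyPrint(const std::vector<double> &data,
                               std::size_t rows, std::size_t cols) {
  assert(data.size() == rows * cols && "buffer does not match the shape");
  std::string out;
  char buf[352];  // large enough for any double printed with "%f "
  for (std::size_t i = 0; i < rows; ++i) {    // outer loop over rows
    for (std::size_t j = 0; j < cols; ++j) {  // inner loop over columns
      std::snprintf(buf, sizeof(buf), "%f ", data[i * cols + j]);
      out += buf;
    }
    out += "\n";  // newline once the inner dimension is finished
  }
  return out;
}
```

Calling formatLikeToyPrint on a 3x2 buffer and writing the returned string to stdout would give one line per row, each element followed by a space.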

During this lowering, printf is declared by the following code in mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp:

  /// Return a symbol reference to the printf function, inserting it into the
  /// module if necessary.
  static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
                                             ModuleOp module) {
    auto *context = module.getContext();
    if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
      return SymbolRefAttr::get(context, "printf");

    // Create a function declaration for printf, the signature is:
    //   * `i32 (i8*, ...)`
    auto llvmI32Ty = IntegerType::get(context, 32);
    auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
                                                  /*isVarArg=*/true);

    // Insert the printf function into the body of the parent module.
    PatternRewriter::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
    return SymbolRefAttr::get(context, "printf");
  }

This code returns a symbol reference to the printf function. If the module already contains printf, the reference is returned directly; otherwise the function first creates a declaration for printf, inserts it at the start of the parent module's body, and then returns the reference.
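The get-or-insert idiom used here can be sketched without any MLIR types. In the following illustration, Module is a hypothetical stand-in (just an ordered list of function names) and getOrInsertSymbol is a made-up name; neither is part of the MLIR API.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a module: an ordered list of declared functions.
struct Module {
  std::vector<std::string> functions;
};

// Returns the symbol name, declaring it at the start of the module body first
// if it is not declared yet. Calling it repeatedly never adds duplicates.
std::string getOrInsertSymbol(Module &module, const std::string &name) {
  auto it = std::find(module.functions.begin(), module.functions.end(), name);
  if (it == module.functions.end())
    module.functions.insert(module.functions.begin(), name);
  return name;
}
```

The point of the idiom is that every lowering pattern can ask for printf without caring whether an earlier pattern has already declared it.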

2. Identify all components required for the lowering process

The first thing to determine is the conversion target. For this lowering, we lower everything except the top-level module to the LLVM Dialect. The code differs slightly from the official documentation here; the latest code takes precedence.

// The first thing to define is the conversion target. This will define the
// final target for this lowering. For this lowering, we are only targeting
// the LLVM dialect.
LLVMConversionTarget target(getContext());
target.addLegalOp<ModuleOp>();

Next we need a type converter. Our current MLIR expression still uses MemRef types, which must be converted to LLVM types. To perform this conversion, we use a TypeConverter as part of the lowering. The converter specifies how one type maps to another. Since no Toy-specific types remain that need to be lowered, MLIR's default converter is sufficient. It is defined as follows:

  // During this lowering, we will also be lowering the MemRef types, that are
  // currently being operated on, to a representation in LLVM. To perform this
  // conversion we use a TypeConverter as part of the lowering. This converter
  // details how one type maps to another. This is necessary now that we will be
  // doing more complicated lowerings, involving loop region arguments.
  LLVMTypeConverter typeConverter(&getContext());

Next we define the conversion patterns. The relevant code is:

  // Now that the conversion target has been defined, we need to provide the
  // patterns used for lowering. At this point of the compilation process, we
  // have a combination of `toy`, `affine`, and `std` operations. Luckily, there
  // already exists a set of patterns to transform the `affine` and `std`
  // dialects. These patterns lower in multiple stages, relying on transitive
  // lowerings. Transitive lowering, or A->B->C lowering, is when multiple
  // patterns must be applied to fully transform an illegal operation into a
  // set of legal ones.
  RewritePatternSet patterns(&getContext());
  populateAffineToStdConversionPatterns(patterns);
  populateLoopToStdConversionPatterns(patterns);
  populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
  populateStdToLLVMConversionPatterns(typeConverter, patterns);

  // The only remaining operation to lower from the `toy` dialect, is the
  // PrintOp.
  patterns.add<PrintOpLowering>(&getContext());

The code above shows how to set up the rewrite patterns for the Affine Dialect, the Standard Dialect, and the remaining toy.print operation. First, populateAffineToStdConversionPatterns lowers the Affine Dialect to the Standard Dialect. Next, populateLoopToStdConversionPatterns lowers loops (including the loop nest that toy.print has been lowered to) to the Standard Dialect. Finally, populateMemRefToLLVMConversionPatterns and populateStdToLLVMConversionPatterns convert the MemRef Dialect and the Standard Dialect to the LLVM Dialect. Do not forget to add PrintOpLowering, the lowering pattern for toy.print, to the patterns as well.
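The idea of transitive (A->B->C) lowering can be illustrated with a toy model in plain C++. This is only a sketch under strong simplifying assumptions: an "op" is reduced to the name of its dialect, each pattern maps one dialect to the dialects it produces, and the pattern set is assumed to be acyclic so the loop terminates. The names Op, Patterns, makeToyPatterns, and legalize are hypothetical and do not correspond to the MLIR conversion API.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model: an op is just the name of the dialect it belongs to.
using Op = std::string;
// A pattern rewrites one op of the key dialect into ops of other dialects.
using Patterns = std::map<Op, std::vector<Op>>;

// A simplified pattern set loosely following the stages described in the text.
Patterns makeToyPatterns() {
  return {{"toy", {"scf"}},
          {"affine", {"scf"}},
          {"scf", {"std"}},
          {"std", {"llvm"}},
          {"memref", {"llvm"}}};
}

// Applies patterns repeatedly until no op can be rewritten any further.
// Assumes the patterns contain no cycles.
std::vector<Op> legalize(std::vector<Op> ops, const Patterns &patterns) {
  bool changed = true;
  while (changed) {
    changed = false;
    std::vector<Op> next;
    for (const Op &op : ops) {
      auto it = patterns.find(op);
      if (it == patterns.end()) {
        next.push_back(op);  // no pattern applies: keep the op as it is
      } else {
        next.insert(next.end(), it->second.begin(), it->second.end());
        changed = true;
      }
    }
    ops = std::move(next);
  }
  return ops;
}
```

In this model a "toy" op reaches "llvm" only after passing through "scf" and "std", which is the multi-stage behavior the comment in the listing describes.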

3. Full lowering

Once all the components required for the lowering are defined, we can perform the full lowering. Calling applyFullConversion(module, target, std::move(patterns)) guarantees that only legal operations remain in the result; if any illegal operation is left, the conversion fails. By contrast, the partial lowering in the previous note called mlir::applyPartialConversion(function, target, patterns).

  // We want to completely lower to LLVM, so we use a `FullConversion`. This
  // ensures that only legal operations will remain after the conversion.
  auto module = getOperation();
  if (failed(applyFullConversion(module, target, std::move(patterns))))
    signalPassFailure();
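The difference between the two conversion modes can be sketched as two checks over the operations that remain after the patterns have run. This is a deliberately simplified model, not the MLIR implementation: Target, partialConversionSucceeds, and fullConversionSucceeds are hypothetical names, ops are plain strings, and dynamic legality is ignored.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical conversion target: ops explicitly marked legal or illegal.
// Ops in neither set have unknown legality.
struct Target {
  std::set<std::string> legal;
  std::set<std::string> illegal;
};

// Partial conversion (simplified): fails only if an explicitly illegal op
// remains; ops of unknown legality may stay.
bool partialConversionSucceeds(const std::vector<std::string> &remaining,
                               const Target &target) {
  for (const std::string &op : remaining)
    if (target.illegal.count(op))
      return false;
  return true;
}

// Full conversion (simplified): every remaining op must be explicitly legal.
bool fullConversionSucceeds(const std::vector<std::string> &remaining,
                            const Target &target) {
  for (const std::string &op : remaining)
    if (!target.legal.count(op))
      return false;
  return true;
}
```

Under this model, a leftover op that was never marked legal or illegal is tolerated by the partial mode but makes the full mode fail, which is why the full mode is the right choice when the whole module must end up in the LLVM Dialect.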

4. Add the full-lowering pass defined above to the pipeline

This code is in mlir/examples/toy/Ch6/toyc.cpp:

if (isLoweringToLLVM) {
  // Finish lowering the toy IR to the LLVM dialect.
  pm.addPass(mlir::toy::createLowerToLLVMPass());
}

This code adds mlir::toy::createLowerToLLVMPass() to the optimization pipeline as the full-lowering pass, which lowers the MLIR expression to an LLVM Dialect expression. Let's run the sample program to see the result.

Execute the following command:

cd llvm-project/build/bin
./toyc-ch6 ../../mlir/test/Examples/Toy/Ch6/llvm-lowering.mlir -emit=mlir-llvm

This produces the fully lowered MLIR expression. The output is long, so only part of it is shown here. You can see that the MLIR expression is now entirely in the LLVM Dialect.

llvm.func @free(!llvm<"i8*">)
llvm.func @printf(!llvm<"i8*">, ...) -> i32
llvm.func @malloc(i64) -> !llvm<"i8*">
llvm.func @main() {
  %0 = llvm.mlir.constant(1.000000e+00 : f64) : f64
  %1 = llvm.mlir.constant(2.000000e+00 : f64) : f64

  ...

^bb16:
  %221 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %222 = llvm.mlir.constant(0 : index) : i64
  %223 = llvm.mlir.constant(2 : index) : i64
  %224 = llvm.mul %214, %223 : i64
  %225 = llvm.add %222, %224 : i64
  %226 = llvm.mlir.constant(1 : index) : i64
  %227 = llvm.mul %219, %226 : i64
  %228 = llvm.add %225, %227 : i64
  %229 = llvm.getelementptr %221[%228] : (!llvm<"double*">, i64) -> !llvm<"double*">
  %230 = llvm.load %229 : !llvm<"double*">
  %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, f64) -> i32
  %232 = llvm.add %219, %218 : i64
  llvm.br ^bb15(%232 : i64)

  ...

^bb18:
  %235 = llvm.extractvalue %65[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%236) : (!llvm<"i8*">) -> ()
  %237 = llvm.extractvalue %45[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%238) : (!llvm<"i8*">) -> ()
  %239 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%240) : (!llvm<"i8*">) -> ()
  llvm.return
}
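The address arithmetic in block ^bb16 (extract the data pointer, then compute 0 + i * 2 + j * 1 before the getelementptr and load) is the usual strided indexing of a 3x2 row-major buffer. The following plain C++ sketch mirrors it. The struct MemRef2D and the helpers linearIndex and load are hypothetical names for illustration, and the field order (pointer, offset, sizes, strides) is an assumption read off the type `{ double*, i64, [2 x i64], [2 x i64] }` shown in the listing.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of the descriptor type in the listing, assuming the
// fields are: data pointer, offset, sizes, strides.
struct MemRef2D {
  double *data;
  std::int64_t offset;
  std::int64_t sizes[2];
  std::int64_t strides[2];
};

// offset + i * strides[0] + j * strides[1], the same sum that ^bb16 builds
// from llvm.mul and llvm.add.
std::int64_t linearIndex(const MemRef2D &m, std::int64_t i, std::int64_t j) {
  assert(i >= 0 && i < m.sizes[0] && j >= 0 && j < m.sizes[1]);
  return m.offset + i * m.strides[0] + j * m.strides[1];
}

// Counterpart of llvm.getelementptr followed by llvm.load.
double load(const MemRef2D &m, std::int64_t i, std::int64_t j) {
  return m.data[linearIndex(m, i, j)];
}
```

For a memref<3x2xf64> the strides are {2, 1} and the offset is 0, which matches the constants 2, 1, and 0 materialized in ^bb16.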

0x2. Code generation and JIT execution

We can use a JIT compilation engine to run the LLVM Dialect IR obtained above and get the computed results. Here we use the mlir::ExecutionEngine infrastructure to run the LLVM Dialect IR. The code is in mlir/examples/toy/Ch6/toyc.cpp:

int runJit(mlir::ModuleOp module) {
  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // Register the translation from MLIR to LLVM IR, which must happen before we
  // can JIT-compile.
  mlir::registerLLVMDialectTranslation(*module->getContext());

  // An optimization pipeline to use within the execution engine.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);

  // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
  // the module.
  auto maybeEngine = mlir::ExecutionEngine::create(
      module, /*llvmModuleBuilder=*/nullptr, optPipeline);
  assert(maybeEngine && "failed to construct an execution engine");
  auto &engine = maybeEngine.get();

  // Invoke the JIT-compiled function.
  auto invocationResult = engine->invokePacked("main");
  if (invocationResult) {
    llvm::errs() << "JIT invocation failed\n";
    return -1;
  }

  return 0;
}

Note this line in particular: mlir::registerLLVMDialectTranslation(*module->getContext());. As the code comment says, it registers the translation from the LLVM Dialect to LLVM IR with the MLIR context. The execution engine relies on this registration to translate the module to LLVM IR before JIT-compiling it, so the call must come before the engine is created.

Here we create an MLIR execution engine, mlir::ExecutionEngine, to run the main function of the expression. The following command prints the final results:

cd llvm-project/build/bin
./toyc-ch6 ../../mlir/test/Examples/Toy/Ch6/codegen.toy -emit=jit -opt

The result is:

1.000000 16.000000 
4.000000 25.000000 
9.000000 36.000000
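As a cross-check of these numbers, here is a minimal plain C++ sketch of the affine loop nest from section 0x1: it reads the transposed element from the 2x3 input and stores its square into the 3x2 output. It is not part of the Toy code base; the aliases Input and Output and the functions makeInput and transposeAndSquare are made up for illustration.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Plain C++ stand-ins for memref<2x3xf64> and memref<3x2xf64>.
using Input = std::array<std::array<double, 3>, 2>;
using Output = std::array<std::array<double, 2>, 3>;

// The constants written into the input buffer by the affine.store operations.
Input makeInput() {
  return {{{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}};
}

// Mirrors the affine.for nest: out[i][j] = in[j][i] * in[j][i].
Output transposeAndSquare(const Input &in) {
  Output out{};
  for (std::size_t i = 0; i < 3; ++i) {    // affine.for %arg0 = 0 to 3
    for (std::size_t j = 0; j < 2; ++j) {  // affine.for %arg1 = 0 to 2
      const double v = in[j][i];           // affine.load %1[%arg1, %arg0]
      out[i][j] = v * v;                   // arith.mulf, then affine.store
    }
  }
  return out;
}
```

Row by row, transposeAndSquare(makeInput()) holds {1, 16}, {4, 25}, {9, 36}, which agrees with the JIT output.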

To recap, we optimized the original MLIR expression through a series of passes, partially lowered it to an expression mixing three dialects, and then fully lowered it to an LLVM Dialect expression. Finally, we translated it to LLVM IR and executed it with MLIR's JIT execution engine to obtain the final result.

In addition, mlir/examples/toy/Ch6/toyc.cpp provides a dumpLLVMIR function, which translates the MLIR expression to LLVM IR and then runs LLVM's optimizations on it. The resulting LLVM IR can be printed with the following command:

$ cd llvm-project/build/bin
$ ./toyc-ch6 ../../mlir/test/Examples/Toy/Ch6/codegen.toy -emit=llvm -opt

0x3. Summary

This article described how to fully lower the partially lowered MLIR expression to the LLVM Dialect, and then execute the code and obtain the results with the JIT compilation engine. In addition, you can print the LLVM IR generated from the LLVM Dialect.

Posted by WesPear on Wed, 17 Nov 2021 08:19:48 -0800