SeExpr
Evaluator.h
Go to the documentation of this file.
1/*
2 Copyright Disney Enterprises, Inc. All rights reserved.
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License
6 and the following modification to it: Section 6 Trademarks.
7 deleted and replaced with:
8
9 6. Trademarks. This License does not grant permission to use the
10 trade names, trademarks, service marks, or product names of the
11 Licensor and its affiliates, except as required for reproducing
12 the content of the NOTICE file.
13
14 You may obtain a copy of the License at
15 http://www.apache.org/licenses/LICENSE-2.0
16*/
17
18#include "ExprConfig.h"
19#include "ExprLLVMAll.h"
20#include "VarBlock.h"
21
22#ifdef SEEXPR_ENABLE_LLVM
23#include <llvm/Config/llvm-config.h>
24#include <llvm/Support/Compiler.h>
25#endif
26
27extern "C" void SeExpr2LLVMEvalFPVarRef(SeExpr2::ExprVarRef *seVR, double *result);
28extern "C" void SeExpr2LLVMEvalStrVarRef(SeExpr2::ExprVarRef *seVR, double *result);
29extern "C" void SeExpr2LLVMEvalCustomFunction(int *opDataArg,
30 double *fpArg,
31 char **strArg,
32 void **funcdata,
33 const SeExpr2::ExprFuncNode *node);
34
35namespace SeExpr2 {
36#ifdef SEEXPR_ENABLE_LLVM
37
38LLVM_VALUE promoteToDim(LLVM_VALUE val, unsigned dim, llvm::IRBuilder<> &Builder);
39
40class LLVMEvaluator {
41 // TODO: this seems needlessly complex, let's fix it
42 // TODO: let the dev code allocate memory?
43 // FP is the native function for this expression.
44 template <class T>
45 class LLVMEvaluationContext {
46 private:
47 typedef void (*FunctionPtr)(T *, char **, uint32_t);
48 typedef void (*FunctionPtrMultiple)(char **, uint32_t, uint32_t, uint32_t);
49 FunctionPtr functionPtr;
50 FunctionPtrMultiple functionPtrMultiple;
51 T *resultData;
52
53 public:
54 LLVMEvaluationContext(const LLVMEvaluationContext &) = delete;
55 LLVMEvaluationContext &operator=(const LLVMEvaluationContext &) = delete;
56 ~LLVMEvaluationContext() { delete[] resultData; }
57 LLVMEvaluationContext() : functionPtr(nullptr), resultData(nullptr) {}
58 void init(void *fp, void *fpLoop, int dim) {
59 reset();
60 functionPtr = reinterpret_cast<FunctionPtr>(fp);
61 functionPtrMultiple = reinterpret_cast<FunctionPtrMultiple>(fpLoop);
62 resultData = new T[dim];
63 }
64 void reset() {
65 if (resultData) delete[] resultData;
66 functionPtr = nullptr;
67 resultData = nullptr;
68 }
69 const T *operator()(VarBlock *varBlock) {
70 assert(functionPtr && resultData);
71 functionPtr(resultData, varBlock ? varBlock->data() : nullptr, varBlock ? varBlock->indirectIndex : 0);
72 return resultData;
73 }
74 void operator()(VarBlock *varBlock, size_t outputVarBlockOffset, size_t rangeStart, size_t rangeEnd) {
75 assert(functionPtr && resultData);
76 functionPtrMultiple(varBlock ? varBlock->data() : nullptr, outputVarBlockOffset, rangeStart, rangeEnd);
77 }
78 };
79 std::unique_ptr<LLVMEvaluationContext<double>> _llvmEvalFP;
80 std::unique_ptr<LLVMEvaluationContext<char *>> _llvmEvalStr;
81
82 std::unique_ptr<llvm::LLVMContext> _llvmContext;
83 std::unique_ptr<llvm::ExecutionEngine> TheExecutionEngine;
84
85 public:
86 LLVMEvaluator() {}
87
88 const char *evalStr(VarBlock *varBlock) { return *(*_llvmEvalStr)(varBlock); }
89 const double *evalFP(VarBlock *varBlock) { return (*_llvmEvalFP)(varBlock); }
90
91 void evalMultiple(VarBlock *varBlock, uint32_t outputVarBlockOffset, uint32_t rangeStart, uint32_t rangeEnd) {
92 return (*_llvmEvalFP)(varBlock, outputVarBlockOffset, rangeStart, rangeEnd);
93 }
94
95 void debugPrint() {
96 // TheModule->print(llvm::errs(), nullptr);
97 }
98
99 bool prepLLVM(ExprNode *parseTree, ExprType desiredReturnType) {
100 using namespace llvm;
101 InitializeNativeTarget();
102 InitializeNativeTargetAsmPrinter();
103 InitializeNativeTargetAsmParser();
104
105 std::string uniqueName = getUniqueName();
106
107 // create Module
108 _llvmContext.reset(new LLVMContext());
109
110 std::unique_ptr<Module> TheModule(new Module(uniqueName + "_module", *_llvmContext));
111
112 // create all needed types
113 Type *i8PtrTy = Type::getInt8PtrTy(*_llvmContext); // char *
114 PointerType *i8PtrPtrTy = PointerType::getUnqual(i8PtrTy); // char **
115 PointerType *i8PtrPtrPtrTy = PointerType::getUnqual(i8PtrPtrTy); // char ***
116 Type *i32Ty = Type::getInt32Ty(*_llvmContext); // int
117 Type *i32PtrTy = Type::getInt32PtrTy(*_llvmContext); // int *
118 Type *i64Ty = Type::getInt64Ty(*_llvmContext); // int64 *
119 Type *doublePtrTy = Type::getDoublePtrTy(*_llvmContext); // double *
120 PointerType *doublePtrPtrTy = PointerType::getUnqual(doublePtrTy); // double **
121 Type *voidTy = Type::getVoidTy(*_llvmContext); // void
122
123 // create bindings to helper functions for variables and fucntions
124 Function *SeExpr2LLVMEvalCustomFunctionFunc = nullptr;
125 Function *SeExpr2LLVMEvalFPVarRefFunc = nullptr;
126 Function *SeExpr2LLVMEvalStrVarRefFunc = nullptr;
127 Function *SeExpr2LLVMEvalstrlenFunc = nullptr;
128 Function *SeExpr2LLVMEvalmallocFunc = nullptr;
129 Function *SeExpr2LLVMEvalfreeFunc = nullptr;
130 Function *SeExpr2LLVMEvalmemsetFunc = nullptr;
131 Function *SeExpr2LLVMEvalstrcatFunc = nullptr;
132 {
133 {
134 FunctionType *FT = FunctionType::get(voidTy, {i32PtrTy, doublePtrTy, i8PtrPtrTy, i8PtrPtrTy, i64Ty}, false);
135 SeExpr2LLVMEvalCustomFunctionFunc = Function::Create(FT, GlobalValue::ExternalLinkage, "SeExpr2LLVMEvalCustomFunction", TheModule.get());
136 }
137 {
138 FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy, doublePtrTy}, false);
139 SeExpr2LLVMEvalFPVarRefFunc = Function::Create(FT, GlobalValue::ExternalLinkage, "SeExpr2LLVMEvalFPVarRef", TheModule.get());
140 }
141 {
142 FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy, i8PtrPtrTy}, false);
143 SeExpr2LLVMEvalStrVarRefFunc = Function::Create(FT, GlobalValue::ExternalLinkage, "SeExpr2LLVMEvalStrVarRef", TheModule.get());
144 }
145 {
146 FunctionType *FT = FunctionType::get(i32Ty, { i8PtrTy }, false);
147 SeExpr2LLVMEvalstrlenFunc = Function::Create(FT, Function::ExternalLinkage, "strlen", TheModule.get());
148 }
149 {
150 FunctionType *FT = FunctionType::get(i8PtrTy, { i32Ty }, false);
151 SeExpr2LLVMEvalmallocFunc = Function::Create(FT, Function::ExternalLinkage, "malloc", TheModule.get());
152 }
153 {
154 FunctionType *FT = FunctionType::get(voidTy, { i8PtrTy }, false);
155 SeExpr2LLVMEvalfreeFunc = Function::Create(FT, Function::ExternalLinkage, "free", TheModule.get());
156 }
157 {
158 FunctionType *FT = FunctionType::get(voidTy, { i8PtrTy, i32Ty, i32Ty }, false);
159 SeExpr2LLVMEvalmemsetFunc = Function::Create(FT, Function::ExternalLinkage, "memset", TheModule.get());
160 }
161 {
162 FunctionType *FT = FunctionType::get(i8PtrTy, { i8PtrTy, i8PtrTy }, false);
163 SeExpr2LLVMEvalstrcatFunc = Function::Create(FT, Function::ExternalLinkage, "strcat", TheModule.get());
164 }
165 }
166
167 // create function and entry BB
168 bool desireFP = desiredReturnType.isFP();
169 Type *ParamTys[] = {
170 desireFP ? doublePtrTy : i8PtrPtrTy,
171 doublePtrPtrTy,
172 i32Ty
173 };
174 FunctionType *FT = FunctionType::get(voidTy, ParamTys, false);
175 Function *F = Function::Create(FT, Function::ExternalLinkage, uniqueName + "_func", TheModule.get());
176#if LLVM_VERSION_MAJOR > 4
177 F->addAttribute(llvm::AttributeList::FunctionIndex, llvm::Attribute::AlwaysInline);
178#else
179 F->addAttribute(llvm::AttributeSet::FunctionIndex, llvm::Attribute::AlwaysInline);
180#endif
181 {
182 // label the function with names
183 const char *names[] = {"outputPointer", "dataBlock", "indirectIndex"};
184 int idx = 0;
185 for (auto &arg : F->args()) arg.setName(names[idx++]);
186 }
187
188 unsigned int dimDesired = (unsigned)desiredReturnType.dim();
189 unsigned int dimGenerated = parseTree->type().dim();
190 {
191 BasicBlock *BB = BasicBlock::Create(*_llvmContext, "entry", F);
192 IRBuilder<> Builder(BB);
193
194 // codegen
195 Value *lastVal = parseTree->codegen(Builder);
196
197 // return values through parameter.
198 Value *firstArg = &*F->arg_begin();
199 if (desireFP) {
200 if (dimGenerated > 1) {
201 Value *newLastVal = promoteToDim(lastVal, dimDesired, Builder);
202 assert(newLastVal->getType()->getVectorNumElements() >= dimDesired);
203 for (unsigned i = 0; i < dimDesired; ++i) {
204 Value *idx = ConstantInt::get(Type::getInt64Ty(*_llvmContext), i);
205 Value *val = Builder.CreateExtractElement(newLastVal, idx);
206 Value *ptr = Builder.CreateInBoundsGEP(firstArg, idx);
207 Builder.CreateStore(val, ptr);
208 }
209 } else if (dimGenerated == 1) {
210 for (unsigned i = 0; i < dimDesired; ++i) {
211 Value *ptr = Builder.CreateConstInBoundsGEP1_32(nullptr, firstArg, i);
212 Builder.CreateStore(lastVal, ptr);
213 }
214 } else {
215 assert(false && "error. dim of FP is less than 1.");
216 }
217 } else {
218 Builder.CreateStore(lastVal, firstArg);
219 }
220
221 Builder.CreateRetVoid();
222 }
223
224 // write a new function
225 FunctionType *FTLOOP = FunctionType::get(voidTy, {i8PtrTy, i32Ty, i32Ty, i32Ty}, false);
226 Function *FLOOP = Function::Create(FTLOOP, Function::ExternalLinkage, uniqueName + "_loopfunc", TheModule.get());
227 {
228 // label the function with names
229 const char *names[] = {"dataBlock", "outputVarBlockOffset", "rangeStart", "rangeEnd"};
230 int idx = 0;
231 for (auto &arg : FLOOP->args()) {
232 arg.setName(names[idx++]);
233 }
234 }
235 {
236 // Local variables
237 Value *dimValue = ConstantInt::get(i32Ty, dimDesired);
238 Value *oneValue = ConstantInt::get(i32Ty, 1);
239
240 // Basic blocks
241 BasicBlock *entryBlock = BasicBlock::Create(*_llvmContext, "entry", FLOOP);
242 BasicBlock *loopCmpBlock = BasicBlock::Create(*_llvmContext, "loopCmp", FLOOP);
243 BasicBlock *loopRepeatBlock = BasicBlock::Create(*_llvmContext, "loopRepeat", FLOOP);
244 BasicBlock *loopIncBlock = BasicBlock::Create(*_llvmContext, "loopInc", FLOOP);
245 BasicBlock *loopEndBlock = BasicBlock::Create(*_llvmContext, "loopEnd", FLOOP);
246 IRBuilder<> Builder(entryBlock);
247 Builder.SetInsertPoint(entryBlock);
248
249 // Get arguments
250 Function::arg_iterator argIterator = FLOOP->arg_begin();
251 Value *varBlockCharPtrPtrArg = &*argIterator; ++argIterator;
252 Value *outputVarBlockOffsetArg = &*argIterator; ++argIterator;
253 Value *rangeStartArg = &*argIterator; ++argIterator;
254 Value *rangeEndArg = &*argIterator; ++argIterator;
255
256 // Allocate Variables
257 Value *rangeStartVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "rangeStartVar");
258 Value *rangeEndVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "rangeEndVar");
259 Value *indexVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "indexVar");
260 Value *outputVarBlockOffsetVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "outputVarBlockOffsetVar");
261 Value *varBlockDoublePtrPtrVar = Builder.CreateAlloca(doublePtrPtrTy, oneValue, "varBlockDoublePtrPtrVar");
262 Value *varBlockTPtrPtrVar = Builder.CreateAlloca(desireFP == true ? doublePtrPtrTy : i8PtrPtrPtrTy, oneValue, "varBlockTPtrPtrVar");
263
264 // Copy variables from args
265 Builder.CreateStore(Builder.CreatePointerCast(varBlockCharPtrPtrArg, doublePtrPtrTy, "varBlockAsDoublePtrPtr"), varBlockDoublePtrPtrVar);
266 Builder.CreateStore(Builder.CreatePointerCast(varBlockCharPtrPtrArg, desireFP ? doublePtrPtrTy : i8PtrPtrPtrTy, "varBlockAsTPtrPtr"), varBlockTPtrPtrVar);
267 Builder.CreateStore(rangeStartArg, rangeStartVar);
268 Builder.CreateStore(rangeEndArg, rangeEndVar);
269 Builder.CreateStore(outputVarBlockOffsetArg, outputVarBlockOffsetVar);
270
271 // Set output pointer
272 Value *outputBasePtrPtr = Builder.CreateGEP(nullptr, Builder.CreateLoad(varBlockTPtrPtrVar), outputVarBlockOffsetArg, "outputBasePtrPtr");
273 Value *outputBasePtr = Builder.CreateLoad(outputBasePtrPtr, "outputBasePtr");
274 Builder.CreateStore(Builder.CreateLoad(rangeStartVar), indexVar);
275
276 Builder.CreateBr(loopCmpBlock);
277 Builder.SetInsertPoint(loopCmpBlock);
278 Value *cond = Builder.CreateICmpULT(Builder.CreateLoad(indexVar), Builder.CreateLoad(rangeEndVar));
279 Builder.CreateCondBr(cond, loopRepeatBlock, loopEndBlock);
280
281 Builder.SetInsertPoint(loopRepeatBlock);
282 Value *myOutputPtr = Builder.CreateGEP(nullptr, outputBasePtr, Builder.CreateMul(dimValue, Builder.CreateLoad(indexVar)));
283 Builder.CreateCall(F, {myOutputPtr, Builder.CreateLoad(varBlockDoublePtrPtrVar), Builder.CreateLoad(indexVar)});
284
285 Builder.CreateBr(loopIncBlock);
286
287 Builder.SetInsertPoint(loopIncBlock);
288 Builder.CreateStore(Builder.CreateAdd(Builder.CreateLoad(indexVar), oneValue), indexVar);
289 Builder.CreateBr(loopCmpBlock);
290
291 Builder.SetInsertPoint(loopEndBlock);
292 Builder.CreateRetVoid();
293 }
294
296 #ifdef DEBUG
297 std::cerr << "Pre verified LLVM byte code " << std::endl;
298 TheModule->print(llvm::errs(), nullptr);
299 #endif
300 }
301
302 // TODO: Find out if there is a new way to veirfy
303 // if (verifyModule(*TheModule)) {
304 // std::cerr << "Logic error in code generation of LLVM alert developers" << std::endl;
305 // TheModule->print(llvm::errs(), nullptr);
306 // }
307 Module *altModule = TheModule.get();
308 std::string ErrStr;
309 TheExecutionEngine.reset(EngineBuilder(std::move(TheModule))
310 .setErrorStr(&ErrStr)
311 // .setUseMCJIT(true)
312 .setOptLevel(CodeGenOpt::Aggressive)
313 .create());
314
315 altModule->setDataLayout(TheExecutionEngine->getDataLayout());
316
317 // Add bindings to C linkage helper functions
318 TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalFPVarRefFunc, (void *)SeExpr2LLVMEvalFPVarRef);
319 TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalStrVarRefFunc, (void *)SeExpr2LLVMEvalStrVarRef);
320 TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalCustomFunctionFunc, (void *)SeExpr2LLVMEvalCustomFunction);
321 TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalstrlenFunc, (void *)strlen);
322 TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalstrcatFunc, (void *)strcat);
323 TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalmemsetFunc, (void *)memset);
324 TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalmallocFunc, (void *)malloc);
325 TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalfreeFunc, (void *)free);
326
327 // [verify]
328 std::string errorStr;
329 llvm::raw_string_ostream raw(errorStr);
330 if (llvm::verifyModule(*altModule, &raw)) {
331 parseTree->addError(raw.str());
332 return false;
333 }
334
335 // Setup optimization
336 llvm::PassManagerBuilder builder;
337 std::unique_ptr<llvm::legacy::PassManager> pm(new llvm::legacy::PassManager);
338 std::unique_ptr<llvm::legacy::FunctionPassManager> fpm(new llvm::legacy::FunctionPassManager(altModule));
339 builder.OptLevel = 3;
340#if (LLVM_VERSION_MAJOR >= 4)
341 builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
342#else
343 builder.Inliner = llvm::createAlwaysInlinerPass();
344#endif
345 builder.populateModulePassManager(*pm);
346 // fpm->add(new llvm::DataLayoutPass());
347 builder.populateFunctionPassManager(*fpm);
348 fpm->run(*F);
349 fpm->run(*FLOOP);
350 pm->run(*altModule);
351
352 // Create the JIT. This takes ownership of the module.
353
354 if (!TheExecutionEngine) {
355 fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str());
356 exit(1);
357 }
358
359 TheExecutionEngine->finalizeObject();
360 void *fp = TheExecutionEngine->getPointerToFunction(F);
361 void *fpLoop = TheExecutionEngine->getPointerToFunction(FLOOP);
362 if (desireFP) {
363 _llvmEvalFP.reset(new LLVMEvaluationContext<double>);
364 _llvmEvalFP->init(fp, fpLoop, dimDesired);
365 } else {
366 _llvmEvalStr.reset(new LLVMEvaluationContext<char *>);
367 _llvmEvalStr->init(fp, fpLoop, dimDesired);
368 }
369
371 #ifdef DEBUG
372 std::cerr << "Pre verified LLVM byte code " << std::endl;
373 altModule->print(llvm::errs(), nullptr);
374 #endif
375 }
376
377 return true;
378 }
379
380 std::string getUniqueName() const {
381 std::ostringstream o;
382 o << std::setbase(16) << (uint64_t)(this);
383 return ("_" + o.str());
384 }
385};
386
387#else // no LLVM support
389 public:
390 void unsupported() { throw std::runtime_error("LLVM is not enabled in build"); }
391 const char *evalStr(VarBlock *varBlock) {
392 unsupported();
393 return "";
394 }
395 const double *evalFP(VarBlock *varBlock) {
396 unsupported();
397 return 0;
398 }
399 bool prepLLVM(ExprNode *parseTree, ExprType desiredReturnType) {
400 unsupported();
401 return false;
402 }
403 void evalMultiple(VarBlock *varBlock, int outputVarBlockOffset, size_t rangeStart, size_t rangeEnd) {
404 unsupported();
405 }
406 void debugPrint() {}
407};
408#endif
409
410} // end namespace SeExpr2
void SeExpr2LLVMEvalFPVarRef(SeExpr2::ExprVarRef *seVR, double *result)
void SeExpr2LLVMEvalCustomFunction(int *opDataArg, double *fpArg, char **strArg, void **funcdata, const SeExpr2::ExprFuncNode *node)
void SeExpr2LLVMEvalStrVarRef(SeExpr2::ExprVarRef *seVR, double *result)
double LLVM_VALUE
Definition ExprLLVM.h:33
Node that calls a function.
Definition ExprNode.h:517
abstract class for implementing variable references
Definition Expression.h:45
static bool debugging
Whether to debug expressions.
Definition Expression.h:86
const char * evalStr(VarBlock *varBlock)
Definition Evaluator.h:391
void evalMultiple(VarBlock *varBlock, int outputVarBlockOffset, size_t rangeStart, size_t rangeEnd)
Definition Evaluator.h:403
bool prepLLVM(ExprNode *parseTree, ExprType desiredReturnType)
Definition Evaluator.h:399
const double * evalFP(VarBlock *varBlock)
Definition Evaluator.h:395
A thread local evaluation context. Just allocate and fill in with data.
Definition VarBlock.h:33