Index: include/llvm/LinkAllPasses.h
===================================================================
--- include/llvm/LinkAllPasses.h	(revision 56688)
+++ include/llvm/LinkAllPasses.h	(working copy)
@@ -64,6 +64,7 @@
       (void) llvm::createFunctionInliningPass();
       (void) llvm::createAlwaysInlinerPass();
       (void) llvm::createFunctionProfilerPass();
+      (void) llvm::createFunctionRetValuesPass();
       (void) llvm::createGlobalDCEPass();
       (void) llvm::createGlobalOptimizerPass();
       (void) llvm::createGlobalsModRefPass();
Index: include/llvm/Transforms/IPO.h
===================================================================
--- include/llvm/Transforms/IPO.h	(revision 56688)
+++ include/llvm/Transforms/IPO.h	(working copy)
@@ -151,6 +151,11 @@
 ModulePass *createIPConstantPropagationPass();
 
 //===----------------------------------------------------------------------===//
+/// createFunctionRetValuesPass() - This pass infers function return values
+///
+ModulePass *createFunctionRetValuesPass();
+
+//===----------------------------------------------------------------------===//
 /// createIPSCCPPass - This pass propagates constants from call sites into the
 /// bodies of functions, and keeps track of whether basic blocks are executable
 /// in the process.
Index: lib/Transforms/IPO/FunctionRetValues.cpp
===================================================================
--- lib/Transforms/IPO/FunctionRetValues.cpp	(revision 0)
+++ lib/Transforms/IPO/FunctionRetValues.cpp	(revision 0)
@@ -0,0 +1,336 @@
+//===- FunctionRetValues.cpp - Infer constant function return values ------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass loops over all of the functions and produces a summary of the
+// possible return values. Then it tries to prune dead code based on the
+// inferred return values.
+//
+// It is a good idea to run a constant propagator before this pass, because it
+// does not propagate constants itself, and to run DCE afterwards, because this
+// pass can expose dead code.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "function-ret-values"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Instructions.h"
+#include "llvm/Constants.h"
+#include "llvm/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Module.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+using namespace llvm;
+
+STATISTIC(NumInferRetValues, "Number of inferred function return values");
+
+namespace {
+
+/// @brief Pass to infer function return values.
+class VISIBILITY_HIDDEN FunctionRetValuesPass : public ModulePass {
+public:
+  static char ID; // Pass identification, replacement for typeid
+  FunctionRetValuesPass() : ModulePass(&ID) { }
+  virtual bool runOnModule(Module &M);
+
+  // getAnalysisUsage() is deliberately not overridden: pruning switch cases
+  // removes CFG edges, so this pass must not claim to preserve the CFG.
+};
+
+} // end anonymous namespace
+
+char FunctionRetValuesPass::ID = 0;
+static RegisterPass<FunctionRetValuesPass>
+X("function-ret-values", "Function return values inferer");
+
+/// Set - The constants a function may return.
+typedef SmallPtrSet<Constant*, 8> Set;
+
+/// Map - Maps a function to its possible return values. An empty set means
+/// that the return values are unknown.
+typedef DenseMap<Function*, Set> Map;
+
+
+/// GetPossibleValues - Add to 'set' every constant that 'val' may evaluate to.
+/// Constants, selects (both arms) and direct calls to functions already present
+/// in 'map' are understood. Returns false if the values cannot be enumerated;
+/// 'set' may then hold a partial result that the caller has to discard.
+static bool GetPossibleValues(Value *val, Map &map, Set &set) {
+  if (Constant *C = dyn_cast<Constant>(val)) {
+    set.insert(C);
+    return true;
+  }
+
+  if (SelectInst *s = dyn_cast<SelectInst>(val))
+    return GetPossibleValues(s->getTrueValue(), map, set) &&
+           GetPossibleValues(s->getFalseValue(), map, set);
+
+  if (CallInst *c = dyn_cast<CallInst>(val)) {
+    Function *callee = c->getCalledFunction();
+    if (!callee) return false; // indirect call
+
+    Map::iterator it = map.find(callee);
+    if (it == map.end()) return false;
+
+    Set &s = it->second;
+
+    // A self-recursive call returns nothing beyond what the other return
+    // statements contribute, and a set must not be inserted into itself.
+    if (&s == &set) return true;
+
+    if (s.empty()) return false;
+
+    set.insert(s.begin(), s.end());
+    return true;
+  }
+
+  if (Instruction *I = dyn_cast<Instruction>(val))
+    DOUT << "GetPossibleValues(): unhandled instruction '"
+         << I->getOpcodeName() << "'\n";
+  else
+    DOUT << "GetPossibleValues(): unhandled value type: "
+         << val->getValueID() << "\n";
+
+  return false;
+}
+
+
+/// AnalyzeFunction - Record in 'map' the set of constants that F may return.
+/// The entry is left empty, meaning "unknown", when some returned value cannot
+/// be enumerated.
+static void AnalyzeFunction(Function *F, Map &map) {
+  // If this function could be overridden later in the link stage, we can't
+  // propagate information about its results into callers.
+  if (F->hasLinkOnceLinkage() || F->hasWeakLinkage())
+    return;
+
+  for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) {
+    ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator());
+    if (!RI) continue;
+
+    if (RI->getNumOperands() != 1)
+      return; // void and multiple return values are not supported
+
+    // GetPossibleValues() never inserts into the map, so this reference
+    // stays valid across the call.
+    Set &values = map[F];
+    if (!GetPossibleValues(RI->getReturnValue(), map, values)) {
+      values.clear();
+      return;
+    }
+  }
+}
+
+
+/// fixPHIs - The CFG edge from 'src' to 'BB' is about to be removed: drop the
+/// matching incoming value from the PHI nodes of 'BB'. This has to be called
+/// while 'src' is still a predecessor of 'BB'.
+static void fixPHIs(BasicBlock *BB, BasicBlock *src) {
+  BB->removePredecessor(src);
+}
+
+
+/// AllConstantInts - Return true if every element of 's' is a ConstantInt.
+static bool AllConstantInts(Set &s) {
+  for (Set::iterator it = s.begin(), end = s.end(); it != end; ++it)
+    if (!isa<ConstantInt>(*it))
+      return false;
+
+  return true;
+}
+
+
+/// ConstantEqualsAny - Return true if some element of 's' has the same value
+/// as C2. Every element of 's' must be a ConstantInt.
+static bool ConstantEqualsAny(ConstantInt *C2, Set &s) {
+  for (Set::iterator it = s.begin(), end = s.end(); it != end; ++it)
+    if (cast<ConstantInt>(*it)->getValue() == C2->getValue())
+      return true;
+
+  return false;
+}
+
+
+/// EvaluatePredicate - Evaluate the integer comparison 'LHS Pred RHS'.
+static bool EvaluatePredicate(CmpInst::Predicate Pred, const APInt &LHS,
+                              const APInt &RHS) {
+  switch (Pred) {
+  case CmpInst::ICMP_EQ:  return LHS == RHS;
+  case CmpInst::ICMP_NE:  return LHS != RHS;
+  case CmpInst::ICMP_UGT: return LHS.ugt(RHS);
+  case CmpInst::ICMP_UGE: return LHS.uge(RHS);
+  case CmpInst::ICMP_ULT: return LHS.ult(RHS);
+  case CmpInst::ICMP_ULE: return LHS.ule(RHS);
+  case CmpInst::ICMP_SGT: return LHS.sgt(RHS);
+  case CmpInst::ICMP_SGE: return LHS.sge(RHS);
+  case CmpInst::ICMP_SLT: return LHS.slt(RHS);
+  case CmpInst::ICMP_SLE: return LHS.sle(RHS);
+  default:
+    assert(0 && "unhandled ICmpInst predicate");
+    return false;
+  }
+}
+
+
+/// PredicateHoldsForAll - Return true if 'V Pred C' holds for every element V
+/// of 's'. Every element of 's' must be a ConstantInt.
+static bool PredicateHoldsForAll(CmpInst::Predicate Pred, Set &s,
+                                 ConstantInt *C) {
+  for (Set::iterator it = s.begin(), end = s.end(); it != end; ++it)
+    if (!EvaluatePredicate(Pred, cast<ConstantInt>(*it)->getValue(),
+                           C->getValue()))
+      return false;
+
+  return true;
+}
+
+
+/// MakeConstant - 'I' uses 'op', a call whose possible results are 'set'. Try
+/// to simplify 'I' with that knowledge: fold an icmp against a constant, or
+/// prune the switch cases that can never be taken. Returns true if the IR was
+/// changed.
+static bool MakeConstant(Instruction *I, Instruction *op, Set &set) {
+  // Only integer results can be reasoned about below; undef, null pointers,
+  // constant expressions and the like make the set unusable here.
+  if (!AllConstantInts(set))
+    return false;
+
+  if (ICmpInst *cmp = dyn_cast<ICmpInst>(I)) {
+    bool changed = false;
+
+    // make the following code simpler by having a known operand order
+    if (cmp->getOperand(0) != op) {
+      cmp->swapOperands();
+      changed = true;
+    }
+
+    ConstantInt *C = dyn_cast<ConstantInt>(cmp->getOperand(1));
+    if (!C) return changed;
+
+    // The compare folds to true if it holds for every possible result, and
+    // to false if its inverse does. Otherwise the outcome is unknown.
+    CmpInst::Predicate Pred = CmpInst::Predicate(cmp->getPredicate());
+    bool result;
+    if (PredicateHoldsForAll(Pred, set, C))
+      result = true;
+    else if (PredicateHoldsForAll(ICmpInst::getInversePredicate(Pred), set, C))
+      result = false;
+    else
+      return changed;
+
+    ++NumInferRetValues;
+    I->replaceAllUsesWith(ConstantInt::get(APInt(1, result)));
+    return true;
+  }
+
+  if (SwitchInst *s = dyn_cast<SwitchInst>(I)) {
+    assert(s->getCondition() == op);
+    bool changed = false;
+
+    // Case 0 is the default destination. Drop every other case whose value
+    // the call can never return.
+    for (unsigned i = 1; i < s->getNumCases(); ) {
+      if (ConstantEqualsAny(s->getCaseValue(i), set)) {
+        ++i;
+        continue;
+      }
+
+      fixPHIs(s->getSuccessor(i), s->getParent());
+      s->removeCase(i); // index i now names a case not examined yet, if any
+      ++NumInferRetValues;
+      changed = true;
+    }
+
+    // this means that the cases are all covered and thus we can replace the
+    // default destination with one of the cases' target
+    unsigned idx = s->getNumCases()-1; // it's faster to remove the last one
+    if (idx != 0 && set.size() == idx) {
+      fixPHIs(s->getSuccessor(0), s->getParent());
+      s->setSuccessor(0, s->getSuccessor(idx));
+      s->removeCase(idx);
+      ++NumInferRetValues;
+      changed = true;
+    }
+
+    return changed;
+  }
+
+  // TODO: simplify SelectInst users.
+  DOUT << "MakeConstant(): unhandled instruction: '"
+       << I->getOpcodeName() << "'\n";
+
+  return false;
+}
+
+
+/// TransformFunction - Simplify the users of every direct call to F with the
+/// return values recorded for F in 'map'. Returns true if the IR was changed.
+static bool TransformFunction(Function *F, Map &map) {
+  Map::iterator entry = map.find(F);
+  if (entry == map.end()) return false;
+
+  Set &s = entry->second;
+  if (s.empty()) return false; // the return values are unknown
+
+  bool MadeChange = false;
+
+  for (Value::use_iterator I = F->use_begin(), E = F->use_end(); I != E; ++I) {
+    // Skip users that are not calls, calls that merely pass F as an argument
+    // and calls whose result is unused.
+    CallInst *CI = dyn_cast<CallInst>(*I);
+    if (!CI || CI->getCalledFunction() != F || CI->use_empty()) continue;
+
+    if (s.size() == 1) { // the function only returns one value
+      NumInferRetValues += CI->getNumUses();
+      CI->replaceAllUsesWith(*s.begin());
+      MadeChange = true;
+      continue;
+    }
+
+    // MakeConstant() edits the use list of the call (swapping the operands
+    // of a compare) and may delete PHI nodes (pruning a switch), so first
+    // snapshot the users it can simplify; none of those is ever deleted.
+    SmallVector<Instruction*, 16> users;
+    for (Value::use_iterator UI = CI->use_begin(), UE = CI->use_end();
+         UI != UE; ++UI)
+      if (isa<ICmpInst>(*UI) || isa<SwitchInst>(*UI))
+        users.push_back(cast<Instruction>(*UI));
+
+    for (unsigned i = 0, e = users.size(); i != e; ++i)
+      MadeChange |= MakeConstant(users[i], CI, s);
+  }
+
+  return MadeChange;
+}
+
+
+/// runOnModule - Summarize the return values of every function, then use the
+/// summaries to simplify the callers.
+bool FunctionRetValuesPass::runOnModule(Module &M) {
+  Map map;
+
+  for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
+    AnalyzeFunction(I, map);
+
+  // Walk the module rather than the map so that the order of the
+  // transformations does not depend on the hash order of the map.
+  bool MadeChange = false;
+  for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
+    MadeChange |= TransformFunction(I, map);
+
+  return MadeChange;
+}
+
+ModulePass *llvm::createFunctionRetValuesPass() {
+  return new FunctionRetValuesPass();
+}