diff --git a/src/prometheus/ast.lua b/src/prometheus/ast.lua index 2ca0523a..ef390dd7 100644 --- a/src/prometheus/ast.lua +++ b/src/prometheus/ast.lua @@ -148,11 +148,12 @@ function Ast.NopStatement() } end -function Ast.IfElseExpression(condition, true_value, false_value) +function Ast.IfElseExpression(condition, true_value, elseifs, false_value) return { kind = AstKind.IfElseExpression, condition = condition, true_value = true_value, + elseifs = elseifs, false_value = false_value } end diff --git a/src/prometheus/compiler/expressions.lua b/src/prometheus/compiler/expressions.lua index 877f75d1..f55fb60a 100644 --- a/src/prometheus/compiler/expressions.lua +++ b/src/prometheus/compiler/expressions.lua @@ -28,6 +28,7 @@ handlers[AstKind.AndExpression] = requireExpression("and"); handlers[AstKind.TableConstructorExpression] = requireExpression("table_constructor"); handlers[AstKind.FunctionLiteralExpression] = requireExpression("function_literal"); handlers[AstKind.VarargExpression] = requireExpression("vararg"); +handlers[AstKind.IfElseExpression] = requireExpression("if_else"); -- Binary ops share one handler local binaryHandler = requireExpression("binary"); diff --git a/src/prometheus/compiler/expressions/if_else.lua b/src/prometheus/compiler/expressions/if_else.lua new file mode 100644 index 00000000..317b0c68 --- /dev/null +++ b/src/prometheus/compiler/expressions/if_else.lua @@ -0,0 +1,80 @@ +-- This Script is Part of the Prometheus Obfuscator by levno-710 +-- +-- if_else.lua +-- +-- This Script contains the statement handler for the IfElseExpression. + +local Ast = require("prometheus.ast"); + +return function(self, expression, funcDepth, numReturns) + local scope = self.activeBlock.scope; + local posState = self.registers[self.POS_REGISTER]; + self.registers[self.POS_REGISTER] = self.VAR_REGISTER; + + local regs = {}; + for i = 1, numReturns do + regs[i] = self:allocRegister(); + if i ~= 1 then + self:addStatement(self:setRegister(scope, regs[i], Ast.NilExpression()), {regs[i]}, {}, false); + end + end + + local resReg = regs[1]; + local tmpReg; + + if posState then + tmpReg = self:allocRegister(false); + self:addStatement(self:copyRegisters(scope, {tmpReg}, {self.POS_REGISTER}), {tmpReg}, {self.POS_REGISTER}, false); + end + + local conditionReg = self:compileExpression(expression.condition, funcDepth, 1)[1]; + + local finalBlock = self:createBlock(); + local nextBlock = self:createBlock(); + local innerBlock = self:createBlock(); + + self:addStatement(self:setRegister(scope, self.POS_REGISTER, Ast.OrExpression(Ast.AndExpression(self:register(scope, conditionReg), Ast.NumberExpression(innerBlock.id)), Ast.NumberExpression(nextBlock.id))), {self.POS_REGISTER}, {conditionReg}, false); + self:freeRegister(conditionReg, false); + + self:setActiveBlock(innerBlock); + scope = innerBlock.scope; + + local trueReg = self:compileExpression(expression.true_value, funcDepth, 1)[1]; + self:addStatement(self:copyRegisters(scope, {resReg}, {trueReg}), {resReg}, {trueReg}, false); + self:addStatement(self:setRegister(scope, self.POS_REGISTER, Ast.NumberExpression(finalBlock.id)), {self.POS_REGISTER}, {}, false); + + for _, elif in ipairs(expression.elseifs) do + self:setActiveBlock(nextBlock); + conditionReg = self:compileExpression(elif.condition, funcDepth, 1)[1]; + local elifBlock = self:createBlock(); + nextBlock = self:createBlock(); + local elifScope = self.activeBlock.scope; + + self:addStatement(self:setRegister(elifScope, self.POS_REGISTER, Ast.OrExpression(Ast.AndExpression(self:register(elifScope, conditionReg), Ast.NumberExpression(elifBlock.id)), Ast.NumberExpression(nextBlock.id))), {self.POS_REGISTER}, {conditionReg}, false); + self:freeRegister(conditionReg, false); + + self:setActiveBlock(elifBlock); + elifScope = elifBlock.scope; + local valueReg = self:compileExpression(elif.value, funcDepth, 1)[1]; + self:addStatement(self:copyRegisters(elifScope, {resReg}, {valueReg}), {resReg}, {valueReg}, false); + self:addStatement(self:setRegister(elifScope, self.POS_REGISTER, Ast.NumberExpression(finalBlock.id)), {self.POS_REGISTER}, {}, false); + end + + self:setActiveBlock(nextBlock); + scope = self.activeBlock.scope; + local falseReg = self:compileExpression(expression.false_value, funcDepth, 1)[1]; + self:addStatement(self:copyRegisters(scope, {resReg}, {falseReg}), {resReg}, {falseReg}, false); + self:addStatement(self:setRegister(scope, self.POS_REGISTER, Ast.NumberExpression(finalBlock.id)), {self.POS_REGISTER}, {}, false); + + self.registers[self.POS_REGISTER] = posState; + + self:setActiveBlock(finalBlock); + scope = finalBlock.scope; + + if tmpReg then + self:addStatement(self:copyRegisters(scope, {self.POS_REGISTER}, {tmpReg}), {self.POS_REGISTER}, {tmpReg}, false); + self:freeRegister(tmpReg, false); + end + + return regs; +end \ No newline at end of file diff --git a/src/prometheus/parser.lua b/src/prometheus/parser.lua index 7274bf5e..57ce4dca 100644 --- a/src/prometheus/parser.lua +++ b/src/prometheus/parser.lua @@ -946,10 +946,23 @@ function Parser:expressionLiteral(scope) local condition = self:expression(scope); expect(self, TokenKind.Keyword, "then"); local true_value = self:expression(scope); + + local elseifs = {} + while(consume(self, TokenKind.Keyword, "elseif")) do + local elseif_condition = self:expression(scope); + expect(self, TokenKind.Keyword, "then"); + local elseif_value = self:expression(scope); + + table.insert(elseifs, { + condition = elseif_condition, + value = elseif_value + }); + end + expect(self, TokenKind.Keyword, "else"); local false_value = self:expression(scope); - return Ast.IfElseExpression(condition, true_value, false_value); + return Ast.IfElseExpression(condition, true_value, elseifs, false_value); end end diff --git a/src/prometheus/unparser.lua b/src/prometheus/unparser.lua index 661baf29..0aacd511 100644 --- a/src/prometheus/unparser.lua +++ b/src/prometheus/unparser.lua @@ -688,6 +688,12 @@ function Unparser:unparseExpression(expression, tabbing) push(self:unparseExpression(expression.condition)); push(" then "); push(self:unparseExpression(expression.true_value)); + for _, elseifexp in pairs(expression.elseifs) do + push(" elseif "); + push(self:unparseExpression(elseifexp.condition)); + push(" then "); + push(self:unparseExpression(elseifexp.value)); + end push(" else "); push(self:unparseExpression(expression.false_value)); return joinParts(parts); diff --git a/src/prometheus/visitast.lua b/src/prometheus/visitast.lua index 16dc4fb4..17098859 100644 --- a/src/prometheus/visitast.lua +++ b/src/prometheus/visitast.lua @@ -238,8 +238,12 @@ function visitExpression(expression, previsit, postvisit, data) end if(expression.kind == AstKind.IfElseExpression) then expression.condition = visitExpression(expression.condition, previsit, postvisit, data); - expression.true_expr = visitExpression(expression.true_expr, previsit, postvisit, data); - expression.false_expr = visitExpression(expression.false_expr, previsit, postvisit, data); + expression.true_value = visitExpression(expression.true_value, previsit, postvisit, data); + for i, elif in pairs(expression.elseifs) do + elif.condition = visitExpression(elif.condition, previsit, postvisit, data); + elif.value = visitExpression(elif.value, previsit, postvisit, data); + end + expression.false_value = visitExpression(expression.false_value, previsit, postvisit, data); end if(type(postvisit) == "function") then