Skip to content
Open
3 changes: 2 additions & 1 deletion src/prometheus/ast.lua
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,12 @@ function Ast.NopStatement()
}
end

function Ast.IfElseExpression(condition, true_value, false_value)
function Ast.IfElseExpression(condition, true_value, elseifs, false_value)
return {
kind = AstKind.IfElseExpression,
condition = condition,
true_value = true_value,
elseifs = elseifs,
false_value = false_value
}
end
Expand Down
1 change: 1 addition & 0 deletions src/prometheus/compiler/expressions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ handlers[AstKind.AndExpression] = requireExpression("and");
handlers[AstKind.TableConstructorExpression] = requireExpression("table_constructor");
handlers[AstKind.FunctionLiteralExpression] = requireExpression("function_literal");
handlers[AstKind.VarargExpression] = requireExpression("vararg");
handlers[AstKind.IfElseExpression] = requireExpression("if_else");

-- Binary ops share one handler
local binaryHandler = requireExpression("binary");
Expand Down
80 changes: 80 additions & 0 deletions src/prometheus/compiler/expressions/if_else.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
-- This Script is Part of the Prometheus Obfuscator by levno-710
--
-- if_else.lua
--
-- This Script contains the statement handler for the IfElseExpression.

local Ast = require("prometheus.ast");

return function(self, expression, funcDepth, numReturns)
local scope = self.activeBlock.scope;
local posState = self.registers[self.POS_REGISTER];
self.registers[self.POS_REGISTER] = self.VAR_REGISTER;

local regs = {};
for i = 1, numReturns do
regs[i] = self:allocRegister();
if i ~= 1 then
self:addStatement(self:setRegister(scope, regs[i], Ast.NilExpression()), {regs[i]}, {}, false);
end
end

local resReg = regs[1];
local tmpReg;

if posState then
tmpReg = self:allocRegister(false);
self:addStatement(self:copyRegisters(scope, {tmpReg}, {self.POS_REGISTER}), {tmpReg}, {self.POS_REGISTER}, false);
end

local conditionReg = self:compileExpression(expression.condition, funcDepth, 1)[1];

local finalBlock = self:createBlock();
local nextBlock = self:createBlock();
local innerBlock = self:createBlock();

self:addStatement(self:setRegister(scope, self.POS_REGISTER, Ast.OrExpression(Ast.AndExpression(self:register(scope, conditionReg), Ast.NumberExpression(innerBlock.id)), Ast.NumberExpression(nextBlock.id))), {self.POS_REGISTER}, {conditionReg}, false);
self:freeRegister(conditionReg, false);

self:setActiveBlock(innerBlock);
scope = innerBlock.scope;

local trueReg = self:compileExpression(expression.true_value, funcDepth, 1)[1];
self:addStatement(self:copyRegisters(scope, {resReg}, {trueReg}), {resReg}, {trueReg}, false);
self:addStatement(self:setRegister(scope, self.POS_REGISTER, Ast.NumberExpression(finalBlock.id)), {self.POS_REGISTER}, {}, false);

for _, elif in ipairs(expression.elseifs) do
self:setActiveBlock(nextBlock);
conditionReg = self:compileExpression(elif.condition, funcDepth, 1)[1];
local elifBlock = self:createBlock();
nextBlock = self:createBlock();
local elifScope = self.activeBlock.scope;

self:addStatement(self:setRegister(elifScope, self.POS_REGISTER, Ast.OrExpression(Ast.AndExpression(self:register(elifScope, conditionReg), Ast.NumberExpression(elifBlock.id)), Ast.NumberExpression(nextBlock.id))), {self.POS_REGISTER}, {conditionReg}, false);
self:freeRegister(conditionReg, false);

self:setActiveBlock(elifBlock);
elifScope = elifBlock.scope;
local valueReg = self:compileExpression(elif.value, funcDepth, 1)[1];
self:addStatement(self:copyRegisters(elifScope, {resReg}, {valueReg}), {resReg}, {valueReg}, false);
self:addStatement(self:setRegister(elifScope, self.POS_REGISTER, Ast.NumberExpression(finalBlock.id)), {self.POS_REGISTER}, {}, false);
end

self:setActiveBlock(nextBlock);
scope = self.activeBlock.scope;
local falseReg = self:compileExpression(expression.false_value, funcDepth, 1)[1];
self:addStatement(self:copyRegisters(scope, {resReg}, {falseReg}), {resReg}, {falseReg}, false);
self:addStatement(self:setRegister(scope, self.POS_REGISTER, Ast.NumberExpression(finalBlock.id)), {self.POS_REGISTER}, {}, false);

self.registers[self.POS_REGISTER] = posState;

self:setActiveBlock(finalBlock);
scope = finalBlock.scope;

if tmpReg then
self:addStatement(self:copyRegisters(scope, {self.POS_REGISTER}, {tmpReg}), {self.POS_REGISTER}, {tmpReg}, false);
self:freeRegister(tmpReg, false);
end

return regs;
end
15 changes: 14 additions & 1 deletion src/prometheus/parser.lua
Original file line number Diff line number Diff line change
Expand Up @@ -946,10 +946,23 @@ function Parser:expressionLiteral(scope)
local condition = self:expression(scope);
expect(self, TokenKind.Keyword, "then");
local true_value = self:expression(scope);

local elseifs = {}
while(consume(self, TokenKind.Keyword, "elseif")) do
local elseif_condition = self:expression(scope);
expect(self, TokenKind.Keyword, "then");
local elseif_value = self:expression(scope);

table.insert(elseifs, {
condition = elseif_condition,
value = elseif_value
});
end

expect(self, TokenKind.Keyword, "else");
local false_value = self:expression(scope);

return Ast.IfElseExpression(condition, true_value, false_value);
return Ast.IfElseExpression(condition, true_value, elseifs, false_value);
end
end

Expand Down
6 changes: 6 additions & 0 deletions src/prometheus/unparser.lua
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,12 @@ function Unparser:unparseExpression(expression, tabbing)
push(self:unparseExpression(expression.condition));
push(" then ");
push(self:unparseExpression(expression.true_value));
for _, elseifexp in pairs(expression.elseifs) do
push(" elseif ");
push(self:unparseExpression(elseifexp.condition));
push(" then ");
push(self:unparseExpression(elseifexp.value));
end
push(" else ");
push(self:unparseExpression(expression.false_value));
return joinParts(parts);
Expand Down
8 changes: 6 additions & 2 deletions src/prometheus/visitast.lua
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,12 @@ function visitExpression(expression, previsit, postvisit, data)
end
if(expression.kind == AstKind.IfElseExpression) then
expression.condition = visitExpression(expression.condition, previsit, postvisit, data);
expression.true_expr = visitExpression(expression.true_expr, previsit, postvisit, data);
expression.false_expr = visitExpression(expression.false_expr, previsit, postvisit, data);
expression.true_value = visitExpression(expression.true_value, previsit, postvisit, data);
for i, elif in pairs(expression.elseifs) do
elif.condition = visitExpression(elif.condition, previsit, postvisit, data);
elif.value = visitExpression(elif.value, previsit, postvisit, data);
end
expression.false_value = visitExpression(expression.false_value, previsit, postvisit, data);
end

if(type(postvisit) == "function") then
Expand Down