Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
dbb001f
[ASR Pass] Symbolic: Use a function to create `basic_new_stack` BindC…
Thirumalai-Shaktivel Nov 25, 2023
ed2997d
[ASR Pass] Symbolic: Use a function to create `basic_free_stack` Bind…
Thirumalai-Shaktivel Nov 25, 2023
2fe8807
[ASR Pass] Symbolic: Add `basic_free_stack` to function dependencies
Thirumalai-Shaktivel Nov 25, 2023
0c3d6f2
[ASR Pass] Symbolic: Simplify `basic_get_args` to return `SubroutineC…
Thirumalai-Shaktivel Nov 25, 2023
93546df
[ASR Pass] Symbolic: Simplify `vecbasic_new` to return `FunctionCall`
Thirumalai-Shaktivel Nov 25, 2023
799932e
[ASR Pass] Symbolic: Simplify `vecbasic_get` to return `SubroutineCall`
Thirumalai-Shaktivel Nov 25, 2023
a7eae7b
[ASR Pass] Symbolic: Simplify `vecbasic_size` to return `FunctionCall`
Thirumalai-Shaktivel Nov 25, 2023
f497d07
[ASR Pass] Symbolic: Simplify `basic_assign` to return `SubroutineCall`
Thirumalai-Shaktivel Nov 25, 2023
0aa4435
[ASR Pass] Symbolic: Simplify `basic_str` to return `FunctionCall`
Thirumalai-Shaktivel Nov 25, 2023
3096c28
[ASR Pass] Symbolic: Simplify `basic_get_type` to return `FunctionCall`
Thirumalai-Shaktivel Nov 25, 2023
bb48bdb
[ASR Pass] Symbolic: Simplify `basic_eq` & `basic_neq` into `basic_co…
Thirumalai-Shaktivel Nov 25, 2023
e8724c1
[ASR Pass] Symbolic: Simplify `integer_set_si` to return `SubroutineC…
Thirumalai-Shaktivel Nov 25, 2023
2745451
[ASR Pass] Symbolic: Simplify `symbol_set` to return `SubroutineCall`
Thirumalai-Shaktivel Nov 25, 2023
9440150
[ASR Pass] Symbolic: Simplify `basic_const` to return `SubroutineCall`
Thirumalai-Shaktivel Nov 25, 2023
3903658
[ASR Pass] Symbolic: Simplify `basic_binop` to return `SubroutineCall`
Thirumalai-Shaktivel Nov 25, 2023
79066ee
[ASR Pass] Symbolic: Simplify `basic_unaryop` to return `SubroutineCall`
Thirumalai-Shaktivel Nov 25, 2023
d331f27
[ASR Pass] Symbolic: Simplify `process_intrinsic_function` arguments
Thirumalai-Shaktivel Nov 25, 2023
f18ae18
[ASR Pass] Symbolic: Simplify `process_intrinsic_function` to use macros
Thirumalai-Shaktivel Nov 25, 2023
760380a
[ASR Pass] Symbolic: Simplify `process_attributes` to use macros
Thirumalai-Shaktivel Nov 25, 2023
f6d0bd6
[ASR Pass] Symbolic: Simplify `basic_has_symbol` to return `FunctionC…
Thirumalai-Shaktivel Nov 25, 2023
6a7b2cd
[ASR Pass] Symbolic: Simplify `process_attributes` arguments
Thirumalai-Shaktivel Nov 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[ASR Pass] Symbolic: Simplify basic_assign to return SubroutineCall
  • Loading branch information
Thirumalai-Shaktivel committed Nov 25, 2023
commit f497d07128198aa313d3b76e39fba15143dc2ee5
122 changes: 52 additions & 70 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,52 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
vecbasic_size_sym, vecbasic_size_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
}

ASR::stmt_t* basic_assign(const Location& loc,
ASR::expr_t *target, ASR::expr_t *value) {
std::string fn_name = "basic_assign";
symbolic_dependencies.push_back(fn_name);
ASR::ttype_t *cptr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, loc));
ASR::symbol_t *basic_assign_sym = current_scope->resolve_symbol(fn_name);
if ( !basic_assign_sym ) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(current_scope->parent);

Vec<ASR::expr_t*> args; args.reserve(al, 2);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, cptr_type,
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg1);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)));
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, cptr_type,
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "y"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body; body.reserve(al, 1);
Vec<char*> dep; dep.reserve(al, 1);
basic_assign_sym = ASR::down_cast<ASR::symbol_t>(
ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name),
dep.p, dep.n, args.p, args.n, body.p, body.n, nullptr,
ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header)));
current_scope->parent->add_symbol(s2c(al, fn_name), basic_assign_sym);
}
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 2);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = target;
call_args.push_back(al, call_arg);
call_arg.m_value = value;
call_args.push_back(al, call_arg);
return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_assign_sym,
basic_assign_sym, call_args.p, call_args.n, nullptr));
}
/********************************** Utils *********************************/

void visit_Function(const ASR::Function_t &x) {
Expand Down Expand Up @@ -780,45 +826,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
}
}

ASR::symbol_t* declare_basic_assign_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_assign";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 2);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg1);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)));
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "y"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_basic_str_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_str";
symbolic_dependencies.push_back(name);
Expand Down Expand Up @@ -1197,22 +1204,9 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
if (ASR::is_a<ASR::Var_t>(*x.m_value) && ASR::is_a<ASR::CPtr_t>(*ASRUtils::expr_type(x.m_value))) {
ASR::symbol_t *v = ASR::down_cast<ASR::Var_t>(x.m_value)->m_v;
if (symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) return;
ASR::symbol_t* basic_assign_sym = declare_basic_assign_function(al, x.base.base.loc, module_scope);
ASR::symbol_t* var_sym = ASR::down_cast<ASR::Var_t>(x.m_value)->m_v;
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym));

Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 2);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = x.base.base.loc;
call_arg1.m_value = x.m_target;
call_arg2.loc = x.base.base.loc;
call_arg2.m_value = target;
call_args.push_back(al, call_arg1);
call_args.push_back(al, call_arg2);
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_assign_sym,
basic_assign_sym, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
pass_result.push_back(al, basic_assign(x.base.base.loc, x.m_target,
ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym))));
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*x.m_value)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(x.m_value);
if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) {
Expand Down Expand Up @@ -1305,22 +1299,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
} else if (ASR::is_a<ASR::ListItem_t>(*x.m_value)) {
ASR::ListItem_t* list_item = ASR::down_cast<ASR::ListItem_t>(x.m_value);
if (list_item->m_type->type == ASR::ttypeType::SymbolicExpression) {
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
ASR::symbol_t* basic_assign_sym = declare_basic_assign_function(al, x.base.base.loc, module_scope);

Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 2);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = x.base.base.loc;
call_arg1.m_value = x.m_target;
call_arg2.loc = x.base.base.loc;
call_arg2.m_value = ASRUtils::EXPR(ASR::make_ListItem_t(al, x.base.base.loc, list_item->m_a,
list_item->m_pos, CPtr_type, nullptr));
call_args.push_back(al, call_arg1);
call_args.push_back(al, call_arg2);
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_assign_sym,
basic_assign_sym, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
ASR::expr_t *value = ASRUtils::EXPR(ASR::make_ListItem_t(al,
x.base.base.loc, list_item->m_a, list_item->m_pos,
ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), nullptr));
pass_result.push_back(al, basic_assign(x.base.base.loc, x.m_target, value));
}
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_value)) {
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_value);
Expand Down