Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
correct logic, add other ops
  • Loading branch information
kabra1110 committed Jul 9, 2023
commit e11b611518211bfd3e26b59f9b055835503f9e97
24 changes: 17 additions & 7 deletions integration_tests/test_list_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,23 @@ def test_list_compare():
t2: tuple[i32, i32]
i: i32

assert l1 < l2
assert l1 < l2 and l1 <= l2
assert not l1 > l2 and not l1 >= l2
i = l2.pop()
i = l2.pop()
assert l2 < l1
assert l2 < l1 and l1 > l2 and l1 >= l2
assert not (l1 < l2)

l1 = [3,4,5]
l2 = [1,6,7]
assert l2 < l1
l1 = [3, 4, 5]
l2 = [1, 6, 7]
assert l1 > l2 and l1 >= l2
assert not l1 < l2 and not l1 <= l2

assert l3 < l4
l1 = l2
assert l1 == l2 and l1 <= l2 and l1 >= l2
assert not l1 < l2 and not l1 > l2

assert l4 > l3 and l4 >= l3
l4[0] = l3[0]
assert l4 < l3

Expand All @@ -35,9 +41,13 @@ def test_list_compare():
l5[1] = l7
if i % 2 == 0:
assert l5[1 - i % 2] < l5[i % 2]
assert l5[1 - i % 2] <= l5[i % 2]
assert not l5[1 - i % 2] > l5[i % 2]
assert not l5[1 - i % 2] >= l5[i % 2]

t1 = (1, 2)
t2 = (2, 3)
assert t1 < t2
assert t1 < t2 and t1 <= t2
assert not t1 > t2 and not t1 >= t2

test_list_compare()
35 changes: 31 additions & 4 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1888,6 +1888,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
this->visit_expr(*x.m_right);
llvm::Value* right = tmp;
ptr_loads = ptr_loads_copy;

ASR::ttype_t* int32_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 4));

if(x.m_op == ASR::cmpopType::Eq || x.m_op == ASR::cmpopType::NotEq) {
tmp = llvm_utils->is_equal_by_value(left, right, *module,
ASRUtils::expr_type(x.m_left));
Expand All @@ -1896,8 +1899,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
}
else if(x.m_op == ASR::cmpopType::Lt) {
tmp = llvm_utils->is_less_by_value(left, right, *module,
ASRUtils::expr_type(x.m_left));
tmp = llvm_utils->is_ineq_by_value(left, right, *module,
ASRUtils::expr_type(x.m_left), 0, int32_type);
}
else if(x.m_op == ASR::cmpopType::LtE) {
tmp = llvm_utils->is_ineq_by_value(left, right, *module,
ASRUtils::expr_type(x.m_left), 1, int32_type);
}
else if(x.m_op == ASR::cmpopType::Gt) {
tmp = llvm_utils->is_ineq_by_value(left, right, *module,
ASRUtils::expr_type(x.m_left), 2, int32_type);
}
else if(x.m_op == ASR::cmpopType::GtE) {
tmp = llvm_utils->is_ineq_by_value(left, right, *module,
ASRUtils::expr_type(x.m_left), 3, int32_type);
}
}

Expand Down Expand Up @@ -2237,8 +2252,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
}
else if(x.m_op == ASR::cmpopType::Lt) {
tmp = llvm_utils->is_less_by_value(left, right, *module,
ASRUtils::expr_type(x.m_left));
tmp = llvm_utils->is_ineq_by_value(left, right, *module,
ASRUtils::expr_type(x.m_left), 0);
}
else if(x.m_op == ASR::cmpopType::LtE) {
tmp = llvm_utils->is_ineq_by_value(left, right, *module,
ASRUtils::expr_type(x.m_left), 1);
}
else if(x.m_op == ASR::cmpopType::Gt) {
tmp = llvm_utils->is_ineq_by_value(left, right, *module,
ASRUtils::expr_type(x.m_left), 2);
}
else if(x.m_op == ASR::cmpopType::GtE) {
tmp = llvm_utils->is_ineq_by_value(left, right, *module,
ASRUtils::expr_type(x.m_left), 3);
}
}

Expand Down
136 changes: 103 additions & 33 deletions src/libasr/codegen/llvm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,17 +310,67 @@ namespace LCompilers {
}
}

llvm::Value* LLVMUtils::is_less_by_value(llvm::Value* left, llvm::Value* right,
llvm::Module& module, ASR::ttype_t* asr_type) {
llvm::Value* LLVMUtils::is_ineq_by_value(llvm::Value* left, llvm::Value* right,
llvm::Module& module, ASR::ttype_t* asr_type,
int8_t overload_id, ASR::ttype_t* int32_type) {
/**
* overloads:
* 0 <
* 1 <=
* 2 >
* 3 >=
*/
llvm::CmpInst::Predicate pred;

switch( asr_type->type ) {
case ASR::ttypeType::Integer: {
return builder->CreateICmpSLT(left, right);
}
case ASR::ttypeType::Integer:
case ASR::ttypeType::Logical: {
return builder->CreateICmpSLT(left, right); // signed?
switch( overload_id ) {
case 0: {
pred = llvm::CmpInst::Predicate::ICMP_SLT;
break;
}
case 1: {
pred = llvm::CmpInst::Predicate::ICMP_SLE;
break;
}
case 2: {
pred = llvm::CmpInst::Predicate::ICMP_SGT;
break;
}
case 3: {
pred = llvm::CmpInst::Predicate::ICMP_SGE;
break;
}
default: {
// can exit with error
}
}
return builder->CreateCmp(pred, left, right);
}
case ASR::ttypeType::Real: {
return builder->CreateFCmpOLT(left, right);
switch( overload_id ) {
case 0: {
pred = llvm::CmpInst::Predicate::FCMP_OLT;
break;
}
case 1: {
pred = llvm::CmpInst::Predicate::FCMP_OLE;
break;
}
case 2: {
pred = llvm::CmpInst::Predicate::FCMP_OGT;
break;
}
case 3: {
pred = llvm::CmpInst::Predicate::FCMP_OGE;
break;
}
default: {
// can exit with error
}
}
return builder->CreateCmp(pred, left, right);
}
case ASR::ttypeType::Character: {
if( !are_iterators_set ) {
Expand All @@ -346,7 +396,28 @@ namespace LCompilers {
builder->CreateICmpNE(l, null_char),
builder->CreateICmpNE(r, null_char)
);
cond = builder->CreateAnd(cond, builder->CreateICmpULT(l, r)); // unsigned?
switch( overload_id ) {
case 0: {
pred = llvm::CmpInst::Predicate::ICMP_ULT;
break;
}
case 1: {
pred = llvm::CmpInst::Predicate::ICMP_ULE;
break;
}
case 2: {
pred = llvm::CmpInst::Predicate::ICMP_UGT;
break;
}
case 3: {
pred = llvm::CmpInst::Predicate::ICMP_UGE;
break;
}
default: {
// can exit with error
}
}
cond = builder->CreateAnd(cond, builder->CreateCmp(pred, l, r));
builder->CreateCondBr(cond, loopbody, loopend);
}

Expand All @@ -371,12 +442,13 @@ namespace LCompilers {
case ASR::ttypeType::Tuple: {
ASR::Tuple_t* tuple_type = ASR::down_cast<ASR::Tuple_t>(asr_type);
return tuple_api->check_tuple_inequality(left, right, tuple_type, context,
builder, module);
builder, module, overload_id);
}
case ASR::ttypeType::List: {
ASR::List_t* list_type = ASR::down_cast<ASR::List_t>(asr_type);
return list_api->check_list_inequality(left, right, list_type->m_type,
context, builder, module);
context, builder, module,
overload_id, int32_type);
}
default: {
throw LCompilersException("LLVMUtils::is_equal_by_value isn't implemented for " +
Expand Down Expand Up @@ -3185,28 +3257,23 @@ namespace LCompilers {
ASR::ttype_t* item_type,
llvm::LLVMContext& context,
llvm::IRBuilder<>* builder,
llvm::Module& module) {
// TODO:
// - ineq operations other than "<"
// - abstract out this code, possibly switch over operators
// - short-circuit without initial allocation of res? Also for equality

llvm::Module& module, int8_t overload_id,
ASR::ttype_t* int32_type) {
/**
* Equivalent in C++
* For "<"
*
* equality_holds = 1;
* inequality_holds = 0;
* i = 0;
*
* while( i < a_len && i < b_len && equality_holds ) {
* equality_holds &= (a[i] == b[i]);
* inequality_holds |= (a[i] < b[i]);
* inequality_holds |= (a[i] op b[i]);
* i++;
* }
*
* if( i == a_len && a_len < b_len && equality_holds ) {
* inequality_holds = 1;
* if( (i == a_len || i == b_len) && equality_holds ) {
* inequality_holds = a_len op b_len;
* }
*
*/
Expand Down Expand Up @@ -3247,8 +3314,8 @@ namespace LCompilers {
false, module, LLVM::is_llvm_struct(item_type));
llvm::Value* right_arg = llvm_utils->list_api->read_item(l2, i,
false, module, LLVM::is_llvm_struct(item_type));
llvm::Value* res = llvm_utils->is_less_by_value(left_arg, right_arg, module,
item_type);
llvm::Value* res = llvm_utils->is_ineq_by_value(left_arg, right_arg, module,
item_type, overload_id);
res = builder->CreateOr(LLVM::CreateLoad(*builder, inequality_holds), res);
LLVM::CreateStore(*builder, res, inequality_holds);
res = llvm_utils->is_equal_by_value(left_arg, right_arg, module,
Expand All @@ -3267,13 +3334,15 @@ namespace LCompilers {

llvm::Value* cond = builder->CreateICmpEQ(LLVM::CreateLoad(*builder, idx),
a_len);
cond = builder->CreateAnd(cond, builder->CreateICmpSLT(a_len, b_len));
cond = builder->CreateOr(cond, builder->CreateICmpEQ(
LLVM::CreateLoad(*builder, idx), b_len));
cond = builder->CreateAnd(cond, LLVM::CreateLoad(*builder, equality_holds));
llvm_utils->create_if_else(cond, [&]() {
LLVM::CreateStore(*builder, llvm::ConstantInt::get(
context, llvm::APInt(1, 1)), inequality_holds);
LLVM::CreateStore(*builder, llvm_utils->is_ineq_by_value(a_len, b_len,
module, int32_type, overload_id), inequality_holds);
}, []() {
// will be already 0 from the loop
// LLVM::CreateStore(*builder, llvm::ConstantInt::get(
// context, llvm::APInt(1, 0)), inequality_holds);
});

return LLVM::CreateLoad(*builder, inequality_holds);
Expand Down Expand Up @@ -3428,23 +3497,24 @@ namespace LCompilers {
ASR::Tuple_t* tuple_type,
llvm::LLVMContext& context,
llvm::IRBuilder<>* builder,
llvm::Module& module) {
// TODO: operators other than "<"

llvm::Module& module, int8_t overload_id) {
/**
* Equivalent in C++
* For "<"
*
* equality_holds = 1;
* inequality_holds = 0;
* i = 0;
*
* // owing to compile-time access of indices,
* // loop is unrolled into multiple if statements
* while( i < a_len && equality_holds ) {
* inequality_holds |= (a[i] < b[i]);
* inequality_holds |= (a[i] op b[i]);
* equality_holds &= (a[i] == b[i]);
* i++;
* }
*
* return inequality_holds;
*
*/

llvm::AllocaInst *equality_holds = builder->CreateAlloca(
Expand All @@ -3462,8 +3532,8 @@ namespace LCompilers {
tuple_type->m_type[i]));
llvm::Value* t2i = llvm_utils->tuple_api->read_item(t2, i, LLVM::is_llvm_struct(
tuple_type->m_type[i]));
llvm::Value* res = llvm_utils->is_less_by_value(t1i, t2i, module,
tuple_type->m_type[i]);
llvm::Value* res = llvm_utils->is_ineq_by_value(t1i, t2i, module,
tuple_type->m_type[i], overload_id);
res = builder->CreateOr(LLVM::CreateLoad(*builder, inequality_holds), res);
LLVM::CreateStore(*builder, res, inequality_holds);
res = llvm_utils->is_equal_by_value(t1i, t2i, module, tuple_type->m_type[i]);
Expand Down
13 changes: 8 additions & 5 deletions src/libasr/codegen/llvm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,9 @@ namespace LCompilers {
llvm::Value* is_equal_by_value(llvm::Value* left, llvm::Value* right,
llvm::Module& module, ASR::ttype_t* asr_type);

llvm::Value* is_less_by_value(llvm::Value* left, llvm::Value* right,
llvm::Module& module, ASR::ttype_t* asr_type);
llvm::Value* is_ineq_by_value(llvm::Value* left, llvm::Value* right,
llvm::Module& module, ASR::ttype_t* asr_type,
int8_t overload_id, ASR::ttype_t* int32_type=nullptr);

void set_iterators();

Expand Down Expand Up @@ -289,8 +290,10 @@ namespace LCompilers {
llvm::Value* check_list_equality(llvm::Value* l1, llvm::Value* l2, ASR::ttype_t *item_type,
llvm::LLVMContext& context, llvm::IRBuilder<>* builder, llvm::Module& module);

llvm::Value* check_list_inequality(llvm::Value* l1, llvm::Value* l2, ASR::ttype_t *item_type,
llvm::LLVMContext& context, llvm::IRBuilder<>* builder, llvm::Module& module);
llvm::Value* check_list_inequality(llvm::Value* l1, llvm::Value* l2,
ASR::ttype_t *item_type, llvm::LLVMContext& context,
llvm::IRBuilder<>* builder, llvm::Module& module,
int8_t overload_id, ASR::ttype_t* int32_type=nullptr);

void list_repeat_copy(llvm::Value* repeat_list, llvm::Value* init_list,
llvm::Value* num_times, llvm::Value* init_list_len,
Expand Down Expand Up @@ -335,7 +338,7 @@ namespace LCompilers {

llvm::Value* check_tuple_inequality(llvm::Value* t1, llvm::Value* t2,
ASR::Tuple_t* tuple_type, llvm::LLVMContext& context,
llvm::IRBuilder<>* builder, llvm::Module& module);
llvm::IRBuilder<>* builder, llvm::Module& module, int8_t overload_id);

void concat(llvm::Value* t1, llvm::Value* t2, ASR::Tuple_t* tuple_type_1,
ASR::Tuple_t* tuple_type_2, llvm::Value* concat_tuple,
Expand Down
10 changes: 6 additions & 4 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6209,16 +6209,18 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
tmp = ASR::make_StringCompare_t(al, x.base.base.loc, left, asr_op, right, type, value);
} else if (ASR::is_a<ASR::Tuple_t>(*dest_type)) {
if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq
&& asr_op != ASR::cmpopType::Lt) {
throw SemanticError("Only Equal, Not-equal and Less-than operators "
&& asr_op != ASR::cmpopType::Lt && asr_op != ASR::cmpopType::LtE
&& asr_op != ASR::cmpopType::Gt && asr_op != ASR::cmpopType::GtE) {
throw SemanticError("Only ==, !=, <, <=, >, >= operators "
"are supported for Tuples",
x.base.base.loc);
}
tmp = ASR::make_TupleCompare_t(al, x.base.base.loc, left, asr_op, right, type, value);
} else if (ASR::is_a<ASR::List_t>(*dest_type)) {
if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq
&& asr_op != ASR::cmpopType::Lt) {
throw SemanticError("Only Equal, Not-equal and Less-than operators "
&& asr_op != ASR::cmpopType::Lt && asr_op != ASR::cmpopType::LtE
&& asr_op != ASR::cmpopType::Gt && asr_op != ASR::cmpopType::GtE) {
throw SemanticError("Only ==, !=, <, <=, >, >= operators "
"are supported for Lists",
x.base.base.loc);
}
Expand Down