Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ RUN(NAME test_list_repeat LABELS cpython llvm NOFAST)
RUN(NAME test_list_reverse LABELS cpython llvm)
RUN(NAME test_list_pop LABELS cpython llvm NOFAST) # TODO: Remove NOFAST from here.
RUN(NAME test_list_pop2 LABELS cpython llvm NOFAST) # TODO: Remove NOFAST from here.
RUN(NAME test_list_compare LABELS cpython llvm)
RUN(NAME test_tuple_01 LABELS cpython llvm c)
RUN(NAME test_tuple_02 LABELS cpython llvm c NOFAST)
RUN(NAME test_tuple_03 LABELS cpython llvm c)
Expand Down
43 changes: 43 additions & 0 deletions integration_tests/test_list_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from lpython import i32, f64

def test_list_compare():
l1: list[i32] = [1, 2, 3]
l2: list[i32] = [1, 2, 3, 4]
l3: list[tuple[i32, f64, str]] = [(1, 2.0, 'a'), (3, 4.0, 'b')]
l4: list[tuple[i32, f64, str]] = [(1, 3.0, 'a')]
l5: list[list[str]] = [[''], ['']]
l6: list[str] = []
l7: list[str] = []
t1: tuple[i32, i32]
t2: tuple[i32, i32]
i: i32

assert l1 < l2
i = l2.pop()
i = l2.pop()
assert l2 < l1
assert not (l1 < l2)

l1 = [3,4,5]
l2 = [1,6,7]
assert l2 < l1

# assert l3 < l4
# l4[0] = l3[0]
# assert l4 < l3

for i in range(0, 10):
if i % 2 == 0:
l6.append('a')
else:
l7.append('a')
l5[0] = l6
l5[1] = l7
if i % 2 == 0:
assert l5[1 - i % 2] < l5[i % 2]

# t1 = (1, 2)
# t2 = (3, 4)
# assert t1 < t2

test_list_compare()
14 changes: 10 additions & 4 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1888,10 +1888,16 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
this->visit_expr(*x.m_right);
llvm::Value* right = tmp;
ptr_loads = ptr_loads_copy;
tmp = llvm_utils->is_equal_by_value(left, right, *module,
ASRUtils::expr_type(x.m_left));
if (x.m_op == ASR::cmpopType::NotEq) {
tmp = builder->CreateNot(tmp);
if(x.m_op == ASR::cmpopType::Eq || x.m_op == ASR::cmpopType::NotEq) {
tmp = llvm_utils->is_equal_by_value(left, right, *module,
ASRUtils::expr_type(x.m_left));
if (x.m_op == ASR::cmpopType::NotEq) {
tmp = builder->CreateNot(tmp);
}
}
else if(x.m_op == ASR::cmpopType::Lt) {
tmp = llvm_utils->is_less_by_value(left, right, *module,
ASRUtils::expr_type(x.m_left));
}
}

Expand Down
247 changes: 247 additions & 0 deletions src/libasr/codegen/llvm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,81 @@ namespace LCompilers {
}
}

llvm::Value* LLVMUtils::is_less_by_value(llvm::Value* left, llvm::Value* right,
llvm::Module& module, ASR::ttype_t* asr_type) {
switch( asr_type->type ) {
case ASR::ttypeType::Integer: {
return builder->CreateICmpSLT(left, right);
}
case ASR::ttypeType::Logical: {
return builder->CreateICmpSLT(left, right); // signed?
}
case ASR::ttypeType::Real: {
return builder->CreateFCmpOLT(left, right);
}
case ASR::ttypeType::Character: {
if( !are_iterators_set ) {
str_cmp_itr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
}
llvm::Value* null_char = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context),
llvm::APInt(8, '\0'));
llvm::Value* idx = str_cmp_itr;
LLVM::CreateStore(*builder,
llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)),
idx);
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");

// head
start_new_block(loophead);
{
llvm::Value* i = LLVM::CreateLoad(*builder, idx);
llvm::Value* l = LLVM::CreateLoad(*builder, create_ptr_gep(left, i));
llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i));
llvm::Value *cond = builder->CreateAnd(
builder->CreateICmpNE(l, null_char),
builder->CreateICmpNE(r, null_char)
);
cond = builder->CreateAnd(cond, builder->CreateICmpULT(l, r)); // unsigned?
builder->CreateCondBr(cond, loopbody, loopend);
}

// body
start_new_block(loopbody);
{
llvm::Value* i = LLVM::CreateLoad(*builder, idx);
i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
llvm::APInt(32, 1)));
LLVM::CreateStore(*builder, i, idx);
}

builder->CreateBr(loophead);

// end
start_new_block(loopend);
llvm::Value* i = LLVM::CreateLoad(*builder, idx);
llvm::Value* l = LLVM::CreateLoad(*builder, create_ptr_gep(left, i));
llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i));
return builder->CreateICmpULT(l, r);
}
case ASR::ttypeType::Tuple: {
ASR::Tuple_t* tuple_type = ASR::down_cast<ASR::Tuple_t>(asr_type);
return tuple_api->check_tuple_inequality(left, right, tuple_type, context,
builder, module);
}
case ASR::ttypeType::List: {
ASR::List_t* list_type = ASR::down_cast<ASR::List_t>(asr_type);
return list_api->check_list_inequality(left, right, list_type->m_type,
context, builder, module);
}
default: {
throw LCompilersException("LLVMUtils::is_equal_by_value isn't implemented for " +
ASRUtils::type_to_str_python(asr_type));
}
}
}

void LLVMUtils::deepcopy(llvm::Value* src, llvm::Value* dest,
ASR::ttype_t* asr_type, llvm::Module* module,
std::map<std::string, std::map<std::string, int>>& name2memidx) {
Expand Down Expand Up @@ -3106,6 +3181,103 @@ namespace LCompilers {
return LLVM::CreateLoad(*builder, is_equal);
}

llvm::Value* LLVMList::check_list_inequality(llvm::Value* l1, llvm::Value* l2,
ASR::ttype_t* item_type,
llvm::LLVMContext& context,
llvm::IRBuilder<>* builder,
llvm::Module& module) {
// TODO:
// - ineq operations other than "<"
// - abstract out this code, possibly switch over operators
// - short-circuit without initial allocation of res? Also for equality

/**
* Equivalent in C++
* For "<"
*
* equality_holds = 1;
* inequality_holds = 0;
* i = 0;
*
* while( i < a_len && i < b_len && equality_holds ) {
* equality_holds &= (a[i] == b[i]);
* inequality_holds |= (a[i] < b[i]);
* }
*
* if( i == a_len && a_len < b_len && equality_holds ) {
* inequality_holds = 1;
* }
*
*/

llvm::AllocaInst *equality_holds = builder->CreateAlloca(
llvm::Type::getInt1Ty(context), nullptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(context, llvm::APInt(1, 1)),
equality_holds);
llvm::AllocaInst *inequality_holds = builder->CreateAlloca(
llvm::Type::getInt1Ty(context), nullptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(context, llvm::APInt(1, 0)),
inequality_holds);

llvm::Value *a_len = llvm_utils->list_api->len(l1);
llvm::Value *b_len = llvm_utils->list_api->len(l2);
llvm::AllocaInst *idx = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(
context, llvm::APInt(32, 0)), idx);
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");

// head
llvm_utils->start_new_block(loophead);
{
llvm::Value* i = LLVM::CreateLoad(*builder, idx);
llvm::Value* cnd = builder->CreateICmpSLT(i, a_len);
cnd = builder->CreateAnd(cnd, builder->CreateICmpSLT(i, b_len));
cnd = builder->CreateAnd(cnd, LLVM::CreateLoad(*builder, equality_holds));
builder->CreateCondBr(cnd, loopbody, loopend);
}

// body
llvm_utils->start_new_block(loopbody);
{
llvm::Value* i = LLVM::CreateLoad(*builder, idx);
llvm::Value* left_arg = llvm_utils->list_api->read_item(l1, i,
false, module, LLVM::is_llvm_struct(item_type));
llvm::Value* right_arg = llvm_utils->list_api->read_item(l2, i,
false, module, LLVM::is_llvm_struct(item_type));
llvm::Value* res = llvm_utils->is_less_by_value(left_arg, right_arg, module,
item_type);
res = builder->CreateOr(LLVM::CreateLoad(*builder, inequality_holds), res);
LLVM::CreateStore(*builder, res, inequality_holds);
res = llvm_utils->is_equal_by_value(left_arg, right_arg, module,
item_type);
res = builder->CreateAnd(LLVM::CreateLoad(*builder, equality_holds), res);
LLVM::CreateStore(*builder, res, equality_holds);
i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
llvm::APInt(32, 1)));
LLVM::CreateStore(*builder, i, idx);
}

builder->CreateBr(loophead);

// end
llvm_utils->start_new_block(loopend);

llvm::Value* cond = builder->CreateICmpEQ(LLVM::CreateLoad(*builder, idx),
a_len);
cond = builder->CreateAnd(cond, builder->CreateICmpSLT(a_len, b_len));
cond = builder->CreateAnd(cond, LLVM::CreateLoad(*builder, equality_holds));
llvm_utils->create_if_else(cond, [&]() {
LLVM::CreateStore(*builder, llvm::ConstantInt::get(
context, llvm::APInt(1, 1)), inequality_holds);
}, [=]() {
// will be already 0 from the loop
});

return LLVM::CreateLoad(*builder, inequality_holds);
}

void LLVMList::list_repeat_copy(llvm::Value* repeat_list, llvm::Value* init_list,
llvm::Value* num_times, llvm::Value* init_list_len,
llvm::Module* module) {
Expand Down Expand Up @@ -3251,6 +3423,81 @@ namespace LCompilers {
return is_equal;
}

llvm::Value* LLVMTuple::check_tuple_inequality(llvm::Value* t1, llvm::Value* t2,
ASR::Tuple_t* tuple_type,
llvm::LLVMContext& context,
llvm::IRBuilder<>* builder,
llvm::Module& module) {
// TODO: operators other than "<"

/**
* Equivalent in C++
* For "<"
*
* equality_holds = 1;
* inequality_holds = 0;
* i = 0;
*
* while( i < a_len && equality_holds ) {
* equality_holds &= (a[i] == b[i]);
* inequality_holds |= (a[i] < b[i]);
* }
*
*/

llvm::AllocaInst *equality_holds = builder->CreateAlloca(
llvm::Type::getInt1Ty(context), nullptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(context, llvm::APInt(1, 1)),
equality_holds);
llvm::AllocaInst *inequality_holds = builder->CreateAlloca(
llvm::Type::getInt1Ty(context), nullptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(context, llvm::APInt(1, 0)),
inequality_holds);

llvm::AllocaInst *idx = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(
context, llvm::APInt(32, 0)), idx);
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");

// head
llvm_utils->start_new_block(loophead);
{
llvm::Value* i = LLVM::CreateLoad(*builder, idx);
llvm::Value* cnd = builder->CreateICmpSLT(i, llvm::ConstantInt::get(
context, llvm::APInt(32, tuple_type->n_type)));
cnd = builder->CreateAnd(cnd, LLVM::CreateLoad(*builder, equality_holds));
builder->CreateCondBr(cnd, loopbody, loopend);
}

// body
llvm_utils->start_new_block(loopbody);
{
llvm::Value* i = LLVM::CreateLoad(*builder, idx);
// llvm::Value* t1i = llvm_utils->tuple_api->read_item(t1, i, LLVM::is_llvm_struct(
// tuple_type->m_type[i]));
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we index tuple_type->m_type with run time value i?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't. :-). Tuples should only be indexed with fixed indices available at compile time so that type of the indexed item can be figured out at compile time itself. Its the pattern as C++ std::tuple.

// llvm::Value* t2i = llvm_utils->tuple_api->read_item(t2, i, LLVM::is_llvm_struct(
// tuple_type->m_type[i]));
// llvm::Value* res = llvm_utils->is_less_by_value(t1i, t2i, module,
// tuple_type->m_type[i]);
// res = builder->CreateOr(LLVM::CreateLoad(*builder, inequality_holds), res);
// LLVM::CreateStore(*builder, res, inequality_holds);
// res = llvm_utils->is_equal_by_value(t1i, t2i, module, tuple_type->m_type[i]);
// res = builder->CreateAnd(LLVM::CreateLoad(*builder, equality_holds), res);
// LLVM::CreateStore(*builder, res, equality_holds);
i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
llvm::APInt(32, 1)));
LLVM::CreateStore(*builder, i, idx);
}

builder->CreateBr(loophead);

// end
llvm_utils->start_new_block(loopend);
return LLVM::CreateLoad(*builder, inequality_holds);
}

void LLVMTuple::concat(llvm::Value* t1, llvm::Value* t2, ASR::Tuple_t* tuple_type_1,
ASR::Tuple_t* tuple_type_2, llvm::Value* concat_tuple,
ASR::Tuple_t* concat_tuple_type, llvm::Module& module,
Expand Down
10 changes: 10 additions & 0 deletions src/libasr/codegen/llvm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ namespace LCompilers {
llvm::Value* is_equal_by_value(llvm::Value* left, llvm::Value* right,
llvm::Module& module, ASR::ttype_t* asr_type);

llvm::Value* is_less_by_value(llvm::Value* left, llvm::Value* right,
llvm::Module& module, ASR::ttype_t* asr_type);

void set_iterators();

void reset_iterators();
Expand Down Expand Up @@ -286,6 +289,9 @@ namespace LCompilers {
llvm::Value* check_list_equality(llvm::Value* l1, llvm::Value* l2, ASR::ttype_t *item_type,
llvm::LLVMContext& context, llvm::IRBuilder<>* builder, llvm::Module& module);

llvm::Value* check_list_inequality(llvm::Value* l1, llvm::Value* l2, ASR::ttype_t *item_type,
llvm::LLVMContext& context, llvm::IRBuilder<>* builder, llvm::Module& module);

void list_repeat_copy(llvm::Value* repeat_list, llvm::Value* init_list,
llvm::Value* num_times, llvm::Value* init_list_len,
llvm::Module* module);
Expand Down Expand Up @@ -327,6 +333,10 @@ namespace LCompilers {
ASR::Tuple_t* tuple_type, llvm::LLVMContext& context,
llvm::IRBuilder<>* builder, llvm::Module& module);

llvm::Value* check_tuple_inequality(llvm::Value* t1, llvm::Value* t2,
ASR::Tuple_t* tuple_type, llvm::LLVMContext& context,
llvm::IRBuilder<>* builder, llvm::Module& module);

void concat(llvm::Value* t1, llvm::Value* t2, ASR::Tuple_t* tuple_type_1,
ASR::Tuple_t* tuple_type_2, llvm::Value* concat_tuple,
ASR::Tuple_t* concat_tuple_type, llvm::Module& module,
Expand Down
12 changes: 8 additions & 4 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6208,14 +6208,18 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {

tmp = ASR::make_StringCompare_t(al, x.base.base.loc, left, asr_op, right, type, value);
} else if (ASR::is_a<ASR::Tuple_t>(*dest_type)) {
if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq) {
throw SemanticError("Only Equal and Not-equal operators are supported for Tuples",
if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq
&& asr_op != ASR::cmpopType::Lt) {
throw SemanticError("Only Equal, Not-equal and Less-than operators "
"are supported for Tuples",
x.base.base.loc);
}
tmp = ASR::make_TupleCompare_t(al, x.base.base.loc, left, asr_op, right, type, value);
} else if (ASR::is_a<ASR::List_t>(*dest_type)) {
if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq) {
throw SemanticError("Only Equal and Not-equal operators are supported for Tuples",
if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq
&& asr_op != ASR::cmpopType::Lt) {
throw SemanticError("Only Equal, Not-equal and Less-than operators "
"are supported for Lists",
x.base.base.loc);
}
tmp = ASR::make_ListCompare_t(al, x.base.base.loc, left, asr_op, right, type, value);
Expand Down