Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix dict keys and values for LP
  • Loading branch information
kabra1110 committed Jul 8, 2023
commit f28eaf06b676fb0a5500be20415e36913dbeb555
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ RUN(NAME test_dict_12 LABELS cpython llvm c)
RUN(NAME test_dict_13 LABELS cpython llvm c)
RUN(NAME test_dict_bool LABELS cpython llvm)
RUN(NAME test_dict_increment LABELS cpython llvm)
RUN(NAME test_dict_keys_values LABELS cpython llvm)
RUN(NAME test_for_loop LABELS cpython llvm c)
RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
Expand Down
43 changes: 43 additions & 0 deletions integration_tests/test_dict_keys_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from lpython import i32, f64

def test_dict_keys_values():
d1: dict[i32, i32] = {}
d2: dict[tuple[i32, i32], tuple[i32, tuple[str, f64]]] = {}
k1: list[i32]
k2: list[tuple[i32, i32]]
v1: list[i32]
v2: list[tuple[i32, tuple[str, f64]]]
i: i32
j: i32
key_count: i32
s: str

for i in range(105, 115):
d1[i] = i + 1
k1 = d1.keys()
v1 = d1.values()
assert len(k1) == 10
for i in range(105, 115):
key_count = 0
for j in range(len(k1)):
if k1[j] == i:
key_count += 1
assert v1[j] == d1[i]
assert key_count == 1

s = 'a'
for i in range(10):
d2[(i, i + 1)] = (i, (s, f64(i * i)))
s += 'a'
k2 = d2.keys()
v2 = d2.values()
assert len(k2) == 10
for i in range(10):
key_count = 0
for j in range(len(k2)):
if k2[j] == (i, i + 1):
key_count += 1
assert v2[j] == d2[k2[j]]
assert key_count == 1

test_dict_keys_values()
43 changes: 40 additions & 3 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2072,18 +2072,51 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tmp = list_api->pop_position(plist, pos, asr_el_type, module.get(), name2memidx);
}

void generate_DictKeys(ASR::expr_t* m_arg) {
void generate_DictElems(ASR::expr_t* m_arg, bool key_or_value, const Location &loc) {
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(
ASRUtils::expr_type(m_arg));
ASR::ttype_t* el_type = key_or_value == 0 ?
dict_type->m_key_type : dict_type->m_value_type;

int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*m_arg);
llvm::Value* pdict = tmp;

set_dict_api(dict_type);
if(llvm_utils->dict_api == dict_api_sc.get()) {
throw CodeGenError("dict.keys and dict.values are only implemented "
"for linear probing for now", loc);
}
ptr_loads = ptr_loads_copy;
tmp = llvm_utils->dict_api->get_key_list(pdict);

bool is_array_type_local = false, is_malloc_array_type_local = false;
bool is_list_local = false;
ASR::dimension_t* m_dims_local = nullptr;
int n_dims_local = -1, a_kind_local = -1;
llvm::Type* llvm_el_type = get_type_from_ttype_t(el_type,
nullptr,
ASR::storage_typeType::Default, is_array_type_local,
is_malloc_array_type_local, is_list_local, m_dims_local,
n_dims_local, a_kind_local);
std::string type_code = ASRUtils::get_type_code(el_type);
int32_t type_size = -1;
if( ASR::is_a<ASR::Character_t>(*el_type) ||
LLVM::is_llvm_struct(el_type) ||
ASR::is_a<ASR::Complex_t>(*el_type) ) {
llvm::DataLayout data_layout(module.get());
type_size = data_layout.getTypeAllocSize(llvm_el_type);
} else {
type_size = ASRUtils::extract_kind_from_ttype_t(el_type);
}
llvm::Type* el_list_type = list_api->get_list_type(llvm_el_type, type_code, type_size);
llvm::Value* el_list = builder->CreateAlloca(el_list_type, nullptr, key_or_value == 0 ?
"keys_list" : "values_list");
list_api->list_init(type_code, el_list, *module, 0, 0);

llvm_utils->dict_api->get_elements_list(pdict, el_list, el_type, *module,
name2memidx, key_or_value);
tmp = el_list;
}

void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) {
Expand Down Expand Up @@ -2130,7 +2163,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
break;
}
case ASRUtils::IntrinsicFunctions::DictKeys: {
generate_DictKeys(x.m_args[0]);
generate_DictElems(x.m_args[0], 0, x.base.base.loc);
break;
}
case ASRUtils::IntrinsicFunctions::DictValues: {
generate_DictElems(x.m_args[0], 1, x.base.base.loc);
break;
}
case ASRUtils::IntrinsicFunctions::Exp: {
Expand Down
89 changes: 89 additions & 0 deletions src/libasr/codegen/llvm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,12 @@ namespace LCompilers {
list_api->list_deepcopy(src, dest, list_type, module, name2memidx);
break ;
}
case ASR::ttypeType::Dict: {
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(asr_type);
// set dict api here?
dict_api->dict_deepcopy(src, dest, dict_type, module, name2memidx);
break ;
}
case ASR::ttypeType::Struct: {
ASR::Struct_t* struct_t = ASR::down_cast<ASR::Struct_t>(asr_type);
ASR::StructType_t* struct_type_t = ASR::down_cast<ASR::StructType_t>(
Expand Down Expand Up @@ -2469,6 +2475,89 @@ namespace LCompilers {
return LLVM::CreateLoad(*builder, value_ptr);
}

void LLVMDict::get_elements_list(llvm::Value* dict,
llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module,
std::map<std::string, std::map<std::string, int>>& name2memidx,
bool key_or_value) {

/**
* C++ equivalent:
*
* idx = 0;
*
* while( capacity > idx ) {
* el = key_or_value_list[idx];
* key_mask_value = key_mask[idx];
*
* is_key_skip = key_mask_value == 3; // tombstone
* is_key_set = key_mask_value != 0;
* add_el = is_key_set && !is_key_skip;
* if( add_el ) {
* elements_list.append(el);
* }
*
* idx++;
* }
*
*/

llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
llvm::Value* el_list = key_or_value == 0 ? get_key_list(dict) : get_value_list(dict);
if( !are_iterators_set ) {
idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
}
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
llvm::APInt(32, 0)), idx_ptr);

llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");

// head
llvm_utils->start_new_block(loophead);
{
llvm::Value *cond = builder->CreateICmpSGT(capacity, LLVM::CreateLoad(*builder, idx_ptr));
builder->CreateCondBr(cond, loopbody, loopend);
}

// body
llvm_utils->start_new_block(loopbody);
{
llvm::Value* idx = LLVM::CreateLoad(*builder, idx_ptr);
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
llvm_utils->create_ptr_gep(key_mask, idx));
llvm::Value* is_key_skip = builder->CreateICmpEQ(key_mask_value,
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3)));
llvm::Value* is_key_set = builder->CreateICmpNE(key_mask_value,
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)));

llvm::Value* add_el = builder->CreateAnd(is_key_set,
builder->CreateNot(is_key_skip));
llvm_utils->create_if_else(add_el, [&]() {
llvm::Value* el = llvm_utils->list_api->read_item(el_list, idx,
false, module, LLVM::is_llvm_struct(el_asr_type));
llvm_utils->list_api->append(elements_list, el,
el_asr_type, &module, name2memidx);
}, [=]() {
});

idx = builder->CreateAdd(idx, llvm::ConstantInt::get(
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1)));
LLVM::CreateStore(*builder, idx, idx_ptr);
}

builder->CreateBr(loophead);

// end
llvm_utils->start_new_block(loopend);
}

void LLVMDictSeparateChaining::get_elements_list(llvm::Value* /*dict*/,
llvm::Value* /*elements_list*/, ASR::ttype_t* /*el_asr_type*/, llvm::Module& /*module*/,
std::map<std::string, std::map<std::string, int>>& /*name2memidx*/,
bool /*key_or_value*/) {}

llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos,
bool enable_bounds_checking,
llvm::Module& module, bool get_pointer) {
Expand Down
16 changes: 16 additions & 0 deletions src/libasr/codegen/llvm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,12 @@ namespace LCompilers {
virtual
void set_is_dict_present(bool value);

virtual
void get_elements_list(llvm::Value* dict,
llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module,
std::map<std::string, std::map<std::string, int>>& name2memidx,
bool key_or_value) = 0;

virtual ~LLVMDictInterface() = 0;

};
Expand Down Expand Up @@ -555,6 +561,11 @@ namespace LCompilers {

llvm::Value* len(llvm::Value* dict);

void get_elements_list(llvm::Value* dict,
llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module,
std::map<std::string, std::map<std::string, int>>& name2memidx,
bool key_or_value);

virtual ~LLVMDict();
};

Expand Down Expand Up @@ -702,6 +713,11 @@ namespace LCompilers {

llvm::Value* len(llvm::Value* dict);

void get_elements_list(llvm::Value* dict,
llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module,
std::map<std::string, std::map<std::string, int>>& name2memidx,
bool key_or_value);

virtual ~LLVMDictSeparateChaining();

};
Expand Down
57 changes: 55 additions & 2 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ enum class IntrinsicFunctions : int64_t {
ListReverse,
ListPop,
DictKeys,
DictValues,
SymbolicSymbol,
SymbolicAdd,
SymbolicSub,
Expand Down Expand Up @@ -1148,8 +1149,8 @@ static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnost
ASRUtils::require_impl(ASR::is_a<ASR::Dict_t>(*ASRUtils::expr_type(x.m_args[0])),
"Argument to dict.keys must be of dict type",
x.base.base.loc, diagnostics);
ASRUtils::require_impl(ASRUtils::check_equal_type(
ASRUtils::get_contained_type(x.m_type),
ASRUtils::require_impl(ASR::is_a<ASR::List_t>(*x.m_type) &&
ASRUtils::check_equal_type(ASRUtils::get_contained_type(x.m_type),
ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_args[0]), 0)),
"Return type of dict.keys must be of list of dict key element type",
x.base.base.loc, diagnostics);
Expand Down Expand Up @@ -1186,6 +1187,52 @@ static inline ASR::asr_t* create_DictKeys(Allocator& al, const Location& loc,

} // namespace DictKeys

namespace DictValues {

static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) {
ASRUtils::require_impl(x.n_args == 1, "Call to dict.values must have no argument",
x.base.base.loc, diagnostics);
ASRUtils::require_impl(ASR::is_a<ASR::Dict_t>(*ASRUtils::expr_type(x.m_args[0])),
"Argument to dict.values must be of dict type",
x.base.base.loc, diagnostics);
ASRUtils::require_impl(ASR::is_a<ASR::List_t>(*x.m_type) &&
ASRUtils::check_equal_type(ASRUtils::get_contained_type(x.m_type),
ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_args[0]), 1)),
"Return type of dict.values must be of list of dict value element type",
x.base.base.loc, diagnostics);
}

static inline ASR::expr_t *eval_dict_values(Allocator &/*al*/,
const Location &/*loc*/, Vec<ASR::expr_t*>& /*args*/) {
// TODO: To be implemented for DictConstant expression
return nullptr;
}

static inline ASR::asr_t* create_DictValues(Allocator& al, const Location& loc,
Vec<ASR::expr_t*>& args,
const std::function<void (const std::string &, const Location &)> err) {
if (args.size() != 1) {
err("Call to dict.values must have no argument", loc);
}

ASR::expr_t* dict_expr = args[0];
ASR::ttype_t *type = ASRUtils::expr_type(dict_expr);
ASR::ttype_t *dict_values_type = ASR::down_cast<ASR::Dict_t>(type)->m_value_type;

Vec<ASR::expr_t*> arg_values;
arg_values.reserve(al, args.size());
for( size_t i = 0; i < args.size(); i++ ) {
arg_values.push_back(al, ASRUtils::expr_value(args[i]));
}
ASR::expr_t* compile_time_value = eval_dict_values(al, loc, arg_values);
ASR::ttype_t *to_type = List(dict_values_type);
return ASR::make_IntrinsicFunction_t(al, loc,
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::DictValues),
args.p, args.size(), 0, to_type, compile_time_value);
}

} // namespace DictValues

namespace Any {

static inline void verify_array(ASR::expr_t* array, ASR::ttype_t* return_type,
Expand Down Expand Up @@ -2261,6 +2308,8 @@ namespace IntrinsicFunctionRegistry {
{nullptr, &ListReverse::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::DictKeys),
{nullptr, &DictKeys::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::DictValues),
{nullptr, &DictValues::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSymbol),
{nullptr, &SymbolicSymbol::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd),
Expand Down Expand Up @@ -2317,6 +2366,8 @@ namespace IntrinsicFunctionRegistry {
"list.pop"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::DictKeys),
"dict.keys"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::DictValues),
"dict.values"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSymbol),
"Symbol"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd),
Expand Down Expand Up @@ -2363,6 +2414,7 @@ namespace IntrinsicFunctionRegistry {
{"list.reverse", {&ListReverse::create_ListReverse, &ListReverse::eval_list_reverse}},
{"list.pop", {&ListPop::create_ListPop, &ListPop::eval_list_pop}},
{"dict.keys", {&DictKeys::create_DictKeys, &DictKeys::eval_dict_keys}},
{"dict.values", {&DictValues::create_DictValues, &DictValues::eval_dict_values}},
{"Symbol", {&SymbolicSymbol::create_SymbolicSymbol, &SymbolicSymbol::eval_SymbolicSymbol}},
{"SymbolicAdd", {&SymbolicAdd::create_SymbolicAdd, &SymbolicAdd::eval_SymbolicAdd}},
{"SymbolicSub", {&SymbolicSub::create_SymbolicSub, &SymbolicSub::eval_SymbolicSub}},
Expand Down Expand Up @@ -2478,6 +2530,7 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(ListReverse)
INTRINSIC_NAME_CASE(ListPop)
INTRINSIC_NAME_CASE(DictKeys)
INTRINSIC_NAME_CASE(DictValues)
INTRINSIC_NAME_CASE(SymbolicSymbol)
INTRINSIC_NAME_CASE(SymbolicAdd)
INTRINSIC_NAME_CASE(SymbolicSub)
Expand Down
17 changes: 16 additions & 1 deletion src/lpython/semantics/python_attribute_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ struct AttributeHandler {
{"set@remove", &eval_set_remove},
{"dict@get", &eval_dict_get},
{"dict@pop", &eval_dict_pop},
{"dict@keys", &eval_dict_keys}
{"dict@keys", &eval_dict_keys},
{"dict@values", &eval_dict_values}
};

modify_attr_set = {"list@append", "list@remove",
Expand Down Expand Up @@ -403,6 +404,20 @@ struct AttributeHandler {
{ throw SemanticError(msg, loc); });
}

static ASR::asr_t* eval_dict_values(ASR::expr_t *s, Allocator &al, const Location &loc,
Vec<ASR::expr_t*> &args, diag::Diagnostics &/*diag*/) {
Vec<ASR::expr_t*> args_with_dict;
args_with_dict.reserve(al, args.size() + 1);
args_with_dict.push_back(al, s);
for(size_t i = 0; i < args.size(); i++) {
args_with_dict.push_back(al, args[i]);
}
ASRUtils::create_intrinsic_function create_function =
ASRUtils::IntrinsicFunctionRegistry::get_create_function("dict.values");
return create_function(al, loc, args_with_dict, [&](const std::string &msg, const Location &loc)
{ throw SemanticError(msg, loc); });
}

}; // AttributeHandler

} // namespace LCompilers::LPython
Expand Down