Skip to content

Commit

Permalink
fix: Support deriving Print for enums (dojoengine#1091)
Browse files Browse the repository at this point in the history
* skeleton for deriving print for enum

* print enum variant names

* add example enum to plugin_test_data

* cargo fmt

* remove unecessary derives

* add core and comma

* fix tests

* fix tests

* remove class hash changes

* revert Cargo.lock

* fix tests

* use OptionTypeClause

* fix: ensure print is expanded under test cfg

---------

Co-authored-by: glihm <[email protected]>
  • Loading branch information
0xicosahedron and glihm authored Jan 17, 2024
1 parent 8fb9941 commit 05e59cf
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 176 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ cairo-lang-formatter = "2.4.0"
cairo-lang-language-server = "2.4.0"
cairo-lang-lowering = "2.4.0"
cairo-lang-parser = "2.4.0"
cairo-lang-plugins = "2.4.0"
cairo-lang-plugins = { version = "2.4.0", features = [ "testing" ] }
cairo-lang-project = "2.4.0"
cairo-lang-semantic = { version = "2.4.0", features = [ "testing" ] }
cairo-lang-sierra = "2.4.0"
Expand Down
2 changes: 1 addition & 1 deletion crates/dojo-lang/src/introspect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ pub fn handle_introspect_struct(db: &dyn SyntaxGroup, struct_ast: ItemStruct) ->
/// A handler for Dojo code derives Introspect for an enum
/// Parameters:
/// * db: The semantic database.
/// * struct_ast: The AST of the struct.
/// * enum_ast: The AST of the enum.
/// Returns:
/// * A RewriteNode containing the generated code.
pub fn handle_introspect_enum(
Expand Down
5 changes: 3 additions & 2 deletions crates/dojo-lang/src/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::inline_macros::get::GetMacro;
use crate::inline_macros::set::SetMacro;
use crate::introspect::{handle_introspect_enum, handle_introspect_struct};
use crate::model::handle_model_struct;
use crate::print::derive_print;
use crate::print::{handle_print_enum, handle_print_struct};

const DOJO_CONTRACT_ATTR: &str = "dojo::contract";

Expand Down Expand Up @@ -279,6 +279,7 @@ impl MacroPlugin for BuiltinDojoPlugin {
enum_ast.clone(),
));
}
"Print" => rewrite_nodes.push(handle_print_enum(db, enum_ast.clone())),
_ => continue,
}
}
Expand Down Expand Up @@ -355,7 +356,7 @@ impl MacroPlugin for BuiltinDojoPlugin {
diagnostics.extend(model_diagnostics);
}
"Print" => {
rewrite_nodes.push(derive_print(db, struct_ast.clone()));
rewrite_nodes.push(handle_print_struct(db, struct_ast.clone()));
}
"Introspect" => {
rewrite_nodes
Expand Down
62 changes: 2 additions & 60 deletions crates/dojo-lang/src/plugin_test.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
use std::sync::Arc;

use cairo_lang_defs::db::{DefsDatabase, DefsGroup};
use cairo_lang_defs::ids::{LanguageElementId, ModuleId, ModuleItemId};
use cairo_lang_defs::ids::ModuleId;
use cairo_lang_defs::plugin::MacroPlugin;
use cairo_lang_diagnostics::{format_diagnostics, DiagnosticLocation};
use cairo_lang_filesystem::cfg::CfgSet;
use cairo_lang_filesystem::db::{
init_files_group, AsFilesGroupMut, CrateConfiguration, FilesDatabase, FilesGroup, FilesGroupEx,
};
use cairo_lang_filesystem::ids::{CrateLongId, Directory, FileLongId};
use cairo_lang_parser::db::ParserDatabase;
use cairo_lang_plugins::get_base_plugins;
use cairo_lang_plugins::test_utils::expand_module_text;
use cairo_lang_syntax::node::db::{SyntaxDatabase, SyntaxGroup};
use cairo_lang_syntax::node::kind::SyntaxKind;
use cairo_lang_syntax::node::{ast, TypedSyntaxNode};
use cairo_lang_test_utils::parse_test_file::TestRunnerResult;
use cairo_lang_test_utils::verify_diagnostics_expectation;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use cairo_lang_utils::Upcast;

use super::BuiltinDojoPlugin;
Expand Down Expand Up @@ -118,58 +115,3 @@ pub fn test_expand_plugin_inner(
error,
}
}

pub fn expand_module_text(
db: &dyn DefsGroup,
module_id: ModuleId,
diagnostics: &mut Vec<String>,
) -> String {
let mut output = String::new();
// A collection of all the use statements in the module.
let mut uses_list = UnorderedHashSet::default();
let syntax_db = db.upcast();
// Collect the module diagnostics.
for (file_id, diag) in db.module_plugin_diagnostics(module_id).unwrap().iter() {
let syntax_node = diag.stable_ptr.lookup(syntax_db);
let location = DiagnosticLocation {
file_id: file_id.file_id(db.upcast()).unwrap(),
span: syntax_node.span_without_trivia(syntax_db),
};
diagnostics.push(format_diagnostics(db.upcast(), &diag.message, location));
}
for item_id in db.module_items(module_id).unwrap().iter() {
if let ModuleItemId::Submodule(item) = item_id {
let submodule_item = item.stable_ptr(db).lookup(syntax_db);
if let ast::MaybeModuleBody::Some(body) = submodule_item.body(syntax_db) {
// Recursively expand inline submodules.
output.extend([
submodule_item.attributes(syntax_db).node.get_text(syntax_db),
submodule_item.module_kw(syntax_db).as_syntax_node().get_text(syntax_db),
submodule_item.name(syntax_db).as_syntax_node().get_text(syntax_db),
body.lbrace(syntax_db).as_syntax_node().get_text(syntax_db),
expand_module_text(db, ModuleId::Submodule(*item), diagnostics),
body.rbrace(syntax_db).as_syntax_node().get_text(syntax_db),
]);
continue;
}
} else if let ModuleItemId::Use(use_id) = item_id {
let mut use_item = use_id.stable_ptr(db).lookup(syntax_db).as_syntax_node();
// Climb up the AST until the syntax kind is ItemUse. This is needed since the use item
// points to the use leaf as one use statement can represent multiple use items.
while let Some(parent) = use_item.parent() {
use_item = parent;
if use_item.kind(syntax_db) == SyntaxKind::ItemUse {
break;
}
}
if uses_list.insert(use_item.clone()) {
output.push_str(&use_item.get_text(syntax_db));
}
continue;
}
let syntax_item = item_id.untyped_stable_ptr(db);
// Output other items as is.
output.push_str(&syntax_item.lookup(syntax_db).get_text(syntax_db));
}
output
}
157 changes: 53 additions & 104 deletions crates/dojo-lang/src/plugin_test_data/print
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
//! > test_runner_name
test_expand_plugin

//! > cfg
["test"]

//! > cairo_code
use serde::Serde;
use debug::PrintTrait;

#[derive(Print, Copy, Drop, Serde)]
#[derive(Print)]
struct Position {
#[key]
id: felt252,
Expand All @@ -15,14 +19,14 @@ struct Position {
y: felt252
}

#[derive(Print, Serde)]
#[derive(Print)]
struct Roles {
role_ids: Array<u8>
}

use starknet::ContractAddress;

#[derive(Print, Copy, Drop, Serde)]
#[derive(Print)]
struct Player {
#[key]
game: felt252,
Expand All @@ -32,11 +36,18 @@ struct Player {
name: felt252,
}

//! > generated_cairo_code
use serde::Serde;
#[derive(Print)]
enum Enemy {
Unknown,
Bot: felt252,
OtherPlayer: ContractAddress,
}

//! > expanded_cairo_code
use serde::Serde;
use debug::PrintTrait;

#[derive(Print, Copy, Drop, Serde)]
#[derive(Print)]
struct Position {
#[key]
id: felt252,
Expand All @@ -45,37 +56,14 @@ struct Position {
y: felt252
}

#[cfg(test)]
impl PositionPrintImpl of core::debug::PrintTrait<Position> {
fn print(self: Position) {
core::debug::PrintTrait::print('id');
core::debug::PrintTrait::print(self.id);
core::debug::PrintTrait::print('x');
core::debug::PrintTrait::print(self.x);
core::debug::PrintTrait::print('y');
core::debug::PrintTrait::print(self.y);
}
}


#[derive(Print, Serde)]
#[derive(Print)]
struct Roles {
role_ids: Array<u8>
}

#[cfg(test)]
impl RolesPrintImpl of core::debug::PrintTrait<Roles> {
fn print(self: Roles) {
core::debug::PrintTrait::print('role_ids');
core::debug::PrintTrait::print(self.role_ids);
}
}


use starknet::ContractAddress;


#[derive(Print, Copy, Drop, Serde)]
#[derive(Print)]
struct Player {
#[key]
game: felt252,
Expand All @@ -84,87 +72,48 @@ struct Player {

name: felt252,
}
#[cfg(test)]
impl PlayerPrintImpl of core::debug::PrintTrait<Player> {
fn print(self: Player) {
core::debug::PrintTrait::print('game');
core::debug::PrintTrait::print(self.game);
core::debug::PrintTrait::print('player');
core::debug::PrintTrait::print(self.player);
core::debug::PrintTrait::print('name');
core::debug::PrintTrait::print(self.name);
}
}

//! > expected_diagnostics

//! > expanded_cairo_code
use serde::Serde;

#[derive(Print, Copy, Drop, Serde)]
struct Position {
#[key]
id: felt252,

x: felt252,
y: felt252
#[derive(Print)]
enum Enemy {
Unknown,
Bot: felt252,
OtherPlayer: ContractAddress,
}

#[derive(Print, Serde)]
struct Roles {
role_ids: Array<u8>
#[cfg(test)]
impl PositionStructPrintImpl of core::debug::PrintTrait<Position> {
fn print(self: Position) {
core::debug::PrintTrait::print('id'); core::debug::PrintTrait::print(self.id);
core::debug::PrintTrait::print('x'); core::debug::PrintTrait::print(self.x);
core::debug::PrintTrait::print('y'); core::debug::PrintTrait::print(self.y);
}
}

use starknet::ContractAddress;

#[derive(Print, Copy, Drop, Serde)]
struct Player {
#[key]
game: felt252,
#[key]
player: ContractAddress,

name: felt252,
}
impl PositionCopy of core::traits::Copy::<Position>;
impl PositionDrop of core::traits::Drop::<Position>;
impl PositionSerde of core::serde::Serde::<Position> {
fn serialize(self: @Position, ref output: core::array::Array<felt252>) {
core::serde::Serde::serialize(self.id, ref output);
core::serde::Serde::serialize(self.x, ref output);
core::serde::Serde::serialize(self.y, ref output)
}
fn deserialize(ref serialized: core::array::Span<felt252>) -> core::option::Option<Position> {
core::option::Option::Some(Position {
id: core::serde::Serde::deserialize(ref serialized)?,
x: core::serde::Serde::deserialize(ref serialized)?,
y: core::serde::Serde::deserialize(ref serialized)?,
})
#[cfg(test)]
impl RolesStructPrintImpl of core::debug::PrintTrait<Roles> {
fn print(self: Roles) {
core::debug::PrintTrait::print('role_ids'); core::debug::PrintTrait::print(self.role_ids);
}
}
impl RolesSerde of core::serde::Serde::<Roles> {
fn serialize(self: @Roles, ref output: core::array::Array<felt252>) {
core::serde::Serde::serialize(self.role_ids, ref output)
}
fn deserialize(ref serialized: core::array::Span<felt252>) -> core::option::Option<Roles> {
core::option::Option::Some(Roles {
role_ids: core::serde::Serde::deserialize(ref serialized)?,
})

#[cfg(test)]
impl PlayerStructPrintImpl of core::debug::PrintTrait<Player> {
fn print(self: Player) {
core::debug::PrintTrait::print('game'); core::debug::PrintTrait::print(self.game);
core::debug::PrintTrait::print('player'); core::debug::PrintTrait::print(self.player);
core::debug::PrintTrait::print('name'); core::debug::PrintTrait::print(self.name);
}
}
impl PlayerCopy of core::traits::Copy::<Player>;
impl PlayerDrop of core::traits::Drop::<Player>;
impl PlayerSerde of core::serde::Serde::<Player> {
fn serialize(self: @Player, ref output: core::array::Array<felt252>) {
core::serde::Serde::serialize(self.game, ref output);
core::serde::Serde::serialize(self.player, ref output);
core::serde::Serde::serialize(self.name, ref output)
}
fn deserialize(ref serialized: core::array::Span<felt252>) -> core::option::Option<Player> {
core::option::Option::Some(Player {
game: core::serde::Serde::deserialize(ref serialized)?,
player: core::serde::Serde::deserialize(ref serialized)?,
name: core::serde::Serde::deserialize(ref serialized)?,
})

#[cfg(test)]
impl EnemyEnumPrintImpl of core::debug::PrintTrait<Enemy> {
fn print(self: Enemy) {
match self {
Enemy::Unknown => { core::debug::PrintTrait::print('Unknown'); },
Enemy::Bot(v) => { core::debug::PrintTrait::print('Bot'); core::debug::PrintTrait::print(v); },
Enemy::OtherPlayer(v) => { core::debug::PrintTrait::print('OtherPlayer'); core::debug::PrintTrait::print(v); }
}
}
}

//! > expected_diagnostics
Loading

0 comments on commit 05e59cf

Please sign in to comment.