Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support and testing for equality-atom and if-then-else based RDDL-style boolean interoperability #120

Draft
wants to merge 3 commits into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 137 additions & 23 deletions src/tarski/io/rddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .common import load_tpl
from ..fol import FirstOrderLanguage
from ..syntax import implies, land, lor, neg, Connective, Quantifier, CompoundTerm, Interval, Atom, IfThenElse, \
Contradiction, Tautology, CompoundFormula, forall, ite, AggregateCompoundTerm, QuantifiedFormula, Term, Function, \
Contradiction, Tautology, CompoundFormula, forall, exists, ite, AggregateCompoundTerm, QuantifiedFormula, Term, Function, \
Variable, Predicate, Constant, Formula, builtins
from ..syntax import arithmetic as tm
from ..syntax.temporal import ltl as tt
Expand All @@ -17,6 +17,11 @@
from ..errors import LanguageError
from ..theories import Theory, language

def is_boolean_constant_equal_to_true(expr):
if isinstance(expr, Constant):
if Theory.BOOLEAN in expr.language.theories:
return expr.sort == expr.language.Boolean and expr.is_syntactically_equal(expr.language.constant(1, expr.language.Boolean))
return False

class TranslationError(Exception):
pass
Expand All @@ -25,6 +30,7 @@ class TranslationError(Exception):
logic_rddl_to_tarski = {
'=>': implies,
'^': land,
'&': land,
'|': lor,
'~': neg}

Expand Down Expand Up @@ -149,9 +155,15 @@ def translate_expression(lang, rddl_expr):
prod_expr = ite(prod_expr, Constant(1, lang.Integer), Constant(0, lang.Integer))
return tm.product(var, prod_expr)
elif expr_sym == 'forall':
var = translate_expression(lang, rddl_expr.args[0])
forall_expr = translate_expression(lang, rddl_expr.args[1])
return forall(var, forall_expr)
vars = [translate_expression(lang, a) for a in rddl_expr.args[:-1]]
#var = translate_expression(lang, rddl_expr.args[0])
forall_expr = translate_expression(lang, rddl_expr.args[-1])
return forall(*(vars + [forall_expr]))
elif expr_sym == 'exists':
vars = [translate_expression(lang, a) for a in rddl_expr.args[:-1]]
#var = translate_expression(lang, rddl_expr.args[0])
exists_expr = translate_expression(lang, rddl_expr.args[-1])
return exists(*(vars + [exists_expr]))
elif expr_type == 'arithmetic':
op = arithmetic_rddl_to_tarski[expr_sym]
targs = [lang] + [translate_expression(lang, arg) for arg in rddl_expr.args]
Expand Down Expand Up @@ -211,9 +223,12 @@ class Reader:
that specify a RDDL task
"""

def __init__(self, filename):
def __init__(self, domain_filename, inst_filename = None):
self.language = None
self.rddl_model = self._load_rddl_model(filename)
if inst_filename is None:
self.rddl_model = self._load_rddl_model(domain_filename)
else:
self.rddl_model = self._load_rddl_model_ippc(domain_filename, inst_filename)
self.parameters = Parameters()
self.x0 = None

Expand All @@ -226,6 +241,22 @@ def _load_rddl_model(filename):
# parse RDDL
return parser.parse(rddl)

@staticmethod
def _load_rddl_model_ippc(dom_filename, inst_filename):
with open(dom_filename, 'r') as input_file:
dom_text = input_file.read()
with open(inst_filename, 'r') as input_file:
inst_text = input_file.read()
full_text = '\n\n'.join([dom_text, inst_text])
# MRJ: for debug purposes
#for k, l in enumerate(full_text.split('\n')):
# print(k, l)
parser = modules.import_pyrddl_parser()()
parser.debugging = True
parser.build()
# parse RDDL
return parser.parse(full_text)

def _translate_types(self):
for typename, parent_type in self.rddl_model.domain.types:
assert parent_type == 'object'
Expand Down Expand Up @@ -262,9 +293,11 @@ def translate_rddl_model(self):
# 3. acquire instance parameters
self.parameters.horizon = self.rddl_model.instance.horizon
self.parameters.discount = self.rddl_model.instance.discount
if self.rddl_model.instance.max_nondef_actions != 'pos-inf':
self.parameters.max_actions = self.rddl_model.instance.max_nondef_actions

try:
if self.rddl_model.instance.max_nondef_actions != 'pos-inf':
self.parameters.max_actions = self.rddl_model.instance.max_nondef_actions
except AttributeError:
pass
# 4. recover initial state, interpretation of fluents
self.x0 = Model(self.language)
self.x0.evaluator = evaluate
Expand All @@ -283,8 +316,8 @@ def translate_rddl_model(self):
self.x0.add(expr.predicate, *expr.subterms)


built_in_type_map = {'object': 'Object', 'real': 'Real', 'int': 'Integer'}
reverse_built_in_type_map = {'Object': 'object', 'Real': 'real', 'Integer': 'int'}
built_in_type_map = {'object': 'Object', 'real': 'Real', 'int': 'Integer', 'bool' : 'Boolean'}
reverse_built_in_type_map = {'Object': 'object', 'Real': 'real', 'Integer': 'int', 'Boolean': 'bool'}


def translate_builtin_type(lang: FirstOrderLanguage, name):
Expand Down Expand Up @@ -317,6 +350,7 @@ class Requirements(Enum):
CONTINUOUS = "continuous"
MULTIVALUED = "multivalued"
REWARD_DET = "reward-deterministic"
PRECONDITIONS = "preconditions"
INTERMEDIATE_NODES = "intermediate-nodes"
PARTIALLY_OBS = "partially-observed"
CONCURRENT = "concurrent"
Expand Down Expand Up @@ -415,7 +449,33 @@ def __init__(self, task):
self.non_fluent_signatures = set()
self.interm_signatures = set()

def write_model(self, filename):
self.bool_t = self.task.L.constant(1, self.task.L.Boolean)
self.bool_f = self.task.L.constant(0, self.task.L.Boolean)

def rddl_2018_format(self):
tpl = load_tpl("rddl_model_2018.tpl")
domain_content = tpl.format(
domain_name=self.task.domain_name,
req_list=self.get_requirements(),
type_list=self.get_types(),
pvar_list=self.get_pvars(),
cpfs_list=self.get_cpfs(),
reward_expr=self.get_reward(),
action_precondition_list=self.get_preconditions(),
state_invariant_list=self.get_state_invariants(),
domain_non_fluents='{}_non_fluents'.format(self.task.instance_name),
object_list=self.get_objects(),
non_fluent_expr=self.get_non_fluent_init(),
instance_name=self.task.instance_name,
init_state_fluent_expr=self.get_state_fluent_init(),
non_fluents_ref='{}_non_fluents'.format(self.task.instance_name),
max_nondef_actions=self.get_max_nondef_actions(),
horizon=self.get_horizon(),
discount=self.get_discount()
)
return domain_content

def rddl_pre_2018_format(self):
tpl = load_tpl("rddl_model.tpl")
content = tpl.format(
domain_name=self.task.domain_name,
Expand All @@ -436,8 +496,22 @@ def write_model(self, filename):
horizon=self.get_horizon(),
discount=self.get_discount()
)
return content

def write_model(self, filename, format_2018_style=False):
with open(filename, 'w') as file:
if format_2018_style:
content = self.rddl_2018_format()
else:
content = self.rddl_pre_2018_format()
file.write(content)
self.reset()

def reset(self):
self.need_obj_decl = []
self.need_constraints = {}
self.non_fluent_signatures = set()
self.interm_signatures = set()

def get_requirements(self):
return ', '.join([str(r) for r in self.task.requirements])
Expand Down Expand Up @@ -480,12 +554,22 @@ def get_signature(fl):
return '{}'.format(head)
return '{}({})'.format(head, ','.join(domain))

def get_valstring(self, fluent, value):
if self.get_type(fluent) == 'bool':
if value == 0:
return "false"
elif value == 1:
return "true"
else:
assert(False)
return str(value)

def get_pvars(self):
pvar_decl_list = []
# state fluents
for fl, v in self.task.state_fluents:
rsig = self.get_signature(fl)
pvar_decl_list += ['\t{} : {{state-fluent, {}, default = {}}};'.format(rsig, self.get_type(fl), str(v))]
pvar_decl_list += ['\t{} : {{state-fluent, {}, default = {}}};'.format(rsig, self.get_type(fl), self.get_valstring(fl, v))]
for fl, level in self.task.interm_fluents:
rsig = self.get_signature(fl)
try:
Expand All @@ -495,14 +579,14 @@ def get_pvars(self):
pvar_decl_list += ['\t{} : {{interm-fluent, {}, level = {}}};'.format(rsig, self.get_type(fl), str(level))]
for fl, v in self.task.action_fluents:
rsig = self.get_signature(fl)
pvar_decl_list += ['\t{} : {{action-fluent, {}, default = {}}};'.format(rsig, self.get_type(fl), str(v))]
pvar_decl_list += ['\t{} : {{action-fluent, {}, default = {}}};'.format(rsig, self.get_type(fl), self.get_valstring(fl,v))]
for fl, v in self.task.non_fluents:
rsig = self.get_signature(fl)
try:
self.non_fluent_signatures.add(fl.symbol.signature)
except AttributeError:
self.non_fluent_signatures.add(fl.predicate.signature)
pvar_decl_list += ['\t{} : {{non-fluent, {}, default = {}}};'.format(rsig, self.get_type(fl), str(v))]
pvar_decl_list += ['\t{} : {{non-fluent, {}, default = {}}};'.format(rsig, self.get_type(fl), self.get_valstring(fl,v))]
return '\n'.join(pvar_decl_list)

def get_cpfs(self):
Expand Down Expand Up @@ -552,6 +636,11 @@ def get_non_fluent_init(self):
term_str = signature[0]
else:
term_str = str(self.task.L.get(signature[0])(*subterms))
if signature[-1] == "Boolean":
if value.is_syntactically_equal(self.bool_f):
value = "false"
elif value.is_syntactically_equal(self.bool_t):
value = "true"
non_fluent_init_list += ['\t{} = {};'.format(term_str, value)]
for signature, defs in self.task.x0.predicate_extensions.items():
if signature not in self.non_fluent_signatures:
Expand Down Expand Up @@ -581,6 +670,11 @@ def get_state_fluent_init(self):
term_str = signature[0]
else:
term_str = str(self.task.L.get(signature[0])(*subterms))
if signature[-1] == "Boolean":
if value.is_syntactically_equal(self.bool_f):
value = "false"
elif value.is_syntactically_equal(self.bool_t):
value = "true"
init_list += ['\t{} = {};'.format(term_str, value)]
for signature, defs in self.task.x0.predicate_extensions.items():
if signature in self.non_fluent_signatures \
Expand Down Expand Up @@ -610,7 +704,9 @@ def rewrite(self, expr):
if len(re_st) > 0:
# MRJ: Random variables need parenthesis, other functions need
# brackets...
if expr.symbol.symbol in builtins.get_random_binary_functions():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, this one was a bit of a puzzle back in the day. Nice improvement.

relevant_functions = builtins.get_random_binary_functions()
relevant_functions += builtins.get_random_unary_functions()
if expr.symbol.symbol in relevant_functions:
st_str = '({})'.format(','.join(re_st))
else:
st_str = '[{}]'.format(','.join(re_st))
Expand All @@ -620,7 +716,14 @@ def rewrite(self, expr):
return '{}{}'.format(expr.symbol.signature[0], st_str)
elif isinstance(expr, Atom):
re_st = [self.rewrite(st) for st in expr.subterms]
st = expr.subterms
if expr.predicate.builtin:
#check to see if we have a equality atom wrapping a Boolean codomain function to a literal 1 constant
if expr.predicate.symbol == BPS.EQ:
if is_boolean_constant_equal_to_true(st[0]) and isinstance(st[1], CompoundTerm) and st[1].codomain == st[1].language.Boolean:
return re_st[1]
elif is_boolean_constant_equal_to_true(st[1]) and isinstance(st[0], CompoundTerm) and st[0].sort == st[0].language.Boolean:
return re_st[0]
if expr.predicate.symbol in symbol_map.keys():
return '({} {} {})'.format(re_st[0], symbol_map[expr.predicate.symbol], re_st[1])
st_str = ''
Expand All @@ -633,10 +736,21 @@ def rewrite(self, expr):
elif isinstance(expr, Constant):
return str(expr)
elif isinstance(expr, IfThenElse):
cond = self.rewrite(expr.condition)
expr1 = self.rewrite(expr.subterms[0])
expr2 = self.rewrite(expr.subterms[1])
return 'if ({}) then ({}) else ({})'.format(cond, expr1, expr2)
#check to see if it is a conversion to a boolean type
isconversion = True
lang = expr.subterms[0].language
for index, val in [(0, 1),(1, 0)]:
st = expr.subterms[index]
if not (isinstance(st, Constant) and st.sort == lang.Boolean and st.is_syntactically_equal(lang.constant(val, lang.Boolean))):
isconversion = False
break
if isconversion:
return self.rewrite(expr.condition)
else:
cond = self.rewrite(expr.condition)
expr1 = self.rewrite(expr.subterms[0])
expr2 = self.rewrite(expr.subterms[1])
return 'if ({}) then ({}) else ({})'.format(cond, expr1, expr2)
elif isinstance(expr, Tautology):
return 'true'
elif isinstance(expr, Contradiction):
Expand All @@ -645,21 +759,21 @@ def rewrite(self, expr):
re_sf = [self.rewrite(st) for st in expr.subformulas]
re_sym = symbol_map[expr.connective]
if len(re_sf) == 1:
return '{}{}'.format(re_sym, re_sf)
return '{}({})'.format(re_sym, re_sf[0])
return '({} {} {})'.format(re_sf[0], re_sym, re_sf[1])
elif isinstance(expr, QuantifiedFormula):
re_f = self.rewrite(expr.formula)
re_vars = ['?{} : {}'.format(x.symbol, x.sort.name) for x in expr.variables]
re_sym = symbol_map[expr.quantifier]
return '{}_{{{}}} ({})'.format(re_sym, ','.join(re_vars), re_f)
return '({}_{{{}}} [{}])'.format(re_sym, ','.join(re_vars), re_f)
elif isinstance(expr, AggregateCompoundTerm):
re_expr = self.rewrite(expr.subterm)
re_vars = ['?{} : {}'.format(x.symbol, x.sort.name) for x in expr.bound_vars]
if expr.symbol == BFS.ADD:
re_sym = 'sum'
elif expr.symbol == BFS.MUL:
re_sym = 'prod'
return '{}_{{{}}} ({})'.format(re_sym, ','.join(re_vars), re_expr)
return '({}_{{{}}} [{}])'.format(re_sym, ','.join(re_vars), re_expr)
raise RuntimeError(f"Unknown expression type for '{expr}'")

@staticmethod
Expand Down
8 changes: 8 additions & 0 deletions src/tarski/syntax/arithmetic/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ def normal(mu, sigma):
return np.random.normal(mu, sigma)
return normal_func(mu, sigma)

def bernoulli(p):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cheers for adding the "master" distribution

try:
bernoulli_func = p.language.get_function(bfs.BERNOULLI)
except AttributeError:
np = modules.import_numpy()
return np.random.random(p)
return bernoulli_func(p)


def gamma(shape, scale):
try:
Expand Down
5 changes: 2 additions & 3 deletions src/tarski/syntax/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,8 @@ def get_random_binary_functions():


def get_random_unary_functions():
# BFS = BuiltinFunctionSymbol
return []

BFS = BuiltinFunctionSymbol
return [BFS.BERNOULLI]

def get_predicate_from_symbol(symbol: str):
return BuiltinPredicateSymbol(symbol)
Expand Down
Loading