Skip to content

Commit 7f1fa96

Browse files
committed
Use type of custom-defined comparison in synthesized methods
1 parent 76f3b0d commit 7f1fa96

File tree

2 files changed

+38
-27
lines changed

2 files changed

+38
-27
lines changed

pyrefly/lib/alt/class/total_ordering.rs

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,11 @@ use crate::alt::answers::AnswersSolver;
55
use crate::alt::answers::LookupAnswer;
66
use crate::alt::types::class_metadata::ClassSynthesizedField;
77
use crate::alt::types::class_metadata::ClassSynthesizedFields;
8+
use crate::binding::binding::KeyClassField;
89
use crate::dunder;
910
use crate::error::collector::ErrorCollector;
1011
use crate::error::kind::ErrorKind;
11-
use crate::types::callable::Callable;
12-
use crate::types::callable::FuncMetadata;
13-
use crate::types::callable::Function;
14-
use crate::types::callable::Param;
15-
use crate::types::callable::ParamList;
16-
use crate::types::callable::Params;
17-
use crate::types::callable::Required;
1812
use crate::types::class::Class;
19-
use crate::types::types::Type;
2013

2114
// https://github.com/python/cpython/blob/a8ec511900d0d84cffbb4ee6419c9a790d131129/Lib/functools.py#L173
2215
// conversion order of rich comparison methods:
@@ -42,23 +35,10 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
4235
for other_cmp in conversion_order {
4336
let other_cmp_field = cls.fields().find(|f| **f == *other_cmp);
4437
if other_cmp_field.is_some() {
45-
// FIXME: We should use the type from `other_cmp_field` instead of `cls_type`.
46-
// However, here we use the type of the class itself, which is not always correct.
47-
let cls_type = self.instantiate(cls);
48-
let self_param = self.class_self_param(cls, false);
49-
let other_param =
50-
Param::Pos(Name::new_static("other"), cls_type, Required::Required);
51-
return ClassSynthesizedField::new(Type::Function(Box::new(Function {
52-
signature: Callable {
53-
params: Params::List(ParamList::new(vec![self_param, other_param])),
54-
ret: self.stdlib.bool().clone().to_type(),
55-
},
56-
metadata: FuncMetadata::def(
57-
self.module_info().name(),
58-
cls.name().clone(),
59-
cmp.clone(),
60-
),
61-
})));
38+
let other_cmp_field =
39+
self.get_from_class(cls, &KeyClassField(cls.index(), other_cmp.clone()));
40+
let ty = other_cmp_field.as_named_tuple_type();
41+
return ClassSynthesizedField::new(ty);
6242
}
6343
}
6444
unreachable!("No rich comparison method found for {}", cmp);

pyrefly/lib/test/decorators.rs

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ b = A(x=2)
346346
# This should give the correct type for the method `__lt__`
347347
reveal_type(A.__lt__) # E: revealed type: (self: Self@A, other: A) -> bool
348348
# This should give be synthesized via `functools.total_ordering`
349-
reveal_type(A.__gt__) # E: revealed type: (self: A, other: A) -> bool
349+
reveal_type(A.__gt__) # E: revealed type: (self: Self@A, other: A) -> bool
350350
a <= b
351351
"#,
352352
);
@@ -383,7 +383,38 @@ b = A(x=2)
383383
# This should give the correct type for the method `__lt__`
384384
reveal_type(A.__lt__) # E: revealed type: (self: Self@A, other: A) -> bool
385385
# This should give be synthesized via `functools.total_ordering`
386-
reveal_type(A.__gt__) # E: revealed type: (self: A, other: A) -> bool
386+
reveal_type(A.__gt__) # E: revealed type: (self: Self@A, other: A) -> bool
387387
a <= b
388388
"#,
389389
);
390+
391+
testcase!(
392+
test_total_ordering_precedence,
393+
r#"
394+
from functools import total_ordering
395+
from typing import reveal_type
396+
397+
@total_ordering
398+
class A:
399+
def __init__(self, x: int) -> None:
400+
self.x = x
401+
def __eq__(self, other: "A") -> bool:
402+
return self.x == other.x
403+
def __lt__(self, other: "A") -> bool:
404+
return self.x < other.x
405+
def __le__(self, other: object) -> bool:
406+
if not isinstance(other, A):
407+
return NotImplemented
408+
return self.x <= other.x
409+
410+
# This should give the correct type for the method `__lt__`
411+
reveal_type(A.__lt__) # E: revealed type: (self: Self@A, other: A) -> bool
412+
# This should give be synthesized via `functools.total_ordering`
413+
reveal_type(A.__gt__) # E: revealed type: (self: Self@A, other: A) -> bool
414+
415+
# This should give the correct type for the method `__le__`
416+
reveal_type(A.__le__) # E: revealed type: (self: Self@A, other: object) -> bool
417+
# This should give be synthesized via `functools.total_ordering`
418+
reveal_type(A.__ge__) # E: revealed type: (self: Self@A, other: object) -> bool
419+
"#,
420+
);

0 commit comments

Comments
 (0)