diff --git a/src/MultivariateBases.jl b/src/MultivariateBases.jl index e80ba8c..8e6f94b 100644 --- a/src/MultivariateBases.jl +++ b/src/MultivariateBases.jl @@ -32,5 +32,6 @@ include("hermite.jl") include("laguerre.jl") include("legendre.jl") include("chebyshev.jl") +include("quotient.jl") end # module diff --git a/src/monomial.jl b/src/monomial.jl index 35078fa..6abc7bc 100644 --- a/src/monomial.jl +++ b/src/monomial.jl @@ -265,12 +265,15 @@ function MP.polynomial(Q::AbstractMatrix, mb::SubBasis{Monomial}, T::Type) return MP.polynomial(Q, mb.monomials, T) end -function MP.coefficients(p, basis::SubBasis{Monomial}) +function MP.coefficients( + p::MP.AbstractPolynomialLike, + basis::SubBasis{Monomial}, +) return MP.coefficients(p, basis.monomials) end -function MP.coefficients(p, ::FullBasis{Monomial}) - return MP.coefficients(p) +function MP.coefficients(p::MP.AbstractPolynomialLike, ::FullBasis{Monomial}) + return p end # Overload some of the `MP` interface for convenience diff --git a/src/quotient.jl b/src/quotient.jl new file mode 100644 index 0000000..48c970b --- /dev/null +++ b/src/quotient.jl @@ -0,0 +1,10 @@ +struct QuotientBasis{T,I,B<:SA.AbstractBasis{T,I},D} <: SA.ExplicitBasis{T,I} + basis::B + divisor::D +end + +Base.length(basis::QuotientBasis) = length(basis.basis) + +function MP.coefficients(p, basis::QuotientBasis) + return MP.coefficients(rem(p, basis.divisor), basis.basis) +end diff --git a/test/quotient.jl b/test/quotient.jl new file mode 100644 index 0000000..417edf3 --- /dev/null +++ b/test/quotient.jl @@ -0,0 +1,19 @@ +struct PlusMinusOne # Reduce modulo `x^2 = 1` +end +function Base.rem(m::AbstractMonomial, ::PlusMinusOne) + return prod(v^mod(e, 2) for (v, e) in powers(m)) +end +function Base.rem(t::AbstractTerm, d::PlusMinusOne) + return coefficient(t) * rem(monomial(t), d) +end +function Base.rem(p::AbstractPolynomial, d::PlusMinusOne) + return sum(rem(t, d) for t in terms(p)) +end + +@testset "PlusMinusOne" begin + @polyvar x y + basis = + MB.QuotientBasis(MB.SubBasis{MB.Monomial}([1, y, x]), PlusMinusOne()) + @test length(basis) == 3 + @test coefficients(x^3 - 2x^2 * y + 3x^2, basis) == [3, -2, 1] +end diff --git a/test/runtests.jl b/test/runtests.jl index 445e878..984dce0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -147,3 +147,6 @@ end @testset "Chebyshev" begin include("chebyshev.jl") end +@testset "Quotient" begin + include("quotient.jl") +end