Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Aqua tests #775

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Distributions = "0.25"
DocStringExtensions = "0.9"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
InteractiveUtils = "1"
JET = "0.9"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
Expand Down
5 changes: 4 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,10 @@
```
"""
unwrap_right_left_vns(right, left, vns) = right, left, vns
function unwrap_right_left_vns(right::NamedDist, left, vns)
function unwrap_right_left_vns(right::NamedDist, left::AbstractArray, ::VarName)
return unwrap_right_left_vns(right.dist, left, right.name)

Check warning on line 224 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L223-L224

Added lines #L223 - L224 were not covered by tests
end
function unwrap_right_left_vns(right::NamedDist, left::AbstractMatrix, ::VarName)

Check warning on line 226 in src/compiler.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler.jl#L226

Added line #L226 was not covered by tests
return unwrap_right_left_vns(right.dist, left, right.name)
end
function unwrap_right_left_vns(
Expand Down
22 changes: 14 additions & 8 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@
return left, acclogp_observe!!(context, vi, logp)
end

function assume(rng, spl::Sampler, dist)
function assume(rng::Random.AbstractRNG, spl::Sampler, dist)

Check warning on line 198 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L198

Added line #L198 was not covered by tests
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
end

Expand Down Expand Up @@ -291,14 +291,18 @@
function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, vi)
return dot_assume(right, left, vns, vi)
end
function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, vi)
function dot_tilde_assume(

Check warning on line 294 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L294

Added line #L294 was not covered by tests
::IsLeaf, rng::Random.AbstractRNG, ::AbstractContext, sampler, right, left, vns, vi
)
return dot_assume(rng, sampler, right, vns, left, vi)
end

function dot_tilde_assume(::IsParent, context::AbstractContext, args...)
return dot_tilde_assume(childcontext(context), args...)
end
function dot_tilde_assume(::IsParent, rng, context::AbstractContext, args...)
function dot_tilde_assume(

Check warning on line 303 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L303

Added line #L303 was not covered by tests
::IsParent, rng::Random.AbstractRNG, context::AbstractContext, args...
)
return dot_tilde_assume(rng, childcontext(context), args...)
end

Expand Down Expand Up @@ -371,7 +375,7 @@
end

function dot_assume(
rng,
rng::Random.AbstractRNG,
spl::Union{SampleFromPrior,SampleFromUniform},
dist::MultivariateDistribution,
vns::AbstractVector{<:VarName},
Expand Down Expand Up @@ -404,7 +408,7 @@
end

function dot_assume(
rng,
rng::Random.AbstractRNG,
spl::Union{SampleFromPrior,SampleFromUniform},
dists::Union{Distribution,AbstractArray{<:Distribution}},
vns::AbstractArray{<:VarName},
Expand All @@ -416,7 +420,9 @@
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans.((vi,), vns)))
return r, lp, vi
end
function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any)
function dot_assume(

Check warning on line 423 in src/context_implementations.jl

View check run for this annotation

Codecov / codecov/patch

src/context_implementations.jl#L423

Added line #L423 was not covered by tests
rng::Random.AbstractRNG, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any
)
return error(
"[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing assume statement"
)
Expand All @@ -436,7 +442,7 @@
end

function get_and_set_val!(
rng,
rng::Random.AbstractRNG,
vi::VarInfoOrThreadSafeVarInfo,
vns::AbstractVector{<:VarName},
dist::MultivariateDistribution,
Expand Down Expand Up @@ -478,7 +484,7 @@
end

function get_and_set_val!(
rng,
rng::Random.AbstractRNG,
vi::VarInfoOrThreadSafeVarInfo,
vns::AbstractArray{<:VarName},
dists::Union{Distribution,AbstractArray{<:Distribution}},
Expand Down
2 changes: 2 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@
end
# Optimisation when there are no values to condition on
ConditionContext(::NamedTuple{()}, context::AbstractContext) = context
# Same as above, and avoids method ambiguity with below
ConditionContext(::NamedTuple{()}, context::NamedConditionContext) = context

Check warning on line 340 in src/contexts.jl

View check run for this annotation

Codecov / codecov/patch

src/contexts.jl#L340

Added line #L340 was not covered by tests
# Collapse consecutive levels of `ConditionContext`. Note that this overrides
# values inside the child context, thus giving precedence to the outermost
# `ConditionContext`.
Expand Down
4 changes: 4 additions & 0 deletions src/distribution_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
Base.size(dist::NamedDist) = Base.size(dist.dist)

Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x)
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real,0})

Check warning on line 20 in src/distribution_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/distribution_wrappers.jl#L20

Added line #L20 was not covered by tests
# extract the singleton value from 0-dimensional array
return Distributions.logpdf(dist.dist, first(x))

Check warning on line 22 in src/distribution_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/distribution_wrappers.jl#L22

Added line #L22 was not covered by tests
end
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real})
return Distributions.logpdf(dist.dist, x)
end
Expand Down
14 changes: 10 additions & 4 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,14 @@
end

# Constructor from `Model`.
SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...)
function SimpleVarInfo{T}(model::Model, args...) where {T<:Real}
function SimpleVarInfo(

Check warning on line 235 in src/simple_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/simple_varinfo.jl#L235

Added line #L235 was not covered by tests
model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
)
return SimpleVarInfo{Float64}(model, args...)

Check warning on line 238 in src/simple_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/simple_varinfo.jl#L238

Added line #L238 was not covered by tests
end
function SimpleVarInfo{T}(

Check warning on line 240 in src/simple_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/simple_varinfo.jl#L240

Added line #L240 was not covered by tests
model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
) where {T<:Real}
return last(evaluate!!(model, SimpleVarInfo{T}(), args...))
end

Expand Down Expand Up @@ -497,7 +503,7 @@
end

function dot_assume(
rng,
rng::Random.AbstractRNG,
spl::Union{SampleFromPrior,SampleFromUniform},
dists::Union{Distribution,AbstractArray{<:Distribution}},
vns::AbstractArray{<:VarName},
Expand All @@ -523,7 +529,7 @@
end

function dot_assume(
rng,
rng::Random.AbstractRNG,
spl::Union{SampleFromPrior,SampleFromUniform},
dist::MultivariateDistribution,
vns::AbstractVector{<:VarName},
Expand Down
6 changes: 5 additions & 1 deletion src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,11 @@
)
return typed_varinfo(model, SamplingContext(rng, sampler, context), metadata)
end
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)
function VarInfo(

Check warning on line 209 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L209

Added line #L209 was not covered by tests
model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}...
)
return VarInfo(Random.default_rng(), model, args...)

Check warning on line 212 in src/varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/varinfo.jl#L212

Added line #L212 was not covered by tests
end

"""
vector_length(varinfo::VarInfo)
Expand Down
3 changes: 0 additions & 3 deletions src/varname.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,3 @@ Possibly existing indices of `varname` are neglected.
) where {s,missings,_F,_a,_T}
return s in missings
end

# HACK: Type-piracy. Is this really the way to go?
AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym
8 changes: 8 additions & 0 deletions test/Aqua.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module AquaTests

using Aqua: Aqua
using DynamicPPL

Aqua.test_all(DynamicPPL)

end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
99 changes: 50 additions & 49 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,57 +45,58 @@ include("test_util.jl")
# groups are chosen to make both groups take roughly the same amount of
# time, but beyond that there is no particular reason for the split.
if GROUP == "All" || GROUP == "Group1"
include("utils.jl")
include("compiler.jl")
include("varnamedvector.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
include("model.jl")
include("sampler.jl")
include("independence.jl")
include("distribution_wrappers.jl")
include("logdensityfunction.jl")
include("linking.jl")
include("serialization.jl")
include("pointwise_logdensities.jl")
include("lkj.jl")
include("deprecated.jl")
include("Aqua.jl")
# include("utils.jl")
# include("compiler.jl")
# include("varnamedvector.jl")
# include("varinfo.jl")
# include("simple_varinfo.jl")
# include("model.jl")
# include("sampler.jl")
# include("independence.jl")
# include("distribution_wrappers.jl")
# include("logdensityfunction.jl")
# include("linking.jl")
# include("serialization.jl")
# include("pointwise_logdensities.jl")
# include("lkj.jl")
# include("deprecated.jl")
end

if GROUP == "All" || GROUP == "Group2"
include("contexts.jl")
include("context_implementations.jl")
include("threadsafe.jl")
include("debug_utils.jl")
@testset "compat" begin
include(joinpath("compat", "ad.jl"))
end
@testset "extensions" begin
include("ext/DynamicPPLMCMCChainsExt.jl")
include("ext/DynamicPPLJETExt.jl")
end
@testset "ad" begin
include("ext/DynamicPPLForwardDiffExt.jl")
include("ext/DynamicPPLMooncakeExt.jl")
include("ad.jl")
end
@testset "prob and logprob macro" begin
@test_throws ErrorException prob"..."
@test_throws ErrorException logprob"..."
end
@testset "doctests" begin
DocMeta.setdocmeta!(
DynamicPPL,
:DocTestSetup,
:(using DynamicPPL, Distributions);
recursive=true,
)
doctestfilters = [
# Ignore the source of a warning in the doctest output, since this is dependent on host.
# This is a line that starts with "└ @ " and ends with the line number.
r"└ @ .+:[0-9]+",
]
doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
end
# include("contexts.jl")
# include("context_implementations.jl")
# include("threadsafe.jl")
# include("debug_utils.jl")
# @testset "compat" begin
# include(joinpath("compat", "ad.jl"))
# end
# @testset "extensions" begin
# include("ext/DynamicPPLMCMCChainsExt.jl")
# include("ext/DynamicPPLJETExt.jl")
# end
# @testset "ad" begin
# include("ext/DynamicPPLForwardDiffExt.jl")
# include("ext/DynamicPPLMooncakeExt.jl")
# include("ad.jl")
# end
# @testset "prob and logprob macro" begin
# @test_throws ErrorException prob"..."
# @test_throws ErrorException logprob"..."
# end
# @testset "doctests" begin
# DocMeta.setdocmeta!(
# DynamicPPL,
# :DocTestSetup,
# :(using DynamicPPL, Distributions);
# recursive=true,
# )
# doctestfilters = [
# # Ignore the source of a warning in the doctest output, since this is dependent on host.
# # This is a line that starts with "└ @ " and ends with the line number.
# r"└ @ .+:[0-9]+",
# ]
# doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
# end
end
end
Loading