diff --git a/src/Core/Py.jl b/src/Core/Py.jl index 0705bded..1e26ed04 100644 --- a/src/Core/Py.jl +++ b/src/Core/Py.jl @@ -125,10 +125,25 @@ Py(x::Date) = pydate(x) Py(x::Time) = pytime(x) Py(x::DateTime) = pydatetime(x) +function _py_on_main_thread_if_needed(f) + if C.PyGILState_Check() == 1 + f() + else + C.on_main_thread(f) + end +end + Base.string(x::Py) = pyisnull(x) ? "" : pystr(String, x) Base.print(io::IO, x::Py) = print(io, string(x)) function Base.show(io::IO, x::Py) + _py_on_main_thread_if_needed() do + _show(io, x) + nothing + end +end + +function _show(io::IO, x::Py) if get(io, :typeinfo, Any) == Py if pyisnull(x) print(io, "NULL") @@ -292,13 +307,9 @@ function _propertynames(x::Py, private::Bool) end function Base.propertynames(x::Py, private::Bool = false) - if C.PyGILState_Check() == 1 + _py_on_main_thread_if_needed() do _propertynames(x, private) - else - C.on_main_thread() do - _propertynames(x, private) - end::Vector{Symbol} - end + end::Vector{Symbol} end Base.Bool(x::Py) = pytruth(x) diff --git a/src/Wrap/PyDict.jl b/src/Wrap/PyDict.jl index 7ea1ba1b..7d1abd52 100644 --- a/src/Wrap/PyDict.jl +++ b/src/Wrap/PyDict.jl @@ -24,7 +24,19 @@ function Base.iterate(x::PyDict{K,V}, it::Py = pyiter(x)) where {K,V} return (k => v, it) end -function Base.iterate(x::Base.KeySet{K,PyDict{K,V}}, it::Py = pyiter(x.dict)) where {K,V} +function Base.iterate(x::Base.KeySet{K,<:PyDict{K}})::Union{Nothing,Tuple{K,Py}} where {K} + _py_on_main_thread_if_needed() do + _iterate(x, pyiter(x.dict)) + end +end + +function Base.iterate(x::Base.KeySet{K,<:PyDict{K}}, it::Py)::Union{Nothing,Tuple{K,Py}} where {K} + _py_on_main_thread_if_needed() do + _iterate(x, it) + end +end + +function _iterate(x::Base.KeySet{K,<:PyDict{K}}, it::Py) where {K} k_ = unsafe_pynext(it) pyisnull(k_) && return nothing k = pyconvert(K, k_) diff --git a/src/Wrap/Wrap.jl b/src/Wrap/Wrap.jl index d3b30ff2..3ba2c7a3 100644 --- a/src/Wrap/Wrap.jl +++ b/src/Wrap/Wrap.jl @@ -20,7 +20,7 @@ using Base: @propagate_inbounds using Tables: Tables using UnsafePointers: UnsafePtr -import ..Core: Py, ispy +import ..Core: Py, ispy, _py_on_main_thread_if_needed include("PyIterable.jl") include("PyDict.jl") diff --git a/test/Project.toml b/test/Project.toml index 3f088776..50bccf9d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,6 +5,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" +REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/Wrap.jl b/test/Wrap.jl index c42d45ab..a5c90276 100644 --- a/test/Wrap.jl +++ b/test/Wrap.jl @@ -122,6 +122,16 @@ end @testset "iterate keys" begin @test collect(keys(z)) == ["foo"] end + @testset "complete keys from another thread" begin + using REPL + task = Threads.@spawn begin + completions, _, _ = REPL.REPLCompletions.completions("y[", 2, @__MODULE__) + length(completions) + end + wait(task) + completion_count = fetch(task) + @test completion_count == 1 + end @testset "getindex" begin @test z["foo"] === 12 @test_throws KeyError z["bar"]