Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement first-class List type #60629

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
PeriodDtype,
IntervalDtype,
DatetimeTZDtype,
ListDtype,
StringDtype,
BooleanDtype,
# missing
Expand Down Expand Up @@ -261,6 +262,7 @@
"Interval",
"IntervalDtype",
"IntervalIndex",
"ListDtype",
"MultiIndex",
"NaT",
"NamedAgg",
Expand Down
6 changes: 6 additions & 0 deletions pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
TimedeltaArray,
)
from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin
from pandas.core.arrays.list_ import ListDtype
from pandas.core.arrays.string_ import StringDtype
from pandas.core.indexes.api import safe_sort_index

Expand Down Expand Up @@ -824,6 +825,11 @@ def assert_extension_array_equal(
[np.isnan(val) for val in right._ndarray[right_na]] # type: ignore[attr-defined]
), "wrong missing value sentinels"

# TODO: not every array type may be convertible to NumPy; should catch here
if isinstance(left.dtype, ListDtype) and isinstance(right.dtype, ListDtype):
assert left._pa_array == right._pa_array
return

left_valid = left[~left_na].to_numpy(dtype=object)
right_valid = right[~right_na].to_numpy(dtype=object)
if check_exact:
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
UInt32Dtype,
UInt64Dtype,
)
from pandas.core.arrays.list_ import ListDtype
from pandas.core.arrays.string_ import StringDtype
from pandas.core.construction import array # noqa: ICN001
from pandas.core.flags import Flags
Expand Down Expand Up @@ -103,6 +104,7 @@
"Interval",
"IntervalDtype",
"IntervalIndex",
"ListDtype",
"MultiIndex",
"NaT",
"NamedAgg",
Expand Down
180 changes: 180 additions & 0 deletions pandas/core/arrays/list_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from __future__ import annotations

from typing import (
TYPE_CHECKING,
ClassVar,
)

import numpy as np

from pandas._libs import missing as libmissing
from pandas.compat import HAS_PYARROW
from pandas.util._decorators import set_module

from pandas.core.dtypes.base import (
ExtensionDtype,
register_extension_dtype,
)
from pandas.core.dtypes.common import (
is_object_dtype,
is_string_dtype,
)

from pandas.core.arrays import ExtensionArray

if TYPE_CHECKING:
from pandas._typing import (
type_t,
Shape,
)

import pyarrow as pa


@register_extension_dtype
@set_module("pandas")
class ListDtype(ExtensionDtype):
    """
    An ExtensionDtype suitable for storing homogeneous lists of data.
    """

    # Scalar type of the elements: selecting a single entry yields a
    # Python ``list`` (or NA for missing entries).
    type = list
    name: ClassVar[str] = "list"

    @property
    def na_value(self) -> libmissing.NAType:
        """Missing-value sentinel used by this dtype (``pd.NA``)."""
        return libmissing.NA

    @property
    def kind(self) -> str:
        # TODO: our extension interface says this field should be the
        # NumPy type character, but no such thing exists for list
        # this assumes a PyArrow large list
        return "+L"

    @classmethod
    def construct_array_type(cls) -> type_t[ListArray]:
        """
        Return the array type associated with this dtype.

        Returns
        -------
        type
        """
        return ListArray


class ListArray(ExtensionArray):
    """
    Extension array for homogeneous lists, backed by a PyArrow ChunkedArray.

    Missing entries are stored as PyArrow nulls and surfaced as ``pd.NA``.
    Requires pyarrow to be installed.
    """

    dtype = ListDtype()
    # make NumPy binary ops defer to this array's implementations
    __array_priority__ = 1000

    def __init__(self, values: pa.Array | pa.ChunkedArray | list | ListArray) -> None:
        if not HAS_PYARROW:
            raise NotImplementedError("ListArray requires pyarrow to be installed")

        if isinstance(values, type(self)):
            # share the existing chunked array; no copy
            self._pa_array = values._pa_array
        elif not isinstance(values, pa.ChunkedArray):
            # To support NA, we need to create an Array first :-(
            arr = pa.array(values, from_pandas=True)
            self._pa_array = pa.chunked_array(arr)
        else:
            self._pa_array = values

    @classmethod
    def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False):
        """
        Construct a new ListArray from a sequence of scalars.

        Notes
        -----
        ``dtype`` and ``copy`` are currently ignored; the element type is
        inferred by pyarrow.  # TODO confirm whether dtype should be honored
        """
        if isinstance(scalars, ListArray):
            return cls(scalars)
        elif isinstance(scalars, pa.Scalar):
            scalars = [scalars]
            return cls(scalars)

        try:
            values = pa.array(scalars, from_pandas=True)
        except TypeError:
            # TypeError: object of type 'NoneType' has no len() if you have
            # pa.ListScalar(None). Upstream issue in Arrow - see:
            # https://github.com/apache/arrow/issues/40319
            # Build a new list instead of assigning into ``scalars`` so the
            # caller's sequence is not mutated (and read-only sequences work).
            scalars = [s if s.is_valid else None for s in scalars]
            values = pa.array(scalars, from_pandas=True)
        if values.type == "null":
            # TODO(wayd): this is a hack to get the tests to pass, but the
            # overall issue is that our extension types don't support
            # parametrization; default an all-null input to list-of-null
            values = pa.array(values, type=pa.list_(pa.null()))

        return cls(values)

    def __getitem__(self, item):
        # PyArrow does not support NumPy's selection with an equal length
        # mask, so convert a boolean mask to integral positions first
        if isinstance(item, np.ndarray) and item.dtype == bool:
            return type(self)(self._pa_array.take(np.flatnonzero(item)))
        elif isinstance(item, int):  # scalar case
            return self._pa_array[item]

        return type(self)(self._pa_array[item])

    def __len__(self) -> int:
        return len(self._pa_array)

    def isna(self):
        """Return a boolean ndarray marking null (missing) entries."""
        return np.array(self._pa_array.is_null())

    def take(self, indexer, allow_fill=False, fill_value=None):
        # TODO: what do we need to do with allow_fill and fill_value here?
        return type(self)(self._pa_array.take(indexer))

    @classmethod
    def _empty(cls, shape: Shape, dtype: ExtensionDtype):
        """
        Create an ExtensionArray with the given shape and dtype.

        See also
        --------
        ExtensionDtype.empty
            ExtensionDtype.empty is the 'official' public version of this API.
        """
        # Implementer note: while ExtensionDtype.empty is the public way to
        # call this method, it is still required to implement this `_empty`
        # method as well (it is called internally in pandas)
        # NOTE(review): ``dtype`` is currently ignored here (the value passed
        # to _from_sequence below is also unused) — confirm this is intended
        if isinstance(shape, tuple):
            if len(shape) > 1:
                raise ValueError("ListArray may only be 1-D")
            else:
                length = shape[0]
        else:
            length = shape
        return cls._from_sequence([None] * length, dtype=pa.list_(pa.null()))

    def copy(self):
        """Return a deep copy of the array's data."""
        mm = pa.default_cpu_memory_manager()

        # TODO(wayd): ChunkedArray does not implement copy_to so this
        # ends up creating an Array
        copied = self._pa_array.combine_chunks().copy_to(mm.device)
        return type(self)(copied)

    def astype(self, dtype, copy=True):
        """
        Cast to ``dtype``; same-dtype casts return self (or a copy),
        other targets are materialized through NumPy object conversion.
        """
        if isinstance(dtype, type(self.dtype)) and dtype == self.dtype:
            if copy:
                return self.copy()
            return self
        elif is_string_dtype(dtype) and not is_object_dtype(dtype):
            # numpy has problems with astype(str) for nested elements
            # and pyarrow cannot cast from list[string] to string
            return np.array([str(x) for x in self._pa_array], dtype=dtype)

        # conversion through to_pylist always materializes new objects
        if not copy:
            raise TypeError(f"astype from ListArray to {dtype} requires a copy")

        return np.array(self._pa_array.to_pylist(), dtype=dtype, copy=copy)

    @classmethod
    def _concat_same_type(cls, to_concat):
        """Concatenate multiple ListArrays into one."""
        data = [x._pa_array for x in to_concat]
        return cls(data)
3 changes: 3 additions & 0 deletions pandas/core/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import numpy as np
from numpy import ma
import pyarrow as pa

from pandas._config import using_string_dtype

Expand Down Expand Up @@ -460,6 +461,8 @@ def treat_as_nested(data) -> bool:
len(data) > 0
and is_list_like(data[0])
and getattr(data[0], "ndim", 1) == 1
# TODO(wayd): hack so pyarrow list elements don't expand
and not isinstance(data[0], pa.ListScalar)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think having `is_list_like` return False for pyarrow scalars would be less hacky?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's probably true in this particular case, although I'm not sure how it will generalize to all uses of is_list_like. Will do more research

and not (isinstance(data, ExtensionArray) and data.ndim == 2)
)

Expand Down
5 changes: 4 additions & 1 deletion pandas/core/internals/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1976,7 +1976,10 @@ def from_blocks(

@classmethod
def from_array(
cls, array: ArrayLike, index: Index, refs: BlockValuesRefs | None = None
cls,
array: ArrayLike,
index: Index,
refs: BlockValuesRefs | None = None,
) -> SingleBlockManager:
"""
Constructor for if we have an array that is not yet a Block.
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
StructAccessor,
)
from pandas.core.arrays.categorical import CategoricalAccessor
from pandas.core.arrays.list_ import ListDtype
from pandas.core.arrays.sparse import SparseAccessor
from pandas.core.arrays.string_ import StringDtype
from pandas.core.construction import (
Expand Down Expand Up @@ -494,7 +495,7 @@ def __init__(
if not is_list_like(data):
data = [data]
index = default_index(len(data))
elif is_list_like(data):
elif is_list_like(data) and not isinstance(dtype, ListDtype):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about nested list?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea this is a tough one to handle. I'm not sure if something like:

pd.Series([1, 2, 3], index=range(3), dtype=pd.ListDtype())

should raise or broadcast. I think the tests currently want it to broadcast, but we could override that expectation for this array

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good example. also indexing is going to be a pain point. a lot of checking for nested listlikes happens in indexing.py and im wary of stumbling blocks there. also indexing in general with a ListDtype index. i think the solution is to interpret lists as sequences and only treat them as a scalar if explicitly wrapped in pa.ListScalar (or something equivalent)

Copy link
Member Author

@WillAyd WillAyd Jan 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally started down that path, but I think the problem with expecting pa.ListScalar is that when you select from a ListArray, most users probably expect a plain Python list back. So to use that same selection as an indexer, I think it's hard to avoid the built-in type.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So to use that same selection as an indexer

I think we have to just disallow that. Otherwise we have a long tail of places where we expect labels to be non-listlike and nested objects to mean 2D. e.g. df.set_index(list_column).T.set_index([1, 2, 3]) is ambiguous whether it wants three columns to be set as an index or a single column with label [1, 2, 3]. I'm particularly worried about indexing.py code where we check for nested data getting way more complicated if the behavior has to depend on dtypes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

df.set_index(list_column).T.set_index([1, 2, 3]) is ambiguous whether it wants three columns to be set as an index or a single column with label [1, 2, 3]

Do we normally require index elements to be hashable? I think there's an argument to be made that we shouldn't allow this as a index, given Python lists aren't hashable (even though PyArrow lists are)

If we do want to allow this as an index, the example cited is a single column with label [1, 2, 3]

I'm particularly worried about indexing.py code where we check for nested data getting way more complicated if the behavior has to depend on dtypes.

Definitely a valid concern. Do you know what might not be covered in this area by the extension tests?

Copy link
Member

@jbrockmendel jbrockmendel Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we normally require index elements to be hashable?

We don't do any validation at construction-time for an object dtype Index. I expect a bunch of methods will break with non-hashable elements.

I think there's an argument to be made that we shouldn't allow this as a index, given Python lists aren't hashable (even though PyArrow lists are)

Wouldn't that break a lot e.g. obj.groupby(list_dtype_column).method()? I expect there are lots of methods/functions that do a obj.set_index(col).something().reset_index(col) that would become fragile.

The simple solution to all of these problems is to require users to be explicit when they want to treat a list as a scalar.

Definitely a valid concern. Do you know what might not be covered in this area by the extension tests?

I can come up with a few examples if we go down this road, but that part of the code has a lot of paths so it won't be comprehensive. We'll be chasing down corner cases for a long time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea that's a good point about the groupby stuff. I don't have a strong objection to requiring an explicit scalar argument, though I guess the follow up question is do we want users passing a pyarrow scalar or do we want to create our own scalar type.

Not sure how other libraries like polars handle this, but @MarcoGorelli might have some insight

com.require_length_match(data, index)

# create/copy the manager
Expand Down
21 changes: 21 additions & 0 deletions pandas/io/formats/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,6 +1467,27 @@ def _format_strings(self) -> list[str]:
return fmt_values


class _NullFormatter(_GenericArrayFormatter):
    def _format_strings(self) -> list[str]:
        # Render every value, including NA sentinels, via ``str``.
        return [str(value) for value in self.values]


class _ListFormatter(_GenericArrayFormatter):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesnt look like this is used?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep dead code - thanks!

def _format_strings(self) -> list[str]:
    # TODO(wayd): This doesn't seem right - where should missing values
    # be handled
    # NOTE(review): any falsy Python value (e.g. an empty list), not just
    # NA, is rendered as "" — confirm that is the intent
    return [x.as_py() or "" for x in self.values]


class _Datetime64Formatter(_GenericArrayFormatter):
values: DatetimeArray

Expand Down
1 change: 1 addition & 0 deletions pandas/tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class TestPDApi(Base):
"RangeIndex",
"Series",
"SparseDtype",
"ListDtype",
"StringDtype",
"Timedelta",
"TimedeltaIndex",
Expand Down
7 changes: 0 additions & 7 deletions pandas/tests/extension/list/__init__.py

This file was deleted.

Loading
Loading