hand
This commit is contained in:
@@ -0,0 +1,143 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
import operator
|
||||
from itertools import accumulate
|
||||
from typing import Optional
|
||||
|
||||
from ._memref_ops_gen import *
|
||||
from ._memref_ops_gen import _Dialect
|
||||
from ._ods_common import _dispatch_mixed_values, MixedValues
|
||||
from ..ir import (
|
||||
IndexType,
|
||||
IntegerType,
|
||||
MemRefType,
|
||||
ShapedType,
|
||||
StridedLayoutAttr,
|
||||
Value,
|
||||
)
|
||||
from . import arith
|
||||
|
||||
|
||||
def _is_constant_int_like(i):
|
||||
return (
|
||||
isinstance(i, Value)
|
||||
and isinstance(i.owner, arith.ConstantOp)
|
||||
and isinstance(i.type, (IntegerType, IndexType))
|
||||
)
|
||||
|
||||
|
||||
def _is_static_int_like(i):
|
||||
return (
|
||||
isinstance(i, int) and not ShapedType.is_dynamic_size(i)
|
||||
) or _is_constant_int_like(i)
|
||||
|
||||
|
||||
def _infer_memref_subview_result_type(
|
||||
source_memref_type, offsets, static_sizes, static_strides
|
||||
):
|
||||
source_strides, source_offset = source_memref_type.get_strides_and_offset()
|
||||
# "canonicalize" from tuple|list -> list
|
||||
offsets, static_sizes, static_strides, source_strides = map(
|
||||
list, (offsets, static_sizes, static_strides, source_strides)
|
||||
)
|
||||
|
||||
if not all(
|
||||
all(_is_static_int_like(i) for i in s)
|
||||
for s in [
|
||||
static_sizes,
|
||||
static_strides,
|
||||
source_strides,
|
||||
]
|
||||
):
|
||||
raise ValueError(
|
||||
"Only inferring from python or mlir integer constant is supported."
|
||||
)
|
||||
|
||||
for s in [offsets, static_sizes, static_strides]:
|
||||
for idx, i in enumerate(s):
|
||||
if _is_constant_int_like(i):
|
||||
s[idx] = i.owner.opview.literal_value
|
||||
|
||||
if any(not _is_static_int_like(i) for i in offsets + [source_offset]):
|
||||
target_offset = ShapedType.get_dynamic_size()
|
||||
else:
|
||||
target_offset = source_offset
|
||||
for offset, target_stride in zip(offsets, source_strides):
|
||||
target_offset += offset * target_stride
|
||||
|
||||
target_strides = []
|
||||
for source_stride, static_stride in zip(source_strides, static_strides):
|
||||
target_strides.append(source_stride * static_stride)
|
||||
|
||||
# If default striding then no need to complicate things for downstream ops (e.g., expand_shape).
|
||||
default_strides = list(accumulate(static_sizes[1:][::-1], operator.mul))[::-1] + [1]
|
||||
if target_strides == default_strides and target_offset == 0:
|
||||
layout = None
|
||||
else:
|
||||
layout = StridedLayoutAttr.get(target_offset, target_strides)
|
||||
return (
|
||||
offsets,
|
||||
static_sizes,
|
||||
static_strides,
|
||||
MemRefType.get(
|
||||
static_sizes,
|
||||
source_memref_type.element_type,
|
||||
layout,
|
||||
source_memref_type.memory_space,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_generated_subview = subview
|
||||
|
||||
|
||||
def subview(
|
||||
source: Value,
|
||||
offsets: MixedValues,
|
||||
sizes: MixedValues,
|
||||
strides: MixedValues,
|
||||
*,
|
||||
result_type: Optional[MemRefType] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if offsets is None:
|
||||
offsets = []
|
||||
if sizes is None:
|
||||
sizes = []
|
||||
if strides is None:
|
||||
strides = []
|
||||
source_strides, source_offset = source.type.get_strides_and_offset()
|
||||
if result_type is None and all(
|
||||
all(_is_static_int_like(i) for i in s) for s in [sizes, strides, source_strides]
|
||||
):
|
||||
# If any are arith.constant results then this will canonicalize to python int
|
||||
# (which can then be used to fully specify the subview).
|
||||
(
|
||||
offsets,
|
||||
sizes,
|
||||
strides,
|
||||
result_type,
|
||||
) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
|
||||
elif result_type is None:
|
||||
raise ValueError(
|
||||
"mixed static/dynamic offset/sizes/strides requires explicit result type."
|
||||
)
|
||||
|
||||
offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
|
||||
sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
|
||||
strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)
|
||||
|
||||
return _generated_subview(
|
||||
result_type,
|
||||
source,
|
||||
offsets,
|
||||
sizes,
|
||||
strides,
|
||||
static_offsets,
|
||||
static_sizes,
|
||||
static_strides,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
Reference in New Issue
Block a user