// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=LOAD-ND
// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=LOAD-GATHER

gpu.module @xevm_module {
// 1-D in-bounds read of 8 f32 values from a 3-D memref. The checks expect
// both pipelines to emit a gather-style xegpu.load: a scalar offset built
// from the indices is broadcast and added to a vector.step, and the load is
// issued on the memref's aligned pointer (cast to i64) with an all-true mask.
// NOTE(review): the `COUNT2` lines in both expectation sets below lack the
// hyphen of FileCheck's COUNT-<n> directive, so FileCheck skips them.
// Do not just add the hyphen: the stated counts and their relative order
// have not been re-verified against the pass output.
gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
    {in_bounds = [true]} : memref<8x16x32xf32>, vector<8xf32>
  gpu.return %0 : vector<8xf32>
}

// LOAD-ND-LABEL:  @load_1D_vector(
// LOAD-ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-ND:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
// LOAD-ND:        %[[STEP:.+]] = vector.step : vector<8xindex>
// LOAD-ND-COUNT2: arith.muli {{.*}} : index
// LOAD-ND-COUNT2: arith.addi {{.*}} : index
// LOAD-ND:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
// LOAD-ND:        %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
// LOAD-ND:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
// LOAD-ND:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-ND:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>

// LOAD-GATHER-LABEL:  @load_1D_vector(
// LOAD-GATHER-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
// LOAD-GATHER:        %[[STEP:.+]] = vector.step : vector<8xindex>
// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>

}

// -----
gpu.module @xevm_module {
// 2-D in-bounds 8x16 f32 read from a 3-D memref.
// With the xevm target, the checks expect the following sequence:
//   - the leading index is taken into a memref.subview at [offset, 0, 0];
//   - the subview's byte offset (metadata offset * 4) is added to its
//     aligned pointer;
//   - a non-boundary-checked 8x16 nd descriptor (shape [16, 32],
//     strides [32, 1]) is built on that i64 address;
//   - the trailing two indices are passed to xegpu.load_nd.
// Without a target, the checks expect a gather-style xegpu.load.
gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
    %offset: index) -> vector<8x16xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
    {in_bounds = [true, true]} : memref<8x16x32xf32>, vector<8x16xf32>
  gpu.return %0 : vector<8x16xf32>
}

// LOAD-ND-LABEL:  @load_2D_vector(
// LOAD-ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-ND-SAME:   %[[OFFSET:.+]]: index
// LOAD-ND:        %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// LOAD-ND-SAME:     : memref<f32> -> index
// LOAD-ND:        %[[MUL:.*]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
// LOAD-ND:        %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
// LOAD-ND-SAME:                   strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
// LOAD-ND-SAME:     boundary_check = false
// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND:        return %[[VEC]]

// NOTE(review): the `COUNT2` lines below lack the hyphen of FileCheck's
// COUNT-<n> directive, so FileCheck skips them. Do not just add the hyphen:
// the stated counts and their relative order have not been re-verified.
// LOAD-GATHER-LABEL:  @load_2D_vector(
// LOAD-GATHER-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
// LOAD-GATHER-COUNT2: vector.step
// LOAD-GATHER-COUNT2: vector.shape_cast
// LOAD-GATHER-COUNT2: vector.broadcast
// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex>
// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>

}


// -----
gpu.module @xevm_module {
// The row dimension may be out of bounds (`in_bounds = [false, true]`), and
// the padding value is the constant 0.0.
// With the xevm target, the checks expect a descriptor built directly on the
// source memref. Unlike the in-bounds cases, that descriptor type carries no
// `boundary_check = false` attribute.
// Without a target, the checks expect the transfer_read to remain.
gpu.func @load_zero_pad_out_of_bounds(%src: memref<32x64xf32>, %idx: index) -> vector<8x16xf32> {
  %zero = arith.constant 0.0 : f32
  %tile = vector.transfer_read %src[%idx, %idx], %zero {in_bounds = [false, true]}
    : memref<32x64xf32>, vector<8x16xf32>
  gpu.return %tile : vector<8x16xf32>
}

// LOAD-ND-LABEL:  @load_zero_pad_out_of_bounds(
// LOAD-ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
// LOAD-ND-SAME:   %[[OFFSET:.+]]: index
// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// LOAD-ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND:        return %[[VEC]]

// LOAD-GATHER-LABEL:  @load_zero_pad_out_of_bounds(
// LOAD-GATHER:        vector.transfer_read

}


// -----
gpu.module @xevm_module {
// Transposed read (permutation_map (d0, d1) -> (d1, d0)) of an 8x16 f32 tile.
// With the xevm target, the checks expect a 16x8 nd descriptor on the source
// and an xegpu.load_nd carrying `transpose = array<i64: 1, 0>`.
// Without a target, they expect a gather-style xegpu.load on the source's
// aligned pointer.
gpu.func @load_transposed(%source: memref<32x64xf32>,
    %i: index, %j: index) -> vector<8x16xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%i, %j], %c0
    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
    in_bounds = [true, true]} : memref<32x64xf32>, vector<8x16xf32>
  gpu.return %0 : vector<8x16xf32>
}

// LOAD-ND-LABEL:  @load_transposed(
// LOAD-ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
// LOAD-ND-SAME:   %[[OFFSET1:.+]]: index,
// LOAD-ND-SAME:   %[[OFFSET2:.+]]: index
// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// LOAD-ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET1]], %[[OFFSET2]]] <{transpose = array<i64: 1, 0>}>
// LOAD-ND-SAME:     -> vector<8x16xf32>
// LOAD-ND:        return %[[VEC]]


// The source operand is matched through the captured SRC variable rather
// than a hard-coded argument name.
// NOTE(review): the `COUNT2` lines below lack the hyphen of FileCheck's
// COUNT-<n> directive, so FileCheck skips them. Do not just add the hyphen:
// the stated counts and their relative order have not been re-verified.
// LOAD-GATHER-LABEL:  @load_transposed(
// LOAD-GATHER-SAME:    %[[SRC:.+]]: memref<32x64xf32>,
// LOAD-GATHER:         %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
// LOAD-GATHER-COUNT2:  vector.step
// LOAD-GATHER-COUNT2:  vector.shape_cast
// LOAD-GATHER-COUNT2: vector.broadcast
// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
// LOAD-GATHER:        %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32x64xf32> -> index
// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>

}

// -----
gpu.module @xevm_module {
// 8x16 read from a fully dynamic 3-D memref.
// With the xevm target, the checks expect the descriptor's shape and outer
// stride to come from the SSA results of memref.extract_strided_metadata on
// the collapsing subview. The inner stride is expected as the literal 1.
// Without a target, the checks expect a gather-style xegpu.load.
gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
    %i: index, %j: index, %k: index) -> vector<8x16xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%i, %j, %k], %c0
    {in_bounds = [true, true]} : memref<?x?x?xf32>, vector<8x16xf32>
  gpu.return %0 : vector<8x16xf32>
}
// LOAD-ND-LABEL:  @load_dynamic_source(
// LOAD-ND-SAME:   %[[SRC:.+]]: memref<?x?x?xf32>,
// LOAD-ND-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND:        %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
// LOAD-ND:        %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
// LOAD-ND:        %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
// LOAD-ND-SAME:                    strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
// LOAD-ND-SAME:                      #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND:        return %[[VEC]]


// NOTE(review): the `COUNT2` lines below lack the hyphen of FileCheck's
// COUNT-<n> directive, so FileCheck skips them. Do not just add the hyphen:
// the stated counts and their relative order have not been re-verified.
// LOAD-GATHER-LABEL:  @load_dynamic_source(
// LOAD-GATHER-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>,
// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
// LOAD-GATHER:        memref.extract_strided_metadata %[[ARG0]]
// LOAD-GATHER-COUNT2: vector.step
// LOAD-GATHER-COUNT2: vector.shape_cast
// LOAD-GATHER-COUNT2: vector.broadcast
// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER:        %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER:        %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex>
// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?x?xf32> -> index
// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER:        %[[RES:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[FINALIDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// LOAD-GATHER:        gpu.return %[[RES]] : vector<8x16xf32>
}

// -----
gpu.module @xevm_module {
// 8x16 read from a memref whose only dynamic dimension is the leading one.
// With the xevm target, the checks expect the leading index to be taken into
// a subview, and a non-boundary-checked descriptor with static shape [8, 16]
// and strides [16, 1] to be built on the resulting i64 address.
// Without a target, they expect a gather-style xegpu.load.
gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
    %i: index, %j: index, %k: index) -> vector<8x16xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%i, %j, %k], %c0
    {in_bounds = [true, true]} : memref<?x8x16xf32>, vector<8x16xf32>
  gpu.return %0 : vector<8x16xf32>
}

// LOAD-ND-LABEL:  @load_dynamic_source2(
// LOAD-ND-SAME:   %[[SRC:.+]]: memref<?x8x16xf32>,
// LOAD-ND-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND:        %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// LOAD-ND:        %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
// LOAD-ND:        %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
// LOAD-ND:        %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [8, 16], strides : [16, 1] :
// LOAD-ND-SAME:                    i64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
// LOAD-ND:        return %[[VEC]] : vector<8x16xf32>

// The source argument is captured from the signature so the pointer
// extraction below does not depend on the printer's argument naming.
// NOTE(review): the `COUNT2` lines below lack the hyphen of FileCheck's
// COUNT-<n> directive, so FileCheck skips them. Do not just add the hyphen:
// the stated counts and their relative order have not been re-verified.
// LOAD-GATHER-LABEL:  @load_dynamic_source2(
// LOAD-GATHER-SAME:   %[[SRC:.+]]: memref<?x8x16xf32>,
// LOAD-GATHER-DAG:    %[[CST_0:.+]] = arith.constant dense<true> : vector<8x16xi1>
// LOAD-GATHER-COUNT2: vector.step
// LOAD-GATHER-COUNT2: vector.shape_cast
// LOAD-GATHER-COUNT2: vector.broadcast
// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER-DAG:    %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER-DAG:    %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex>
// LOAD-GATHER-DAG:    %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x8x16xf32> -> index
// LOAD-GATHER-DAG:    %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>

}

// -----
gpu.module @xevm_module {
// 4-D (2x4x8x16) read from a fully dynamic 5-D memref.
// With the xevm target, the checks expect the transfer_read to remain.
// Without a target, they expect a gather-style xegpu.load over a
// 2x4x8x16 index vector, with strides taken from extract_strided_metadata.
gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
    %i: index, %j: index, %k: index, %l: index, %m: index) -> vector<2x4x8x16xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%i, %j, %k, %l, %m], %c0
    {in_bounds = [true, true, true, true]} : memref<?x?x?x?x?xf32>, vector<2x4x8x16xf32>
  gpu.return %0 : vector<2x4x8x16xf32>
}

// LOAD-ND-LABEL:  @load_dynamic_source3(
// LOAD-ND:        vector.transfer_read

// NOTE(review): the `COUNT3`/`COUNT4` lines below lack the hyphen of
// FileCheck's COUNT-<n> directive, so FileCheck skips them. Do not just add
// the hyphen: the stated counts and their relative order have not been
// re-verified.
// LOAD-GATHER-LABEL:  @load_dynamic_source3(
// LOAD-GATHER-SAME:   %[[SRC:.+]]: memref<?x?x?x?x?xf32>
// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<2x4x8x16xi1>
// LOAD-GATHER:        memref.extract_strided_metadata %[[SRC]] : memref<?x?x?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index, index, index, index, index
// LOAD-GATHER-COUNT4: vector.step
// LOAD-GATHER-COUNT3: vector.broadcast
// LOAD-GATHER-COUNT4: vector.shape_cast
// LOAD-GATHER-COUNT4: vector.broadcast {{.*}} : vector<2x4x8x16xindex>
// LOAD-GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex>
// LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?x?x?xf32> -> index
// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
// LOAD-GATHER:        return %[[VEC]]
}

// -----
gpu.module @xevm_module {
// 3-D (8x16x32) read from a static 16x32x64 memref.
// With the xevm target, the checks expect the transfer_read to remain.
// Without a target, they expect the static strides (2048, 64) to appear as
// constants and a gather-style xegpu.load over an 8x16x32 index vector.
gpu.func @load_high_dim_vector(%source: memref<16x32x64xf32>,
    %offset: index, %arg2: index) -> vector<8x16x32xf32> {
  %c0 = arith.constant 0.0 : f32
  %0 = vector.transfer_read %source[%offset, %arg2, %offset], %c0
    {in_bounds = [true, true, true]} : memref<16x32x64xf32>, vector<8x16x32xf32>
  gpu.return %0 : vector<8x16x32xf32>
}

// LOAD-ND-LABEL:  @load_high_dim_vector(
// LOAD-ND:        vector.transfer_read

// The source argument is captured from the signature so the pointer
// extraction below does not depend on the printer's argument naming.
// NOTE(review): the `COUNT2`/`COUNT3` lines below lack the hyphen of
// FileCheck's COUNT-<n> directive, so FileCheck skips them. Do not just add
// the hyphen: the stated counts and their relative order have not been
// re-verified.
// LOAD-GATHER-LABEL:  @load_high_dim_vector(
// LOAD-GATHER-SAME:   %[[SRC:.+]]: memref<16x32x64xf32>
// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
// LOAD-GATHER:        %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
// LOAD-GATHER:        %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
// LOAD-GATHER:        %[[C2048:.+]] = arith.constant 2048 : index
// LOAD-GATHER:        %[[C64:.+]] = arith.constant 64 : index
// LOAD-GATHER-COUNT3: vector.step
// LOAD-GATHER-COUNT3: vector.shape_cast
// LOAD-GATHER-COUNT3: vector.broadcast {{.*}} : vector<8x16x32xindex>
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
// LOAD-GATHER:        %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<16x32x64xf32> -> index
// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>

}

// -----
gpu.module @xevm_module {
// Transposed read (permutation_map (d0, d1) -> (d1, d0)) of an 8x16 f16 tile.
// With the xevm target, the checks expect the transfer_read to remain.
// Without a target, they expect a gather-style xegpu.load on the source's
// aligned pointer.
gpu.func @load_transpose_f16(%source: memref<32x64xf16>,
    %offset: index) -> vector<8x16xf16> {
  %c0 = arith.constant 0.0 : f16
  %0 = vector.transfer_read %source[%offset, %offset], %c0
    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
    in_bounds = [true, true]} : memref<32x64xf16>, vector<8x16xf16>
  gpu.return %0 : vector<8x16xf16>
}

// LOAD-ND-LABEL:  @load_transpose_f16(
// LOAD-ND:        vector.transfer_read

// The source operand is matched through the captured SRC variable rather
// than a hard-coded argument name.
// NOTE(review): the `COUNT2` lines below lack the hyphen of FileCheck's
// COUNT-<n> directive, so FileCheck skips them. Do not just add the hyphen:
// the stated counts and their relative order have not been re-verified.
// LOAD-GATHER-LABEL:  @load_transpose_f16(
// LOAD-GATHER-SAME:    %[[SRC:.+]]: memref<32x64xf16>,
// LOAD-GATHER:         %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
// LOAD-GATHER-COUNT2:  vector.step
// LOAD-GATHER-COUNT2:  vector.shape_cast
// LOAD-GATHER-COUNT2: vector.broadcast
// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
// LOAD-GATHER:        %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32x64xf16> -> index
// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
}

// -----
gpu.module @xevm_module {
// Negative case: each read has one possibly out-of-bounds dimension, and its
// padding value is not the constant 0.0. The first read pads with the
// constant 1.0; the second pads with a function argument.
// The checks expect both transfer_reads to remain in both pipelines.
gpu.func @no_load_out_of_bounds_non_zero_pad(
    %src: memref<32x64xf32>, %row: index, %col: index, %padding: f32)
    -> (vector<8x16xf32>, vector<8x16xf32>) {
  %one = arith.constant 1.0 : f32
  %first = vector.transfer_read %src[%row, %col], %one {in_bounds = [true, false]}
    : memref<32x64xf32>, vector<8x16xf32>
  %second = vector.transfer_read %src[%col, %row], %padding {in_bounds = [false, true]}
    : memref<32x64xf32>, vector<8x16xf32>
  gpu.return %first, %second : vector<8x16xf32>, vector<8x16xf32>
}

// LOAD-ND-LABEL:    @no_load_out_of_bounds_non_zero_pad(
// LOAD-ND-COUNT-2: vector.transfer_read

// LOAD-GATHER-LABEL: @no_load_out_of_bounds_non_zero_pad(
// LOAD-GATHER-COUNT-2: vector.transfer_read
}

// -----
gpu.module @xevm_module {
// Negative case: the single vector dimension is marked possibly out of
// bounds (`in_bounds = [false]`). The checks expect the transfer_read to
// remain in both pipelines.
gpu.func @no_load_out_of_bounds_1D_vector(%src: memref<8x16x32xf32>, %idx: index) -> vector<8xf32> {
  %zero = arith.constant 0.0 : f32
  %vec = vector.transfer_read %src[%idx, %idx, %idx], %zero {in_bounds = [false]}
    : memref<8x16x32xf32>, vector<8xf32>
  gpu.return %vec : vector<8xf32>
}

// LOAD-ND-LABEL:  @no_load_out_of_bounds_1D_vector(
// LOAD-ND:        vector.transfer_read

// LOAD-GATHER-LABEL:  @no_load_out_of_bounds_1D_vector(
// LOAD-GATHER:        vector.transfer_read
}

// -----
gpu.module @xevm_module {
// Negative case: the read carries an explicit mask operand. The checks
// expect the transfer_read to remain in both pipelines.
gpu.func @no_load_masked(%src : memref<4xf32>, %idx : index) -> vector<4xf32> {
  %zero = arith.constant 0.0 : f32
  %lanes = arith.constant dense<[0, 1, 0, 1]> : vector<4xi1>
  %vec = vector.transfer_read %src[%idx], %zero, %lanes {in_bounds = [true]}
    : memref<4xf32>, vector<4xf32>
  gpu.return %vec : vector<4xf32>
}

// LOAD-ND-LABEL:  @no_load_masked(
// LOAD-ND:        vector.transfer_read

// LOAD-GATHER-LABEL:  @no_load_masked(
// LOAD-GATHER:        vector.transfer_read
}

// -----
gpu.module @xevm_module {
// Negative case: the source is a tensor rather than a memref. The checks
// expect the transfer_read to remain in both pipelines.
gpu.func @no_load_tensor(%src: tensor<32x64xf32>, %row: index, %col: index) -> vector<8x16xf32> {
  %zero = arith.constant 0.0 : f32
  %vec = vector.transfer_read %src[%row, %col], %zero {in_bounds = [true, true]}
    : tensor<32x64xf32>, vector<8x16xf32>
  gpu.return %vec : vector<8x16xf32>
}

// LOAD-ND-LABEL:  @no_load_tensor(
// LOAD-ND:        vector.transfer_read

// LOAD-GATHER-LABEL:  @no_load_tensor(
// LOAD-GATHER:        vector.transfer_read
}


// -----
gpu.module @xevm_module {
// Negative case: the source layout is `strided<[?], offset: ?>`, so the
// innermost stride is dynamic and not known to be 1. The checks expect the
// transfer_read to remain in both pipelines.
gpu.func @no_load_non_unit_inner_stride(%src: memref<32xf32, strided<[?], offset: ?>>, %idx: index) -> vector<8xf32> {
  %zero = arith.constant 0.0 : f32
  %vec = vector.transfer_read %src[%idx], %zero {in_bounds = [true]}
    : memref<32xf32, strided<[?], offset: ?>>, vector<8xf32>
  gpu.return %vec : vector<8xf32>
}

// LOAD-ND-LABEL:  @no_load_non_unit_inner_stride(
// LOAD-ND:        vector.transfer_read

// LOAD-GATHER-LABEL:  @no_load_non_unit_inner_stride(
// LOAD-GATHER:        vector.transfer_read
}


// -----
gpu.module @xevm_module {
// Negative case: the permutation map (d0, d1, d2) -> (d0, d2) skips the
// middle source dimension. It is therefore neither a minor identity nor a
// 2-D transpose. The checks expect the transfer_read to remain in both
// pipelines.
gpu.func @no_load_unsupported_map(%src: memref<16x32x64xf32>, %idx: index) -> vector<8x16xf32> {
  %zero = arith.constant 0.0 : f32
  %vec = vector.transfer_read %src[%idx, %idx, %idx], %zero {
    permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>,
    in_bounds = [true, true]
  } : memref<16x32x64xf32>, vector<8x16xf32>
  gpu.return %vec : vector<8x16xf32>
}

// LOAD-ND-LABEL:  @no_load_unsupported_map(
// LOAD-ND:        vector.transfer_read

// LOAD-GATHER-LABEL:  @no_load_unsupported_map(
// LOAD-GATHER:        vector.transfer_read
}

// -----
gpu.module @xevm_module {
// 1-D read of 8 f16 values from a 256x256 subview (strides [4096, 1],
// dynamic offset) of a 4096x4096 buffer. The checks expect both pipelines
// to emit a gather-style xegpu.load addressed from the subview's aligned
// pointer. The scalar base offset is expected to include the offset result
// of memref.extract_strided_metadata on the subview.
gpu.func @load_from_subview_1D(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
  %c0 = arith.constant 0.0 : f16
  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
  %0 = vector.transfer_read %subview[%off2, %off2], %c0
    {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
  gpu.return %0 : vector<8xf16>
}

// LOAD-ND-LABEL:  @load_from_subview_1D(
// LOAD-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
// LOAD-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
// LOAD-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// LOAD-ND:        %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
// LOAD-ND:        %[[STEP:.+]] = vector.step : vector<8xindex>
// LOAD-ND:        arith.muli {{.*}} : index
// LOAD-ND:        arith.addi %[[OFFSET]]{{.*}} : index
// LOAD-ND:        arith.addi {{.*}} : index
// LOAD-ND:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
// LOAD-ND:        %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
// LOAD-ND:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
// LOAD-ND:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-ND:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>

// LOAD-GATHER-LABEL:  @load_from_subview_1D(
// LOAD-GATHER-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
// LOAD-GATHER-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
// LOAD-GATHER:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// LOAD-GATHER:        %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
// LOAD-GATHER:        %[[STEP:.+]] = vector.step : vector<8xindex>
// LOAD-GATHER:        arith.muli {{.*}} : index
// LOAD-GATHER:        arith.addi %[[OFFSET]]{{.*}} : index
// LOAD-GATHER:        arith.addi {{.*}} : index
// LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
}

// -----
gpu.module @xevm_module {
// 8x16 read of f16 values from a 256x256 subview (strides [4096, 1], dynamic
// offset) of a 4096x4096 buffer.
// With the xevm target, the checks expect the subview's byte offset
// (metadata offset * 2) to be added to its aligned pointer. A
// non-boundary-checked descriptor with shape [256, 256] and strides
// [4096, 1] is then built on that i64 address.
// Without a target, the checks expect a gather-style xegpu.load.
gpu.func @load_from_subview_2D(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8x16xf16> {
  %c0 = arith.constant 0.0 : f16
  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
  %0 = vector.transfer_read %subview[%off2, %off2], %c0
    {in_bounds = [true, true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8x16xf16>
  gpu.return %0 : vector<8x16xf16>
}

// LOAD-ND-LABEL:  @load_from_subview_2D(
// LOAD-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
// LOAD-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND:        %[[ELEM_BYTES:.+]] = arith.constant 2 : index
// LOAD-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[SUBVIEW]]
// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// LOAD-ND:        %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
// LOAD-ND:        %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
// LOAD-ND:        %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [256, 256], strides : [4096, 1] :
// LOAD-ND-SAME:                    i64 -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8x16xf16>
// LOAD-ND:        return %[[VEC]]

// NOTE(review): the `COUNT2` lines below lack the hyphen of FileCheck's
// COUNT-<n> directive, so FileCheck skips them. Do not just add the hyphen:
// the stated counts and their relative order have not been re-verified.
// LOAD-GATHER-LABEL:  @load_from_subview_2D(
// LOAD-GATHER-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
// LOAD-GATHER-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
// LOAD-GATHER:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// LOAD-GATHER:        %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
// LOAD-GATHER-COUNT2: vector.step
// LOAD-GATHER-COUNT2: vector.shape_cast
// LOAD-GATHER-COUNT2: vector.broadcast
// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8x16xindex>
// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<8x16xindex>
// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
}
