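/*
 * Goal: for each query rectangle, count the rectangles (each spanned by a pair
 * of the given points as opposite corners) that intersect it.
 * Count the complement instead: take the total number of pairs and subtract the
 * pairs whose rectangle is disjoint from the query.
 * A pair's rectangle is disjoint from the query exactly when both of its points
 * lie on the same side of it: both to the left, both to the right, both below,
 * or both above.
 * Pairs lying in one of the four corner regions (upper-left, upper-right,
 * lower-left, lower-right) satisfy two of these conditions, so they were
 * subtracted twice; add the number of pairs in each corner back once
 * (inclusion-exclusion).
 * The number of points with x in [l, r] and y in [d, u] is obtained with a
 * persistent segment tree ("chairman tree").
 */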
#include <iostream>
#include <algorithm>
#include <sstream>
#include <string>
#include <queue>
#include <cstdio>
#include <map>
#include <set>
#include <utility>
#include <stack>
#include <cstring>
#include <cmath>
#include <vector>
#include <ctime>
#include <bitset>
using namespace std;
// ---------------------------------------------------------------------------
// Contest shorthand. Every macro argument, and every macro that expands to an
// expression, is parenthesized. This keeps the intended operator precedence
// when an expression is passed in (lowbit(a + b), r1(i, c ? x : y)) or when
// the result is combined with other operators (lrt + 1).
// ---------------------------------------------------------------------------
#define pb push_back
// scanf wrappers: d = int, ld = long long, f = double; they read 1-3 lvalues.
#define sd(n) scanf("%d",&(n))
#define sdd(n,m) scanf("%d%d",&(n),&(m))
#define sddd(n,m,k) scanf("%d%d%d",&(n),&(m),&(k))
#define sld(n) scanf("%lld",&(n))
#define sldd(n,m) scanf("%lld%lld",&(n),&(m))
#define slddd(n,m,k) scanf("%lld%lld%lld",&(n),&(m),&(k))
#define sf(n) scanf("%lf",&(n))
#define sff(n,m) scanf("%lf%lf",&(n),&(m))
#define sfff(n,m,k) scanf("%lf%lf%lf",&(n),&(m),&(k))
#define ss(str) scanf("%s",(str))
// printf wrappers: print the variable literally named `ans` in the calling scope.
#define ans() printf("%d",ans)
#define ansn() printf("%d\n",ans)
#define anss() printf("%d ",ans)
#define lans() printf("%lld",ans)
#define lanss() printf("%lld ",ans)
#define lansn() printf("%lld\n",ans)
#define fansn() printf("%.10f\n",ans)
// Loops: r0 over [0, n), r1 over [1, e], rn from e down to 1,
// rsz over the indices of container v.
#define r0(i,n) for(int i=0;i<(n);++i)
#define r1(i,e) for(int i=1;i<=(e);++i)
#define rn(i,e) for(int i=(e);i>=1;--i)
#define rsz(i,v) for(int i=0;i<(int)(v).size();++i)
#define szz(x) ((int)(x).size())
#define mst(abc,bca) memset((abc),(bca),sizeof(abc))
// Lowest set bit of a (two's complement).
#define lowbit(a) ((a)&(-(a)))
#define all(a) (a).begin(),(a).end()
#define pii pair<int,int>
#define pli pair<ll,int>
#define mp make_pair
// Child indices of node `rt` in a heap-indexed (1-based) segment tree.
#define lrt (rt<<1)
#define rrt (rt<<1|1)
#define X first
#define Y second
#define PI (acos(-1.0))
#define sqr(a) ((a)*(a))
typedef long long ll;
typedef unsigned long long ull;
const ll mod = 1000000000+7;
const double eps=1e-9;
// 0x3f3f3f3f (= 1061109567): an "infinity" that memset(…, 0x3f, …) also produces.
const int inf=0x3f3f3f3f;
const ll infl = 10000000000000000;   // 1e16
// Capacity for n; presumably the input guarantees n <= 200000 (TODO confirm).
const int maxn= 200000+10;
// Node pool for the persistent segment tree. Each insertion allocates one node
// per tree level (at most 19 levels for n <= 2*10^5), so 21 nodes per
// insertion leaves some slack.
const int maxm = maxn*21+10;
//Pretests passed
// Fast reader for one signed decimal integer from stdin.
// Skips characters until a digit or '-' is seen, then reads an optional minus
// sign followed by decimal digits. The character that ends the number is
// consumed.
// On success stores the value in `ret` and returns 1. If end of input is
// reached before a number starts, returns -1 and leaves `ret` unchanged.
int in(int &ret)
{
    // `int`, not `char`: getchar() returns EOF (a negative int) that must stay
    // distinguishable from every byte value. With `char` the EOF test is
    // unreliable, and the skip loop below could spin forever once input ends.
    int c = getchar();
    while (c != EOF && c != '-' && (c < '0' || c > '9'))
        c = getchar();
    if (c == EOF)
        return -1;   // only separators remained before end of input
    const int sgn = (c == '-') ? -1 : 1;
    ret = (c == '-') ? 0 : (c - '0');
    while (c = getchar(), c >= '0' && c <= '9')
        ret = ret * 10 + (c - '0');
    ret *= sgn;
    return 1;
}
// root[i]: index in seg[] of the root of the tree version that contains the
// first i inserted values. root[0] is never assigned, so it stays 0 (the
// empty tree).
int root[maxn];

// One node of the persistent segment tree.
// Index 0 is never handed out by the allocator (tot is pre-incremented), so as
// a zero-initialized global node 0 acts as the shared empty child.
struct Seg
{
    int cnt;   // number of inserted values inside this node's interval
    int lch;   // seg[] index of the left child (0 = empty)
    int rch;   // seg[] index of the right child (0 = empty)
};

// Node pool; nodes are allocated sequentially as seg[++tot].
Seg seg[maxm];

// Index of the most recently allocated node in seg[].
int tot;
// Inserts value x into a new tree version by path copying.
// `rt` is replaced by a fresh copy of the node it referred to, over the value
// interval [l, r]. The copy's counter is bumped, and the recursion descends
// into the child that contains x. Nodes of the previous version are never
// modified.
void update(int &rt,int l,int r,int x)
{
    const int prev = rt;
    rt = ++tot;                 // allocate the replacement node
    seg[rt] = seg[prev];        // start from the old version's node
    seg[rt].cnt += 1;
    if (l == r)
        return;                 // reached the leaf for x
    const int mid = (l + r) >> 1;
    if (x > mid)
        update(seg[rt].rch, mid + 1, r, x);
    else
        update(seg[rt].lch, l, mid, x);
}
int query(int rt,int L,int R,int l,int r)
{
if(l<=L&&R<=r)return seg[rt].cnt;
int m = (L+R)>>1;
if(r<=m)return query(seg[rt].lch,L,m,l,r);
if(m<l)return query(seg[rt].rch,m+1,R,l,r);
return query(seg[rt].lch,L,m,l,r) + query(seg[rt].rch,m+1,R,l,r);
}
int n;   // number of points; also the upper end of the tree's value range [1, n]

// Number of points (i, v_i) with xl <= i <= xr and yl <= v_i <= yr, where v_i
// is the i-th inserted value. Computed as the difference between prefix
// versions xr and xl - 1 of the persistent tree.
int query(int xl,int xr,int yl,int yr)
{
    const int throughRight = query(root[xr], 1, n, yl, yr);
    const int beforeLeft = query(root[xl - 1], 1, n, yl, yr);
    return throughRight - beforeLeft;
}
// Number of unordered pairs that can be formed from x items: x * (x - 1) / 2.
ll cal(int x)
{
    const ll items = x;
    // items * (items - 1) is a product of consecutive integers, hence even,
    // so the division is exact.
    return items * (items - 1) / 2;
}
int main()
{
#ifdef LOCAL
freopen("input.txt","r",stdin);
// freopen("output.txt","w",stdout);
#endif // LOCAL
int q;
sdd(n,q);
r1(i,n)
{
int x;
sd(x);
root[i] = root[i-1];
update(root[i],1,n,x);
}
while(q--)
{
int x1,y1,x2,y2;
sdd(x1,y1),sdd(x2,y2);
ll ans = cal(n);
ans -= cal(x1-1) + cal(n-x2) + cal(y1-1) +cal(n-y2);
int lu = 1<x1&&1<y1? query(1,x1-1,1,y1-1) : 0 ;
int ld = 1<x1&&y2<n? query(1,x1-1,y2+1,n) : 0 ;
int ru = x2<n&&y1>1? query(x2+1,n,1,y1-1) : 0 ;
int rd = x2<n&&y2<n? query(x2+1,n,y2+1,n) : 0 ;
ans += cal(lu) + cal(ld) + cal(ru) + cal(rd);
lansn();
}
return 0;
}